Кастомные loss-функции в TensorFlow-Keras и PyTorch. keras.. keras. loss-функции.. keras. loss-функции. ml.. keras. loss-функции. ml. PyTorch.. keras. loss-функции. ml. PyTorch. tensorflow.

Стандартные loss‑функции, такие как MSE или CrossEntropy, хороши, но часто им не хватает гибкости для сложных задач. Допустим, есть тот же проект с огромным дисбалансом классов, или хочется внедрить специфическую регуляризацию прямо в функцию потерь. Стандартный функционал тут бессилен — тут на помощь приходят кастомные loss’ы.

Custom Loss Functions в TensorFlow/Keras

TensorFlow/Keras радуют удобным API, но за простоту приходится платить вниманием к деталям.

Focal Loss

Focal Loss помогает сместить фокус обучения на сложные примеры, снижая влияние легко классифицируемых данных:

import tensorflow as tf
from tensorflow.keras import backend as K

def focal_loss(gamma=2., alpha=0.25):
    """
    Реализация Focal Loss для задач с дисбалансом классов.
    :param gamma: фокусирующий параметр для усиления влияния сложных примеров.
    :param alpha: коэффициент балансировки классов.
    :return: функция потерь, принимающая (y_true, y_pred).
    """
    def focal_loss_fixed(y_true, y_pred):
        # Защита от log(0) – обрезаем значения предсказаний.
        y_pred = K.clip(y_pred, K.epsilon(), 1. - K.epsilon())
        # Вычисляем кросс-энтропию для каждого примера.
        cross_entropy = -y_true * tf.math.log(y_pred)
        # Применяем вес для "тяжёлых" примеров.
        weight = alpha * tf.pow(1 - y_pred, gamma)
        loss = weight * cross_entropy
        # Усредняем по батчу и классам.
        return tf.reduce_mean(tf.reduce_sum(loss, axis=-1))
    return focal_loss_fixed

# Пример использования Focal Loss:
if __name__ == "__main__":
    # Тестовые данные для отладки (да, я тоже люблю маленькие эксперименты)
    y_true = tf.constant([[1, 0], [0, 1]], dtype=tf.float32)
    y_pred = tf.constant([[0.9, 0.1], [0.2, 0.8]], dtype=tf.float32)
    
    loss_fn = focal_loss(gamma=2.0, alpha=0.25)
    loss_value = loss_fn(y_true, y_pred)
    print("Focal Loss:", loss_value.numpy())

Интеграция кастомного loss в модель Keras

Создадим простую CNN‑модель для распознавания изображений и подключим Focal Loss:

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense

def create_model(input_shape=(28, 28, 1), num_classes=10):
    model = Sequential([
        Conv2D(32, kernel_size=(3, 3), activation='relu', input_shape=input_shape),
        MaxPooling2D(pool_size=(2, 2)),
        Flatten(),
        Dense(128, activation='relu'),
        Dense(num_classes, activation='softmax')
    ])
    return model

# Компилируем модель с кастомной функцией потерь
model = create_model()
model.compile(optimizer='adam', loss=focal_loss(gamma=2.0, alpha=0.25), metrics=['accuracy'])

# Создадим тестовые данные (набор из случайных изображений и меток)
import numpy as np
X_train = np.random.rand(100, 28, 28, 1)
y_train = tf.keras.utils.to_categorical(np.random.randint(0, 10, 100), num_classes=10)

print("Запускаем обучение модели с кастомным Focal Loss...")
model.fit(X_train, y_train, epochs=3, batch_size=16)

Модель обучается и градиенты сходятся.

Нюансы вычисления градиентов

Нельзя забывать — любые операции, выполняемые с numpy, ломают автоматическое вычисление градиентов. Пример плохой практики:

import numpy as np
import tensorflow as tf

def loss_with_numpy(y_true, y_pred):
    # Плохая практика: переводим тензоры в numpy и разрываем градиентный поток.
    y_true_np = y_true.numpy()  # Ой-ой, ошибка внутри GradientTape!
    y_pred_np = y_pred.numpy()
    loss_np = np.mean((y_true_np - y_pred_np) ** 2)
    return tf.constant(loss_np, dtype=tf.float32)

if __name__ == "__main__":
    x = tf.constant([[1.0], [2.0]])
    y_true = tf.constant([[1.5], [2.5]])
    
    with tf.GradientTape() as tape:
        tape.watch(x)
        y_pred = x * 2
        try:
            loss = loss_with_numpy(y_true, y_pred)
            grad = tape.gradient(loss, x)
            print("Gradient:", grad)
        except Exception as e:
            print("Ошибка при вычислении градиента:", e)

Оставайтесь в мире тензоров — TensorFlow умеет всё, что нужно, если вы не решите подмешать туда numpy.

Custom Loss Functions в PyTorch

Реализация кастомной loss через torch.autograd.Function

Начнем с простейшей реализации кастомной loss‑функции, которая считает квадратичную ошибку:

import torch

class CustomLossFunction(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, target):
        """
        Прямой проход: вычисляем MSE.
        """
        ctx.save_for_backward(input, target)
        loss = torch.mean((input - target) ** 2)
        return loss

    @staticmethod
    def backward(ctx, grad_output):
        """
        Обратный проход: аккуратно считаем градиенты.
        """
        input, target = ctx.saved_tensors
        grad_input = grad_output * 2 * (input - target) / input.numel()
        return grad_input, None

# Тестовый пример использования:
if __name__ == "__main__":
    x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
    y = torch.tensor([1.5, 2.5, 3.5])
    
    loss = CustomLossFunction.apply(x, y)
    print("Custom Loss (PyTorch):", loss.item())
    
    loss.backward()
    print("Gradient (PyTorch):", x.grad)

Focal Loss в PyTorch

Focal Loss существует не только в TensorFlow. В PyTorch можно сделать не хуже:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        # Если inputs – логиты, используем sigmoid для преобразования
        BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-BCE_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        
        if self.reduction == 'mean':
            return focal_loss.mean()
        elif self.reduction == 'sum':
            return focal_loss.sum()
        else:
            return focal_loss

# Тестируем Focal Loss в PyTorch:
if __name__ == "__main__":
    inputs = torch.tensor([[0.2, -1.0], [1.5, 0.3]], requires_grad=True)
    targets = torch.tensor([[0, 1], [1, 0]], dtype=torch.float32)
    
    criterion = FocalLoss(alpha=0.25, gamma=2.0)
    loss = criterion(inputs, targets)
    print("Focal Loss (PyTorch):", loss.item())
    
    loss.backward()
    print("Gradients (Focal Loss):", inputs.grad)

Работа с эмбеддингами

Для задач, где нужно сравнивать схожесть объектов, подойдут Contrastive и Triplet Loss. Реализуем их в PyTorch.

Contrastive Loss

class ContrastiveLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(ContrastiveLoss, self).__init__()
        self.margin = margin

    def forward(self, output1, output2, label):
        # Евклидова дистанция между эмбеддингами
        euclidean_distance = F.pairwise_distance(output1, output2)
        loss_contrastive = torch.mean((1 - label) * torch.pow(euclidean_distance, 2) +
                                      (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
        return loss_contrastive

# Пример использования Contrastive Loss:
if __name__ == "__main__":
    output1 = torch.tensor([[1.0, 2.0], [3.0, 4.0]], requires_grad=True)
    output2 = torch.tensor([[1.5, 2.5], [2.5, 3.5]], requires_grad=True)
    # label: 0 для похожих пар, 1 для непохожих.
    label = torch.tensor([0, 1], dtype=torch.float32)
    
    criterion = ContrastiveLoss(margin=1.0)
    loss = criterion(output1, output2, label)
    print("Contrastive Loss:", loss.item())
    loss.backward()

Triplet Loss

class TripletLoss(nn.Module):
    def __init__(self, margin=1.0):
        super(TripletLoss, self).__init__()
        self.margin = margin

    def forward(self, anchor, positive, negative):
        pos_distance = F.pairwise_distance(anchor, positive, p=2)
        neg_distance = F.pairwise_distance(anchor, negative, p=2)
        losses = torch.relu(pos_distance - neg_distance + self.margin)
        return losses.mean()

# Пример использования Triplet Loss:
if __name__ == "__main__":
    anchor = torch.tensor([[1.0, 2.0], [2.0, 3.0]], requires_grad=True)
    positive = torch.tensor([[1.1, 2.1], [1.9, 2.9]], requires_grad=True)
    negative = torch.tensor([[3.0, 4.0], [4.0, 5.0]], requires_grad=True)
    
    criterion = TripletLoss(margin=1.0)
    loss = criterion(anchor, positive, negative)
    print("Triplet Loss:", loss.item())
    loss.backward()

Если вам хочется поделиться опытом — пишите в комментариях.

Все актуальные методы и инструменты DS и ML можно освоить на онлайн-курсах OTUS: в каталоге можно посмотреть список всех программ, а в календаре — записаться на открытые уроки.

Автор: badcasedaily1

Источник

Rambler's Top100