Mixture of Experts: когда нейросеть учится делегировать. deeplearning.. deeplearning. llm.. deeplearning. llm. mixture of experts.. deeplearning. llm. mixture of experts. moe.. deeplearning. llm. mixture of experts. moe. vmoes.. deeplearning. llm. mixture of experts. moe. vmoes. Блог компании Data Feeling School.. deeplearning. llm. mixture of experts. moe. vmoes. Блог компании Data Feeling School. Машинное обучение.
Mixture of Experts: когда нейросеть учится делегировать - 1

Привет, чемпионы!

Представьте, что у вас есть большой и сложный проект, и вы наняли двух управленцев: Кабан-Кабаныча и Руководителева. Вы даете им одинаковую задачу: набрать штат сотрудников и выполнить ваш проект. Вся прибыль вместе с начальным бюджетом останется у них.

Кабан-Кабаныч решил, что нет смысла платить отдельным специалистам по DevOps, backend, ML и другим направлениям, и нанял всего одного сотрудника за 80 монеток. Этот бедняга работал в стиле «один за всех» и, естественно, быстро выгорел и «умер». Кабан-Кабаныч, не долго думая, нанял еще одного такого же сотрудника. В итоге вы вернулись и увидели печальную картину: задачу никто не решил, остался лишь Кабан-Кабаныч и кладбище несчастных сотрудников.

Mixture of Experts: когда нейросеть учится делегировать - 2

А вот Руководителев поступил иначе: он распределил бюджет на несколько похожих сотрудников, но сначала не понимал, кто из них в чём лучше. Тогда он стал давать им небольшие задачи и внимательно наблюдать за результатами. Через некоторое время он понял, что сотрудник №1 на 70% лучше справляется с задачами по ML, сотрудник №2 на 80% эффективнее в backend-разработке и так далее. Так Руководителев постепенно сформировал команду экспертов, сам став управляющим (или “gating”-узлом), который распределяет задачи на основе знаний о возможностях каждого сотрудника. Сотрудники углубляли экспертизу в своих направлениях, а Руководителев становился всё эффективнее в распределении задач.

Внезапно мы пришли к интересному решению:

  • Руководителев — это gating network, который распределяет задачи, исходя из предыдущих успехов сотрудников.

  • Сотрудники — это local experts, каждый из которых специализируется на своей части задач.

Mixture of Experts: когда нейросеть учится делегировать - 3

Таким образом, мы экономим ресурсы, получаем сильных специалистов и достигаем отличных результатов за короткое время.

Именно так в 1991 году и появилось решение Adaptive Mixtures of local Experts

Этот подход доказал эффективность, сокращая время обучения моделей почти вдвое.

Как работает MoE?

Представьте модель, у которой есть входные и выходные данные, а между ними набор экспертов. Этих экспертов организует управляющая сеть (gating network), определяющая, какие эксперты могут лучше справиться с конкретной задачей. Gating-сеть, которая присваивает веса результату каждого эксперта, объединяя их в итоговый ответ.

Звучит красиво, но не всё так просто… Во время обучения возникают интересные и даже «ломающие мозг» ситуации, особенно когда осознаёшь, что созданная тобой модель может «вынести» тебя самого.

Conditional Computation одна из фишек MoE: возможность отключать или частично использовать экспертов. Это позволяет комбинировать разные архитектуры, каждая из которых выявляет уникальные паттерны в данных. Модель становится гибкой: сама решает, каких экспертов задействовать активно, кого игнорировать, а кого подключить чуть-чуть.

Ключевая особенность — разреженность. С помощью MoE можно масштабировать модель без пропорционального увеличения вычислительной нагрузки. Это очень важно, ведь позволяет обучать огромное количество экспертов, используя при этом только нужных. В этом нам помогает важный гиперпараметр — top_k, определяющий, сколько лучших экспертов будет выбрано для каждого входа.

Mixture of Experts: когда нейросеть учится делегировать - 4

Но основные сложности начинаются с настройки гиперпараметров и архитектурных решений. Самая большая проблема MoE — это «прилипание гейта», когда маршрутизатор начинает постоянно выбирать одних и тех же экспертов. Эти избранные эксперты получают больше данных и быстрее обучаются, в то время как остальные «скучают и пьют кофе».

Возникает закономерный вопрос: зачем тогда вообще нужны остальные эксперты?

Как с этим бороться? В своём коде я добавил трекер распределения данных по экспертам, чтобы контролировать, не «залип» ли гейт. Также я внедрил несколько хитрых решений, подсмотренных на профессиональных форумах.

Давайте кратко резюмируем:

Технология MoE выгодна за счёт разреженности и гибкости использования экспертов. Однако это «сделка с дьяволом», поскольку возникают сложности:

  • Сложная балансировка работы экспертов.

  • Функция потерь должна учитывать как производительность экспертов, так и маршрутизатора.

  • Количество гиперпараметров (количество экспертов, архитектура gating-сети) усложняет настройку модели.

Где сейчас используют MoE?

Почти все современные LLM используют MoE. Например, недавно вышедшая модель Llama4 Scout с 16x17B параметрами — это 16 экспертов по 17 миллиардов параметров каждый. То есть на инференсе вы используете не все 272 млрд параметров, а только top_k выбранных. Впечатляющее снижение вычислительных затрат, правда?

Также технология активно применяется в компьютерном зрении, и сейчас мы её протестируем на простом примере V-MoEs.

Тест драйв технологии

Итак, для обучения возьмем простенький датасет CIFAR100 и обучим на нем нашу кастомную V-MoEs для классификации изображений.

Сама по себе архитектура будет состоять из следующего:

Классический VIT, но ее часть классификатора мы обернем в decoder блок, где у нас будет применена MOE

Mixture of Experts: когда нейросеть учится делегировать - 5

Начнем с маршрутизатора, в нашем случае он был реализован следующим образом

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatingNetwork(nn.Module):
    def __init__(self,
                 input_dim = 151296,
                 num_experts=4,
                 top_k=2,
                 use_noise=True,
                 noise_std=1e-2,
                 temperature=1.0):
        super().__init__()

        self.num_experts = num_experts
        self.top_k = top_k
        self.use_noise = use_noise
        self.noise_std = noise_std
        self.temperature = temperature

        self.gate = nn.Sequential(
            nn.Linear(input_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, num_experts)
        )

    def forward(self, x):
        logits = self.gate(x)  # (B, num_experts)

        if self.use_noise and self.training:
            scale = logits.std(dim=1, keepdim=True).clamp(min=1e-3)  
            noise = torch.randn_like(logits) * self.noise_std * scale
            logits = logits + noise

        topk_vals, topk_indices = torch.topk(logits, self.top_k, dim=1)

        gates = F.softmax(topk_vals / self.temperature, dim=1)  # (B, top_k)

        return topk_indices, gates

Он берёт входной вектор x, оценивает, какие эксперты из num_experts лучше подойдут для каждого примера в батче, и возвращает top_k лучших экспертов с их весами.

То есть это — классическая Gating Network, которая решает, каким экспертам дать поработать с входом.

Обратите внимание, что тут есть noisy gating — это один из способов избежать “залипания гейта” на одном и том же эксперте. Во время тренировки шум масштабируется и в зависимости от поставленной нами пропорции влияет на решение о том какого эксперта повыбирать. Иными словами мы влияем на “результатова”, чтобы он давал шансы большему числу экспертов, а не выбирал любимчиков.

Создадим экспертов

import torch.nn as nn


class FFNExpert(nn.Module):
    def __init__(self, input_dim, hidden_dims, output_dim, dropout_prob=0.5):
        super(FFNExpert, self).__init__()

        layers = []
        self.linears = nn.ModuleList()

        prev_dim = input_dim
        for hidden_dim in hidden_dims:
            linear = nn.Linear(prev_dim, hidden_dim)
            self.linears.append(linear)
            layers.append(linear)
            layers.append(nn.LayerNorm(hidden_dim))
            layers.append(nn.ReLU())
            layers.append(nn.Dropout(dropout_prob))
            prev_dim = hidden_dim

        final_linear = nn.Linear(prev_dim, output_dim)
        self.linears.append(final_linear)
        layers.append(final_linear)

        self.network = nn.Sequential(*layers)
        self._initialize_weights()

    def _initialize_weights(self):
        for linear in self.linears:
            nn.init.xavier_uniform_(linear.weight)
            if linear.bias is not None:
                nn.init.zeros_(linear.bias)

    def forward(self, x):
        return self.network(x)


class FFNExpertSmall(FFNExpert):
    def __init__(self, input_dim, output_dim):
        super(FFNExpertSmall, self).__init__(input_dim, hidden_dims=[256, 128], output_dim=output_dim, dropout_prob=0.3)


class FFNExpertMedium(FFNExpert):
    def __init__(self, input_dim, output_dim):
        super(FFNExpertMedium, self).__init__(input_dim, hidden_dims=[512, 256, 128], output_dim=output_dim,
                                              dropout_prob=0.4)


class FFNExpertLarge(FFNExpert):
    def __init__(self, input_dim, output_dim):
        super(FFNExpertLarge, self).__init__(input_dim, hidden_dims=[1024, 512, 256, 128], output_dim=output_dim,
                                             dropout_prob=0.5)


class FFNExpertVeryLarge(FFNExpert):
    def __init__(self, input_dim, output_dim):
        super(FFNExpertVeryLarge, self).__init__(input_dim, hidden_dims=[2048, 1024, 512, 256, 128],
                                                 output_dim=output_dim, dropout_prob=0.6)

Тут в целом все просто, мы набросали 4 эксперта с разными параметрами и посмотрим на то как они будут обучаться.

Начнем собирать модель

import torch.nn as nn
import timm

class ViT_backbone(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = timm.create_model('vit_base_patch16_224',
                                          pretrained=True)

        for param in self.backbone.parameters():
            param.requires_grad = False

        self.embed_dim = self.backbone.head.in_features

        self.backbone.reset_classifier(0)

        self.ln = nn.LayerNorm(self.embed_dim)
        self.ln2 = nn.LayerNorm(self.embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim=self.embed_dim,
                                          num_heads=8,
                                          batch_first=True)

    def forward(self, x):
        skip = self.backbone.forward_features(x)  # [B, N, D]
        x_ln = self.ln(skip)
        attn_out, _ = self.attn(x_ln, x_ln, x_ln)
        x_attn = attn_out + skip
        x_final = self.ln2(x_attn).flatten(1)  # [B, N*D]
        return x_final

Тут все просто возьмем классическую модель VIT и добавим к ней слои нормализации после Multihead Attention и skip connection.

После сделаем наше объединение и наконец-то MOE

import torch
import torch.nn as nn
import torch.nn.functional as F
from model.gating_network import GatingNetwork
from model.Vit_model import ViT_backbone

class MoECNN(nn.Module):
    def __init__(self,
                 experts,
                 input_for_gating = 151296,
                 top_k=2,
                 output_dim=100,
                 use_aux_loss=True,
                 aux_loss_weight=0.01,
                 warmup_iters=500,
                 noise_std = 0.5):
        super().__init__()
        self.num_experts = len(experts)
        self.top_k = top_k
        self.output_dim = output_dim
        self.use_aux_loss = use_aux_loss
        self.aux_loss_weight = aux_loss_weight
        self.warmup_iters = warmup_iters
        self.iter = 0

        self.backbone = ViT_backbone()
        self.experts = nn.ModuleList(experts)
        self.gating = GatingNetwork( input_dim = input_for_gating,
                                     num_experts=self.num_experts,
                                     top_k = top_k,
                                     noise_std=noise_std)

        self.register_buffer("expert_usage",
                             torch.zeros(self.num_experts))

    def forward(self, x):
        batch_size = x.size(0)
        device = x.device
        x = self.backbone(x)

        if self.training and self.iter < self.warmup_iters:
            random_indices = torch.randint(0,
                                           self.num_experts,
                                           (batch_size, self.top_k),
                                           device=device)
            gates = torch.full((batch_size, self.top_k),
                               1.0 / self.top_k,
                               device=device)
            topk_indices = random_indices
            self.iter += 1
        else:
            topk_indices, gates = self.gating(x)

        output = torch.zeros(batch_size, self.output_dim, device=device)
        self.expert_usage.zero_()

        for i in range(self.top_k):
            idx = topk_indices[:, i]
            for expert_idx in torch.unique(idx):
                expert_mask = (idx == expert_idx)
                if expert_mask.sum() == 0:
                    continue
                x_sel = x[expert_mask]
                y_sel = self.experts[expert_idx](x_sel)
                gate_weight = gates[expert_mask, i].unsqueeze(1)
                output[expert_mask] += gate_weight * y_sel

                self.expert_usage[expert_idx] += expert_mask.sum()

        aux_loss = None
        if self.use_aux_loss and self.training:
            usage = self.expert_usage / batch_size
            aux_loss = ((usage - usage.mean()) ** 2).mean() * self.aux_loss_weight

        return output, aux_loss

Первое на что обратим внимание это это warmup_iters. Тут у нас это число итераций где мы как-бы отключаем gating-сеть , чтобы избежать коллапса распределения (один эксперт выбирается чаще остальных до того, как сеть обучится разумно маршрутизировать входы). Это дает нам “разогреть” экспертов передавая им равномерно данные и далее мы уже начинаем более тонко избирать экспертов за счет gating network.

Второй момент это добавление use_aux_loss. Данный параметр позволяет нам учитывать в общем лоссе неравномерное распределение по экспертам в общий loss.

Как итог модель выбирает tok_k экспертов и на основе их предсказаний делает взвешенную сумму, после чего выдает результат и loss по распределению.

Что в итоге?

При простом “на коленке” мы смогли получить f1 на тесте 89.%. Более явно поиграв с гипперпараметрами, типами экспертов и некоторыми изощренностями думаю, что можно получить результат лучше. Самое главное, что

Mixture of Experts: когда нейросеть учится делегировать - 6

Давайте теперь проведем модель через один батч и посмотрим, что там на одном батче, что произошло по графику и посмотрим на первые 10 сэмплов батча.

Mixture of Experts: когда нейросеть учится делегировать - 7

Как можем увидеть, у нас 2 эксперт оказался в данной итерации не востребован, а использовали мы с 0,1,3 эксперта в разной пропорции.

Можно сказать, что вот: “второй эксперт переобучился или обучился плохо”. Однако давайте глянем глубже! Мы ведь отслеживаем все через clearml :)

Mixture of Experts: когда нейросеть учится делегировать - 8

На тестовом датасете в среднем можно заметить, что все эксперты примерно вышли на какую-то свою зону ответственности. Хотя конечно от первого эксперта хотелось ожидать побольше!

Теперь давайте посмотрим на визуализацию результатов:

Mixture of Experts: когда нейросеть учится делегировать - 9

Несмотря на шакальность(мы работаем с CIFAR100 напоминаю) мы получили весьма неплохие результаты. И теперь вишенка на торте – это отслеживание по экспертам. Их собственно говоря мы итак логируем и сейчас можем провести на маленьком сэмпле аналитику. Если у нас есть очень большой эксперт и он не пригодился в использовании в вычислениях, то мы можем сэкономить очень много памяти.

Mixture of Experts: когда нейросеть учится делегировать - 10

Подводя итоги основным концептом было показать какие проблемы бывают и сложности при работе с технологией, а также ее возможности и потенциал, который уже сейчас очень успешно реализуется!

🔥 Ставьте лайк и напишите какие темы было бы интересно разобрать дальше! Самое главное — пробуйте и экспериментируйте!

✔️ Присоединяйтесь к нашему Telegram-сообществу @datafeeling, чтобы первыми применять на практике передовые технологии!

Автор: Aleron75

Источник

Рейтинг@Mail.ru
Rambler's Top100