Пробуем KAN (Kolmogorov-Arnold Networks) для классификации данных ЭЭГ. kan.. kan. kolmogorov-arnold networks.

После ознакомления с несколькими статьями (1, 2, 3, 4) о новой архитектуре KAN (Kolmogorov‑Arnold Networks), у меня возникло желание опробовать её для классификации данных ЭЭГ. Архитектура KAN показалась перспективной благодаря своей способности моделировать сложные нелинейные зависимости, что может быть особенно полезно для анализа сигналов ЭЭГ.

Из‑за ограниченного времени было решено взять за основу готовое исследование и проверить эффективность KAN на тех же данных, используя преимущественно стандартные настройки.

Для классификации был выбран датасет, доступный по ссылке. Набор данных включает записи ЭЭГ от 14 пациентов с параноидной шизофренией и 14 здоровых людей из контрольной группы. Данные были записаны с частотой дискретизации 250 Гц с использованием стандартной схемы размещения электродов 10–20 и 19 каналов: Fp1, Fp2, F7, F3, Fz, F4, F8, T3, C3, Cz, C4, T4, T5, P3, Pz, P4, T6, O1, O2. Предобработка и фильтрация данных выполнены в соответствии с примером на гитхаб.

Тестирование pykan

Первым этапом было тестирование библиотеки pykan с различными конфигурациями, представленными в таблице:

Модель

width

grid

k

model_pykan_1

[19, 1]

3

3

model_pykan_2

[19, 9, 1]

3

3

model_pykan_3

[19, 9, 4, 1]

3

3

Основным изменяемым параметром был width. В дальнейшем планируется экспериментировать с другими параметрами.

Также можно визуализировать архитектуры моделей. Пример кода для построения графиков:

model_pykan_1 = KAN(width=[19, 1], grid=3, k=3)

x = torch.normal(0,1,size=(10, 19))

model_pykan_1(x)

figure(figsize=(38, 20), dpi=80)

model_pykan_1.plot(beta=100, )

Графики архитектур моделей

Графики архитектур моделей

Для разделения данных на обучающую и валидационную выборки использовался метод gkf.split(), который гарантирует, что данные одного пациента не попадут одновременно в обе выборки. Функция обучения:

def kan_train(df_x, df_y, df_group, model):

accuracy_train = []

accuracy_test=[]

for train_index, val_index in gkf.split(df_x, df_y, groups=df_group):

train_features, train_labels = df_x[train_index], df_y[train_index]

val_features, val_labels = df_x[val_index], df_y[val_index]

dataset = dataset_user(train_features[:, -1, :], val_features[:, -1, :], train_labels, val_labels)

def train_acc():

return torch.mean((torch.round(model(dataset['train_input'])[:,0]) == dataset['train_label'][:,0]).float())

def test_acc():

return torch.mean((torch.round(model(dataset['test_input'])[:,0]) == dataset['test_label'][:,0]).float())

results = model.fit(dataset, opt="LBFGS", steps=50, metrics=(train_acc, test_acc));

print(results['train_acc'][-1], results['test_acc'][-1])

accuracy_train.append(results['train_acc'][-1])

accuracy_test.append(results['test_acc'][-1])

return accuracy_train, accuracy_test

После обучения были получены следующие результаты:

Модель

Test Acc

Val Acc

model_pykan_1

0.67

0.55

model_pykan_2

0.88

0.58

model_pykan_3

0.81

0.53

Как видно из таблицы, с увеличением количества слоёв модель начинает переобучаться.

Тестирование DeepKAN

Первым этапом было тестирование библиотеки deepkan с различными конфигурациями, представленными в таблице:

Модель

input_dim

hidden_layers

num_knots

spline_order

model_deepkan_1

19

[1]

3

3

model_deepkan_2

19

[9, 1]

3

3

model_deepkan_3

19

[9, 4, 1]

3

3

Основным изменяемым параметром был hidden_layers (Список, определяющий размерности скрытых слоев). В дальнейшем планируется экспериментировать с другими параметрами.

После обучения были получены следующие результаты:

Модель

Test Acc

Val Acc

model_deepkan_1

0.5490

0.5476

model_deepkan_2

0.5480

0.5466

model_deepkan_3

0.5425

0.4730

Как видно из таблицы, данная реализация KAN ведет себя стабильней.

Тестирование модели с GitHub

Далее была протестирована модель, найденная на GitHub. Подробнее с ней можно ознакомиться по ссылке. Сама модель выглядит так:

модель представленная на гитхаб

модель представленная на гитхаб

Результат: avg cross‑val acc: 0.6662.

Замена линейного слоя на SplineLinearLayer

После этого я заменил линейный слой на SplineLinearLayer из библиотеки KAN. Описание библиотеки доступно тут.

Получившиеся модель:

Модель с замененным линейным слоем на SplineLinearLayer

Модель с замененным линейным слоем на SplineLinearLayer

Результат: avg cross‑val acc: 0.6588.

Выводы

Учитывая, что модель с GitHub построена на довольно старой архитектуре, KAN на базовых настройках не показал значительного улучшения результата. Однако стоит отметить, что совместное использование KAN с другими методами может привести к более высоким результатам.

Тетрадка с представленными тестами

Автор: MxaTs

Источник

Рейтинг@Mail.ru
Rambler's Top100