После ознакомления с несколькими статьями (1, 2, 3, 4) о новой архитектуре KAN (Kolmogorov‑Arnold Networks), у меня возникло желание опробовать её для классификации данных ЭЭГ. Архитектура KAN показалась перспективной благодаря своей способности моделировать сложные нелинейные зависимости, что может быть особенно полезно для анализа сигналов ЭЭГ.
Из‑за ограниченного времени было решено взять за основу готовое исследование и проверить эффективность KAN на тех же данных, используя преимущественно стандартные настройки.
Для классификации был выбран датасет, доступный по ссылке. Набор данных включает записи ЭЭГ от 14 пациентов с параноидной шизофренией и 14 здоровых людей из контрольной группы. Данные были записаны с частотой дискретизации 250 Гц с использованием стандартной схемы размещения электродов 10–20 и 19 каналов: Fp1, Fp2, F7, F3, Fz, F4, F8, T3, C3, Cz, C4, T4, T5, P3, Pz, P4, T6, O1, O2. Предобработка и фильтрация данных выполнены в соответствии с примером на гитхаб.
Тестирование pykan
Первым этапом было тестирование библиотеки pykan с различными конфигурациями, представленными в таблице:
Модель |
width |
grid |
k |
model_pykan_1 |
[19, 1] |
3 |
3 |
model_pykan_2 |
[19, 9, 1] |
3 |
3 |
model_pykan_3 |
[19, 9, 4, 1] |
3 |
3 |
Основным изменяемым параметром был width. В дальнейшем планируется экспериментировать с другими параметрами.
Также можно визуализировать архитектуры моделей. Пример кода для построения графиков:
model_pykan_1 = KAN(width=[19, 1], grid=3, k=3)
x = torch.normal(0,1,size=(10, 19))
model_pykan_1(x)
figure(figsize=(38, 20), dpi=80)
model_pykan_1.plot(beta=100, )

Для разделения данных на обучающую и валидационную выборки использовался метод gkf.split(), который гарантирует, что данные одного пациента не попадут одновременно в обе выборки. Функция обучения:
def kan_train(df_x, df_y, df_group, model):
accuracy_train = []
accuracy_test=[]
for train_index, val_index in gkf.split(df_x, df_y, groups=df_group):
train_features, train_labels = df_x[train_index], df_y[train_index]
val_features, val_labels = df_x[val_index], df_y[val_index]
dataset = dataset_user(train_features[:, -1, :], val_features[:, -1, :], train_labels, val_labels)
def train_acc():
return torch.mean((torch.round(model(dataset['train_input'])[:,0]) == dataset['train_label'][:,0]).float())
def test_acc():
return torch.mean((torch.round(model(dataset['test_input'])[:,0]) == dataset['test_label'][:,0]).float())
results =
model.fit
(dataset, opt="LBFGS", steps=50, metrics=(train_acc, test_acc));
print(results['train_acc'][-1], results['test_acc'][-1])
accuracy_train.append(results['train_acc'][-1])
accuracy_test.append(results['test_acc'][-1])
return accuracy_train, accuracy_test
После обучения были получены следующие результаты:
Модель |
Test Acc |
Val Acc |
model_pykan_1 |
0.67 |
0.55 |
model_pykan_2 |
0.88 |
0.58 |
model_pykan_3 |
0.81 |
0.53 |
Как видно из таблицы, с увеличением количества слоёв модель начинает переобучаться.
Тестирование DeepKAN
Первым этапом было тестирование библиотеки deepkan с различными конфигурациями, представленными в таблице:
Модель |
input_dim |
hidden_layers |
num_knots |
spline_order |
model_deepkan_1 |
19 |
[1] |
3 |
3 |
model_deepkan_2 |
19 |
[9, 1] |
3 |
3 |
model_deepkan_3 |
19 |
[9, 4, 1] |
3 |
3 |
Основным изменяемым параметром был hidden_layers (Список, определяющий размерности скрытых слоев). В дальнейшем планируется экспериментировать с другими параметрами.
После обучения были получены следующие результаты:
Модель |
Test Acc |
Val Acc |
model_deepkan_1 |
0.5490 |
0.5476 |
model_deepkan_2 |
0.5480 |
0.5466 |
model_deepkan_3 |
0.5425 |
0.4730 |
Как видно из таблицы, данная реализация KAN ведет себя стабильней.
Тестирование модели с GitHub
Далее была протестирована модель, найденная на GitHub. Подробнее с ней можно ознакомиться по ссылке. Сама модель выглядит так:

Результат: avg cross‑val acc: 0.6662.
Замена линейного слоя на SplineLinearLayer
После этого я заменил линейный слой на SplineLinearLayer из библиотеки KAN. Описание библиотеки доступно тут.
Получившиеся модель:

Результат: avg cross‑val acc: 0.6588.
Выводы
Учитывая, что модель с GitHub построена на довольно старой архитектуре, KAN на базовых настройках не показал значительного улучшения результата. Однако стоит отметить, что совместное использование KAN с другими методами может привести к более высоким результатам.
Автор: MxaTs