import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
from sklearn.metrics import classification_report
from matplotlib import pyplot as plt

from dataset import EEGDataset, EEGDataset_Batch_normal
from net import IntegratedNet

# Per-recording z-score normalisation
def normalize(data):
    mean = np.mean(data)
    std = np.std(data)
    return (data - mean) / std

transform = transforms.Compose([
    transforms.Lambda(normalize),  # apply the custom normalisation defined above
    transforms.ToTensor()
])


def train_identityformer_model(model, model_name, num_epochs=100, num_classes=3,
                               batch_size=16, learning_rate=0.0001,
                               w_wight=1025, chennal=33):
    # w_wight: number of time samples per window; chennal: number of EEG channels
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    m = nn.Softmax(dim=1)

    train_dataset = EEGDataset(csv_file='train_data.csv', transform=transform)
    test_dataset = EEGDataset(csv_file='test_data.csv', transform=transform)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    model.to(device)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.RMSprop(model.parameters(), lr=learning_rate)

    save_dir = 'loss'
    os.makedirs(save_dir, exist_ok=True)
    train_loss_arr = []
    train_acc_arr = []
    val_loss_arr = []
    val_acc_arr = []
    best_val_acc = 0.0
    best_epoch = 0

    for epoch in range(num_epochs):
        train_loss_total = 0
        train_acc_total = 0
        val_loss_total = 0
        val_acc_total = 0
        model.train()
        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader))
        for i, (train_x, train_y) in progress_bar:
            train_x = train_x.to(device)
            train_y = train_y.to(device)
            # Reshape the batch to (batch, 1, channels, samples) as expected by the model
            train_x = train_x.unsqueeze(1)
            train_x = train_x.view(batch_size, 1, chennal, w_wight)
            train_y_pred = model(train_x)
            train_loss = loss_fn(train_y_pred, train_y)
            # Batch accuracy: fraction of samples whose argmax prediction matches the label
            train_acc = (m(train_y_pred).max(dim=1)[1] == train_y).sum() / train_y.shape[0]
            train_loss_total += train_loss.item()
            train_acc_total += train_acc.item()
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()
            progress_bar.set_description(
                f"Epoch {epoch+1}/{num_epochs}, Batch {i+1}/{len(train_loader)}, "
                f"Train Loss: {train_loss.item():.4f}, Train Acc: {train_acc.item():.4f}")

        train_loss_arr.append(train_loss_total / len(train_loader))
        train_acc_arr.append(train_acc_total / len(train_loader))

        # Validation pass (no gradient tracking needed)
        model.eval()
        with torch.no_grad():
            for j, (val_x, val_y) in enumerate(test_loader):
                val_x = val_x.to(device)
                val_y = val_y.to(device)
                val_x = val_x.unsqueeze(1)
                val_x = val_x.view(batch_size, 1, chennal, w_wight)

                val_y_pred = model(val_x)
                val_loss = loss_fn(val_y_pred, val_y)
                val_acc = (m(val_y_pred).max(dim=1)[1] == val_y).sum() / val_y.shape[0]
                val_loss_total += val_loss.item()
                val_acc_total += val_acc.item()
        val_loss_arr.append(val_loss_total / len(test_loader))
        val_acc_arr.append(val_acc_total / len(test_loader))
        # Keep the checkpoint with the best validation accuracy so far
        if val_acc_arr[-1] > best_val_acc:
            best_val_acc = val_acc_arr[-1]
            best_epoch = epoch
            torch.save(model.state_dict(), f"{model_name}_best.pth")
        print("epoch:{} val_loss:{} val_acc:{}".format(epoch, val_loss_arr[-1], val_acc_arr[-1]))

    # Persist the learning curves once training has finished
    np.save(os.path.join(save_dir, 'train_loss_arr.npy'), np.array(train_loss_arr))
    np.save(os.path.join(save_dir, 'train_acc_arr.npy'), np.array(train_acc_arr))
    np.save(os.path.join(save_dir, 'val_loss_arr.npy'), np.array(val_loss_arr))
    np.save(os.path.join(save_dir, 'val_acc_arr.npy'), np.array(val_acc_arr))

    plt.subplot(1, 2, 1)
    plt.title("loss")
    plt.plot(train_loss_arr, "r", label="train")
    plt.plot(val_loss_arr, "b", label="val")
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.title("acc")
    plt.plot(train_acc_arr, "r", label="train")
    plt.plot(val_acc_arr, "b", label="val")
    plt.legend()
    plt.savefig("loss_acc.png")
    plt.show()

    print(f"Best model at epoch {best_epoch+1}, val_acc={best_val_acc:.4f}")
    print('Training completed!')


# Build the model and train it
model = IntegratedNet(input_size=1, in_feature=157, num_classes=2)  # make sure the output layer matches the number of classes
train_identityformer_model(model, model_name='MLPFormer_betch_16_fft_opendata', chennal=19, w_wight=2500)
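

# Optional follow-up (not part of the original script): a minimal sketch of how the saved
# "<model_name>_best.pth" checkpoint could be reloaded and scored with sklearn's
# classification_report, which is imported above but otherwise unused. The file names,
# batch size and input layout are assumed to match the training setup; adjust as needed.
def evaluate_best_model(model, model_name, batch_size=16, chennal=19, w_wight=2500):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.load_state_dict(torch.load(f"{model_name}_best.pth", map_location=device))
    model.to(device)
    model.eval()

    test_loader = DataLoader(EEGDataset(csv_file='test_data.csv', transform=transform),
                             batch_size=batch_size, shuffle=False, drop_last=True)

    y_true, y_pred = [], []
    with torch.no_grad():
        for x, y in test_loader:
            x = x.to(device).view(batch_size, 1, chennal, w_wight)
            logits = model(x)
            y_true.extend(y.tolist())
            y_pred.extend(logits.argmax(dim=1).cpu().tolist())
    print(classification_report(y_true, y_pred))

# evaluate_best_model(model, 'MLPFormer_betch_16_fft_opendata')  # uncomment to run after training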