weibo_senti_cls / train.py
ZaynSu99's picture
Update weibo_senti_cls
995278d
from cnn import CNN,Model
from utils import loader_train
import torch
# Training script: fits the classifier on the batches yielded by
# `loader_train`, printing per-batch metrics, and saves the weights once
# training finishes.

# Hyperparameters.
epochs = 5
learning_rate = 5e-4
# Print per-batch metrics every `log_interval` batches (1 = every batch).
log_interval = 1
# Zero-based epoch index at whose start the learning rate is decayed once.
lr_decay_epoch = 3
# Destination file for the trained weights (a state_dict).
checkpoint_path = 'cls_params.pth'

# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines where CUDA is unavailable.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Alternative architecture kept from the original script:
# model = CNN().to(device)
model = Model().to(device)

# AdamW optimizer with a learning rate of 5e-4.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# Each scheduler.step() multiplies the learning rate by 0.1 (a 90% decay).
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.1)
# Cross-entropy loss.
criterion = torch.nn.CrossEntropyLoss()

model.train()
for epoch in range(epochs):
    # The scheduler is stepped only once, at the start of this epoch, so
    # the earlier epochs use the base rate and the remaining ones use
    # 0.1x of it.
    if epoch == lr_decay_epoch:
        scheduler.step()
    print('轮数', epoch + 1)

    for i, (input_ids, attention_mask, token_type_ids, labels) in enumerate(loader_train):
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        token_type_ids = token_type_ids.to(device)
        labels = labels.to(device)

        # NOTE(review): presumably `out` holds per-class scores shaped
        # (batch, num_classes); argmax(dim=1) below relies on that -- confirm
        # against Model.forward.
        out = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        ).to(device)
        loss = criterion(out, labels)

        loss.backward()
        optimizer.step()
        # Clear gradients so they do not accumulate into the next batch.
        optimizer.zero_grad()

        if i % log_interval == 0:
            preds = out.argmax(dim=1)
            # Count the correct predictions once and reuse the value.
            correct = (preds == labels).sum().item()
            print(preds, labels, correct, len(labels))
            accuracy = correct / len(labels)
            print(i, loss.item(), accuracy)

# The commented-out CNN variant in the original saved to 'net_params.pth'.
torch.save(model.state_dict(), checkpoint_path)