asl_model_uploader / utils.py
louiecerv's picture
save changes
1cc1116
raw
history blame
1.52 kB
import torch
import torch.nn as nn
class MyConvBlock(nn.Module):
    """Feature block: Conv2d -> BatchNorm2d -> ReLU -> Dropout -> MaxPool2d.

    The 3x3 convolution runs with stride 1 and padding 1, so it preserves the
    spatial size. The trailing ``MaxPool2d(2, stride=2)`` then halves height and
    width (flooring odd sizes).

    Args:
        in_ch: Number of input channels for the convolution.
        out_ch: Number of output channels for the convolution (and BatchNorm).
        dropout_p: Drop probability handed to ``nn.Dropout``.
    """

    def __init__(self, in_ch, out_ch, dropout_p):
        super().__init__()
        ksize = 3
        # The position of each layer in this list becomes its index inside
        # ``self.model`` (model.0 ... model.4). Those indices are the state_dict
        # keys, so keep the order stable for saved checkpoints.
        stages = [
            nn.Conv2d(in_ch, out_ch, ksize, stride=1, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(),
            nn.Dropout(dropout_p),
            nn.MaxPool2d(2, stride=2),
        ]
        self.model = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the conv/BN/ReLU/dropout/pool stack to ``x``."""
        result = self.model(x)
        return result
def get_batch_accuracy(output, y, N):
    """Return the count of correct argmax predictions in a batch, divided by ``N``.

    ``output`` is reduced with argmax over dim 1 (keeping that dim). ``y`` is
    reshaped with ``view_as`` to match the predictions before comparing. When
    ``N`` is the size of the whole dataset, summing this value over every batch
    yields the epoch accuracy.

    Args:
        output: Model scores; the predicted class is the argmax along dim 1.
        y: Target class indices; must have one element per prediction.
        N: Denominator for the correct-count.

    Returns:
        ``correct / N`` as a Python float.
    """
    predicted = torch.argmax(output, dim=1, keepdim=True)
    targets = y.view_as(predicted)
    hits = (predicted == targets).sum().item()
    return hits / N
def train(model, train_loader, train_N, random_trans, optimizer, loss_function):
    """Run one training epoch over ``train_loader``, print and return its metrics.

    The model is switched to training mode first. Each input batch goes through
    ``random_trans`` before the forward pass, and the optimizer takes one step
    per batch.

    Args:
        model: Network to optimize; called as ``model(random_trans(x))``.
        train_loader: Iterable yielding ``(x, y)`` batches.
        train_N: Total number of training samples. Each batch's correct count is
            divided by it, so the accumulated accuracy is the epoch accuracy.
        random_trans: Callable applied to every input batch before the forward
            pass.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        loss_function: Callable ``loss_function(output, y)`` returning a
            scalar tensor that supports ``backward()``.

    Returns:
        Tuple ``(loss, accuracy)`` with the same values that are printed.
        ``loss`` is the SUM of ``batch_loss.item()`` over all batches (not
        averaged), so it grows with the number of batches. ``accuracy`` is the
        fraction of correct predictions over ``train_N`` samples.
    """
    loss = 0
    accuracy = 0
    model.train()
    for x, y in train_loader:
        output = model(random_trans(x))
        optimizer.zero_grad()
        batch_loss = loss_function(output, y)
        batch_loss.backward()
        optimizer.step()
        # .item() detaches to a Python number so the running total does not
        # keep every batch's autograd graph alive.
        loss += batch_loss.item()
        accuracy += get_batch_accuracy(output, y, train_N)
    print('Train - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))
    # Returning the metrics lets callers record history or early-stop; callers
    # that ignored the previous ``None`` result are unaffected.
    return loss, accuracy
def validate(model, valid_loader, valid_N, loss_function):
    """Evaluate ``model`` on ``valid_loader`` without gradients; print and return metrics.

    The model is switched to evaluation mode and is left in that mode on
    return. Callers must switch it back (e.g. ``model.train()``) before
    training further.

    Args:
        model: Network to evaluate; called as ``model(x)``.
        valid_loader: Iterable yielding ``(x, y)`` batches.
        valid_N: Total number of validation samples. Each batch's correct count
            is divided by it, so the accumulated accuracy is the overall accuracy.
        loss_function: Callable ``loss_function(output, y)`` returning a
            scalar tensor.

    Returns:
        Tuple ``(loss, accuracy)`` with the same values that are printed.
        ``loss`` is the SUM of per-batch ``loss_function(...).item()`` values
        (not averaged). ``accuracy`` is the fraction of correct predictions over
        ``valid_N`` samples.
    """
    loss = 0
    accuracy = 0
    model.eval()
    # No autograd graph is needed for evaluation; this saves memory and time.
    with torch.no_grad():
        for x, y in valid_loader:
            output = model(x)
            loss += loss_function(output, y).item()
            accuracy += get_batch_accuracy(output, y, valid_N)
    print('Valid - Loss: {:.4f} Accuracy: {:.4f}'.format(loss, accuracy))
    # Returning the metrics lets callers record history or early-stop; callers
    # that ignored the previous ``None`` result are unaffected.
    return loss, accuracy