# Source: uploaded by wgcban ("Upload 98 files", commit 803ef9e).
import torch.nn as nn
from model import get_model, get_head
from eval.sgd import eval_sgd
from eval.knn import eval_knn
from eval.get_data import get_data
import torch
import tqdm
class BaseMethod(nn.Module):
    """
    Base class for self-supervised loss implementation.

    Holds the encoder (``self.model``) and the head (``self.head``) used for
    training, plus evaluation helpers that score the learned features with
    ``eval_knn`` and ``eval_sgd``. Subclasses must implement ``forward``.
    """

    def __init__(self, cfg):
        """
        Build the encoder and head from ``cfg``.

        Reads ``cfg.arch``, ``cfg.dataset``, ``cfg.knn``, ``cfg.num_samples``,
        ``cfg.eval_head`` and ``cfg.emb``; ``cfg`` is also forwarded to
        ``get_head``.
        """
        super().__init__()
        self.model, self.out_size = get_model(cfg.arch, cfg.dataset)
        self.head = get_head(self.out_size, cfg)
        self.knn = cfg.knn
        # Number of unordered pairs that can be formed from num_samples items.
        self.num_pairs = cfg.num_samples * (cfg.num_samples - 1) // 2
        self.eval_head = cfg.eval_head
        self.emb_size = cfg.emb

    def forward(self, samples):
        """Compute the method-specific output; must be overridden by subclasses."""
        raise NotImplementedError

    def _embed_with_head(self, x):
        """Run ``x`` through the encoder and then through the head."""
        return self.head(self.model(x))

    def _eval_features(self, ds_clf, ds_test):
        """
        Collect evaluation data for both datasets via ``get_data``.

        When ``self.eval_head`` is truthy the encoder+head output (size
        ``self.emb_size``) is evaluated, otherwise the bare encoder output
        (size ``self.out_size``).

        Returns:
            ``(x_train, y_train, x_test, y_test)`` exactly as produced by
            ``get_data`` (presumably feature/label tensors -- confirm against
            ``eval.get_data``).
        """
        if self.eval_head:
            encoder, out_size = self._embed_with_head, self.emb_size
        else:
            encoder, out_size = self.model, self.out_size
        # torch.cuda.empty_cache()
        x_train, y_train = get_data(encoder, ds_clf, out_size, "cuda")
        x_test, y_test = get_data(encoder, ds_test, out_size, "cuda")
        return x_train, y_train, x_test, y_test

    def get_acc(self, ds_clf, ds_test):
        """
        Evaluate the current features with both ``eval_knn`` and ``eval_sgd``.

        The module is switched to eval mode for the duration of the call and
        is always switched back to train mode afterwards, even if the
        evaluation raises.

        Args:
            ds_clf: dataset passed to ``get_data`` to build the train split.
            ds_test: dataset passed to ``get_data`` to build the test split.

        Returns:
            Tuple ``(acc_knn, acc_linear)`` holding the results of
            ``eval_knn`` and ``eval_sgd`` respectively.
        """
        self.eval()
        try:
            features = self._eval_features(ds_clf, ds_test)
            acc_knn = eval_knn(*features, self.knn)
            acc_linear = eval_sgd(*features)
        finally:
            # Never leave the module stuck in eval mode if evaluation fails.
            self.train()
        return acc_knn, acc_linear

    def get_acc_knn(self, ds_clf, ds_test):
        """
        Evaluate the current features with ``eval_knn`` only.

        The module is switched to eval mode for the duration of the call and
        is always switched back to train mode afterwards, even if the
        evaluation raises.

        Args:
            ds_clf: dataset passed to ``get_data`` to build the train split.
            ds_test: dataset passed to ``get_data`` to build the test split.

        Returns:
            The result of ``eval_knn`` using ``self.knn``.
        """
        self.eval()
        try:
            features = self._eval_features(ds_clf, ds_test)
            acc_knn = eval_knn(*features, self.knn)
        finally:
            # Never leave the module stuck in eval mode if evaluation fails.
            self.train()
        return acc_knn

    def step(self, progress):
        """Per-step hook receiving ``progress``; a no-op in the base class."""
        pass