from typing import List
import torch
from methods.base.model import BaseModel
import tqdm
from torch import nn
import torch.nn.functional as F
from abc import abstractmethod
from methods.elasticdnn.model.base import ElasticDNNUtil
from methods.elasticdnn.pipeline.offline.fm_to_md.base import FM_to_MD_Util
from methods.elasticdnn.pipeline.offline.fm_lora.base import FMLoRA_Util
from utils.dl.common.model import LayerActivation


class ElasticDNN_OfflineFMModel(BaseModel):
    def get_required_model_components(self) -> List[str]:
        return ['main']

    def generate_md_by_reducing_width(self, reducing_width_ratio, samples: torch.Tensor):
        pass

    def forward_to_get_task_loss(self, x, y, *args, **kwargs):
        raise NotImplementedError

    def get_feature_hook(self) -> LayerActivation:
        pass

    def get_elastic_dnn_util(self) -> ElasticDNNUtil:
        pass

    def get_lora_util(self) -> FMLoRA_Util:
        pass

    def get_task_head_params(self):
        pass
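
# Task-specific FM (foundation model) wrappers elsewhere in the repo subclass
# ElasticDNN_OfflineFMModel and fill in the hooks above. A minimal illustrative
# sketch (the subclass name and the bodies below are hypothetical, not the actual
# implementations):
#
#     class ExampleClsFMModel(ElasticDNN_OfflineFMModel):
#         def forward_to_get_task_loss(self, x, y, *args, **kwargs):
#             return F.cross_entropy(self.infer(x).logits, y)
#
#         def get_lora_util(self) -> FMLoRA_Util:
#             return SomeFMLoRA_Util()  # hypothetical FMLoRA_Util subclass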


class ElasticDNN_OfflineClsFMModel(ElasticDNN_OfflineFMModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0

        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                x, y = x.to(self.device), y.to(self.device)
                output = self.infer(x)
                pred = F.softmax(output.logits, dim=1).argmax(dim=1)
                # correct = torch.eq(torch.argmax(output.logits, dim=1), y).sum().item()
                correct = torch.eq(pred, y).sum().item()
                acc += correct
                sample_num += len(y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')

        acc /= sample_num
        return acc

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](x)
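
# Example usage (illustrative; the checkpoint path and dataloader are placeholders):
#
#     fm = ElasticDNN_OfflineClsFMModel('fm', 'path/to/fm_models_dict.pt', 'cuda')
#     top1 = fm.get_accuracy(test_loader)   # fraction of correctly classified samples, in [0, 1]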


import numpy as np


class StreamSegMetrics:
    """
    Stream Metrics for Semantic Segmentation Task
    """
    def __init__(self, n_classes):
        self.n_classes = n_classes
        self.confusion_matrix = np.zeros((n_classes, n_classes))

    def update(self, label_trues, label_preds):
        for lt, lp in zip(label_trues, label_preds):
            self.confusion_matrix += self._fast_hist(lt.flatten(), lp.flatten())

    @staticmethod
    def to_str(results):
        string = "\n"
        for k, v in results.items():
            if k != "Class IoU":
                string += "%s: %f\n" % (k, v)
        return string

    def _fast_hist(self, label_true, label_pred):
        mask = (label_true >= 0) & (label_true < self.n_classes)
        hist = np.bincount(
            self.n_classes * label_true[mask].astype(int) + label_pred[mask],
            minlength=self.n_classes ** 2,
        ).reshape(self.n_classes, self.n_classes)
        return hist

    def get_results(self):
        """Returns accuracy score evaluation result.
            - overall accuracy
            - mean accuracy
            - mean IU
            - fwavacc
        """
        hist = self.confusion_matrix
        acc = np.diag(hist).sum() / hist.sum()
        acc_cls = np.diag(hist) / hist.sum(axis=1)
        acc_cls = np.nanmean(acc_cls)
        iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
        mean_iu = np.nanmean(iu)
        freq = hist.sum(axis=1) / hist.sum()
        fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
        cls_iu = dict(zip(range(self.n_classes), iu))

        return {
            "Overall Acc": acc,
            "Mean Acc": acc_cls,
            "FreqW Acc": fwavacc,
            "Mean IoU": mean_iu,
            "Class IoU": cls_iu,
        }

    def reset(self):
        self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
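
# A small sanity check of the metric math (illustrative, not executed by the module).
# With 2 classes, ground truth [0, 0, 1, 1] and prediction [0, 1, 1, 1]:
#   confusion matrix = [[1, 1],
#                       [0, 2]]
#   IoU(class 0) = 1 / (1 + 1 + 0) = 0.5, IoU(class 1) = 2 / (2 + 1 + 0) ~= 0.667,
#   so "Mean IoU" ~= 0.583 and "Overall Acc" = 3 / 4 = 0.75.
#
#     m = StreamSegMetrics(2)
#     m.update(np.array([[0, 0, 1, 1]]), np.array([[0, 1, 1, 1]]))
#     print(StreamSegMetrics.to_str(m.get_results()))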


class ElasticDNN_OfflineSegFMModel(ElasticDNN_OfflineFMModel):
    def __init__(self, name: str, models_dict_path: str, device: str, num_classes):
        super().__init__(name, models_dict_path, device)
        self.num_classes = num_classes

    def get_accuracy(self, test_loader, *args, **kwargs):
        device = self.device
        self.to_eval_mode()
        metrics = StreamSegMetrics(self.num_classes)
        metrics.reset()
        import tqdm
        pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), leave=False, dynamic_ncols=True)
        with torch.no_grad():
            for batch_index, (x, y) in pbar:
                x, y = x.to(device, dtype=x.dtype, non_blocking=True, copy=False), \
                    y.to(device, dtype=y.dtype, non_blocking=True, copy=False)
                output = self.infer(x)
                pred = output.detach().max(dim=1)[1].cpu().numpy()
                metrics.update((y + 0).cpu().numpy(), pred)

                res = metrics.get_results()
                pbar.set_description(f'cur batch mIoU: {res["Mean IoU"]:.4f}')

        res = metrics.get_results()
        return res['Mean IoU']

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](x)


class ElasticDNN_OfflineDetFMModel(ElasticDNN_OfflineFMModel):
    def __init__(self, name: str, models_dict_path: str, device: str, num_classes):
        super().__init__(name, models_dict_path, device)
        self.num_classes = num_classes

    def get_accuracy(self, test_loader, *args, **kwargs):
        # print('DeeplabV3: start test acc')
        _d = test_loader.dataset
        from data import build_dataloader
        if _d.__class__.__name__ == 'MergedDataset':
            # print('\neval on merged datasets')
            datasets = _d.datasets
            test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None) for d in datasets]
            accs = [self.get_accuracy(loader) for loader in test_loaders]
            # print(accs)
            return sum(accs) / len(accs)

        # print('dataset len', len(test_loader.dataset))

        model = self.models_dict['main']
        device = self.device
        model.eval()
        # print('# classes', model.num_classes)
        model = model.to(device)

        from dnns.yolov3.coco_evaluator import COCOEvaluator
        from utils.common.others import HiddenPrints
        with torch.no_grad():
            with HiddenPrints():
                evaluator = COCOEvaluator(
                    dataloader=test_loader,
                    img_size=(224, 224),
                    confthre=0.01,
                    nmsthre=0.65,
                    num_classes=self.num_classes,
                    testdev=False
                )
                res = evaluator.evaluate(model, False, False)
                map50 = res[1]
                # print('eval info', res[-1])
        return map50

    def infer(self, x, *args, **kwargs):
        if len(args) > 0:
            print(args, len(args))
            return self.models_dict['main'](x, *args)  # forward(x, label)
        return self.models_dict['main'](x)
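
# Note on the detection metric: get_accuracy() returns the mAP@0.50 value (res[1]) reported
# by the YOLOv3 COCOEvaluator at 224x224 input; for a MergedDataset it recursively evaluates
# each sub-dataset and averages the per-dataset scores.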


class ElasticDNN_OfflineSenClsFMModel(ElasticDNN_OfflineFMModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0

        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                output = self.infer(x)
                pred = F.softmax(output, dim=1).argmax(dim=1)
                correct = torch.eq(pred, y).sum().item()
                acc += correct
                sample_num += len(y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')

        acc /= sample_num
        return acc

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'].forward(**x)


from accelerate.utils.operations import pad_across_processes


class ElasticDNN_OfflineTrFMModel(ElasticDNN_OfflineFMModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        # TODO: BLEU
        from sacrebleu import corpus_bleu

        acc = 0
        num_batches = 0

        self.to_eval_mode()

        from data.datasets.sentiment_classification.global_bert_tokenizer import get_tokenizer
        tokenizer = get_tokenizer()

        def _decode(o):
            # https://github.com/huggingface/transformers/blob/main/examples/research_projects/seq2seq-distillation/finetune.py#L133
            o = tokenizer.batch_decode(o, skip_special_tokens=True)
            return [oi.strip() for oi in o]

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                label = y.to(self.device)

                # generated_tokens = self.infer(x, generate=True)
                generated_tokens = self.infer(x).logits.argmax(-1)

                # pad tokens
                generated_tokens = pad_across_processes(
                    generated_tokens, dim=1, pad_index=tokenizer.pad_token_id
                )
                # pad label
                label = pad_across_processes(
                    label, dim=1, pad_index=tokenizer.pad_token_id
                )
                label = label.cpu().numpy()
                label = np.where(label != -100, label, tokenizer.pad_token_id)

                decoded_output = _decode(generated_tokens)
                decoded_y = _decode(label)  # decode the padded labels (raw y may still contain -100)
                decoded_y = [decoded_y]

                if batch_index == 0:
                    print(decoded_y, decoded_output)

                bleu = corpus_bleu(decoded_output, decoded_y).score
                pbar.set_description(f'cur_batch_bleu: {bleu:.4f}')

                acc += bleu
                num_batches += 1

        acc /= num_batches
        return acc

    def infer(self, x, *args, **kwargs):
        if 'token_type_ids' in x.keys():
            del x['token_type_ids']

        if 'generate' in kwargs:
            return self.models_dict['main'].generate(
                x['input_ids'],
                attention_mask=x["attention_mask"],
                max_length=512
            )

        return self.models_dict['main'](**x)
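
# Illustrative call patterns for ElasticDNN_OfflineTrFMModel.infer (assuming `x` is a
# tokenizer output dict with 'input_ids' and 'attention_mask'; variable names are
# placeholders):
#
#     logits = tr_fm.infer(x).logits                # teacher-forced forward pass
#     token_ids = tr_fm.infer(x, generate=True)     # autoregressive decoding, max_length=512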


from nltk.metrics import accuracy as nltk_acc


class ElasticDNN_OfflineTokenClsFMModel(ElasticDNN_OfflineFMModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0

        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                # print(x)
                y = y.to(self.device)
                output = self.infer(x)

                # torch.Size([16, 512, 43]) torch.Size([16, 512])
                for oi, yi, xi in zip(output, y, x['input_ids']):
                    # oi: 512, 43; yi: 512
                    seq_len = xi.nonzero().size(0)
                    # print(output.size(), y.size())
                    pred = F.softmax(oi, dim=-1).argmax(dim=-1)
                    correct = torch.eq(pred[1: seq_len], yi[1: seq_len]).sum().item()
                    # print(output.size(), y.size())
                    acc += correct
                    sample_num += seq_len

                    pbar.set_description(f'seq_len: {seq_len}, cur_seq_acc: {(correct / seq_len):.4f}')

        acc /= sample_num
        return acc

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](**x)


class ElasticDNN_OfflineMMClsFMModel(ElasticDNN_OfflineFMModel):
    # def __init__(self, name: str, models_dict_path: str, device: str, class_to_label_idx_map):
    #     super().__init__(name, models_dict_path, device)
    #     self.class_to_label_idx_map = class_to_label_idx_map

    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0

        self.to_eval_mode()

        batch_size = 1

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                if batch_index * batch_size > 2000:
                    break

                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                # print(x)
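                # CLIP-style evaluation: deduplicate the candidate texts so that
                # logits_per_image has one column per unique caption, then remap each
                # sample's label to the index of its caption in the deduplicated list.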
                raw_texts = x['texts'][:]
                x['texts'] = list(set(x['texts']))
                # print(x['texts'])
                batch_size = len(y)

                x['for_training'] = False
                output = self.infer(x)
                output = output.logits_per_image
                # print(output.size())
                # exit()
                # y = torch.arange(len(y), device=self.device)
                y = torch.LongTensor([x['texts'].index(rt) for rt in raw_texts]).to(self.device)
                # print(y)
                # exit()

                # print(output.size(), y.size())
                pred = F.softmax(output, dim=1).argmax(dim=1)
                correct = torch.eq(pred, y).sum().item()
                acc += correct
                sample_num += len(y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')

        acc /= sample_num
        return acc

    def infer(self, x, *args, **kwargs):
        x['for_training'] = self.models_dict['main'].training
        return self.models_dict['main'](**x)


class VQAScore:
    def __init__(self):
        # self.add_state("score", default=torch.tensor(0.0), dist_reduce_fx="sum")
        # self.add_state("total", default=torch.tensor(0.0), dist_reduce_fx="sum")
        self.score = torch.tensor(0.0)
        self.total = torch.tensor(0.0)

    def update(self, logits, target):
        logits, target = (
            logits.detach().float().to(self.score.device),
            target.detach().float().to(self.score.device),
        )
        logits = torch.max(logits, 1)[1]
        one_hots = torch.zeros(*target.size()).to(target)
        one_hots.scatter_(1, logits.view(-1, 1), 1)
        scores = one_hots * target

        self.score += scores.sum()
        self.total += len(logits)

    def compute(self):
        return self.score / self.total
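
# VQAScore accumulates a VQA-style soft accuracy: `target` holds a soft score in [0, 1]
# for each candidate answer, and each prediction earns the soft score of the answer it
# picked. Illustrative example (made-up numbers):
#
#     score = VQAScore()
#     logits = torch.tensor([[2.0, 0.1, 0.3]])   # model picks answer 0
#     target = torch.tensor([[0.6, 1.0, 0.0]])   # soft ground-truth scores per answer
#     score.update(logits, target)
#     score.compute()                            # tensor(0.6000)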


class ElasticDNN_OfflineVQAFMModel(ElasticDNN_OfflineFMModel):
    # def __init__(self, name: str, models_dict_path: str, device: str, class_to_label_idx_map):
    #     super().__init__(name, models_dict_path, device)
    #     self.class_to_label_idx_map = class_to_label_idx_map

    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0
        vqa_score = VQAScore()

        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                output = self.infer(x).logits
                vqa_score.update(output, y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_acc: {vqa_score.compute():.4f}')

        return vqa_score.compute()

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](**x)


class ElasticDNN_OfflineMDModel(BaseModel):
    def get_required_model_components(self) -> List[str]:
        return ['main']

    def forward_to_get_task_loss(self, x, y, *args, **kwargs):
        raise NotImplementedError

    def get_feature_hook(self) -> LayerActivation:
        pass

    def get_distill_loss(self, student_output, teacher_output):
        pass

    def get_matched_param_of_fm(self, self_param_name, fm: nn.Module):
        pass
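
# The MD wrapper (the smaller model derived from the FM) mirrors the FM wrapper above:
# concrete subclasses implement get_distill_loss() and get_matched_param_of_fm() for
# FM-to-MD knowledge transfer. A minimal sketch of one possible get_distill_loss override
# (illustrative only, not the repo's actual implementation):
#
#     def get_distill_loss(self, student_output, teacher_output):
#         return F.mse_loss(student_output, teacher_output.detach())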


class ElasticDNN_OfflineClsMDModel(ElasticDNN_OfflineMDModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0

        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                x, y = x.to(self.device), y.to(self.device)
                output = self.infer(x)
                pred = F.softmax(output, dim=1).argmax(dim=1)
                correct = torch.eq(pred, y).sum().item()
                acc += correct
                sample_num += len(y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')

        acc /= sample_num
        return acc

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](x)


class ElasticDNN_OfflineSegMDModel(ElasticDNN_OfflineMDModel):
    def __init__(self, name: str, models_dict_path: str, device: str, num_classes):
        super().__init__(name, models_dict_path, device)
        self.num_classes = num_classes

    def get_accuracy(self, test_loader, *args, **kwargs):
        device = self.device
        self.to_eval_mode()
        metrics = StreamSegMetrics(self.num_classes)
        metrics.reset()
        import tqdm
        pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), leave=False, dynamic_ncols=True)
        with torch.no_grad():
            for batch_index, (x, y) in pbar:
                x, y = x.to(device, dtype=x.dtype, non_blocking=True, copy=False), \
                    y.to(device, dtype=y.dtype, non_blocking=True, copy=False)
                output = self.infer(x)
                pred = output.detach().max(dim=1)[1].cpu().numpy()
                metrics.update((y + 0).cpu().numpy(), pred)

                res = metrics.get_results()
                pbar.set_description(f'cur batch mIoU: {res["Mean IoU"]:.4f}')

        res = metrics.get_results()
        return res['Mean IoU']

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](x)


class ElasticDNN_OfflineDetMDModel(ElasticDNN_OfflineMDModel):
    def __init__(self, name: str, models_dict_path: str, device: str, num_classes):
        super().__init__(name, models_dict_path, device)
        self.num_classes = num_classes

    def get_accuracy(self, test_loader, *args, **kwargs):
        # print('DeeplabV3: start test acc')
        _d = test_loader.dataset
        from data import build_dataloader
        if _d.__class__.__name__ == 'MergedDataset':
            # print('\neval on merged datasets')
            datasets = _d.datasets
            test_loaders = [build_dataloader(d, test_loader.batch_size, test_loader.num_workers, False, None) for d in datasets]
            accs = [self.get_accuracy(loader) for loader in test_loaders]
            # print(accs)
            return sum(accs) / len(accs)

        # print('dataset len', len(test_loader.dataset))

        model = self.models_dict['main']
        device = self.device
        model.eval()
        # print('# classes', model.num_classes)
        model = model.to(device)

        from dnns.yolov3.coco_evaluator import COCOEvaluator
        from utils.common.others import HiddenPrints
        with torch.no_grad():
            with HiddenPrints():
                evaluator = COCOEvaluator(
                    dataloader=test_loader,
                    img_size=(224, 224),
                    confthre=0.01,
                    nmsthre=0.65,
                    num_classes=self.num_classes,
                    testdev=False
                )
                res = evaluator.evaluate(model, False, False)
                map50 = res[1]
                # print('eval info', res[-1])
        return map50

    def infer(self, x, *args, **kwargs):
        if len(args) > 0:
            return self.models_dict['main'](x, *args)  # forward(x, label)
        return self.models_dict['main'](x)


class ElasticDNN_OfflineSenClsMDModel(ElasticDNN_OfflineMDModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0

        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                output = self.infer(x)
                pred = F.softmax(output, dim=1).argmax(dim=1)
                # print(pred, y)
                correct = torch.eq(pred, y).sum().item()
                acc += correct
                sample_num += len(y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')

        acc /= sample_num
        return acc

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](**x)


class ElasticDNN_OfflineTrMDModel(ElasticDNN_OfflineMDModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        # TODO: BLEU
        from sacrebleu import corpus_bleu

        acc = 0
        num_batches = 0

        self.to_eval_mode()

        from data.datasets.sentiment_classification.global_bert_tokenizer import get_tokenizer
        tokenizer = get_tokenizer()

        def _decode(o):
            # https://github.com/huggingface/transformers/blob/main/examples/research_projects/seq2seq-distillation/finetune.py#L133
            o = tokenizer.batch_decode(o, skip_special_tokens=True, clean_up_tokenization_spaces=True)
            return [oi.strip().replace(' ', '') for oi in o]

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                output = self.infer(x)

                decoded_output = _decode(output.argmax(-1))
                decoded_y = _decode(y)
                decoded_y = [decoded_y]
                # print(x, decoded_y, decoded_output, output.argmax(-1))

                bleu = corpus_bleu(decoded_output, decoded_y).score
                pbar.set_description(f'cur_batch_bleu: {bleu:.4f}')

                acc += bleu
                num_batches += 1

        acc /= num_batches
        return acc

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](**x)


class ElasticDNN_OfflineTokenClsMDModel(ElasticDNN_OfflineMDModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0

        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                # print(x)
                y = y.to(self.device)
                output = self.infer(x)

                # torch.Size([16, 512, 43]) torch.Size([16, 512])
                for oi, yi, xi in zip(output, y, x['input_ids']):
                    # oi: 512, 43; yi: 512
                    seq_len = xi.nonzero().size(0)
                    # print(output.size(), y.size())
                    pred = F.softmax(oi, dim=-1).argmax(dim=-1)
                    correct = torch.eq(pred[1: seq_len], yi[1: seq_len]).sum().item()
                    # print(output.size(), y.size())
                    acc += correct
                    sample_num += seq_len

                    pbar.set_description(f'seq_len: {seq_len}, cur_seq_acc: {(correct / seq_len):.4f}')

        acc /= sample_num
        return acc

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](**x)


class ElasticDNN_OfflineMMClsMDModel(ElasticDNN_OfflineMDModel):
    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0

        self.to_eval_mode()

        batch_size = 1

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                if batch_index * batch_size > 2000:
                    break

                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                # print(x)
                raw_texts = x['texts'][:]
                x['texts'] = list(set(x['texts']))
                # print(x['texts'])
                batch_size = len(y)

                x['for_training'] = False
                output = self.infer(x)
                output = output.logits_per_image
                # print(output.size())
                # exit()
                # y = torch.arange(len(y), device=self.device)
                y = torch.LongTensor([x['texts'].index(rt) for rt in raw_texts]).to(self.device)
                # print(y)
                # exit()

                # print(output.size(), y.size())
                pred = F.softmax(output, dim=1).argmax(dim=1)
                correct = torch.eq(pred, y).sum().item()
                acc += correct
                sample_num += len(y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_correct: {correct}, '
                                     f'cur_batch_acc: {(correct / len(y)):.4f}')

        acc /= sample_num
        return acc

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](**x)


class ElasticDNN_OfflineVQAMDModel(ElasticDNN_OfflineMDModel):
    # def __init__(self, name: str, models_dict_path: str, device: str, class_to_label_idx_map):
    #     super().__init__(name, models_dict_path, device)
    #     self.class_to_label_idx_map = class_to_label_idx_map

    def get_accuracy(self, test_loader, *args, **kwargs):
        acc = 0
        sample_num = 0
        vqa_score = VQAScore()

        self.to_eval_mode()

        with torch.no_grad():
            pbar = tqdm.tqdm(enumerate(test_loader), total=len(test_loader), dynamic_ncols=True, leave=False)
            for batch_index, (x, y) in pbar:
                for k, v in x.items():
                    if isinstance(v, torch.Tensor):
                        x[k] = v.to(self.device)
                y = y.to(self.device)
                output = self.infer(x).logits
                vqa_score.update(output, y)

                pbar.set_description(f'cur_batch_total: {len(y)}, cur_batch_acc: {vqa_score.compute():.4f}')

        return vqa_score.compute()

    def infer(self, x, *args, **kwargs):
        return self.models_dict['main'](**x)