Spaces:
Runtime error
Runtime error
from pathlib import Path | |
from transformers import AutoTokenizer, pipeline | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from text_utils import post_process_answer | |
from graph_utils import find_best_cluster | |
from optimum.intel import OVModelForQuestionAnswering | |
import os | |
import json | |
from text_utils import * | |
# os.environ['HTTP_PROXY'] = 'http://proxy.hcm.fpt.vn:80' | |
class QAEnsembleModel_modify(nn.Module):
    """Extractive question answering over several passages with answer clustering.

    Wraps one OpenVINO-exported question-answering model in a Hugging Face
    ``pipeline``.  ``forward`` runs the pipeline on every passage, keeps
    sufficiently confident answers as candidates, and returns the
    post-processed answer chosen by ``find_best_cluster``.
    """

    def __init__(self, model_name, entity_dict,
                 thr=0.1, device="cpu", auth_token=None):
        """Load *model_name* as an OpenVINO QA model and build the pipeline.

        Args:
            model_name: Hugging Face Hub id (or local path) of the QA model;
                also used to load the matching tokenizer.
            entity_dict: Passed to ``post_process_answer`` when normalising
                answers in ``forward``.
            thr: Minimum raw pipeline score for an answer to be kept as a
                clustering candidate in ``forward``.
            device: Accepted for interface compatibility; not currently used
                by this constructor.
            auth_token: Hub access token for private repositories.  When
                ``None``, the ``HF_TOKEN`` environment variable is used
                (``None`` if unset).
        """
        super().__init__()
        # Never hard-code Hub tokens in source: a committed token must be
        # treated as leaked and revoked.  Take it from the caller or the
        # environment instead.
        token = auth_token if auth_token is not None else os.environ.get("HF_TOKEN")
        model = OVModelForQuestionAnswering.from_pretrained(
            model_name, export=True, use_auth_token=token)
        # Load the tokenizer explicitly so it is fetched with the same token;
        # passing only the name to ``pipeline`` would download it unauthenticated.
        tokenizer = AutoTokenizer.from_pretrained(model_name, use_auth_token=token)
        self.nlps = [pipeline("question-answering", model=model, tokenizer=tokenizer)]
        self.entity_dict = entity_dict
        self.thr = thr

    def forward(self, question, texts, ranking_scores=None):
        """Answer *question* from *texts* and return the post-processed answer.

        Every pipeline in ``self.nlps`` is run on each (question, text) pair.
        Answers whose raw pipeline score exceeds ``self.thr`` (from any
        pipeline) become clustering candidates.  Separately, the first
        pipeline's answer with the highest score weighted by the passage's
        ranking score is used as the anchor.  The result is
        ``find_best_cluster(candidates, anchor)``, post-processed with
        ``self.entity_dict``.

        Args:
            question: The question string.
            texts: Sized sequence of context passages.
            ranking_scores: Optional per-passage weights, paired with *texts*
                via ``zip``; defaults to all ones.

        Returns:
            The post-processed answer, or ``None`` when no pipeline answer
            exceeds ``self.thr``.
        """
        if ranking_scores is None:
            ranking_scores = np.ones((len(texts),))

        candidates = []
        answer = None
        # Start below any finite weighted score so an anchor answer is always
        # chosen, even when every ranking score is zero or negative (starting
        # at 0 left ``answer`` unbound and raised UnboundLocalError below).
        best_score = float("-inf")
        for i, nlp in enumerate(self.nlps):
            for text, score in zip(texts, ranking_scores):
                res = nlp({"question": question, "context": text})
                if res["score"] > self.thr:
                    candidates.append(res["answer"])
                # Only the first pipeline selects the anchor answer.
                weighted = res["score"] * score
                if i == 0 and weighted > best_score:
                    answer = res["answer"]
                    best_score = weighted

        if not candidates:
            return None
        candidates = [post_process_answer(x, self.entity_dict) for x in candidates]
        answer = post_process_answer(answer, self.entity_dict)
        return post_process_answer(find_best_cluster(candidates, answer), self.entity_dict)