# chatbot_full / qa_model.py
# Uploaded by letrunglinh in commit fa01b79 ("Upload 15 files"); file size 2.62 kB.
from pathlib import Path
from transformers import AutoTokenizer, pipeline
from optimum.onnxruntime import ORTModelForQuestionAnswering
import numpy as np
import torch
import torch.nn as nn
from transformers import AutoModelForQuestionAnswering, pipeline
from text_utils import post_process_answer
from graph_utils import find_best_cluster
from optimum.intel import OVModelForQuestionAnswering
import os
import json
from text_utils import *
# os.environ['HTTP_PROXY'] = 'http://proxy.hcm.fpt.vn:80'
class QAEnsembleModel_modify(nn.Module):
    """Question-answering wrapper around a list of `question-answering` pipelines.

    The constructor builds a single pipeline from ``model_name`` (loaded via
    ``OVModelForQuestionAnswering.from_pretrained(..., export=True)``) and
    stores it in ``self.nlps``; ``forward`` iterates over that list, so more
    pipelines can be appended to it.
    """

    def __init__(self, model_name, entity_dict,
                 thr=0.1, device="cpu"):
        """Load the QA model and build the pipeline.

        Args:
            model_name: Model identifier or path handed to
                ``OVModelForQuestionAnswering.from_pretrained`` and used as
                the pipeline tokenizer.
            entity_dict: Passed unchanged to ``post_process_answer`` as its
                second argument.
            thr: Raw pipeline score an answer must exceed to be kept as a
                candidate.
            device: Accepted for interface compatibility; currently unused.
        """
        super().__init__()
        # SECURITY(review): a Hugging Face access token is hard-coded below.
        # It is kept only as a fallback so existing deployments keep loading
        # the model; export HF_TOKEN in the environment instead, then revoke
        # this token and delete the literal.
        auth_token = (os.environ.get("HF_TOKEN")
                      or "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")
        model_convert = OVModelForQuestionAnswering.from_pretrained(
            model_name, export=True, use_auth_token=auth_token)
        nlp = pipeline('question-answering', model=model_convert,
                       tokenizer=model_name)
        self.nlps = [nlp]
        self.entity_dict = entity_dict
        self.thr = thr

    def forward(self, question, texts, ranking_scores=None):
        """Answer ``question`` from the candidate contexts in ``texts``.

        Args:
            question: The question string given to every pipeline call.
            texts: Sequence of context strings; each one is queried
                separately.
            ranking_scores: Optional per-context weights, paired with
                ``texts`` by position. Defaults to a weight of 1 for every
                context.

        Returns:
            ``None`` when no pipeline answer has a raw score above
            ``self.thr``. Otherwise the post-processed result of
            ``find_best_cluster(candidates, best_answer)``, where
            ``candidates`` are the post-processed answers above the threshold
            and ``best_answer`` is the first pipeline's answer with the
            highest weighted score (raw score * ranking score).
        """
        if ranking_scores is None:
            ranking_scores = np.ones((len(texts),))

        curr_answers = []
        answer = None
        # Start below every possible score: with the former initial value of
        # 0, zero or negative ranking scores left `answer` unbound and the
        # call crashed whenever a candidate had passed the threshold.
        best_score = float("-inf")
        for i, nlp in enumerate(self.nlps):
            for text, score in zip(texts, ranking_scores):
                res = nlp({
                    'question': question,
                    'context': text
                })
                print(res)  # debug trace of the raw pipeline output
                # Candidates are filtered on the raw (unweighted) score.
                if res["score"] > self.thr:
                    curr_answers.append(res["answer"])
                weighted_score = res["score"] * score
                # Only the first pipeline is allowed to pick the best answer.
                if i == 0 and weighted_score > best_score:
                    answer = res["answer"]
                    best_score = weighted_score

        if not curr_answers:
            return None
        if answer is None:
            # No weighted score was comparable (e.g. NaN weights); fall back
            # to the first candidate rather than failing.
            answer = curr_answers[0]

        curr_answers = [post_process_answer(x, self.entity_dict)
                        for x in curr_answers]
        answer = post_process_answer(answer, self.entity_dict)
        new_best_answer = post_process_answer(
            find_best_cluster(curr_answers, answer), self.entity_dict)
        return new_best_answer