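# Gradio demo for acronym disambiguation: given a context and an acronym, a
# BERT span-selection model (BertAD) picks the most likely expansion among the
# candidates listed in model/dict.json.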
import json
import os

import gradio as gr
import numpy as np
import torch
import tokenizers
import transformers

from model import BertAD

# Acronym -> list of candidate expansions.
DICTIONARY = json.load(open('model/dict.json'))
TOKENIZER = tokenizers.BertWordPieceTokenizer("model/vocab.txt", lowercase=True)
MAX_LEN = 256

MODEL = BertAD()
# Keep the model's freshly initialized position_ids buffer: the checkpoint's
# entry may be missing or incompatible, so overwrite it before loading.
vec = MODEL.state_dict()['bert.embeddings.position_ids']
chkp = torch.load(os.path.join('model', 'model_0.bin'), map_location='cpu')
chkp['bert.embeddings.position_ids'] = vec
MODEL.load_state_dict(chkp)
MODEL.eval()

def sample_text(text, acronym, max_len):
    """Window the context to at most ~max_len words centered on the acronym.

    Assumes the acronym occurs in the text (str.index raises otherwise).
    """
    text = text.split()
    idx = text.index(acronym)
    left_idx = max(0, idx - max_len // 2)
    right_idx = min(len(text), idx + max_len // 2)
    sampled_text = text[left_idx:right_idx]
    return ' '.join(sampled_text)
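
# e.g. for a 300-word context, sample_text(context, "CNN", 120) keeps roughly
# 60 words on each side of the first occurrence of "CNN".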

def process_data(text, acronym, expansion, tokenizer, max_len):
    """Build QA-style BERT inputs: [CLS] candidate expansions [SEP] context [SEP]."""
    text = str(text)
    expansion = str(expansion)
    acronym = str(acronym)

    # Keep the context manageable by windowing it around the acronym.
    n_tokens = len(text.split())
    if n_tokens > 120:
        text = sample_text(text, acronym, 120)

    # The "answer" segment lists the acronym followed by all candidate expansions.
    answers = acronym + ' ' + ' '.join(DICTIONARY[acronym])

    # Character-level mask over the span of the given expansion.
    start = answers.find(expansion)
    end = start + len(expansion)
    char_mask = [0] * len(answers)
    for i in range(start, end):
        char_mask[i] = 1

    # Tokenize the answer segment and drop the [CLS]/[SEP] added by the tokenizer.
    tok_answer = tokenizer.encode(answers)
    answer_ids = tok_answer.ids[1:-1]
    answer_offsets = tok_answer.offsets[1:-1]

    # Token indices whose character span overlaps the expansion.
    target_idx = []
    for i, (off1, off2) in enumerate(answer_offsets):
        if sum(char_mask[off1:off2]) > 0:
            target_idx.append(i)
    start = target_idx[0]
    end = target_idx[-1]

    # Assemble [CLS] answer tokens [SEP] context tokens [SEP].
    text_ids = tokenizer.encode(text).ids[1:-1]
    token_ids = [101] + answer_ids + [102] + text_ids + [102]
    offsets = [(0, 0)] + answer_offsets + [(0, 0)] * (len(text_ids) + 2)
    mask = [1] * len(token_ids)
    token_type = [0] * (len(answer_ids) + 1) + [1] * (2 + len(text_ids))

    text = answers + text
    # Shift the target span by one for the leading [CLS].
    start = start + 1
    end = end + 1

    # Pad or truncate everything to max_len.
    padding = max_len - len(token_ids)
    if padding >= 0:
        token_ids = token_ids + [0] * padding
        token_type = token_type + [1] * padding
        mask = mask + [0] * padding
        offsets = offsets + [(0, 0)] * padding
    else:
        token_ids = token_ids[:max_len]
        token_type = token_type[:max_len]
        mask = mask[:max_len]
        offsets = offsets[:max_len]

    assert len(token_ids) == max_len
    assert len(mask) == max_len
    assert len(offsets) == max_len
    assert len(token_type) == max_len

    return {
        'ids': token_ids,
        'mask': mask,
        'token_type': token_type,
        'offset': offsets,
        'start': start,
        'end': end,
        'text': text,
        'expansion': expansion,
        'acronym': acronym,
    }
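
# Example layout (hypothetical acronym "SVM" with one candidate
# "support vector machine"):
#   [CLS] svm support vector machine [SEP] <context tokens> [SEP] <padding>
# 'start'/'end' then index the target expansion's tokens in this sequence.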

def jaccard(str1, str2):
    """Word-level Jaccard similarity between two strings."""
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    c = a.intersection(b)
    return float(len(c)) / (len(a) + len(b) - len(c))
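
# e.g. jaccard("convolutional neural network", "neural network") == 2 / 3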

def evaluate_jaccard(text, selected_text, acronym, offsets, idx_start, idx_end):
    """Decode the predicted token span and snap it to the closest candidate."""
    # Rebuild the predicted string from character offsets, re-inserting a space
    # wherever there is a gap between consecutive token spans.
    filtered_output = ""
    for ix in range(idx_start, idx_end + 1):
        filtered_output += text[offsets[ix][0]:offsets[ix][1]]
        if (ix + 1) < len(offsets) and offsets[ix][1] < offsets[ix + 1][0]:
            filtered_output += " "

    # Return the dictionary candidate with the highest Jaccard overlap.
    candidates = DICTIONARY[acronym]
    candidate_jaccards = [jaccard(w.strip(), filtered_output.strip()) for w in candidates]
    idx = np.argmax(candidate_jaccards)
    return candidate_jaccards[idx], candidates[idx]

def disambiguate(text, acronym):
    # The expansion is unknown at inference time, so pass the acronym itself as
    # a placeholder; the model predicts the true span over the candidates.
    inputs = process_data(text, acronym, acronym, TOKENIZER, MAX_LEN)
    ids = torch.tensor(inputs['ids']).view(1, -1)
    mask = torch.tensor(inputs['mask']).view(1, -1)
    token_type = torch.tensor(inputs['token_type']).view(1, -1)
    offsets = inputs['offset']

    with torch.no_grad():
        start_logits, end_logits = MODEL(ids, mask, token_type)

    start_prob = torch.softmax(start_logits, dim=-1).numpy()
    end_prob = torch.softmax(end_logits, dim=-1).numpy()
    start_idx = np.argmax(start_prob[0, :])
    end_idx = np.argmax(end_prob[0, :])

    # Decode against inputs['text'] (answers + context), which is the string
    # the stored character offsets refer to.
    js, exp = evaluate_jaccard(inputs['text'], inputs['expansion'],
                               inputs['acronym'], offsets, start_idx, end_idx)
    return exp
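
# Example (assuming "CNN" is a key in model/dict.json with a candidate such as
# "convolutional neural network"):
#   disambiguate("We trained a CNN on ImageNet to classify images.", "CNN")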

# Gradio interface.
text = gr.inputs.Textbox(lines=5, label="Context", placeholder="Type a sentence or paragraph here.")
acronym = gr.inputs.Textbox(lines=2, label="Acronym", placeholder="Type acronym")
expansion = gr.outputs.Textbox(label="Expansion")

iface = gr.Interface(fn=disambiguate, inputs=[text, acronym], outputs=expansion)
iface.launch()