|
from transformers import ( |
|
AutoConfig, |
|
AutoModelForSeq2SeqLM, |
|
AutoTokenizer |
|
) |
|
|
|
from datasets import load_dataset |
|
import torch |
|
|
|
def load_nlu_model():
    """Fetch the MASSIVE-finetuned NLU seq2seq checkpoint.

    Returns:
        (model, tokenizer) pair loaded from the Hugging Face hub.
    """
    checkpoint = "Beomseok-LEE/NLU-Speech-MASSIVE-finetune"

    cfg = AutoConfig.from_pretrained(checkpoint)
    tok = AutoTokenizer.from_pretrained(checkpoint)
    mdl = AutoModelForSeq2SeqLM.from_pretrained(checkpoint, config=cfg)

    return mdl, tok
|
|
|
def run_nlu_inference(model, tokenizer, example):
    """Run seq2seq NLU (slot filling + intent classification) on one utterance.

    Args:
        model: a transformers seq2seq model (e.g. from ``load_nlu_model``).
        tokenizer: the tokenizer matching ``model``.
        example: the ASR transcript (plain string) to annotate.

    Returns:
        dict with keys ``'text'`` (the full display string) and ``'entities'``
        (list of dicts with ``'entity'``, ``'word'``, ``'start'``, ``'end'``
        character offsets into ``'text'``) — the shape expected by a
        highlighted-text style UI component.
    """
    print(example)

    # The model was fine-tuned with an "Annotate: " task prefix.
    formatted_example = "Annotate: " + example
    input_values = tokenizer(
        formatted_example,
        max_length=128,
        padding=False,
        truncation=True,
        return_tensors="pt",
    ).input_ids

    with torch.no_grad():
        pred_ids = model.generate(input_values)

    prediction = tokenizer.decode(pred_ids[0], skip_special_tokens=True)
    print(prediction)

    # Expected decoder output: one slot label per input word, then the intent,
    # e.g. "B-date O O weather_query" — presumably; verify against the model card.
    splitted_pred = prediction.strip().split()

    # BUG FIX: slots_prediction was initialised to '' (a str) but later holds a
    # list; keep the type consistent so downstream zip() always sees a list.
    slots_prediction = []
    intent_prediction = ''

    if len(splitted_pred) >= 2:
        slots_prediction = splitted_pred[:-1]
        intent_prediction = splitted_pred[-1]
    # BUG FIX: the branches are mutually exclusive — use elif, not a second if.
    elif len(splitted_pred) == 1:
        # A single token: treat it as a lone slot label, no intent available.
        slots_prediction = splitted_pred

    # split(' ') (not split()) so the character offsets computed below stay
    # aligned with the verbatim text rendered into the output string.
    words = example.split(' ')

    title_1 = '[ASR output]\n'
    title_2 = '\n\n[NLU - slot filling]\n'
    # BUG FIX: corrected the "classifcation" typo in the displayed heading.
    title_3 = '\n\n[NLU - intent classification]\n'

    prefix_str_1 = title_1 + example + title_2
    prefix_str_2 = title_3

    structured_output = {
        'text': prefix_str_1 + example + prefix_str_2 + intent_prediction,
        'entities': [],
    }

    # First entity: the raw ASR transcript in the [ASR output] section.
    structured_output['entities'].append({
        'entity': 'ASR output',
        'word': example,
        'start': len(title_1),
        'end': len(title_1) + len(example),
    })

    # Walk the slot-filling copy of the utterance one word at a time; if the
    # model emitted more/fewer labels than words, zip silently drops the extras.
    idx = len(prefix_str_1)
    for slot, word in zip(slots_prediction, words):
        start = idx
        end = idx + len(word)
        idx = end + 1  # skip the separating space

        structured_output['entities'].append({
            'entity': slot,
            'word': word,
            'start': start,
            'end': end,
        })

    # The intent label sits at the very end of the display string.
    idx = len(prefix_str_1 + example + prefix_str_2)
    if intent_prediction:
        structured_output['entities'].append({
            'entity': 'Classified Intent',
            'word': intent_prediction,
            'start': idx,
            'end': idx + len(intent_prediction),
        })

    return structured_output