# mzboito's picture
# files upload
# ffa317c
from transformers import (
AutoConfig,
AutoModelForSeq2SeqLM,
AutoTokenizer
)
from datasets import load_dataset
import torch
def load_nlu_model(model_name="Beomseok-LEE/NLU-Speech-MASSIVE-finetune"):
    """Load the seq2seq NLU model and its tokenizer.

    Args:
        model_name: Hugging Face Hub id or local path of the checkpoint.
            Defaults to the NLU-Speech-MASSIVE fine-tuned model this demo
            was written for, so existing zero-argument calls are unchanged.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # Config, tokenizer and weights all come from the same checkpoint.
    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config)
    return model, tokenizer
def run_nlu_inference(model, tokenizer, example):
    """Annotate ``example`` with slot tags and an intent, as highlighted text.

    The model is prompted with ``"Annotate: " + example``. Its decoded output is
    read as whitespace-separated tokens: every token but the last is a slot tag
    (paired with the words of ``example`` in order), and the last is the intent.
    If the output is a single token, it is treated as a slot tag and no intent
    is reported.

    Args:
        model: seq2seq model exposing ``generate`` and ``device``.
        tokenizer: tokenizer matching ``model``.
        example: input utterance (e.g. an ASR transcript).

    Returns:
        A dict with ``'text'`` (the rendered report) and ``'entities'`` (a list
        of dicts with ``'entity'``, ``'word'``, ``'start'``, ``'end'``; the
        start/end values are character offsets into ``'text'``). This looks
        like the Gradio ``HighlightedText`` format; confirm with the caller.
    """
    print(example)
    prediction = _generate_annotation(model, tokenizer, example)
    print(prediction)
    slot_tags, intent = _split_prediction(prediction)
    return _build_highlighted_output(example, slot_tags, intent)


def _generate_annotation(model, tokenizer, text):
    """Run the model on ``text`` and return the decoded prediction string."""
    prompt = "Annotate: " + text
    input_ids = tokenizer(
        prompt, max_length=128, padding=False, truncation=True, return_tensors="pt"
    ).input_ids
    # Put the inputs on the model's device; otherwise generate() fails
    # when the model has been moved to a GPU.
    input_ids = input_ids.to(model.device)
    with torch.no_grad():
        pred_ids = model.generate(input_ids)
    return tokenizer.decode(pred_ids[0], skip_special_tokens=True)


def _split_prediction(prediction):
    """Split a prediction string into ``(slot_tags, intent)``."""
    tokens = prediction.strip().split()
    if len(tokens) >= 2:
        return tokens[:-1], tokens[-1]
    # Zero or one token: any lone token is kept as a slot tag and the intent is empty.
    return tokens, ''


def _make_entity(label, word, start):
    """Build one highlighted span covering ``word`` starting at ``start``."""
    return {'entity': label, 'word': word, 'start': start, 'end': start + len(word)}


def _build_highlighted_output(example, slot_tags, intent):
    """Render the ASR / slot / intent report and its highlighted spans."""
    asr_title = '[ASR output]\n'
    slots_title = '\n\n[NLU - slot filling]\n'
    intent_title = '\n\n[NLU - intent classification]\n'

    asr_start = len(asr_title)
    slots_start = asr_start + len(example) + len(slots_title)
    intent_start = slots_start + len(example) + len(intent_title)
    text = asr_title + example + slots_title + example + intent_title + intent

    entities = [_make_entity('ASR output', example, asr_start)]
    # Split on single spaces (not str.split()) so each word starts exactly one
    # character after the previous one ends, which matches the copy of
    # `example` placed in `text`.
    cursor = slots_start
    for tag, word in zip(slot_tags, example.split(' ')):
        entities.append(_make_entity(tag, word, cursor))
        cursor += len(word) + 1
    if intent:
        entities.append(_make_entity('Classified Intent', intent, intent_start))
    return {'text': text, 'entities': entities}