Spaces:
Sleeping
Sleeping
# Based on the Gradio chatbot tutorial: https://gradio.app/creating_a_chatbot/
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import torch | |
import re | |
# Hugging Face Hub checkpoint: GPT-2 fine-tuned for task-oriented dialogue on SGD.
# The tokenizer is loaded first, then the causal-LM weights, both from the same checkpoint.
ckpt = 'armandnlp/gpt2-TOD_finetuned_SGD'
tokenizer, model = (
    AutoTokenizer.from_pretrained(ckpt),
    AutoModelForCausalLM.from_pretrained(ckpt),
)
def format_resp(system_resp):
    """Make a raw system turn human-readable for display.

    Each ``<|belief|>``, ``<|action|>`` and ``<|response|>`` tag is replaced by
    the label ``*Belief State: ``, ``*Actions: `` or ``*System Response: ``
    respectively; all other text is returned unchanged.
    """
    tag_labels = (
        ('<|belief|>', '*Belief State: '),
        ('<|action|>', '*Actions: '),
        ('<|response|>', '*System Response: '),
    )
    readable = system_resp
    for tag, label in tag_labels:
        readable = readable.replace(tag, label)
    return readable
# Matches the user/system turn markers. Raw string: the previous plain literal
# contained the invalid escape "\|" (SyntaxWarning on Python 3.12+).
_TURN_SEPARATOR = re.compile(r'<\|system\|>|<\|user\|>')


def _split_turns(text):
    """Return the alternating user/system turn texts found in ``text``.

    Anything before the first ``<|user|>``/``<|system|>`` marker is dropped.
    """
    return _TURN_SEPARATOR.split(text)[1:]


def _response_part(system_turn):
    """Return the text after the first ``<|response|>`` tag of a system turn.

    The text stops at a second ``<|response|>`` tag, if there is one. Returns ''
    when the tag is missing (e.g. generation was cut off) instead of raising
    IndexError, which would otherwise break every later turn of the session.
    """
    pieces = system_turn.split('<|response|>')
    return pieces[1] if len(pieces) > 1 else ''


def predict(input, history=None):
    """Run one dialogue turn through the model and update the conversation state.

    Args:
        input: the new user utterance.
        history: state returned by a previous call. It is a list holding one list
            of token ids: user turns plus full system turns, including the belief
            and action sections. ``None`` or ``[]`` starts a new conversation.

    Returns:
        ``(resps, history)``. ``resps`` is a list of ``(user, system)`` string
        tuples, one per exchange so far, with the system side passed through
        ``format_resp``. ``history`` is the updated state to pass to the next call.
    """
    # `history=None` replaces the shared mutable default `[]`; both None and an
    # empty list mean "no previous turns".
    if history:
        # The model expects only user and system responses in its context, so
        # strip the belief and action sections from every stored system turn.
        # Only complete (user, system) pairs are kept, so a dangling fragment
        # can never leak into the model input.
        turns = _split_turns(tokenizer.decode(history[0]))
        cleaned = ''.join(
            '<|user|>' + user_turn + '<|system|>' + _response_part(system_turn)
            for user_turn, system_turn in zip(turns[0::2], turns[1::2])
        )
        history4input = tokenizer.encode(cleaned, return_tensors='pt')
        previous_ids = torch.tensor(history, dtype=torch.long)
    else:
        # Explicit (1, 0) tensors concatenate cleanly with the (1, n) encodings below.
        history4input = torch.empty((1, 0), dtype=torch.long)
        previous_ids = torch.empty((1, 0), dtype=torch.long)

    # Model input: <|context|> + cleaned history + new user turn + <|endofcontext|>
    new_user_input_ids = tokenizer.encode(' <|user|> ' + input, return_tensors='pt')
    context = tokenizer.encode('<|context|>', return_tensors='pt')
    endofcontext = tokenizer.encode(' <|endofcontext|>', return_tensors='pt')
    model_input = torch.cat([context, history4input, new_user_input_ids, endofcontext], dim=-1)

    # NOTE(review): 50262 is presumably the id of '<|endofresponse|>' in this
    # checkpoint's vocabulary -- confirm. max_length includes the prompt, so as the
    # conversation approaches 1024 tokens there is little or no room left to generate.
    out = model.generate(model_input, max_length=1024, eos_token_id=50262).tolist()[0]

    # Keep only the generated continuation (after <|endofcontext|>) and drop the
    # end-of-section tags.
    system_out = tokenizer.decode(out).split('<|endofcontext|>')[1]
    for end_tag in ('<|endofbelief|>', '<|endofaction|>', '<|endofresponse|>'):
        system_out = system_out.replace(end_tag, '')

    # The stored history keeps the full system output (belief + action + response)
    # so every turn can be displayed; it is cleaned again before the next model call.
    resp_tokenized = tokenizer.encode(' <|system|> ' + system_out, return_tensors='pt')
    history = torch.cat([previous_ids, new_user_input_ids, resp_tokenized], dim=-1).tolist()

    # One (user, system) tuple per exchange, system side made readable.
    turns = _split_turns(tokenizer.decode(history[0]))
    resps = [
        (user_turn, format_resp(system_turn))
        for user_turn, system_turn in zip(turns[0::2], turns[1::2])
    ]
    return resps, history
# Starter prompts offered in the UI (passed as gr.Interface(examples=...)).
# Each entry is a one-element list holding a single user utterance.
examples = [
    [utterance]
    for utterance in (
        "I want to book a restaurant for 2 people on Saturday.",
        "What's the weather in Cambridge today ?",
        "I need to find a bus to Boston.",
        "I want to add an event to my calendar.",
        "I would like to book a plane ticket to New York.",
        "I want to find a concert around LA.",
        "Hi, I'd like to find an apartment in London please.",
        "Can you find me a hotel room near Seattle please ?",
        "I want to watch a film online, a comedy would be nice",
        "I want to transfer some money please.",
        "I want to reserve a movie ticket for tomorrow evening",
        "Can you play the song Learning to Fly by Tom Petty ?",
        "I need to rent a small car.",
    )
]
# Short introduction shown with the demo; passed to gr.Interface(description=...).
description = """
This is an interactive window to chat with GPT-2 fine-tuned on the Schema-Guided Dialogues dataset,
in which we find domains such as travel, weather, media, calendar, banking,
restaurant booking...
"""
# Longer write-up about the model outputs and the SGD dataset, passed to
# gr.Interface(article=...). Written in Markdown (### headings, links, bullet list).
article = """
### Model Outputs
This task-oriented dialogue system is trained end-to-end, following the method detailed in
[SimpleTOD](https://arxiv.org/pdf/2005.00796.pdf), where GPT-2 is trained by casting task-oriented
dialogue as a seq2seq task.
From the dialogue history, composed of the previous user and system responses, the model is trained
to output the belief state, the action decisions and the system response as a sequence. We show all
three outputs in this demo : the belief state tracks the user goal (restaurant cuisine : Indian or media
genre : comedy for ex.), the action decisions show how the system should proceed (restaurants request city
or media offer title for ex.) and the natural language response provides an output the user can interpret.
The model responses are *de-lexicalized* : database values in the training set have been replaced with their
slot names to make the learning process database agnostic. These slots are meant to later be replaced by actual
results from a database, using the belief state to issue calls.
The model is capable of dealing with multiple domains : a list of possible inputs is provided to get the
conversation going.
### Dataset
The SGD dataset ([blogpost](https://ai.googleblog.com/2019/10/introducing-schema-guided-dialogue.html) and
[article](https://arxiv.org/pdf/1909.05855.pdf)) contains multiple task domains... Here is a list of some
of the services and their descriptions from the dataset:
* **Restaurants** : A leading provider for restaurant search and reservations
* **Weather** : Check the weather for any place and any date
* **Buses** : Find a bus to take you to the city you want
* **Calendar** : Calendar service to manage personal events and reservations
* **Flights** : Find your next flight
* **Events** : Get tickets for the coolest concerts and sports in your area
* **Homes** : A widely used service for finding apartments and scheduling visits
* **Hotels** : A popular service for searching and reserving rooms in hotels
* **Media** : A leading provider of movies for searching and watching on-demand
* **Banks** : Manage bank accounts and transfer money
* **Movies** : A go-to provider for finding movies, searching for show times and booking tickets
* **Music** : A popular provider of a wide range of music content for searching and listening
* **RentalCars** : Car rental service with extensive coverage of locations and cars
"""
import gradio as gr

# Chat UI: predict(text, state) returns (list of (user, system) pairs, new state),
# which map onto the "chatbot" and "state" output components in that order.
demo = gr.Interface(
    fn=predict,
    inputs=["text", "state"],
    outputs=["chatbot", "state"],
    title="Chatting with multi task-oriented GPT2",
    examples=examples,
    description=description,
    article=article,
)
demo.launch()