# gpt2-TOD_app / app.py
# armandstrickernlp
# update article
# 0e21bc9
# Tutorial followed: https://gradio.app/creating_a_chatbot/
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re
# Hugging Face Hub checkpoint: GPT-2 fine-tuned on the Schema-Guided Dialogue dataset.
ckpt = 'armandnlp/gpt2-TOD_finetuned_SGD'

# Tokenizer and causal LM come from the same checkpoint, so the dialogue tags used
# below (<|user|>, <|belief|>, ...) are presumably part of this tokenizer's vocabulary.
tokenizer, model = (
    AutoTokenizer.from_pretrained(ckpt),
    AutoModelForCausalLM.from_pretrained(ckpt),
)
def format_resp(system_resp):
    """Make a decoded system turn readable by swapping SimpleTOD tags for labels.

    The <|belief|>, <|action|> and <|response|> tags that open each section of the
    model output are replaced with short human-readable prefixes.

    Args:
        system_resp: decoded text of one system turn.

    Returns:
        The same text with the three section tags relabelled.
    """
    tag_labels = (
        ('<|belief|>', '*Belief State: '),
        ('<|action|>', '*Actions: '),
        ('<|response|>', '*System Response: '),
    )
    formatted = system_resp
    for tag, label in tag_labels:
        formatted = formatted.replace(tag, label)
    return formatted
# Splits decoded history text on speaker tags into alternating user / system turns.
_TURN_SPLIT_RE = re.compile(r'<\|system\|>|<\|user\|>')


def _split_turns(history):
    """Decode the stored history and split it into alternating user/system turn strings.

    Text before the first speaker tag is dropped.
    """
    return _TURN_SPLIT_RE.split(tokenizer.decode(history[0]))[1:]


def _empty_ids():
    """Return an empty (1, 0) token-id tensor that concatenates cleanly with (1, n) ones."""
    return torch.empty((1, 0), dtype=torch.long)


def _history_to_model_input(history):
    """Build the (1, n) tensor of previous turns the model is conditioned on.

    The stored history keeps each system turn's belief, action and response so they
    can be displayed, but the model only expects user inputs and system responses,
    so belief/action segments are stripped here.
    """
    if not history:
        return _empty_ids()
    turns = _split_turns(history)
    for i in range(0, len(turns) - 1, 2):
        turns[i] = '<|user|>' + turns[i]
        # Keep only the response part of each system output. A turn without a
        # <|response|> tag (e.g. generation was cut short) contributes an empty
        # response instead of raising IndexError.
        parts = turns[i + 1].split('<|response|>')
        turns[i + 1] = '<|system|>' + (parts[1] if len(parts) > 1 else '')
    return tokenizer.encode(''.join(turns), return_tensors='pt')


def _history_to_chat(history):
    """Convert the stored history into chatbot pairs [(user, formatted system), ...]."""
    turns = _split_turns(history)
    return [(turns[i], format_resp(turns[i + 1])) for i in range(0, len(turns) - 1, 2)]


def predict(input, history=None):
    """Generate the next system turn and return the chat display plus updated state.

    Args:
        input: the new user utterance.
        history: token ids of the whole conversation so far, as a nested list
            [[id, ...]] (the gradio "state"). None or [] starts a new conversation.

    Returns:
        (resps, history): resps is a list of (user, system) string pairs for the
        chatbot output, with system turns passed through format_resp. history is
        the updated [[id, ...]] state; it keeps the belief and action segments so
        they can be shown, even though they are removed from the next model input.
    """
    # None (e.g. an uninitialised gradio state) is treated like an empty history.
    history = history if history else []

    history4input = _history_to_model_input(history)

    # Model input: <|context|> + cleaned history + new user turn + <|endofcontext|>.
    new_user_input_ids = tokenizer.encode(' <|user|> ' + input, return_tensors='pt')
    context = tokenizer.encode('<|context|>', return_tensors='pt')
    endofcontext = tokenizer.encode(' <|endofcontext|>', return_tensors='pt')
    model_input = torch.cat([context, history4input, new_user_input_ids, endofcontext], dim=-1)

    # NOTE(review): 50262 is used as the stop token; presumably the id of
    # <|endofresponse|> in this checkpoint's vocabulary -- TODO confirm against tokenizer.
    # max_length counts the prompt tokens too, and the history is not truncated, so
    # long conversations leave little or no room for generation.
    out = model.generate(model_input, max_length=1024, eos_token_id=50262).tolist()[0]

    # Keep only the generated system output (after <|endofcontext|>) and drop the
    # end-of-section markers.
    string_out = tokenizer.decode(out)
    system_out = (string_out.split('<|endofcontext|>')[1]
                  .replace('<|endofbelief|>', '')
                  .replace('<|endofaction|>', '')
                  .replace('<|endofresponse|>', ''))
    resp_tokenized = tokenizer.encode(' <|system|> ' + system_out, return_tensors='pt')

    # history = previous history + last user input + <|system|> <|belief|> ... <|action|> ... <|response|> ...
    past = torch.LongTensor(history) if history else _empty_ids()
    history = torch.cat([past, new_user_input_ids, resp_tokenized], dim=-1).tolist()

    return _history_to_chat(history), history
# Sample opening utterances shown under the input box, one per SGD service listed in
# the article below (same order). Each example is a one-element list holding the text input.
examples = [[utterance] for utterance in (
    "I want to book a restaurant for 2 people on Saturday.",
    "What's the weather in Cambridge today ?",
    "I need to find a bus to Boston.",
    "I want to add an event to my calendar.",
    "I would like to book a plane ticket to New York.",
    "I want to find a concert around LA.",
    "Hi, I'd like to find an apartment in London please.",
    "Can you find me a hotel room near Seattle please ?",
    "I want to watch a film online, a comedy would be nice",
    "I want to transfer some money please.",
    "I want to reserve a movie ticket for tomorrow evening",
    "Can you play the song Learning to Fly by Tom Petty ?",
    "I need to rent a small car.",
)]
# Short blurb displayed above the chat interface.
description = """
This is an interactive window to chat with GPT-2 fine-tuned on the Schema-Guided Dialogue (SGD) dataset,
which covers domains such as travel, weather, media, calendar, banking,
restaurant booking...
"""
# Markdown shown below the chat interface: explains the model outputs and the dataset.
article = """
### Model Outputs
This task-oriented dialogue system is trained end-to-end, following the method detailed in
[SimpleTOD](https://arxiv.org/pdf/2005.00796.pdf), where GPT-2 is trained by recasting task-oriented
dialogue as a single sequence prediction problem.
From the dialogue history, composed of the previous user and system responses, the model is trained
to output the belief state, the action decisions and the system response as one sequence. We show all
three outputs in this demo: the belief state tracks the user goal (e.g. restaurant cuisine: Indian, or media
genre: comedy), the action decisions show how the system should proceed (e.g. restaurants request city,
or media offer title) and the natural language response provides an output the user can read.
The model responses are *delexicalized*: database values in the training set have been replaced with their
slot names to make the learning process database-agnostic. These slots are meant to be replaced later by actual
results from a database, using the belief state to issue queries.
The model can handle multiple domains: a list of possible inputs is provided to get the
conversation going.
### Dataset
The SGD dataset ([blog post](https://ai.googleblog.com/2019/10/introducing-schema-guided-dialogue.html) and
[paper](https://arxiv.org/pdf/1909.05855.pdf)) spans multiple task domains. Here is a list of some
of the services and their descriptions from the dataset:
* **Restaurants**: A leading provider for restaurant search and reservations
* **Weather**: Check the weather for any place and any date
* **Buses**: Find a bus to take you to the city you want
* **Calendar**: Calendar service to manage personal events and reservations
* **Flights**: Find your next flight
* **Events**: Get tickets for the coolest concerts and sports in your area
* **Homes**: A widely used service for finding apartments and scheduling visits
* **Hotels**: A popular service for searching and reserving rooms in hotels
* **Media**: A leading provider of movies for searching and watching on-demand
* **Banks**: Manage bank accounts and transfer money
* **Movies**: A go-to provider for finding movies, searching for show times and booking tickets
* **Music**: A popular provider of a wide range of music content for searching and listening
* **RentalCars**: Car rental service with extensive coverage of locations and cars
"""
import gradio as gr

# Chat UI: the text box and the "state" slot feed predict(input, history), which
# returns the chatbot pairs and the updated state.
demo = gr.Interface(fn=predict,
                    inputs=["text", "state"],
                    outputs=["chatbot", "state"],
                    title="Chatting with multi task-oriented GPT2",
                    examples=examples,
                    description=description,
                    article=article)

# Start the web server only when run as a script, so the module can be imported
# (e.g. to call predict directly) without launching the app.
if __name__ == "__main__":
    demo.launch()