import numpy as np
import pickle
import urllib.request
from transformers import pipeline
from transformers import AutoModelForMaskedLM, AutoTokenizer
import gradio as gr
import matplotlib.pyplot as plt
plot_url = "https://huggingface.co/spaces/fvancesco/test_time_1.1/resolve/main/plot_example.p"
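# plot_example.p is a pickled matplotlib figure; it is loaded further below and shown
# as the initial plot, so the page is not empty before the first generation.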
dates = []
dates.extend([f"18 {m}" for m in range(1, 13)])
dates.extend([f"19 {m}" for m in range(1, 13)])
dates.extend([f"20 {m}" for m in range(1, 13)])
dates.extend([f"21 {m}" for m in range(1, 13)])
months = [x.split(" ")[-1] for x in dates]
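# The 48 "year month" prefixes cover January 2018 through December 2021,
# e.g. dates[0] == "18 1" (January 2018) and dates[-1] == "21 12" (December 2021).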
model_name = "fvancesco/tmp_date"
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()

# pipe = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=0)  # GPU variant
pipe = pipeline('fill-mask', model=model, tokenizer=tokenizer)
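# top_k=50265 below matches roberta-base's vocabulary size, so the pipeline returns a
# score for every candidate token; each prediction is a dict with (among other fields)
# 'token_str' and 'score' keys, sorted by descending score.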
def get_mf_dict(text):
    # predictions
    texts = []
    for d in dates:
        texts.append(f"{d} {text}")
    tmp_preds = pipe(texts, top_k=50265)
    preds = {}
    for i in range(len(tmp_preds)):
        preds[dates[i]] = tmp_preds[i]

    # get preds summary (only top words)
    top_n = 5  # top n for each prediction
    most_freq_tokens = set()
    for d in dates:
        tmp = [t['token_str'] for t in preds[d][:top_n]]
        most_freq_tokens.update(tmp)

    token_prob = {}
    for d in dates:
        token_prob[d] = {p['token_str']: p['score'] for p in preds[d]}

    mf_dict = {p: np.zeros(len(dates)) for p in most_freq_tokens}
    for c, d in enumerate(dates):
        for t in most_freq_tokens:
            mf_dict[t][c] = token_prob[d][t]
    return mf_dict
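# Illustrative call (hypothetical input): get_mf_dict("I love <mask>") returns a dict
# mapping each token that appears in some month's top-5 (e.g. "you") to a length-48
# array of that token's probability in each month.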
def plot_time(text):
    mf_dict = get_mf_dict(text)

    fig = plt.figure(figsize=(16, 9))
    ax = fig.add_subplot(111)

    x = list(range(len(dates)))
    ax.set_xlabel('Month')
    ax.set_xlim(0)
    ax.set_xticks(x)
    ax.set_xticklabels(months)

    # secondary x-axis on top: blank major ticks mark the year boundaries (and drive the
    # grid), while minor ticks at each year's midpoint carry the centered year labels
    ax2 = ax.twiny()
    ax2.set_xlabel('Year')
    ax2.set_xlim(0)
    ax2.set_xticks([0, 12, 24, 36, 47])
    ax2.set_xticklabels([''] * 5)
    ax2.set_xticks([6, 18, 30, 42, 47], minor=True)
    ax2.set_xticklabels(['2018', '2019', '2020', '2021', ''], minor=True)
    ax2.grid()

    # plot one line per candidate token
    for k in mf_dict.keys():
        ax.plot(x, mf_dict[k], label=k)
    ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    return fig
def add_mask(text):
    if len(text) == 0 or text[-1] == " ":
        out = text + "<mask>"
    else:
        out = text + " <mask>"
    return out
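# e.g. add_mask("Happy") -> "Happy <mask>", add_mask("Happy ") -> "Happy <mask>"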
with gr.Blocks() as demo:
    text_description = """
# TimeLMs Demo

This is a demo for **TimeLMs**:
- [Github](https://github.com/cardiffnlp/timelms)
- [Paper](https://aclanthology.org/2022.acl-demo.25.pdf)

Input any text containing a *\\<mask\\>* token, as in the example, and click "Generate Plot"
(the demo does not use GPUs, so generating takes about 1 min). The graph shows the
probability of the top candidate tokens for the mask over different months.

The demo uses a roberta-base model trained on tweets, where the first two tokens of
each input encode the year and the month ("21 1" for January 2021). The model was
trained on tweets posted between January 2018 and December 2021.
"""
    description = gr.Markdown(text_description)
    textbox = gr.Textbox(value="Happy <mask>!", max_lines=1)
    with gr.Row():
        generate_btn = gr.Button("Generate Plot")
        mask_btn = gr.Button("Add <mask>")

    # plot (with starting example already loaded)
    f = urllib.request.urlopen(plot_url)
    plot_example = pickle.load(f)
    plot = gr.Plot(plot_example)

    # textbox.change(fn=plot_time, inputs=textbox, outputs=plot)
    generate_btn.click(fn=plot_time, inputs=textbox, outputs=plot)
    mask_btn.click(fn=add_mask, inputs=textbox, outputs=textbox)

demo.launch(debug=True)
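# To run locally (assuming this file is saved as app.py and gradio, transformers,
# torch and matplotlib are installed):
#   python app.py
# then open the local URL that Gradio prints (http://127.0.0.1:7860 by default).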