"""Gradio demo: plot how a masked LM's fill-in predictions drift over time.

The user's sentence (which must contain exactly one mask token) is prefixed
with every "YY M" date string in ``dates``. The model predicts the mask for
each prefixed sentence, and the tokens that rank in the top few for at least
one date are plotted as probability curves over months/years.
"""
import logging
import pickle
import urllib.request  # also binds ``urllib``; plain ``import urllib`` does not load ``urllib.request``

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.figure import Figure
from transformers import AutoModelForMaskedLM, AutoTokenizer, pipeline

logger = logging.getLogger(__name__)

plot_url = "https://huggingface.co/spaces/fvancesco/test_time_1.1/resolve/main/plot_example.p"

# Date prefixes "18 1" ... "21 12": two-digit year, a space, then month number.
dates = [f"{year} {month}" for year in range(18, 22) for month in range(1, 13)]
months = [d.split(" ")[-1] for d in dates]
years = [d.split(" ")[0] for d in dates]

model_name = "fvancesco/tmp_date"
model = AutoModelForMaskedLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
model.eval()

# The placeholder the model fills in; read from the tokenizer instead of hard-coding it.
MASK_TOKEN = tokenizer.mask_token

#pipe = pipeline('fill-mask', model=model, tokenizer=tokenizer, device=0)
# Not used by get_mf_dict (which reads the model's logits directly); kept as a module-level name.
pipe = pipeline('fill-mask', model=model, tokenizer=tokenizer)


def _date_mask_probs(text):
    """Return the mask-token probability distribution for every date prefix.

    Args:
        text: sentence containing exactly one mask token.

    Returns:
        ``np.ndarray`` of shape ``(len(dates), vocab_size)``; row ``i`` is the
        softmax over the vocabulary at the mask position of ``f"{dates[i]} {text}"``.

    Raises:
        ValueError: if the tokenized sentence does not contain exactly one mask token.
    """
    rows = []
    with torch.no_grad():
        for d in dates:
            enc = tokenizer(f"{d} {text}", return_tensors="pt").to(model.device)
            mask_pos = (enc["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
            if mask_pos.numel() != 1:
                raise ValueError(
                    f"Input must contain exactly one {MASK_TOKEN} token, "
                    f"found {mask_pos.numel()}: {text!r}"
                )
            logits = model(**enc).logits[0, mask_pos[0]]
            rows.append(torch.softmax(logits, dim=-1).cpu().numpy())
    return np.stack(rows)


def get_mf_dict(text, top_n=5):
    """Return ``{token string: probability per date}`` for the most likely mask fillers.

    A token is included when it is among the ``top_n`` predictions for at least
    one date. Each value is a float64 array aligned with ``dates``. Keys are
    inserted in order of peak probability, highest first.

    Args:
        text: sentence containing exactly one mask token.
        top_n: how many top predictions per date to pool into the result.

    Raises:
        ValueError: if ``text`` does not contain exactly one mask token.
    """
    probs = _date_mask_probs(text)
    # Only the few selected ids are decoded, instead of decoding the whole vocabulary per date.
    top_ids = np.argsort(-probs, axis=1, kind="stable")[:, :max(top_n, 0)]
    candidates = sorted(np.unique(top_ids).tolist(), key=lambda i: probs[:, i].max(), reverse=True)

    mf_dict = {}
    for token_id in candidates:
        label = tokenizer.decode([token_id])
        curve = probs[:, token_id].astype(np.float64)
        # Distinct ids may decode to the same string (e.g. incomplete byte-level
        # fragments); keep the larger probability instead of silently overwriting.
        mf_dict[label] = np.maximum(mf_dict[label], curve) if label in mf_dict else curve
    return mf_dict


def _year_ticks():
    """Return ``(boundary ticks, label ticks, labels)`` for the top year axis, derived from ``dates``."""
    starts = [i for i, y in enumerate(years) if i == 0 or y != years[i - 1]]
    ends = starts[1:] + [len(dates)]
    mids = [s + (e - s) // 2 for s, e in zip(starts, ends)]
    labels = [f"20{years[s]}" for s in starts]
    last = len(dates) - 1
    # A trailing tick at the last month closes the final year's grid cell; it carries no label.
    return starts + [last], mids + [last], labels + ['']


def plot_time(text):
    """Plot the probability over time of the top mask fillers for ``text`` and return the figure."""
    mf_dict = get_mf_dict(text)

    # Build the Figure directly rather than through pyplot: pyplot keeps every
    # figure it creates alive until plt.close(), leaking one figure per request
    # in a long-running server.
    fig = Figure(figsize=(16, 9))
    ax = fig.add_subplot(111)

    x = list(range(len(dates)))
    ax.set_xlabel('Month')
    ax.set_xlim(0)
    ax.set_xticks(x)
    ax.set_xticklabels(months)

    # Top axis: gridlines at year boundaries (major ticks, unlabeled) and the
    # year names centred within each year (minor ticks).
    boundaries, label_pos, labels = _year_ticks()
    ax2 = ax.twiny()
    ax2.set_xlabel('Year')
    ax2.set_xlim(0)
    ax2.set_xticks(boundaries)
    ax2.set_xticklabels([])
    ax2.set_xticks(label_pos, minor=True)
    ax2.set_xticklabels(labels, minor=True)
    ax2.grid()

    for token, curve in mf_dict.items():
        ax.plot(x, curve, label=token)
    ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.5))
    return fig


def add_mask(text):
    """Append the mask token to ``text``, inserting a separating space when needed."""
    text = text or ""
    sep = "" if not text or text.endswith(" ") else " "
    return f"{text}{sep}{MASK_TOKEN}"


def _load_example_plot(url, timeout=30):
    """Download the pre-rendered example figure; return None if it cannot be loaded."""
    # SECURITY: pickle.load can execute arbitrary code embedded in the payload.
    # This is only acceptable because ``url`` is a fixed, author-controlled file;
    # never point it at untrusted data.
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return pickle.load(response)
    except (OSError, EOFError, pickle.UnpicklingError, AttributeError, ImportError) as err:
        # Best effort: the demo still works without the starting example.
        logger.warning("Could not load example plot from %s: %s", url, err)
        return None


plot_example = _load_example_plot(plot_url)

with gr.Blocks() as demo:
    textbox = gr.Textbox(value=f"Happy {MASK_TOKEN}!", max_lines=1)
    with gr.Row():
        generate_btn = gr.Button("Generate Plot")
        mask_btn = gr.Button(f"Add {MASK_TOKEN}")
    # Plot area, pre-filled with the starting example when it could be downloaded.
    plot = gr.Plot(plot_example)
    generate_btn.click(fn=plot_time, inputs=textbox, outputs=plot)
    mask_btn.click(fn=add_mask, inputs=textbox, outputs=textbox)

demo.launch(debug=True)