#!/usr/bin/env python
# coding: utf-8
"""Gradio chat demo serving the ``bigscience/bloom-3b`` causal language model.

The conversation is kept in Gradio ``state`` as a list of
``(user_turn, assistant_turn)`` string pairs and is re-serialised into the
model prompt on every request.
"""

import re

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "bigscience/bloom-3b"

# Generated text is cut at the earliest of these substrings, so that the
# model's continuation of the *next* dialogue turn is not shown as the reply.
_STOP_MARKERS = ("A:", "User")

# Strings accepted as "true" for boolean parameters typed into a Textbox.
_TRUE_STRINGS = frozenset({"true", "1", "yes", "y", "on"})

# Default text for the "initial_prompt" box.
# NOTE(review): the Interface below needs a module-level ``prompt``; none was
# defined, so this is a placeholder -- replace it with the intended prompt.
prompt = (
    "The following is a conversation between a User and a helpful AI "
    "Assistant."
)


def cleaning_history_tuple(history, remove=("\n",)):
    """Flatten chat history into prompt text, one turn per line.

    Args:
        history: iterable of ``(user_turn, assistant_turn)`` string pairs
            (tuples or lists), or ``None``/empty for a new conversation.
        remove: literal substrings deleted from every turn before joining.
            The default removes embedded newlines so that each turn stays
            on a single line.

    Returns:
        All turns in order, each followed by ``"\\n"``; ``""`` when
        *history* is empty or ``None``.
    """
    # NOTE(review): this cleanup also stripped two more patterns whose text
    # was lost when the file was flattened (they survive only as line
    # breaks); possibly HTML tags such as "<p>"/"</p>" -- confirm and add
    # them to ``remove``.
    if not history:
        return ""
    lines = []
    for pair in history:
        for turn in pair:
            for pattern in remove:
                turn = turn.replace(pattern, "")
            lines.append(turn + "\n")
    # join() instead of repeated ``+`` / sum(): linear rather than quadratic.
    return "".join(lines)


def _truncate_at_markers(text, markers=_STOP_MARKERS):
    """Return *text* up to (not including) the earliest occurrence of any marker."""
    hits = [pos for pos in (text.find(m) for m in markers) if pos != -1]
    return text[: min(hits, default=len(text))]


def ai_output(string1, string2, markers=_STOP_MARKERS):
    """Extract the assistant reply from a decoded generation.

    Args:
        string1: the prompt that was fed to the model.
        string2: decoded model output that starts with *string1*.
        markers: substrings that mark the start of a following turn.

    Returns:
        The text of *string2* after the first ``len(string1)`` characters,
        cut at the earliest marker; the whole remainder if no marker occurs.
    """
    return _truncate_at_markers(string2[len(string1):], markers)


def _parse_bool(value):
    """Interpret a Textbox value as a bool (``"False"`` -> ``False``)."""
    if isinstance(value, str):
        return value.strip().lower() in _TRUE_STRINGS
    return bool(value)


model4 = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer4 = AutoTokenizer.from_pretrained(MODEL_NAME)


def predict(input, initial_prompt, temperature=0.7, top_p=1, top_k=5,
            max_tokens=64, no_repeat_ngram_size=1, num_beams=6,
            do_sample=True, history=None):
    """Generate one assistant turn and append it to the chat history.

    Numeric and boolean parameters may arrive as strings (they come from
    Gradio Textboxes) and are converted before calling ``generate``.

    Args:
        input: the user's new message.
        initial_prompt: text placed before the serialised conversation.
        temperature, top_p, top_k, max_tokens, no_repeat_ngram_size,
        num_beams, do_sample: forwarded to ``model4.generate``
            (``max_tokens`` as ``max_new_tokens``).
        history: list of ``(user_turn, assistant_turn)`` pairs, or ``None``.
            It is appended to in place.

    Returns:
        ``(history, history)``: one value for the chatbot output and one for
        the state output.
    """
    # A ``[]`` default would be shared between calls (and users); Gradio
    # state may also start as None.
    history = [] if history is None else history

    context = (cleaning_history_tuple(history) + "\n"
               + "User: " + input + "\n" + "Assistant: ")
    full_prompt = initial_prompt + " " + context

    input_ids = tokenizer4.encode(str(full_prompt), return_tensors="pt")
    output_ids = model4.generate(
        input_ids,
        # For decoder-only models min_length counts the prompt tokens too.
        min_length=10,
        max_new_tokens=int(max_tokens),
        top_k=int(top_k),
        top_p=float(top_p),
        temperature=float(temperature),
        no_repeat_ngram_size=int(no_repeat_ngram_size),
        num_beams=int(num_beams),
        # bool("False") is True, so the Textbox string must be parsed.
        do_sample=_parse_bool(do_sample),
    )

    # Decode only the newly generated tokens. Slicing the full decode by the
    # prompt's character length breaks whenever decode does not reproduce the
    # prompt exactly. skip_special_tokens drops EOS etc. from the reply.
    generated = tokenizer4.decode(output_ids[0, input_ids.shape[-1]:],
                                  skip_special_tokens=True)
    print("Response after decoding tokenizer: ", generated)
    print("\n\n")

    reply = _truncate_at_markers(generated)
    history.append(("User: " + input, "Assistant: " + reply))
    return history, history


demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Textbox(label="input", lines=1, value=""),
        gr.Textbox(label="initial_prompt", lines=1, value=prompt),
        gr.Textbox(label="temperature", lines=1, value="0.7"),
        gr.Textbox(label="top_p", lines=1, value="1"),
        gr.Textbox(label="top_k", lines=1, value="5"),
        gr.Textbox(label="max_tokens", lines=1, value="64"),
        gr.Textbox(label="no_repeat_ngram_size", lines=1, value="1"),
        gr.Textbox(label="num_beams", lines=1, value="6"),
        gr.Textbox(label="do_sample", lines=1, value="True"),
        "state",
    ],
    outputs=["chatbot", "state"],
    title="BLOOM-3b",
)

if __name__ == "__main__":
    demo.launch()