# ChatOPT / app.py
# prasanna2003's picture
# Update app.py
# 43789bc
import gradio as gr
import torch
from transformers import AutoTokenizer
class Pipline:
    """Small text-generation wrapper around a model and its tokenizer.

    The model is moved to ``device`` once at construction time. Every call
    to :meth:`respond` wraps the user's text in a fixed instruction-style
    prompt before handing it to ``model.generate``.
    """

    def __init__(self, model, tokenizer, device='cpu'):
        """Store the tokenizer and move ``model`` onto ``device``.

        Args:
            model: Object exposing ``.to(device)`` and ``.generate(...)``.
            tokenizer: Callable returning an encoding with ``input_ids``;
                must also provide ``batch_decode``.
            device: Device that the model and the prompt tensors are moved to.
        """
        self.device = device
        self.model = model.to(self.device)
        self.tokenizer = tokenizer
        # Fixed preamble placed in front of every prompt.
        self.pre_prompt = "\n\nYou are a AI assistant who helps the user to solve their issue\n\n"

    @torch.no_grad()
    def respond(
        self,
        Instruction=None,
        input=None,
        temperature=0.8,
        max_length=200,
        do_sample=True,
        top_k=0,
        top_p=0.9,
        repetition_penalty=1.0,
        num_return_sequences=1,
        num_beams=1,
        early_stopping=False,
        use_cache=True,
        **generate_kwargs,
    ):
        """Generate text for an instruction and/or an input.

        Args:
            Instruction: Text placed after ``Instruction:`` in the prompt.
            input: Text placed after ``Input:`` in the prompt.
            temperature, max_length, do_sample, top_k, top_p,
            repetition_penalty, num_return_sequences, num_beams,
            early_stopping, use_cache: Forwarded unchanged to
                ``model.generate``.
            **generate_kwargs: Extra keyword arguments for ``model.generate``.
                A caller-supplied ``attention_mask`` takes precedence over
                the one produced by the tokenizer.

        Returns:
            The list produced by ``tokenizer.batch_decode`` for the generated
            token ids (special tokens skipped).
            NOTE(review): with decoder-only Hugging Face models the generated
            ids normally begin with the prompt, so the prompt text is expected
            to be part of each returned string - confirm for the model in use.

        Raises:
            ValueError: If both ``Instruction`` and ``input`` are empty.
        """
        if not (Instruction or input):
            raise ValueError("Either Instruction or input must be passed.")

        # Assemble the prompt from explicit lines so that the indentation of
        # this method can never leak into the text sent to the model.
        query = (
            f"{self.pre_prompt}\n"
            f"Instruction: {Instruction or ''}\n"
            f"Input: {input or ''}\n"
            "Output:"
        )

        encoded = self.tokenizer(query, return_tensors='pt')
        inp_tokens = encoded.input_ids.to(self.device)

        # Forward the tokenizer's attention mask when it provides one, so
        # generate() does not have to infer it from pad tokens. setdefault
        # keeps an explicit mask from the caller and avoids a duplicate
        # keyword argument.
        attention_mask = getattr(encoded, 'attention_mask', None)
        if attention_mask is not None:
            generate_kwargs.setdefault(
                'attention_mask', attention_mask.to(self.device)
            )

        out_tokens = self.model.generate(
            inp_tokens,
            max_length=max_length,
            temperature=temperature,
            do_sample=do_sample,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repetition_penalty,
            num_return_sequences=num_return_sequences,
            num_beams=num_beams,
            early_stopping=early_stopping,
            use_cache=use_cache,
            **generate_kwargs,
        )
        return self.tokenizer.batch_decode(out_tokens, skip_special_tokens=True)
# Module-level setup: load the tokenizer by name, read the model object from
# a local file and wrap both in the CPU pipeline used by the UI callback.
# NOTE(review): torch.load relies on pickle, which can execute arbitrary code
# while loading - keep './model-cpu.pkl' restricted to trusted files. Recent
# PyTorch releases may also require an explicit weights_only setting to load
# a whole pickled model; confirm against the pinned torch version.
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125M')
model = torch.load("./model-cpu.pkl")
pipe = Pipline(model, tokenizer, device="cpu")
# Free-text fields of the UI. The component classes are taken from the
# top-level gradio namespace: the older `gr.inputs` / `gr.outputs` namespaces
# are deprecated in Gradio 3.x and no longer exist in Gradio 4.x.
input_components = [
    gr.Textbox(label='Instruction', placeholder='Enter instruction...'),
    gr.Textbox(label='Input', placeholder='Enter input...'),
]
output_components = [
    gr.Textbox(label='Output'),
]
def chatbot_response(Instruction, input, max_length, temperature):
    """Gradio callback: run the module-level pipeline and return its first result.

    Args:
        Instruction: Instruction text from the UI.
        input: Input text from the UI.
        max_length: Slider value, converted with ``int`` before use.
        temperature: Slider value, converted with ``float`` before use.

    Returns:
        The first element of the list returned by ``pipe.respond``.
    """
    # Coerce the slider values first (length, then temperature).
    generation_settings = {
        'max_length': int(max_length),
        'temperature': float(temperature),
    }
    responses = pipe.respond(
        Instruction=Instruction,
        input=input,
        **generation_settings,
    )
    return responses[0]
# Assemble and start the web UI. The inputs are, in order: the two text boxes
# from `input_components`, then the "Max Length" and "Temperature" sliders -
# the same order as the parameters of `chatbot_response`.
# The sliders use `gr.Slider(value=...)`: the older `gr.inputs.Slider` with
# its `default=` argument is deprecated in Gradio 3.x and gone in Gradio 4.x.
interface = gr.Interface(
    fn=chatbot_response,
    inputs=input_components + [
        gr.Slider(
            label='Max Length',
            minimum=10,
            maximum=500,
            step=10,
            value=200,
        ),
        gr.Slider(
            label='Temperature',
            minimum=0.1,
            maximum=1.0,
            step=0.1,
            value=0.8,
        ),
    ],
    outputs=output_components,
    title='ChatOPT',
    description='Type in an instruction and input, and get a response from the model',
)
interface.launch()