import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from config import load_config

config = load_config("config.yaml")
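# The config module is local to this Space. A minimal config.yaml matching the
# keys used below might look like this (illustrative sketch only; the model
# name and settings are assumptions, not the actual file):
#
#   model_config:
#     model_name: "microsoft/phi-2"
#     torch_dtype: "bfloat16"
#     device_map: "auto"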
model_config = config["model_config"]
model_name = model_config.pop("model_name")

# Convert torch_dtype from its string form in the YAML to the torch.dtype
# object that from_pretrained expects
dtype_map = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}
if model_config.get("torch_dtype") in dtype_map:
    model_config["torch_dtype"] = dtype_map[model_config["torch_dtype"]]
# Load the base model, then attach the fine-tuned adapter weights from the
# training checkpoint (model.load_adapter requires the peft package)
model = AutoModelForCausalLM.from_pretrained(model_name, **model_config)
checkpoint_model = "checkpoint_dir/checkpoint-650"
model.load_adapter(checkpoint_model)

tokenizer = AutoTokenizer.from_pretrained(checkpoint_model, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"

pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
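# Note: unless config.yaml sets device placement (e.g. device_map: "auto"),
# from_pretrained leaves the model on CPU and generation will be slow.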
def respond(
    message,
    history,
    system_message,
    max_tokens=256,
    temperature=0.3,
    top_p=0.95,
):
    # Construct the chat list: system message first, then alternating
    # user/assistant turns from the history, then the new user message
    chat_list = [{"role": "system", "content": system_message}]
    for user, assistant in history:
        chat_list.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    chat_list.append({"role": "user", "content": message})

    # Manually construct the prompt as "role: content" lines, ending with an
    # open "assistant:" turn for the model to complete
    prompt = ""
    for chat in chat_list:
        prompt += f"{chat['role']}: {chat['content']}\n"
    prompt += "assistant:"
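    # If the base model's tokenizer ships a chat template, the idiomatic route
    # would be to let it render the exact format the model was trained on
    # (sketch only, assuming such a template exists):
    #   prompt = tokenizer.apply_chat_template(
    #       chat_list, tokenize=False, add_generation_prompt=True
    #   )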
    # max_new_tokens already caps generation relative to the prompt, so the
    # input length need not be measured and max_length need not be passed;
    # transformers warns when both limits are set and uses max_new_tokens.
    outputs = pipe(
        prompt,
        max_new_tokens=int(max_tokens),  # ensure this is an integer
        num_beams=1,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=50,
    )
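    # temperature, top_p, and top_k only apply because do_sample=True;
    # with greedy decoding (do_sample=False) they are ignored.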
    # Slice off the echoed prompt so only the newly generated text is returned
    new_text = outputs[0]["generated_text"][len(prompt):]
    return new_text.strip()
examples = [
    ["Suggest some breeds that get along with each other"],
    ["Explain LLM in AI"],
    ["I want to explore Dubai. What are the best places to visit?"],
]
demo = gr.ChatInterface(
    respond,
    textbox=gr.Textbox(
        placeholder="Enter your message here...", container=False, scale=7
    ),
    examples=examples,
    additional_inputs=[
        gr.Textbox(
            value="You are General Knowledge Assistant. Answer the questions based on the provided information. Be succinct and use first-principles thinking to answer the questions.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.3, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    title="General Knowledge Assistant",
    description="Ask me anything about general knowledge. I'll try to answer succinctly using first principles.",
)
if __name__ == "__main__":
    demo.launch(debug=True)