import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from config import load_config

config = load_config("config.yaml")
model_config = config["model_config"]
model_name = model_config.pop("model_name")
checkpoint_model = "checkpoint_dir/checkpoint-650"
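# Hypothetical shape of config.yaml (illustrative only; the real keys depend
# on how the model was fine-tuned):
#
# model_config:
#   model_name: "org/base-model"   # placeholder id, not the actual base model
#   torch_dtype: "bfloat16"
#   device_map: "auto"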
# Global variables so the model, tokenizer, and pipeline are loaded only once
model = None
tokenizer = None
pipe = None
def load_model_and_tokenizer():
    global model, tokenizer, pipe
    if model is None:
        print("Loading model and tokenizer...")
        # Convert torch_dtype from string to torch.dtype
        if "torch_dtype" in model_config:
            if model_config["torch_dtype"] == "float32":
                model_config["torch_dtype"] = torch.float32
            elif model_config["torch_dtype"] == "float16":
                model_config["torch_dtype"] = torch.float16
            elif model_config["torch_dtype"] == "bfloat16":
                model_config["torch_dtype"] = torch.bfloat16
        # Load the model without quantization config
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            low_cpu_mem_usage=True,
            **model_config
        )
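        # load_adapter attaches the fine-tuned PEFT/LoRA adapter weights from
        # the checkpoint on top of the base model (requires the peft package)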
        model.load_adapter(checkpoint_model)
        tokenizer = AutoTokenizer.from_pretrained(checkpoint_model, trust_remote_code=True)
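        # Many causal LMs ship without a pad token, so reuse EOS for padding;
        # right-side padding matches the usual supervised fine-tuning setup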
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.padding_side = "right"
        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
        print("Model and tokenizer loaded successfully.")
def respond(message, history):
    load_model_and_tokenizer()

    system_message = """You are General Knowledge Assistant.
Answer the questions based on the provided information.
Be succinct and use first-principles thinking to answer the questions."""

    # Construct the chat list
    chat_list = [{"role": "system", "content": system_message}]
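    # history arrives as (user, assistant) pairs in Gradio's tuple-style chat
    # format; replay every prior turn so the model sees the full conversation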
    for user, assistant in history:
        chat_list.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    chat_list.append({"role": "user", "content": message})

    prompt = pipe.tokenizer.apply_chat_template(
        chat_list, tokenize=False, add_generation_prompt=True
    )
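    # apply_chat_template renders the messages with the model's own template;
    # for a ChatML-style model the prompt would look roughly like this
    # (illustrative only, the exact tokens depend on the base model):
    #   <|im_start|>system ... <|im_end|>
    #   <|im_start|>user ... <|im_end|>
    #   <|im_start|>assistant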
    outputs = pipe(
        prompt,
        max_new_tokens=256,
        num_beams=1,
        do_sample=True,
        temperature=0.3,
        top_p=0.95,
        top_k=50,
    )
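    # The pipeline returns the prompt plus the completion, so strip the prompt
    # prefix to keep only the newly generated text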
    new_text = outputs[0]["generated_text"][len(prompt) :]
    return new_text.strip()
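
# Quick sanity check without the UI (hypothetical usage):
#   load_model_and_tokenizer()
#   print(respond("Explain LLM in AI", []))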
examples = [
    ["Suggest some breeds that get along with each other"],
    ["Explain LLM in AI"],
    ["I want to explore Dubai. What are the best places to visit?"],
]
demo = gr.ChatInterface(
    respond,
    textbox=gr.Textbox(
        placeholder="Enter your message here...", container=False, scale=7
    ),
    examples=examples,
    title="General Knowledge Assistant",
    description="Ask me anything about general knowledge. I'll try to answer succinctly using first principles.",
)
if __name__ == "__main__":
    demo.launch(debug=True)