"""Gradio chat demo for a causal LM with an adapter checkpoint.

Base-model settings come from the ``model_config`` section of ``config.yaml``;
the adapter and tokenizer are loaded from ``checkpoint_model``. The model is
loaded lazily, on the first chat message.
"""

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

from config import load_config

config = load_config("config.yaml")
# Copy so popping "model_name" does not mutate the mapping returned by load_config.
model_config = dict(config["model_config"])
model_name = model_config.pop("model_name")
checkpoint_model = "checkpoint_dir/checkpoint-650"

# dtype names accepted in config.yaml -> torch dtypes. Any other value is
# passed to ``from_pretrained`` unchanged.
_TORCH_DTYPES = {
    "float32": torch.float32,
    "float16": torch.float16,
    "bfloat16": torch.bfloat16,
}

SYSTEM_MESSAGE = (
    "You are General Knowledge Assistant. Answer the questions based on the "
    "provided information. \n"
    "Be succinct and use first-principles thinking to answer the questions."
)

# Sampling settings used for every reply.
GENERATION_KWARGS = {
    "max_new_tokens": 256,
    "num_beams": 1,
    "do_sample": True,
    "temperature": 0.3,
    "top_p": 0.95,
    "top_k": 50,
}

# Global variables for model and tokenizer, populated by load_model_and_tokenizer().
model = None
tokenizer = None
pipe = None


def _model_load_kwargs():
    """Build the keyword arguments for ``from_pretrained`` from ``model_config``.

    Works on a copy, so ``model_config`` itself is never modified.
    """
    kwargs = dict(model_config)
    dtype = kwargs.get("torch_dtype")
    if isinstance(dtype, str):
        kwargs["torch_dtype"] = _TORCH_DTYPES.get(dtype, dtype)
    # Default only: an explicit value in config.yaml wins instead of causing
    # a "multiple values for keyword argument" TypeError.
    kwargs.setdefault("low_cpu_mem_usage", True)
    return kwargs


def load_model_and_tokenizer():
    """Load the base model, adapter, tokenizer and pipeline exactly once.

    Later calls return immediately. The module globals are assigned only after
    every step has succeeded. A failure part-way (for example a missing
    adapter checkpoint) therefore leaves them unset, and the next call retries
    instead of using a half-initialised pipeline.
    """
    global model, tokenizer, pipe
    if pipe is not None:
        return

    print("Loading model and tokenizer...")
    loaded_model = AutoModelForCausalLM.from_pretrained(model_name, **_model_load_kwargs())
    # Attach the adapter stored in checkpoint_model on top of the base weights.
    loaded_model.load_adapter(checkpoint_model)

    # NOTE(review): trust_remote_code=True allows code shipped with the
    # checkpoint to execute; keep it only if this tokenizer actually needs it.
    loaded_tokenizer = AutoTokenizer.from_pretrained(checkpoint_model, trust_remote_code=True)
    loaded_tokenizer.pad_token = loaded_tokenizer.eos_token
    # Decoder-only generation needs left padding when prompts are batched;
    # the single prompts sent by respond() are not padded either way.
    loaded_tokenizer.padding_side = "left"
    loaded_pipe = pipeline("text-generation", model=loaded_model, tokenizer=loaded_tokenizer)

    model, tokenizer, pipe = loaded_model, loaded_tokenizer, loaded_pipe
    print("Model and tokenizer loaded successfully.")


def _history_to_messages(history):
    """Convert Gradio chat history into chat-template messages.

    Accepts both history shapes ChatInterface can pass:
    ``[[user, assistant], ...]`` pairs and
    ``[{"role": ..., "content": ...}, ...]`` dicts.
    Only user/assistant turns with string content are kept; missing (``None``)
    or non-text contents are skipped so the chat template never sees them.
    """
    messages = []
    for turn in history or ():
        if isinstance(turn, dict):
            candidates = [(turn.get("role"), turn.get("content"))]
        else:
            user_text, assistant_text = turn
            candidates = [("user", user_text), ("assistant", assistant_text)]
        messages.extend(
            {"role": role, "content": text}
            for role, text in candidates
            if role in ("user", "assistant") and isinstance(text, str)
        )
    return messages


def respond(message, history):
    """Generate the assistant reply to ``message`` given the earlier ``history``.

    Returns only the newly generated text (the prompt is excluded), stripped
    of surrounding whitespace.
    """
    load_model_and_tokenizer()

    # Construct the chat list: system prompt, prior turns, then the new message.
    chat_list = [{"role": "system", "content": SYSTEM_MESSAGE}]
    chat_list.extend(_history_to_messages(history))
    chat_list.append({"role": "user", "content": message})

    prompt = pipe.tokenizer.apply_chat_template(
        chat_list, tokenize=False, add_generation_prompt=True
    )
    # NOTE(review): the templated string is tokenized again by the pipeline;
    # if this chat template already emits a BOS token, check for a doubled BOS.
    # return_full_text=False makes the pipeline return only the completion,
    # so the prompt does not have to be sliced off by character count.
    outputs = pipe(prompt, return_full_text=False, **GENERATION_KWARGS)
    return outputs[0]["generated_text"].strip()


examples = [
    ["Suggest some breeds that get along with each other"],
    ["Explain LLM in AI"],
    ["I want to explore Dubai. What are the best places to visit?"],
]

demo = gr.ChatInterface(
    respond,
    textbox=gr.Textbox(
        placeholder="Enter your message here...", container=False, scale=7
    ),
    examples=examples,
    title="General Knowledge Assistant",
    description="Ask me anything about general knowledge. I'll try to answer succinctly using first principles.",
)

if __name__ == "__main__":
    demo.launch(debug=True)