import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from config import load_config
config = load_config("config.yaml")
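# The repo's config.yaml is not shown here; judging from the keys read below,
# it is assumed to look roughly like this (the model id is a hypothetical
# placeholder, and any extra keys pass straight through to from_pretrained):
#
#   model_config:
#     model_name: microsoft/phi-2
#     torch_dtype: bfloat16
#     trust_remote_code: true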
model_config = config["model_config"]
model_name = model_config.pop("model_name")
# Convert torch_dtype from string to torch.dtype
if "torch_dtype" in model_config:
    if model_config["torch_dtype"] == "float32":
        model_config["torch_dtype"] = torch.float32
    elif model_config["torch_dtype"] == "float16":
        model_config["torch_dtype"] = torch.float16
    elif model_config["torch_dtype"] == "bfloat16":
        model_config["torch_dtype"] = torch.bfloat16
model = AutoModelForCausalLM.from_pretrained(model_name, **model_config)
# Attach the fine-tuned adapter weights (a PEFT checkpoint) on top of the base model
checkpoint_model = "checkpoint_dir/checkpoint-650"
model.load_adapter(checkpoint_model)

tokenizer = AutoTokenizer.from_pretrained(checkpoint_model, trust_remote_code=True)
# Reuse EOS as the pad token in case the checkpoint does not define one
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "right"
pipe = pipeline("text-generation", model=model, tokenizer=tokenizer)
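# Note: no device is pinned here; the pipeline inherits whatever device
# from_pretrained placed the model on (CPU by default, unless config.yaml
# requests otherwise, e.g. via a device_map entry).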
def respond(
    message,
    history,
    system_message,
    max_tokens=256,
    temperature=0.3,
    top_p=0.95,
):
    # Construct the chat list
    chat_list = [{"role": "system", "content": system_message}]
    for user, assistant in history:
        chat_list.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    chat_list.append({"role": "user", "content": message})
    # Manually construct the prompt as "role: content" lines
    prompt = ""
    for chat in chat_list:
        prompt += f"{chat['role']}: {chat['content']}\n"
    prompt += "assistant:"
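    # Note: if the base model ships a chat template, the same message list
    # could instead be rendered with tokenizer.apply_chat_template(
    #     chat_list, tokenize=False, add_generation_prompt=True)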
    # max_new_tokens alone caps the generation budget; passing max_length as
    # well conflicts with it in transformers, so it is omitted here
    outputs = pipe(
        prompt,
        max_new_tokens=int(max_tokens),
        num_beams=1,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=50,
    )
    # The pipeline echoes the prompt, so return only the newly generated text
    new_text = outputs[0]["generated_text"][len(prompt):]
    return new_text.strip()
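# Quick smoke test outside the Gradio UI (hypothetical call; history is a list
# of (user, assistant) tuples, as ChatInterface passes it):
#   print(respond("Explain LLM in AI", [], "You are a helpful assistant."))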
examples = [
    ["Suggest some breeds that get along with each other"],
    ["Explain LLM in AI"],
    ["I want to explore Dubai. What are the best places to visit?"],
]
demo = gr.ChatInterface(
    respond,
    textbox=gr.Textbox(
        placeholder="Enter your message here...", container=False, scale=7
    ),
    examples=examples,
    additional_inputs=[
        gr.Textbox(
            value="You are General Knowledge Assistant. Answer the questions based on the provided information. Be succinct and use first-principles thinking to answer the questions.",
            label="System message",
        ),
        gr.Slider(minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.3, step=0.1, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (nucleus sampling)",
        ),
    ],
    title="General Knowledge Assistant",
    description="Ask me anything about general knowledge. I'll try to answer succinctly using first principles.",
)
if __name__ == "__main__":
    demo.launch(debug=True)