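# Vespa Companion: a Gradio chat app that serves the GPTQ-quantized
# Wizard-Vicuna-7B-Uncensored-SuperHOT-8K model on CPU via auto-gptq.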
import gradio as gr
from transformers import AutoTokenizer
from auto_gptq import AutoGPTQForCausalLM
import torch

# Ensure Torch uses the CPU
torch.set_default_device("cpu")

# Model details
model_name_or_path = "TheBloke/Wizard-Vicuna-7B-Uncensored-SuperHOT-8K-GPTQ"
model_basename = "model"  # Match uploaded file basename
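# NOTE: with use_safetensors=True below, auto_gptq looks for "<model_basename>.safetensors"
# in the repo, so model_basename must match that file's name without the extension.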

# Load the tokenizer and quantized model
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
model = AutoGPTQForCausalLM.from_quantized(
    model_name_or_path,
    model_basename=model_basename,
    trust_remote_code=True,
    device_map="auto",          # Let accelerate place the weights; resolves to CPU when no GPU is available
    use_safetensors=True,
    torch_dtype=torch.float32,  # float32 for CPU inference
    quantize_config=None,
)

# Core personality prompt
core_personality = """You are Vespa Companion, a witty and knowledgeable AI designed to engage in thoughtful and helpful conversations about life, technology, and Vespa scooters."""

# Function to generate a response
def generate_response(input_text):
    try:
        # Prepare the input prompt
        prompt_template = f"{core_personality}\n\nUSER: {input_text}\nASSISTANT:"
        input_ids = tokenizer(prompt_template, return_tensors="pt").input_ids  # Stays on the CPU
        # Define generation configuration
        generation_config = {
            "max_new_tokens": 1024,  # Allow longer responses if memory permits
            "do_sample": True,       # Sampling must be enabled for temperature/top_p to take effect
            "temperature": 0.7,
            "top_p": 0.95,
            "repetition_penalty": 1.2,
        }

        # Generate the response
        output_ids = model.generate(input_ids=input_ids, **generation_config)
        # Decode and clean the output
        raw_output = tokenizer.decode(output_ids[0], skip_special_tokens=True)

        # Extract only the assistant's response
        if "ASSISTANT:" in raw_output:
            response = raw_output.split("ASSISTANT:")[-1].strip()
        else:
            response = raw_output.strip()
        return response
    except Exception as e:
        return f"Response Generation Error: {e}"

# Chat function for Gradio
def chat_with_memory(history, user_input):
    # Add user input to history
    history.append({"role": "user", "content": user_input})

    # Generate response
    response = generate_response(user_input)
    history.append({"role": "assistant", "content": response})

    # Format history as (user, assistant) pairs for the Chatbot component;
    # history alternates user/assistant entries because it is only appended here
    display_history = [
        (history[i]["content"], history[i + 1]["content"] if i + 1 < len(history) else None)
        for i in range(0, len(history), 2)
    ]
    return display_history, history
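
# The full role/content history is kept in gr.State; the Chatbot component only
# renders the (user, assistant) tuples returned by chat_with_memory.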

# Gradio app setup
with gr.Blocks() as demo:
    gr.Markdown("## Vespa Companion - Intelligent Chatbot")
    chatbot = gr.Chatbot(label="Chat with Vespa Companion")
    with gr.Row():
        msg = gr.Textbox(placeholder="Type your message here...", label="Your Message")
        clear = gr.Button("Clear Conversation")
    history = gr.State([])  # Initialize conversation history

    msg.submit(chat_with_memory, [history, msg], [chatbot, history])
    clear.click(lambda: ([], []), None, [chatbot, history])

# Launch the app
demo.launch()
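
# Hypothetical smoke test (illustrative prompt, not part of the original app): run this
# instead of demo.launch() to exercise generate_response without the UI.
# print(generate_response("Tell me something interesting about Vespa scooters."))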