# Hugging Face Space status: Paused
import gradio as gr | |
import plotly.express as px | |
import os | |
import torch | |
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BlenderbotForConditionalGeneration | |
# Choose the execution device. When PyTorch can see a GPU, also configure the
# CUDA caching allocator and cap this process's share of GPU memory.
if torch.cuda.is_available():
    device = "cuda"
    os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
    torch.cuda.empty_cache()
    torch.cuda.set_per_process_memory_fraction(0.8)  # Adjust the fraction as needed
else:
    device = "cpu"
# System message inserted into the prompt built by chat_responses
# (empty placeholder; adjust as needed).
system_message = ""
# Loader for the CapybaraHermes-2.5 Mistral-7B AWQ checkpoint (not called in this file).
def hermes_model():
    """Load and return ``(model, tokenizer)`` for TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ.

    The causal LM is loaded with ``low_cpu_mem_usage=True`` and
    ``device_map="auto"``, letting transformers decide weight placement.
    NOTE(review): an AWQ checkpoint presumably needs the AWQ runtime package
    installed — confirm against the deployment environment.
    """
    repo_id = "TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ"
    tok = AutoTokenizer.from_pretrained(repo_id)
    causal_lm = AutoModelForCausalLM.from_pretrained(
        repo_id,
        low_cpu_mem_usage=True,
        device_map="auto",
    )
    return causal_lm, tok
def blender_model():
    """Load and return ``(model, tokenizer)`` for facebook/blenderbot-400M-distill.

    The seq2seq model is loaded with default ``from_pretrained`` settings
    (no ``device_map``), so it stays wherever transformers places it by default.
    """
    checkpoint = "facebook/blenderbot-400M-distill"
    seq2seq = BlenderbotForConditionalGeneration.from_pretrained(checkpoint)
    tok = AutoTokenizer.from_pretrained(checkpoint)
    return seq2seq, tok


# Module-level model/tokenizer pair read by the generation functions below.
model, tokenizer = blender_model()
def chat_response(msg_prompt: str) -> str:
    """Generate a single reply with the module-level ``model``/``tokenizer``.

    Args:
        msg_prompt: Text typed by the user.

    Returns:
        The first decoded generation with special tokens removed, or the
        exception text if tokenizing, generating or decoding fails (so the
        chat window shows the error instead of the app crashing).
    """
    try:
        encoded = tokenizer(msg_prompt, return_tensors="pt")
        generated_ids = model.generate(**encoded)
        decoded = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        return decoded[0]
    except Exception as err:
        return str(err)
# Function to generate a response from a causal LM using a ChatML prompt
def chat_responses(msg_prompt: str) -> str:
    """
    Generates a response from the model given a prompt.

    The prompt uses the ChatML layout expected by CapybaraHermes-2.5 (the
    checkpoint loaded by ``hermes_model``): each turn is wrapped in
    ``<|im_start|>role ... <|im_end|>`` markers, and the prompt ends with an
    open assistant turn.

    NOTE(review): this reads the module-level ``model``/``tokenizer``, which
    are currently the Blenderbot seq2seq pair from ``blender_model()``. It
    only works once those globals hold a causal LM (e.g. from
    ``hermes_model()``). The Gradio UI calls ``chat_response``, not this.

    Args:
        msg_prompt (str): The user's message prompt.

    Returns:
        str: The model's response with surrounding whitespace stripped, or
        the exception text if generation fails.
    """
    generation_params = {
        "do_sample": True,
        "temperature": 0.7,
        "top_p": 0.95,
        "top_k": 40,
        "max_new_tokens": 512,
        "repetition_penalty": 1.1,
    }
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, **generation_params)
    try:
        prompt_template = (
            "<|im_start|>system\n"
            f"{system_message}<|im_end|>\n"
            "<|im_start|>user\n"
            f"{msg_prompt}<|im_end|>\n"
            "<|im_start|>assistant\n"
        )
        # return_full_text=False makes the pipeline return only the newly
        # generated text. This replaces recovering the reply by splitting on
        # the word "assistant", which truncated any reply containing that word.
        pipe_output = pipe(prompt_template, return_full_text=False)
        return pipe_output[0]["generated_text"].strip()
    except Exception as e:
        return str(e)
# Build a demo scatter plot from Plotly's bundled iris sample data
def random_plot():
    """Return a Plotly scatter of iris sepal width vs. sepal length.

    Points are colored by species and sized by petal length, and petal width
    is shown on hover. Despite the name, the data is the fixed iris sample.
    """
    iris = px.data.iris()
    return px.scatter(
        iris,
        x="sepal_width",
        y="sepal_length",
        color="species",
        size="petal_length",
        hover_data=["petal_width"],
    )
# Log chatbot like/dislike feedback to stdout (for demonstration purposes)
def print_like_dislike(x: gr.LikeData):
    """Print the message index, message value and liked flag of a like event."""
    fields = (x.index, x.value, x.liked)
    print(*fields)
# Function to add the user's files and text to the chat history
def add_message(history, message, files):
    """Append the user's uploaded files and text to the chat history.

    Each file becomes a ``((path,), None)`` entry, the tuple form the Gradio
    Chatbot renders as a file. The text becomes a ``(message, None)`` entry.
    The ``None`` slot is filled in later by ``bot``.

    Returns:
        The updated history, plus a Gradio update that clears the textbox and
        keeps it interactive.
    """
    if history is None:
        history = []
    if files is not None:
        for file in files:
            history.append(((file,), None))
    # Skip empty / whitespace-only text. Otherwise a submit with only files,
    # or an empty box, adds a blank user turn for the bot to answer.
    if message is not None and message.strip():
        history.append((message, None))
    return history, gr.update(value=None, interactive=True)


# Function to generate the bot's reply for the latest user turn
def bot(history):
    """Fill in the reply for the last history entry when it is an unanswered text turn.

    File entries (tuples such as ``(path,)``) are not valid tokenizer input and
    are left unanswered. Turns that already have a reply are not regenerated.
    """
    if not history:
        return history
    user_message = history[-1][0]
    existing_reply = history[-1][1]
    if isinstance(user_message, str) and existing_reply is None:
        # Replace the whole entry rather than assigning into it, so tuple
        # entries (as produced by add_message) are handled too.
        history[-1] = [user_message, chat_response(user_message)]
    return history
# Built at import time; not referenced by the UI below.
fig = random_plot()

# Gradio interface setup
with gr.Blocks(fill_height=True) as demo:
    chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, scale=1)
    with gr.Row():
        chat_input = gr.Textbox(placeholder="Enter message...", show_label=False)
        file_input = gr.File(label="Upload file(s)", file_count="multiple")
    # On submit: append the user's text/files to the history and clear the textbox.
    chat_msg = chat_input.submit(add_message, [chatbot, chat_input, file_input], [chatbot, chat_input])
    # Also clear the upload widget. Otherwise the same files would be
    # re-appended to the history on every later submit.
    chat_msg.then(lambda: None, None, [file_input])
    # Then generate the bot reply and make sure the textbox is editable again.
    bot_msg = chat_msg.then(bot, chatbot, chatbot)
    bot_msg.then(lambda: gr.update(interactive=True), None, [chat_input])
    chatbot.like(print_like_dislike, None, None)

demo.queue()
demo.launch()