Spaces: Runtime error
import gradio as gr
import torch
import gc
import threading
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the tokenizer and model. Note: this checkpoint is gated on the Hugging
# Face Hub, so the Space must be authenticated (e.g. via `huggingface-cli login`
# or an HF_TOKEN secret) before from_pretrained can download the weights.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
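# The 8B weights are roughly 32 GB in fp32, more than most Spaces tiers
# provide, which is a plausible cause of the "Runtime error" above. A minimal
# sketch of a lighter-weight load, assuming `accelerate` is installed
# (torch_dtype and device_map are standard from_pretrained kwargs); it could
# replace the call above when memory is tight:
#
#   model = AutoModelForCausalLM.from_pretrained(
#       "meta-llama/Meta-Llama-3.1-8B-Instruct",
#       torch_dtype=torch.float16,  # roughly halves the footprint vs. fp32
#       device_map="auto",          # lets accelerate place layers across devices
#   )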
# Background task to periodically release cached memory
def clean_memory():
    while True:
        gc.collect()  # free unreferenced CPU memory
        if device == "cuda":
            torch.cuda.empty_cache()  # release cached GPU memory back to the driver
        time.sleep(1)  # run once per second
# Start memory cleanup in a background daemon thread so it exits with the app
cleanup_thread = threading.Thread(target=clean_memory, daemon=True)
cleanup_thread.start()
def generate_response(message, history, max_tokens, temperature, top_p):
    """
    Generates a response from the model, conditioned on the conversation history.
    """
    # Build the conversation in the chat format the instruct model expects
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    # Generate the output using the model
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=max_tokens,  # budget for newly generated tokens, independent of prompt length
            do_sample=True,  # temperature/top_p only take effect when sampling is enabled
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens, not the echoed prompt
    response = tokenizer.decode(output[0, input_ids.shape[-1]:], skip_special_tokens=True)
    history.append((message, response))
    return history, ""
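# A minimal sketch of token streaming with transformers' TextIteratorStreamer,
# so long replies render incrementally instead of after one blocking generate()
# call. It is illustrative and not wired into the UI below;
# `generate_response_streaming` is a hypothetical name, not part of this app.
from transformers import TextIteratorStreamer

def generate_response_streaming(message, history, max_tokens, temperature, top_p):
    # Single-turn prompt for brevity; the history handling above applies here too
    messages = [{"role": "user", "content": message}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    # generate() blocks until done, so it runs in a worker thread while the
    # caller consumes decoded text chunks from the streamer
    threading.Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial = ""
    for chunk in streamer:
        partial += chunk
        yield history + [(message, partial)]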
def update_chatbox(history, message, max_tokens, temperature, top_p):
    """
    Update the chat history and generate the next AI response.
    """
    # generate_response appends the (user, assistant) pair itself; appending the
    # user message here as well would duplicate the turn in the Chatbot
    history, _ = generate_response(message, history, max_tokens, temperature, top_p)
    return history, history, ""  # update the chatbot, persist the state, clear the input
# Define the Gradio interface with the Blocks context
with gr.Blocks(css=".gradio-container {border: none;}") as demo:
    chat_history = gr.State([])  # per-session chat history
    max_tokens = gr.Slider(minimum=1, maximum=1024, value=256, step=1, label="Max Tokens")
    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")
    chatbot = gr.Chatbot(label="Character-like AI Chat")
    user_input = gr.Textbox(show_label=False, placeholder="Type your message here...")
    send_button = gr.Button("Send")
    # When the send button is clicked, generate a reply and refresh the UI
    send_button.click(
        fn=update_chatbox,
        inputs=[chat_history, user_input, max_tokens, temperature, top_p],
        outputs=[chatbot, chat_history, user_input],  # writing the state back persists history across turns
        queue=True,  # ensure responses are shown in order
    )
# Launch the Gradio interface
if __name__ == "__main__":
    demo.launch(share=True)
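# Usage note (an assumption about the Space's setup, not part of the file
# itself): requirements.txt would need at least `gradio`, `torch`, and
# `transformers`, plus `accelerate` if the device_map sketch above is used.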