chat-llm / app.py
Threatthriver's picture
Update app.py
b702fe6 verified
import gradio as gr
import torch
import gc
import threading
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
# Checkpoint used for both the tokenizer and the model. This is an
# 8B-parameter model: in float32 the weights alone need roughly 32 GB, so it
# is not a lightweight choice.
MODEL_ID = "meta-llama/Meta-Llama-3.1-8B-Instruct"

# Pick the device first so the load precision can depend on it.
device = "cuda" if torch.cuda.is_available() else "cpu"

# On GPU load in half precision (bfloat16 when the hardware supports it,
# float16 otherwise) to halve the memory footprint; keep float32 on CPU.
if device == "cuda":
    _model_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    _model_dtype = torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, torch_dtype=_model_dtype)
model = model.to(device)
# Periodic memory reclamation, meant to be run on its own thread
def clean_memory():
    """
    Run a memory sweep once per second, forever.

    Each sweep triggers a Python garbage collection and, when the module-level
    ``device`` is "cuda", also empties the PyTorch CUDA cache. The loop has no
    exit condition, so this function never returns.
    """

    def _sweep():
        # Collect unreachable Python objects (CPU side)
        gc.collect()
        # Release cached GPU memory only when running on CUDA
        if device == "cuda":
            torch.cuda.empty_cache()

    while True:
        _sweep()
        # Pause one second between sweeps
        time.sleep(1)
# Launch clean_memory on a daemon thread so its endless loop does not keep
# the process alive at interpreter exit
cleanup_thread = threading.Thread(
    daemon=True,
    target=clean_memory,
)
cleanup_thread.start()
def generate_response(message, history, max_tokens, temperature, top_p):
    """
    Generate a model reply to ``message`` and record the turn in ``history``.

    The prompt is built from ``message`` only; earlier turns in ``history``
    are not fed back to the model.

    Args:
        message: The user's text for this turn.
        history: List of ``(user_text, reply_text)`` pairs. It is extended in
            place with the new pair; ``None`` is treated as an empty list.
        max_tokens: Maximum number of *new* tokens to generate (the prompt
            length does not count against it).
        temperature: Sampling temperature.
        top_p: Nucleus-sampling probability mass.

    Returns:
        A ``(history, "")`` tuple: the updated history and an empty string.
    """
    if history is None:
        history = []

    # Instruct checkpoints expect their chat template; fall back to the plain
    # "message + EOS" prompt for tokenizers that do not define one.
    if getattr(tokenizer, "chat_template", None):
        encoded = tokenizer.apply_chat_template(
            [{"role": "user", "content": message}],
            add_generation_prompt=True,
            return_dict=True,
            return_tensors="pt",
        )
    else:
        encoded = tokenizer(message + tokenizer.eos_token, return_tensors="pt")

    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    # Generate the output using the model
    with torch.no_grad():
        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            # max_new_tokens budgets the reply only; max_length would also
            # count the prompt and can leave no room for a reply.
            max_new_tokens=int(max_tokens),
            # temperature / top_p only take effect when sampling is enabled.
            do_sample=True,
            temperature=float(temperature),
            top_p=float(top_p),
            pad_token_id=tokenizer.eos_token_id,
        )

    # Drop the prompt tokens so only the newly generated text is decoded
    new_tokens = output[0, input_ids.shape[-1]:]
    response = tokenizer.decode(new_tokens, skip_special_tokens=True)

    history.append((message, response))
    return history, ""
def update_chatbox(history, message, max_tokens, temperature, top_p):
    """
    Handle one submitted message: generate the reply and return the new state.

    Args:
        history: List of ``(user_text, reply_text)`` pairs shown in the chat;
            ``None`` is treated as an empty list.
        message: Text the user submitted.
        max_tokens: Passed through to ``generate_response``.
        temperature: Passed through to ``generate_response``.
        top_p: Passed through to ``generate_response``.

    Returns:
        A ``(history, "")`` tuple: the updated history for the chat display
        and an empty string to clear the input box.
    """
    if history is None:
        history = []

    # Nothing to send: leave the conversation untouched and clear the box.
    if not message or not message.strip():
        return history, ""

    # generate_response appends the (message, response) pair itself, so no
    # separate user-only entry is added here (that would show each turn twice).
    history, _ = generate_response(message, history, max_tokens, temperature, top_p)
    return history, ""  # Return updated history and clear the user input
# Assemble the chat UI inside a Blocks context
with gr.Blocks(css=".gradio-container {border: none;}") as demo:
    # Conversation state, initialised to an empty list
    chat_history = gr.State([])

    # Generation controls
    max_tokens = gr.Slider(
        label="Max Tokens",
        minimum=1,
        maximum=1024,
        step=1,
        value=256,
    )
    temperature = gr.Slider(
        label="Temperature",
        minimum=0.1,
        maximum=2.0,
        step=0.1,
        value=0.7,
    )
    top_p = gr.Slider(
        label="Top-p (nucleus sampling)",
        minimum=0.1,
        maximum=1.0,
        step=0.05,
        value=0.9,
    )

    # Conversation display, message box and send button
    chatbot = gr.Chatbot(label="Character-like AI Chat")
    user_input = gr.Textbox(
        placeholder="Type your message here...",
        show_label=False,
    )
    send_button = gr.Button("Send")

    # Clicking "Send" runs update_chatbox; its two return values go to the
    # chat display and the (cleared) input box.
    # NOTE(review): chat_history is an input but not an output, so any
    # carry-over between clicks presumably relies on update_chatbox mutating
    # the list in place - confirm.
    send_button.click(
        update_chatbox,
        [chat_history, user_input, max_tokens, temperature, top_p],
        [chatbot, user_input],
        queue=True,  # Ensure responses are shown in order
    )

# Start the app only when this file is executed as a script
if __name__ == "__main__":
    demo.launch(share=True)