import gradio as gr
import torch
import gc
import threading
import time
from transformers import AutoTokenizer, AutoModelForCausalLM
# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)
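# Note: at float16 precision the 8B parameters alone occupy roughly 16 GB,
# so a GPU with at least that much free memory (or a very slow CPU fallback)
# is needed to run this Space.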
# Function to clean up memory on a fixed interval
def clean_memory():
    while True:
        gc.collect()  # Free up CPU memory
        if device == "cuda":
            torch.cuda.empty_cache()  # Release cached GPU memory
        time.sleep(1)  # Clean every second
# Start memory cleanup in a background thread
cleanup_thread = threading.Thread(target=clean_memory, daemon=True)
cleanup_thread.start()
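# daemon=True means this loop will not block interpreter shutdown; the
# thread simply dies when the main process exits.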
def generate_response(message, history, max_tokens, temperature, top_p):
    """
    Generates a response from the model for the latest user message.
    """
    # Build the prompt from the conversation history using the model's chat
    # template (the correct input format for an instruct-tuned model)
    messages = []
    for user_msg, bot_msg in history:
        messages.append({"role": "user", "content": user_msg})
        messages.append({"role": "assistant", "content": bot_msg})
    messages.append({"role": "user", "content": message})
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(device)

    # Generate the output; temperature/top_p only take effect when sampling
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt
    response = tokenizer.decode(output[0, input_ids.shape[-1]:], skip_special_tokens=True)
    history.append((message, response))
    return history, ""
def update_chatbox(history, message, max_tokens, temperature, top_p):
    """
    Update the chat history and generate the next AI response.
    """
    # generate_response appends the (message, response) pair itself, so the
    # user message must not be appended here as well (it would show up twice)
    history, _ = generate_response(message, history, max_tokens, temperature, top_p)
    return history, ""  # Return updated history and clear the user input
# Define the Gradio interface with the Blocks context
with gr.Blocks(css=".gradio-container {border: none;}") as demo:
    chat_history = gr.State([])  # Initialize an empty chat history state
    max_tokens = gr.Slider(minimum=1, maximum=1024, value=256, step=1, label="Max Tokens")
    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")
    chatbot = gr.Chatbot(label="Character-like AI Chat")
    user_input = gr.Textbox(show_label=False, placeholder="Type your message here...")
    send_button = gr.Button("Send")

    # When the send button is clicked, update the chat history.
    # chat_history is a mutable gr.State, so the in-place appends made in
    # generate_response persist across turns without listing it as an output.
    send_button.click(
        fn=update_chatbox,
        inputs=[chat_history, user_input, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],  # Update chatbox and clear user input
        queue=True,  # Ensure responses are shown in order
    )
# Launch the Gradio interface
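# share=True additionally creates a temporary public *.gradio.live URL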
if __name__ == "__main__":
demo.launch(share=True)