import gradio as gr
import torch
import gc
import threading
import time
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer and model
# Note: this is a gated 8B model; it requires accepting Meta's license on the
# Hugging Face Hub and a GPU with enough memory to run comfortably.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)


# Function to clean up memory periodically
def clean_memory():
    while True:
        gc.collect()  # Free up CPU memory
        if device == "cuda":
            torch.cuda.empty_cache()  # Free up GPU memory
        time.sleep(1)  # Clean every second (aggressive; lengthen the interval if it causes overhead)


# Start memory cleanup in a background thread
cleanup_thread = threading.Thread(target=clean_memory, daemon=True)
cleanup_thread.start()


def generate_response(message, history, max_tokens, temperature, top_p):
    """Generate a response from the model for a single user message."""
    # Build the prompt with the model's chat template, the idiomatic way to
    # prompt an instruct-tuned model; add_generation_prompt appends the
    # assistant header so the model knows it should reply.
    messages = [{"role": "user", "content": message}]
    input_ids = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(device)

    # Generate the output using the model
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=max_tokens,  # cap newly generated tokens; max_length would count the prompt too
            do_sample=True,             # temperature/top_p only take effect when sampling
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens, skipping the prompt
    response = tokenizer.decode(output[0, input_ids.shape[-1]:], skip_special_tokens=True)
    history.append((message, response))
    return history, ""


def update_chatbox(history, message, max_tokens, temperature, top_p):
    """Update the chat history and generate the next AI response."""
    # generate_response appends the (message, response) pair itself, so
    # appending the user message here as well would duplicate it and break
    # the (user, bot) tuple format the Chatbot component expects.
    history, _ = generate_response(message, history, max_tokens, temperature, top_p)
    return history, ""  # Return updated history and clear the user input


# Define the Gradio interface with the Blocks context
with gr.Blocks(css=".gradio-container {border: none;}") as demo:
    # chat_history is mutated in place, so the State stays in sync with the chatbot
    chat_history = gr.State([])  # Initialize an empty chat history state

    max_tokens = gr.Slider(minimum=1, maximum=1024, value=256, step=1, label="Max Tokens")
    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")

    chatbot = gr.Chatbot(label="Character-like AI Chat")
    user_input = gr.Textbox(show_label=False, placeholder="Type your message here...")
    send_button = gr.Button("Send")

    # When the send button is clicked, update the chat history
    send_button.click(
        fn=update_chatbox,
        inputs=[chat_history, user_input, max_tokens, temperature, top_p],
        outputs=[chatbot, user_input],  # Update chatbox and clear user input
        queue=True,  # Ensure responses are shown in order
    )

# Launch the Gradio interface
if __name__ == "__main__":
    demo.launch(share=True)