import gradio as gr
import torch
import gc
import threading
import time
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer and model (Meta-Llama-3.1-8B-Instruct is a gated Hugging Face repo)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
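# Note (optional tweak, not required for the code above): an 8B-parameter model in full
# fp32 precision occupies roughly 32 GB of memory; passing torch_dtype=torch.bfloat16
# (or float16) to from_pretrained() roughly halves that footprint on supported hardware.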

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device)

# Function to clean up memory
def clean_memory():
    while True:
        gc.collect()  # Free up CPU memory
        if device == "cuda":
            torch.cuda.empty_cache()  # Free up GPU memory
        time.sleep(1)  # Clean every second

# Start memory cleanup in a background thread
cleanup_thread = threading.Thread(target=clean_memory, daemon=True)
cleanup_thread.start()
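# Because the thread is a daemon, it stops automatically when the Gradio process exits;
# no explicit shutdown or join() is required.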

def generate_response(message, history, max_tokens, temperature, top_p):
    """
    Generates a response from the model.
    """
    # Encode the user message (note: the prior history is not folded into this prompt)
    input_ids = tokenizer.encode(message + tokenizer.eos_token, return_tensors="pt").to(device)

    # Generate the output using the model
    with torch.no_grad():
        output = model.generate(
            input_ids,
            max_new_tokens=max_tokens,  # budget for newly generated tokens only
            do_sample=True,  # sampling must be enabled for temperature/top_p to take effect
            temperature=temperature,
            top_p=top_p,
            pad_token_id=tokenizer.eos_token_id,
        )
    
    response = tokenizer.decode(output[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
    history.append((message, response))
    return history, ""
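# Note (assumed refinement, not part of the original flow): Llama-3.1 Instruct models are
# trained against a chat template, so for multi-turn conversations building the prompt with
# tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt")
# would match the expected format more closely than concatenating raw text as above.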

def update_chatbox(history, message, max_tokens, temperature, top_p):
    """
    Update the chat history and generate the next AI response.
    """
    # generate_response appends the (message, response) pair itself, so the user turn is
    # not appended separately here (doing both would duplicate it in the chatbot display)
    history, _ = generate_response(message, history, max_tokens, temperature, top_p)
    return history, history, ""  # Refresh the chatbot, persist the state, clear the input box

# Define the Gradio interface with the Blocks context
with gr.Blocks(css=".gradio-container {border: none;}") as demo:
    chat_history = gr.State([])  # Initialize an empty chat history state
    max_tokens = gr.Slider(minimum=1, maximum=1024, value=256, step=1, label="Max Tokens")
    temperature = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature")
    top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="Top-p (nucleus sampling)")

    chatbot = gr.Chatbot(label="Character-like AI Chat")

    user_input = gr.Textbox(show_label=False, placeholder="Type your message here...")
    send_button = gr.Button("Send")

    # When the send button is clicked, update chat history
    send_button.click(
        fn=update_chatbox,
        inputs=[chat_history, user_input, max_tokens, temperature, top_p],
        outputs=[chatbot, chat_history, user_input],  # Refresh the chatbox, persist the history state, clear the input
        queue=True  # Ensure responses are shown in order
    )
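
    # Optional convenience (assumed addition, not in the original): let the Enter key in the
    # textbox trigger the same handler as the Send button via the Textbox .submit() event.
    user_input.submit(
        fn=update_chatbox,
        inputs=[chat_history, user_input, max_tokens, temperature, top_p],
        outputs=[chatbot, chat_history, user_input],
        queue=True
    )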

# Launch the Gradio interface
if __name__ == "__main__":
    demo.launch(share=True)