File size: 3,503 Bytes
decda19
 
 
 
 
f74f783
 
297e1ab
decda19
 
 
 
f74f783
 
 
297e1ab
decda19
 
 
 
 
 
 
 
 
 
 
 
 
 
f74f783
decda19
 
 
 
 
 
 
297e1ab
decda19
 
 
 
 
297e1ab
f74f783
decda19
 
 
 
 
 
297e1ab
decda19
 
 
297e1ab
f74f783
 
 
 
 
 
 
 
decda19
f74f783
decda19
297e1ab
f74f783
 
 
 
297e1ab
f74f783
 
 
decda19
f74f783
 
 
 
 
 
 
297e1ab
f74f783
297e1ab
f74f783
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch
from transformers import AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import gradio as gr
from typing import List

# Configuration: central settings shared by the model and vector-store classes.
class Config:
    # Hugging Face checkpoint used for chat generation.
    model_name: str = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    # SentenceTransformer checkpoint used to embed texts for retrieval.
    embedding_model: str = "all-MiniLM-L6-v2"
    # Embedding width; must match the output size of embedding_model.
    vector_dim: int = 384
    # How many stored texts a retrieval query returns at most.
    top_k: int = 3
    # NOTE(review): not referenced by the visible code — texts are indexed
    # whole, without chunking. Confirm whether chunking was intended.
    chunk_size: int = 256

# Vector Database
class VectorDB:
    """In-memory semantic store that embeds texts and retrieves the closest ones.

    Every vector is L2-normalised before insertion and before querying, so the
    L2-distance ranking of the flat index orders results like cosine similarity.
    """

    def __init__(self):
        self.embedding_model = SentenceTransformer(Config.embedding_model)
        # Size the index from the model itself so that changing
        # Config.embedding_model cannot silently disagree with
        # Config.vector_dim; fall back to the configured width if the model
        # does not report one.
        dim = self.embedding_model.get_sentence_embedding_dimension() or Config.vector_dim
        self.index = faiss.IndexFlatL2(dim)
        # texts[i] is the original string for index row i (rows are only
        # ever appended, so positions stay aligned).
        self.texts: List[str] = []

    def _embed(self, text: str) -> np.ndarray:
        """Return *text* as a normalised float32 matrix of shape (1, dim)."""
        embedding = self.embedding_model.encode([text])[0]
        embedding = np.array([embedding], dtype=np.float32)
        faiss.normalize_L2(embedding)  # in place
        return embedding

    def add_text(self, text: str):
        """Embed *text* and append it to the index."""
        self.index.add(self._embed(text))
        self.texts.append(text)

    def search(self, query: str) -> List[str]:
        """Return up to Config.top_k stored texts nearest to *query*.

        Returns an empty list when nothing has been indexed yet.
        """
        if self.index.ntotal == 0:
            return []
        k = min(Config.top_k, self.index.ntotal)
        _distances, ids = self.index.search(self._embed(query), k)
        # Faiss pads unfilled result slots with -1. A plain `i < len(...)`
        # check lets -1 through and would return the last stored text.
        return [self.texts[i] for i in ids[0] if 0 <= i < len(self.texts)]

# Load Model
class TinyChatModel:
    """Chat wrapper around a Hugging Face text-generation pipeline."""

    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(Config.model_name)
        # Reuse the tokenizer loaded above instead of letting the pipeline
        # load a second copy from the same checkpoint.
        self.pipe = pipeline(
            "text-generation",
            model=Config.model_name,
            tokenizer=self.tokenizer,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

    def generate_response(self, message: str, context: str = "") -> str:
        """Generate a sampled assistant reply to *message*.

        Args:
            message: The user's turn.
            context: Optional retrieved text, passed to the model as a system
                message prefixed with "Context:".

        Returns:
            The generated reply with surrounding whitespace stripped.
        """
        messages = [{"role": "user", "content": message}]
        if context:
            messages.insert(0, {"role": "system", "content": f"Context:\n{context}"})
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        outputs = self.pipe(
            prompt,
            max_new_tokens=256,
            do_sample=True,
            temperature=0.7,
            top_k=50,
            top_p=0.95,
            # Return only the newly generated text. Splitting the echoed prompt
            # on "<|assistant|>" only works for chat templates that use that
            # exact marker, and it breaks if the reply itself contains it.
            return_full_text=False,
        )
        return outputs[0]["generated_text"].strip()

# Build the shared retrieval store and chat model once, at import time;
# chat_function reads both as module-level globals.
vector_db, chat_model = VectorDB(), TinyChatModel()

# Function to handle context addition and chat
def chat_function(user_input: str, context: str = ""):
    """Answer *user_input* with retrieval-augmented context, then remember the turn.

    Args:
        user_input: The user's message.
        context: Optional reference text to add to the vector store before
            retrieval.

    Returns:
        The assistant's reply, or "" when *user_input* is blank (nothing is
        generated or stored in that case).
    """
    if not user_input or not user_input.strip():
        return ""

    context = (context or "").strip()
    # The context textbox keeps its value between clicks, so the same text
    # arrives on every Send. Index it only once so duplicates do not crowd
    # other entries out of the top-k results.
    if context and context not in vector_db.texts:
        vector_db.add_text(context)

    # Search relevant context
    context_text = "\n".join(vector_db.search(user_input))
    response = chat_model.generate_response(user_input, context_text)
    # Store the exchange so later questions can retrieve earlier turns.
    vector_db.add_text(f"User: {user_input}\nAssistant: {response}")

    return response

# Gradio Interface
def gradio_interface(user_input: str, context: str = ""):
    """Gradio click handler: forward both textbox values to chat_function."""
    return chat_function(user_input, context)

# Build the Gradio UI: a question box, an optional context box, a Send button
# and a response box, stacked in a single column.
with gr.Blocks() as demo:
    gr.Markdown("# TinyChat: A Conversational AI")
    with gr.Row(), gr.Column():
        user_input = gr.Textbox(label="User Input", placeholder="Ask anything...")
        context_input = gr.Textbox(
            label="Optional Context",
            placeholder="Paste context here (optional)",
            lines=3,
        )
        submit_button = gr.Button("Send")
        output = gr.Textbox(
            label="Response",
            placeholder="Assistant's reply will appear here...",
        )

        # Clicking Send routes both inputs through gradio_interface.
        submit_button.click(
            fn=gradio_interface,
            inputs=[user_input, context_input],
            outputs=output,
        )

# Launch only when run as a script; share=True also requests a public share link.
if __name__ == "__main__":
    demo.launch(share=True)