# tinyChat / app.py
import torch
from transformers import AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import gradio as gr
from typing import List
# Configuration
class Config:
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    embedding_model = "all-MiniLM-L6-v2"
    vector_dim = 384   # all-MiniLM-L6-v2 produces 384-dimensional embeddings
    top_k = 3          # retrieve the top 3 relevant chunks
    chunk_size = 256   # text chunk size
# Vector Database
class VectorDB:
    def __init__(self):
        self.index = faiss.IndexFlatL2(Config.vector_dim)
        self.texts = []
        self.embedding_model = SentenceTransformer(Config.embedding_model)

    def add_text(self, text: str):
        embedding = self.embedding_model.encode([text])[0]
        embedding = np.array([embedding], dtype=np.float32)
        faiss.normalize_L2(embedding)  # normalized vectors make L2 search rank like cosine similarity
        self.index.add(embedding)
        self.texts.append(text)

    def search(self, query: str) -> List[str]:
        if self.index.ntotal == 0:
            return []
        query_embedding = self.embedding_model.encode([query])[0]
        query_embedding = np.array([query_embedding], dtype=np.float32)
        faiss.normalize_L2(query_embedding)
        D, I = self.index.search(query_embedding, min(Config.top_k, self.index.ntotal))
        return [self.texts[i] for i in I[0] if 0 <= i < len(self.texts)]
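
# A minimal sketch of exercising VectorDB on its own (illustrative only; the
# example strings below are assumptions, not part of the app):
#
#   db = VectorDB()
#   db.add_text("FAISS indexes the normalized MiniLM embeddings.")
#   db.add_text("TinyLlama generates the chat responses.")
#   print(db.search("Which model produces the embeddings?"))
#   # -> up to Config.top_k stored texts, nearest to the query embedding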

# Load Model
class TinyChatModel:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(Config.model_name)
        self.pipe = pipeline(
            "text-generation",
            model=Config.model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

    def generate_response(self, message: str, context: str = "") -> str:
        messages = [{"role": "user", "content": message}]
        if context:
            messages.insert(0, {"role": "system", "content": f"Context:\n{context}"})
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        outputs = self.pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
        # The pipeline returns the prompt plus the completion; keep only the
        # text after the final assistant marker of the chat template.
        return outputs[0]["generated_text"].split("<|assistant|>")[-1].strip()
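
# For reference: TinyLlama-1.1B-Chat uses a Zephyr-style chat template, so the
# prompt produced by apply_chat_template looks roughly like the sketch below
# (assumed layout, shown only to motivate the "<|assistant|>" split above):
#
#   <|system|>
#   Context:
#   ...</s>
#   <|user|>
#   What is TinyChat?</s>
#   <|assistant|>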
# Initialize
vector_db = VectorDB()
chat_model = TinyChatModel()

# Function to handle context addition and chat
def chat_function(user_input: str, context: str = ""):
    if context:
        vector_db.add_text(context)
    # Retrieve the most relevant stored chunks for this question
    context_text = "\n".join(vector_db.search(user_input))
    response = chat_model.generate_response(user_input, context_text)
    # Store the exchange so later questions can retrieve it as context
    vector_db.add_text(f"User: {user_input}\nAssistant: {response}")
    return response
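
# A minimal end-to-end sketch of calling chat_function directly (illustrative
# only; the question and context strings are made-up examples):
#
#   answer = chat_function(
#       "What does the retriever index?",
#       context="The app stores pasted context and every chat turn in FAISS.",
#   )
#   print(answer)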

# Gradio Interface
def gradio_interface(user_input: str, context: str = ""):
    response = chat_function(user_input, context)
    return response

# Create Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# TinyChat: A Conversational AI")
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="User Input", placeholder="Ask anything...")
            context_input = gr.Textbox(label="Optional Context", placeholder="Paste context here (optional)", lines=3)
            submit_button = gr.Button("Send")
            output = gr.Textbox(label="Response", placeholder="Assistant's reply will appear here...")
    submit_button.click(fn=gradio_interface, inputs=[user_input, context_input], outputs=output)

# Run the Gradio app
if __name__ == "__main__":
    demo.launch(share=True)