import torch
from transformers import AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import gradio as gr
from typing import List
# Configuration
class Config:
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    embedding_model = "all-MiniLM-L6-v2"
    vector_dim = 384  # all-MiniLM-L6-v2 produces 384-dimensional embeddings
    top_k = 3         # retrieve the top 3 most relevant chunks
    chunk_size = 256  # text chunk size (declared but not used below)
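
# Illustrative sketch (not part of the original Space): the FAISS index
# dimension must match the embedding model's output size, which is 384 for
# all-MiniLM-L6-v2. Wrapped in a function so nothing runs at import time;
# the helper name is made up.
def _check_embedding_dim():
    dim = SentenceTransformer(Config.embedding_model).get_sentence_embedding_dimension()
    assert dim == Config.vector_dim, f"index dim {Config.vector_dim} != model dim {dim}"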
# Vector Database
class VectorDB:
    def __init__(self):
        # IndexFlatL2 over L2-normalized vectors ranks identically to cosine similarity.
        self.index = faiss.IndexFlatL2(Config.vector_dim)
        self.texts = []
        self.embedding_model = SentenceTransformer(Config.embedding_model)

    def add_text(self, text: str):
        embedding = self.embedding_model.encode([text])[0]
        embedding = np.array([embedding], dtype=np.float32)
        faiss.normalize_L2(embedding)  # normalize in place so L2 distance tracks cosine
        self.index.add(embedding)
        self.texts.append(text)

    def search(self, query: str) -> List[str]:
        if self.index.ntotal == 0:
            return []
        query_embedding = self.embedding_model.encode([query])[0]
        query_embedding = np.array([query_embedding], dtype=np.float32)
        faiss.normalize_L2(query_embedding)
        # D: distances, I: indices of the nearest stored vectors
        D, I = self.index.search(query_embedding, min(Config.top_k, self.index.ntotal))
        return [self.texts[i] for i in I[0] if i < len(self.texts)]
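
# Minimal usage sketch for VectorDB (illustrative only; the strings are
# made-up examples). Wrapped in a function so importing this module does
# not trigger it.
def _demo_vector_db():
    db = VectorDB()
    db.add_text("FAISS is a library for efficient similarity search.")
    db.add_text("Gradio builds quick web UIs for ML demos.")
    # Returns up to Config.top_k stored texts, nearest first.
    print(db.search("What does FAISS do?"))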
# Load Model
class TinyChatModel:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(Config.model_name)
        self.pipe = pipeline(
            "text-generation",
            model=Config.model_name,
            torch_dtype=torch.bfloat16,
            device_map="auto",
        )

    def generate_response(self, message: str, context: str = "") -> str:
        messages = [{"role": "user", "content": message}]
        if context:
            # Prepend retrieved context as a system message.
            messages.insert(0, {"role": "system", "content": f"Context:\n{context}"})
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        outputs = self.pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
        # generated_text echoes the prompt; keep only the text after the final assistant marker.
        return outputs[0]["generated_text"].split("<|assistant|>")[-1].strip()
# Initialize
vector_db = VectorDB()
chat_model = TinyChatModel()

# Handle optional context ingestion, retrieval, and generation for one turn
def chat_function(user_input: str, context: str = ""):
    if context:
        vector_db.add_text(context)
    # Retrieve the chunks most relevant to the question
    context_text = "\n".join(vector_db.search(user_input))
    response = chat_model.generate_response(user_input, context_text)
    # Store the exchange so later turns can retrieve it
    vector_db.add_text(f"User: {user_input}\nAssistant: {response}")
    return response
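
# Usage sketch for a single retrieval-augmented turn (illustrative only;
# the question and context are made-up examples, and calling this loads
# both models). Defined but never called, so the app flow is unchanged.
def _demo_chat():
    answer = chat_function(
        "What is TinyLlama?",
        context="TinyLlama is a 1.1B-parameter chat model.",
    )
    print(answer)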
# Gradio Interface
def gradio_interface(user_input: str, context: str = ""):
    return chat_function(user_input, context)

# Create Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# TinyChat: A Conversational AI")
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="User Input", placeholder="Ask anything...")
            context_input = gr.Textbox(label="Optional Context", placeholder="Paste context here (optional)", lines=3)
            submit_button = gr.Button("Send")
            output = gr.Textbox(label="Response", placeholder="Assistant's reply will appear here...")
    submit_button.click(fn=gradio_interface, inputs=[user_input, context_input], outputs=output)

# Run the Gradio app
if __name__ == "__main__":
    demo.launch(share=True)