import torch
from transformers import AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import gradio as gr
from typing import List
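
# Dependencies (assumed from the imports above; not pinned in the original):
#   pip install torch transformers sentence-transformers faiss-cpu numpy gradio accelerate
# `accelerate` is required for device_map="auto" in the transformers pipeline below.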

# Configuration
class Config:
    model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    embedding_model = "all-MiniLM-L6-v2"
    vector_dim = 384   # all-MiniLM-L6-v2 produces 384-dimensional embeddings
    top_k = 3          # retrieve the top 3 most relevant chunks
    chunk_size = 256   # text chunk size in characters

# Vector Database
class VectorDB:
    def __init__(self):
        self.index = faiss.IndexFlatL2(Config.vector_dim)
        self.texts = []
        self.embedding_model = SentenceTransformer(Config.embedding_model)

    def add_text(self, text: str):
        embedding = self.embedding_model.encode([text])[0]
        embedding = np.array([embedding], dtype=np.float32)
        faiss.normalize_L2(embedding)
        self.index.add(embedding)
        self.texts.append(text)

    def search(self, query: str) -> List[str]:
        if self.index.ntotal == 0:
            return []
        query_embedding = self.embedding_model.encode([query])[0]
        query_embedding = np.array([query_embedding], dtype=np.float32)
        faiss.normalize_L2(query_embedding)
        D, I = self.index.search(query_embedding, min(Config.top_k, self.index.ntotal))
        # FAISS pads missing results with -1, which the original `i < len(self.texts)`
        # check would silently map to the last stored text; guard both bounds.
        return [self.texts[i] for i in I[0] if 0 <= i < len(self.texts)]
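
# Because every vector is L2-normalized before indexing, ranking by L2 distance in
# IndexFlatL2 is equivalent to ranking by cosine similarity. A quick usage sketch
# (hypothetical; not executed on import):
#   db = VectorDB()
#   db.add_text("FAISS performs efficient similarity search over dense vectors.")
#   db.search("what does FAISS do?")  # -> returns the stored sentence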

# Load Model
class TinyChatModel:
    def __init__(self):
        self.tokenizer = AutoTokenizer.from_pretrained(Config.model_name)
        # bfloat16 only when a GPU is present; fall back to float32 on CPU-only
        # hosts, where bfloat16 inference can be unsupported or very slow
        dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
        self.pipe = pipeline("text-generation", model=Config.model_name, torch_dtype=dtype, device_map="auto")

    def generate_response(self, message: str, context: str = "") -> str:
        messages = [{"role": "user", "content": message}]
        if context:
            messages.insert(0, {"role": "system", "content": f"Context:\n{context}"})
        prompt = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        outputs = self.pipe(prompt, max_new_tokens=256, do_sample=True, temperature=0.7, top_k=50, top_p=0.95)
        # The pipeline echoes the prompt, so keep only the text after the final assistant marker
        return outputs[0]["generated_text"].split("<|assistant|>")[-1].strip()
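
# For reference, TinyLlama's Zephyr-style chat template renders roughly as
# (an assumption; check the tokenizer's chat_template attribute to confirm):
#   <|system|>
#   Context: ...</s>
#   <|user|>
#   message</s>
#   <|assistant|>
# which is why generate_response() splits on the "<|assistant|>" marker.
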
# Initialize
vector_db = VectorDB()
chat_model = TinyChatModel()

# Handle context addition and chat
def chat_function(user_input: str, context: str = ""):
    if context:
        vector_db.add_text(context)
    # Search for relevant stored context
    context_text = "\n".join(vector_db.search(user_input))
    response = chat_model.generate_response(user_input, context_text)
    # Store the exchange so later turns can retrieve it as context
    vector_db.add_text(f"User: {user_input}\nAssistant: {response}")
    return response
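
# Note: Config.chunk_size is defined above but never used. A minimal sketch of how
# long pasted context could be split into chunks before indexing (hypothetical
# helper; not called anywhere by default):
def chunk_text(text: str, size: int = Config.chunk_size) -> List[str]:
    # Plain fixed-width character windows; a production version might split on
    # sentence or paragraph boundaries instead
    return [text[i:i + size] for i in range(0, len(text), size)]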

# Gradio Interface
def gradio_interface(user_input: str, context: str = ""):
    return chat_function(user_input, context)

# Create Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# TinyChat: A Conversational AI")
    with gr.Row():
        with gr.Column():
            user_input = gr.Textbox(label="User Input", placeholder="Ask anything...")
            context_input = gr.Textbox(label="Optional Context", placeholder="Paste context here (optional)", lines=3)
            submit_button = gr.Button("Send")
            output = gr.Textbox(label="Response", placeholder="Assistant's reply will appear here...")
    submit_button.click(fn=gradio_interface, inputs=[user_input, context_input], outputs=output)

# Run the Gradio app
if __name__ == "__main__":
    demo.launch(share=True)