aliMohammad16 commited on
Commit
f74f783
·
verified ·
1 Parent(s): 79d0b35

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -28
app.py CHANGED
@@ -1,17 +1,18 @@
1
- import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, pipeline
4
  from sentence_transformers import SentenceTransformer
5
  import faiss
6
  import numpy as np
 
 
7
 
8
  # Configuration
9
  class Config:
10
  model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
11
  embedding_model = "all-MiniLM-L6-v2"
12
- vector_dim = 384
13
- top_k = 3
14
- chunk_size = 256
15
 
16
  # Vector Database
17
  class VectorDB:
@@ -27,7 +28,7 @@ class VectorDB:
27
  self.index.add(embedding)
28
  self.texts.append(text)
29
 
30
- def search(self, query: str):
31
  if self.index.ntotal == 0:
32
  return []
33
  query_embedding = self.embedding_model.encode([query])[0]
@@ -42,7 +43,7 @@ class TinyChatModel:
42
  self.tokenizer = AutoTokenizer.from_pretrained(Config.model_name)
43
  self.pipe = pipeline("text-generation", model=Config.model_name, torch_dtype=torch.bfloat16, device_map="auto")
44
 
45
- def generate_response(self, message: str, context: str = ""):
46
  messages = [{"role": "user", "content": message}]
47
  if context:
48
  messages.insert(0, {"role": "system", "content": f"Context:\n{context}"})
@@ -54,32 +55,35 @@ class TinyChatModel:
54
  vector_db = VectorDB()
55
  chat_model = TinyChatModel()
56
 
57
- def chat_interface(user_input):
58
- context = "\n".join(vector_db.search(user_input))
59
- response = chat_model.generate_response(user_input, context)
 
 
 
 
 
60
  vector_db.add_text(f"User: {user_input}\nAssistant: {response}")
 
61
  return response
62
 
63
- def add_text_interface(text):
64
- vector_db.add_text(text)
65
- return "Text added to memory!"
 
66
 
67
- # Gradio UI
68
- demo = gr.Blocks()
69
- with demo:
70
- gr.Markdown("# 🦙 TinyChat - AI Chatbot")
71
- with gr.Row():
72
- chatbot = gr.Chatbot()
73
  with gr.Row():
74
- user_input = gr.Textbox(label="Your Message")
75
- send_btn = gr.Button("Send")
76
- with gr.Row():
77
- add_text_input = gr.Textbox(label="Add Knowledge to AI")
78
- add_text_btn = gr.Button("Add Text")
79
-
80
- send_btn.click(chat_interface, inputs=user_input, outputs=chatbot)
81
- add_text_btn.click(add_text_interface, inputs=add_text_input, outputs=gr.Textbox())
82
 
83
- # Launch
84
  if __name__ == "__main__":
85
- demo.launch()
 
 
1
  import torch
2
  from transformers import AutoTokenizer, pipeline
3
  from sentence_transformers import SentenceTransformer
4
  import faiss
5
  import numpy as np
6
+ import gradio as gr
7
+ from typing import List
8
 
9
  # Configuration
10
  class Config:
11
  model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
12
  embedding_model = "all-MiniLM-L6-v2"
13
+ vector_dim = 384 # Sentence Transformer embedding dimension
14
+ top_k = 3 # Retrieve top 3 relevant chunks
15
+ chunk_size = 256 # Text chunk size
16
 
17
  # Vector Database
18
  class VectorDB:
 
28
  self.index.add(embedding)
29
  self.texts.append(text)
30
 
31
+ def search(self, query: str) -> List[str]:
32
  if self.index.ntotal == 0:
33
  return []
34
  query_embedding = self.embedding_model.encode([query])[0]
 
43
  self.tokenizer = AutoTokenizer.from_pretrained(Config.model_name)
44
  self.pipe = pipeline("text-generation", model=Config.model_name, torch_dtype=torch.bfloat16, device_map="auto")
45
 
46
+ def generate_response(self, message: str, context: str = "") -> str:
47
  messages = [{"role": "user", "content": message}]
48
  if context:
49
  messages.insert(0, {"role": "system", "content": f"Context:\n{context}"})
 
55
  vector_db = VectorDB()
56
  chat_model = TinyChatModel()
57
 
58
+ # Function to handle context addition and chat
59
+ def chat_function(user_input: str, context: str = ""):
60
+ if context:
61
+ vector_db.add_text(context)
62
+
63
+ # Search relevant context
64
+ context_text = "\n".join(vector_db.search(user_input))
65
+ response = chat_model.generate_response(user_input, context_text)
66
  vector_db.add_text(f"User: {user_input}\nAssistant: {response}")
67
+
68
  return response
69
 
70
+ # Gradio Interface
71
+ def gradio_interface(user_input: str, context: str = ""):
72
+ response = chat_function(user_input, context)
73
+ return response
74
 
75
+ # Create Gradio UI
76
+ with gr.Blocks() as demo:
77
+ gr.Markdown("# TinyChat: A Conversational AI")
 
 
 
78
  with gr.Row():
79
+ with gr.Column():
80
+ user_input = gr.Textbox(label="User Input", placeholder="Ask anything...")
81
+ context_input = gr.Textbox(label="Optional Context", placeholder="Paste context here (optional)", lines=3)
82
+ submit_button = gr.Button("Send")
83
+ output = gr.Textbox(label="Response", placeholder="Assistant's reply will appear here...")
84
+
85
+ submit_button.click(fn=gradio_interface, inputs=[user_input, context_input], outputs=output)
 
86
 
87
+ # Run the Gradio app
88
  if __name__ == "__main__":
89
+ demo.launch(share=True)