Update app.py
app.py CHANGED
@@ -5,42 +5,15 @@ import chromadb
 import torch
 from sentence_transformers import SentenceTransformer
 import os
+from chromadb.utils import embedding_functions
 
-# ChromaDB
-
-
-
-
-
-
-        collection = client.get_collection(name=COLLECTION_NAME)
-        print("Collection loaded successfully.")
-        return collection
-    except ValueError:
-        print("Creating new collection...")
-        collection = client.create_collection(name=COLLECTION_NAME, overwrite=True)
-        ds = load_dataset("rwmasood/hadith-qa-pair")
-        device = 'cpu'
-        embedding_model = SentenceTransformer('all-MiniLM-L6-v2').to(device)
-
-        for split in ds.keys():
-            documents = [
-                f"Hadith: {row['hadith-eng']}\nQuestion: {row['question']}\nReference: {row['reference']}"
-                for row in ds[split]
-            ]
-            ids = [f"{split}_{i}" for i in range(len(documents))]
-            embeddings = embedding_model.encode(documents, convert_to_tensor=True, device=device).numpy()
-
-            collection.add(
-                documents=documents,
-                ids=ids,
-                embeddings=embeddings
-            )
-
-        print(f"Collection created with {collection.count()} documents.")
-        return collection
-
-collection = load_or_create_collection()
+# Initialize ChromaDB client with the existing path
+client = chromadb.PersistentClient(path="new_hadith_rag_source")
+
+# Load the existing collection
+collection = client.get_collection(name="hadiths_new_complete")
+
+# Debugging print to verify the number of documents in the collection
 print(f"Number of documents in collection: {collection.count()}")
 
 # Model and Tokenizer Loading
@@ -54,74 +27,99 @@ llm = AutoModelForSeq2SeqLM.from_pretrained(
 device_map="auto"
 )
 
-#
-device = 'cpu'
-
+# Load the pre-trained model and tokenizer
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+retrieval_model = SentenceTransformer('all-MiniLM-L6-v2').to(device)
 
-
-
-
-
-
-
+# Function to query the collection
+def query_collection(query, n_results):
+    # Compute the embedding for the query
+    query_embedding = retrieval_model.encode([query], convert_to_tensor=True, device=device).cpu().numpy()
+
+    # Query the collection
+    results = collection.query(query_embeddings=query_embedding, n_results=n_results)
+
     return results
 
-
-
-
-
-
-
-
-
-
-
-
-        responses.append(f"An error occurred during generation: {e}")
-    return responses
-
-def grade_responses(responses, query):
-    best_score = -1
-    best_response = ""
-    for response in responses:
-        score = sum(1 for word in query.lower().split() if word in response.lower())
-        query_embedding = model.encode(query, convert_to_tensor=True, device=device)
-        response_embedding = model.encode(response, convert_to_tensor=True, device=device)
-        similarity = torch.nn.functional.cosine_similarity(query_embedding, response_embedding, dim=0).item()
-        score += similarity * 10
-        if score > best_score:
-            best_score = score
-            best_response = response
-    return best_response
-
-def chatbot_response(user_query, top_k=3, num_candidates=3):
+# Generate a response using the retrieved documents as context
+def generate_response(context, question):
+    prompt = f"Please provide a short, well-structured answer that avoids repetition, based on the context:\n{context}\n\nQuestion:\n{question}\n\nAnswer:"
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+    outputs = llm.generate(**inputs, max_length=2048, num_return_sequences=1, num_beams=5, temperature=0.9, pad_token_id=tokenizer.eos_token_id)
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return response
+
+# Main chatbot function with basic RAG
+def chatbot_response(user_query, top_k=2):
+    # Step 1: Retrieve relevant documents
     results = query_collection(user_query, top_k)
-
-
-
-
+
+    # Step 2: Combine retrieved documents into context
+    documents = [doc for doc_list in results['documents'] for doc in doc_list]
+    combined_context = "\n\n".join(documents)
+
+    # Step 3: Generate a response using the combined context
+    response = generate_response(combined_context, user_query)
+
+    return response
+
+# Global variable to control the processing state
+stop_processing = False
 
 def chatbot(query, num_candidates):
+    global stop_processing
+    stop_processing = False  # Reset stop flag at the beginning of each query
+
+    # If the query is empty, return a default message
     if not query.strip():
         return "Please ask a question about hadiths."
+
+    # Run retrieval and generation with the RAG pipeline
+    answer = chatbot_response(query, num_candidates)
+
+    # Check whether the stop button was pressed
+    if stop_processing:
+        return "Processing was stopped by the user."
+
+    # Format the answer
+    if "don't know" in answer.lower() or "not sure" in answer.lower():
+        return "Sorry, I don't have information about the related hadiths. They might be dhoif or maudhu, or I just don't have the knowledge."
+    else:
+        return answer
+
+def stop():
+    global stop_processing
+    stop_processing = True
+    return "Processing stopped."
+
+# Build the Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown(
+        """
+        # Burhan AI
+        Assalamualaikum! I am Burhan AI, a chatbot that can help you find answers to your questions about hadiths.
+        \n
+        Please note that this is a demo version and may not be perfect.
+        This chatbot is powered by ChromaDB and the Flan-T5-base model with a RAG architecture.
+        Flan-T5-base is a small model and may not be as accurate as bigger models.
+        If you have any feedback or suggestions, you can contact me at [email protected]
+        \n
+        Jazakallah Khairan!
+        """
     )
-
+    with gr.Row():
+        query_input = gr.Textbox(lines=2, placeholder="Enter your question here...")
+        num_candidates_input = gr.Slider(minimum=1, maximum=10, value=2, step=1, label="Number of References")
+        submit_button = gr.Button("Submit")
+
+    output_text = gr.Textbox(label="Response")
+
+    submit_button.click(chatbot, inputs=[query_input, num_candidates_input], outputs=output_text)
+
+    # Add a button to stop processing
+    stop_button = gr.Button("Stop Processing")
+    stop_output = gr.Textbox(visible=False)
+    stop_button.click(stop, inputs=[], outputs=stop_output)
+
+# Run the Gradio interface
+demo.launch()
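
A note on the simplified collection loading: the commit removes the old load_or_create_collection() fallback, so client.get_collection() will now raise if the persisted collection is missing (the removed code caught ValueError for exactly that case). If a guard is still wanted, a minimal sketch could look like the following; the error message wording is illustrative, and it assumes the same path and collection name as the new code.

import chromadb

client = chromadb.PersistentClient(path="new_hadith_rag_source")
try:
    collection = client.get_collection(name="hadiths_new_complete")
except Exception as exc:
    # Fail loudly with an actionable message instead of a bare traceback.
    # (The exact exception type varies across chromadb versions.)
    raise RuntimeError(
        "Collection 'hadiths_new_complete' was not found under 'new_hadith_rag_source'; "
        "re-run the indexing step before launching the Space."
    ) from exc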
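
For a quick local sanity check of the retrieve-then-generate flow this commit switches to, here is a minimal, self-contained sketch. The in-memory collection, the sample documents, and the google/flan-t5-base checkpoint are illustrative assumptions (the from_pretrained() arguments are truncated in the diff; the UI text only says "Flan-T5-base"); only the all-MiniLM-L6-v2 embedder and the query pattern mirror app.py.

# Illustrative sketch only: the collection, documents, and checkpoint are placeholders.
import chromadb
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

embedder = SentenceTransformer("all-MiniLM-L6-v2")    # same retrieval model as app.py
client = chromadb.Client()                            # ephemeral in-memory client for testing
collection = client.get_or_create_collection("demo")  # placeholder collection name

docs = [
    "Hadith: Actions are judged by intentions.\nReference: Bukhari 1",
    "Hadith: The strong person is the one who controls himself when angry.\nReference: Bukhari 6114",
]
collection.add(
    documents=docs,
    ids=[f"doc_{i}" for i in range(len(docs))],
    embeddings=embedder.encode(docs).tolist(),
)

# Retrieve the most relevant document for a query, as query_collection() does.
query = "What role does intention play in deeds?"
results = collection.query(
    query_embeddings=embedder.encode([query]).tolist(),
    n_results=1,
)
context = "\n\n".join(results["documents"][0])

# Generate an answer from the retrieved context, as generate_response() does.
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-base")        # assumed checkpoint
llm = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-base")      # assumed checkpoint

prompt = f"Answer the question using the context.\n\nContext:\n{context}\n\nQuestion:\n{query}\n\nAnswer:"
inputs = tokenizer(prompt, return_tensors="pt")
outputs = llm.generate(**inputs, max_new_tokens=128, num_beams=5)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

As in the updated generate_response(), decoding uses beam search; note that temperature only takes effect when do_sample=True, so the temperature=0.9 passed in the commit has no effect under beam search.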