frendyrachman committed on
Commit f513bae · verified · 1 Parent(s): a4af23f

Update app.py

Files changed (1)
  1. app.py +95 -97

app.py CHANGED
@@ -5,42 +5,15 @@ import chromadb
 import torch
 from sentence_transformers import SentenceTransformer
 import os
+from chromadb.utils import embedding_functions
 
-# ChromaDB Setup (Persistent Client)
-CHROMA_DB_PATH = "new_hadith_rag_source"
-client = chromadb.PersistentClient(path=CHROMA_DB_PATH)
-COLLECTION_NAME = "hadiths_new_complete"
-
-def load_or_create_collection():
-    try:
-        collection = client.get_collection(name=COLLECTION_NAME)
-        print("Collection loaded successfully.")
-        return collection
-    except ValueError:
-        print("Creating new collection...")
-        collection = client.create_collection(name=COLLECTION_NAME, overwrite=True)
-        ds = load_dataset("rwmasood/hadith-qa-pair")
-        device = 'cpu'
-        embedding_model = SentenceTransformer('all-MiniLM-L6-v2').to(device)
-
-        for split in ds.keys():
-            documents = [
-                f"Hadith: {row['hadith-eng']}\nQuestion: {row['question']}\nReference: {row['reference']}"
-                for row in ds[split]
-            ]
-            ids = [f"{split}_{i}" for i in range(len(documents))]
-            embeddings = embedding_model.encode(documents, convert_to_tensor=True, device=device).numpy()
-
-            collection.add(
-                documents=documents,
-                ids=ids,
-                embeddings=embeddings
-            )
-
-        print(f"Collection created with {collection.count()} documents.")
-        return collection
-
-collection = load_or_create_collection()
+# Initialize ChromaDB client with the existing path
+client = chromadb.PersistentClient(path="new_hadith_rag_source")
+
+# Load the existing collection
+collection = client.get_collection(name="hadiths_new_complete")
+
+# Debugging print to verify the number of documents in the collection
 print(f"Number of documents in collection: {collection.count()}")
 
 # Model and Tokenizer Loading
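Note: this hunk removes the in-app ingestion path (`load_or_create_collection`), so the `new_hadith_rag_source` directory must already contain an embedded `hadiths_new_complete` collection when the Space boots. A minimal offline ingestion sketch, reusing the dataset and embedding model from the removed code; it uses `get_or_create_collection`, since ChromaDB's `create_collection` does not document an `overwrite` parameter, so the removed code's fallback branch would likely have failed:

```python
# One-off ingestion sketch (hypothetical, not part of this commit): rebuilds
# the persisted collection that app.py now expects to already exist on disk.
import chromadb
from datasets import load_dataset
from sentence_transformers import SentenceTransformer

client = chromadb.PersistentClient(path="new_hadith_rag_source")
collection = client.get_or_create_collection(name="hadiths_new_complete")

ds = load_dataset("rwmasood/hadith-qa-pair")
model = SentenceTransformer("all-MiniLM-L6-v2")

for split in ds.keys():
    documents = [
        f"Hadith: {row['hadith-eng']}\nQuestion: {row['question']}\nReference: {row['reference']}"
        for row in ds[split]
    ]
    ids = [f"{split}_{i}" for i in range(len(documents))]
    embeddings = model.encode(documents)  # returns a NumPy array by default
    collection.add(documents=documents, ids=ids, embeddings=embeddings.tolist())
```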
@@ -54,74 +27,99 @@ llm = AutoModelForSeq2SeqLM.from_pretrained(
     device_map="auto"
 )
 
-# Embedding Model
-device = 'cpu'
-model = SentenceTransformer('all-MiniLM-L6-v2').to(device)
+# Load the pre-trained embedding model for retrieval
+device = 'cuda' if torch.cuda.is_available() else 'cpu'
+retrieval_model = SentenceTransformer('all-MiniLM-L6-v2').to(device)
 
-def query_collection(query_text, top_k=3):
-    query_embedding = model.encode(query_text, convert_to_tensor=True, device=device).numpy()
-    results = collection.query(
-        query_embeddings=[query_embedding],
-        n_results=top_k
-    )
+# Function to query the collection
+def query_collection(query, n_results):
+    # Compute the embedding for the query
+    query_embedding = retrieval_model.encode([query], convert_to_tensor=True, device=device).cpu().numpy()
+
+    # Query the collection
+    results = collection.query(query_embeddings=query_embedding, n_results=n_results)
+
     return results
 
-def speculative_generation(context, question, num_candidates=3):
-    responses = []
-    for _ in range(num_candidates):
-        prompt = f"Context:\n{context}\n\nQuestion:\n{question}\n\nAnswer:"
-        inputs = tokenizer(prompt, return_tensors="pt").to(device)
-        try:
-            outputs = llm.generate(**inputs, max_length=2048, num_return_sequences=1, num_beams=5, temperature=0.9, pad_token_id=tokenizer.eos_token_id)
-            response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-            responses.append(response)
-        except Exception as e:
-            print(f"Error during generation: {e}")
-            responses.append(f"An error occurred during generation: {e}")
-    return responses
-
-def grade_responses(responses, query):
-    best_score = -1
-    best_response = ""
-    for response in responses:
-        score = sum(1 for word in query.lower().split() if word in response.lower())
-        query_embedding = model.encode(query, convert_to_tensor=True, device=device)
-        response_embedding = model.encode(response, convert_to_tensor=True, device=device)
-        similarity = torch.nn.functional.cosine_similarity(query_embedding, response_embedding, dim=0).item()
-        score += similarity * 10
-        if score > best_score:
-            best_score = score
-            best_response = response
-    return best_response
-
-def chatbot_response(user_query, top_k=3, num_candidates=3):
+# Generate a response using the retrieved documents as context
+def generate_response(context, question):
+    prompt = f"Please provide a short, well-structured answer that avoids repetition, using the following context:\n{context}\n\nQuestion:\n{question}\n\nAnswer:"
+    inputs = tokenizer(prompt, return_tensors="pt").to(device)
+    outputs = llm.generate(**inputs, max_length=2048, num_return_sequences=1, num_beams=5, temperature=0.9, pad_token_id=tokenizer.eos_token_id)
+    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+    return response
+
+# Main chatbot function with basic RAG
+def chatbot_response(user_query, top_k=2):
+    # Step 1: Retrieve relevant documents
     results = query_collection(user_query, top_k)
+
+    # Step 2: Combine retrieved documents into context
+    documents = [doc for doc_list in results['documents'] for doc in doc_list]
+    combined_context = "\n\n".join(documents)
+
+    # Step 3: Generate a response using the combined context
+    response = generate_response(combined_context, user_query)
+
+    return response
+
+# Global variable to control the processing state
+stop_processing = False
 
 def chatbot(query, num_candidates):
+    global stop_processing
+    stop_processing = False  # Reset stop flag at the beginning of each query
+
+    # If the query is empty, return a default message
     if not query.strip():
         return "Please ask a question about hadiths."
-    try:
-        answer = chatbot_response(query, num_candidates)
-        if "don't know" in answer.lower() or "not sure" in answer.lower():
-            return "Sorry. I don't have information about the hadiths related. It might be a dhoif, or maudhu, or I just don't have the knowledge."
-        else:
-            return answer
-    except Exception as e:
-        print(f"Error in chatbot: {e}")
-        return f"An error occurred: {e}"
-
-if __name__ == "__main__":
-    iface = gr.Interface(
-        fn=chatbot,
-        inputs=[
-            gr.Textbox(lines=2, placeholder="Enter your question here..."),
-            gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Number of Hadiths as References")],
-        outputs=gr.Textbox(label="Answer"),
-        title="Hadith QA Chatbot",
-        description="Ask questions related to Hadiths."
+
+    # Perform retrieval and generation
+    answer = chatbot_response(query, num_candidates)
+
+    # Check if stop button was pressed
+    if stop_processing:
+        return "Processing was stopped by the user."
+
+    # Format the answer
+    if "don't know" in answer.lower() or "not sure" in answer.lower():
+        return "Sorry, I don't have information about the related hadiths. It might be a dhoif, or maudhu, or I just don't have the knowledge."
+    else:
+        return answer
+
+def stop():
+    global stop_processing
+    stop_processing = True
+    return "Processing stopped."
+
+# Build the Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown(
+        """
+        # Burhan AI
+        Assalamualaikum! I am Burhan AI, a chatbot that can help you find answers to your questions about hadiths.
+        \n
+        Please note that this is a demo version and may not be perfect.
+        This chatbot is powered by ChromaDB and the Flan-T5-base model in a RAG architecture.
+        Flan-T5-base is a small model and may not be as accurate as larger models.
+        If you have any feedback or suggestions, you can contact me at [email protected]
+        \n
+        Jazakallah Khairan!
+        """
     )
+    with gr.Row():
+        query_input = gr.Textbox(lines=2, placeholder="Enter your question here...")
+        num_candidates_input = gr.Slider(minimum=1, maximum=10, value=2, step=1, label="Number of References")
+    submit_button = gr.Button("Submit")
+
+    output_text = gr.Textbox(label="Response")
+
+    submit_button.click(chatbot, inputs=[query_input, num_candidates_input], outputs=output_text)
+
+    # Add a button to stop processing
+    stop_button = gr.Button("Stop Processing")
+    stop_output = gr.Textbox(visible=False)
+    stop_button.click(stop, inputs=[], outputs=stop_output)
+
+# Launch the Gradio interface
+demo.launch()
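A caveat on the new stop button: `chatbot` only reads `stop_processing` after `chatbot_response` has already returned, so a long-running `llm.generate` call is never actually interrupted; the flag can only suppress an answer that has already been computed. A minimal sketch of an alternative wiring, assuming a Gradio version that supports the `cancels` argument on event listeners and reusing `chatbot` from app.py:

```python
# Sketch: cancel the in-flight Submit event instead of polling a global flag.
import gradio as gr

with gr.Blocks() as demo:
    query_input = gr.Textbox(lines=2, placeholder="Enter your question here...")
    num_candidates_input = gr.Slider(minimum=1, maximum=10, value=2, step=1, label="Number of References")
    submit_button = gr.Button("Submit")
    stop_button = gr.Button("Stop Processing")
    output_text = gr.Textbox(label="Response")

    # Keep a handle on the click event so it can be cancelled later.
    submit_event = submit_button.click(
        chatbot, inputs=[query_input, num_candidates_input], outputs=output_text
    )
    # fn=None: this button's only job is to cancel the pending event.
    stop_button.click(None, inputs=None, outputs=None, cancels=[submit_event])

demo.launch()
```

With this wiring, cancellation is handled by Gradio's event queue itself, which is closer to what the button label promises than a flag checked after generation finishes.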