frendyrachman committed on
Commit 99d6fc6 · verified · 1 Parent(s): fe51286

Update app.py

Files changed (1)
  1. app.py +125 -49
app.py CHANGED
@@ -1,64 +1,140 @@
  import gradio as gr
- from huggingface_hub import InferenceClient

  """
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
  """
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
-
- if __name__ == "__main__":
-     demo.launch()

  import gradio as gr
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from huggingface_hub import login
+ from datasets import load_dataset
+ import chromadb
+ import torch
+ from sentence_transformers import SentenceTransformer
+ import os

  """
  For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
  """
  client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

+ # ChromaDB Setup (Persistent Client)
+ CHROMA_DB_PATH = "new_hadith_rag_source"  # Directory to store ChromaDB data
+ client = chromadb.PersistentClient(path=CHROMA_DB_PATH)

+ COLLECTION_NAME = "hadiths_new_complete"

+ # Function to load or create the ChromaDB collection
+ def load_or_create_collection():
+     try:
+         collection = client.get_collection(name=COLLECTION_NAME)
+         print("Collection loaded successfully.")
+         return collection
+     except ValueError:  # Collection doesn't exist
+         print("Creating new collection...")
+         collection = client.create_collection(name=COLLECTION_NAME)  # only reached when the collection does not exist yet

+         # Load data and add to the collection
+         ds = load_dataset("rwmasood/hadith-qa-pair")
+         device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         embedding_model = SentenceTransformer('all-MiniLM-L6-v2').to(device)  # Using a local name to avoid shadowing

+         for split in ds.keys():
+             documents = [
+                 f"Hadith: {row['hadith-eng']}\nQuestion: {row['question']}\nReference: {row['reference']}"
+                 for row in ds[split]
+             ]
+             ids = [f"{split}_{i}" for i in range(len(documents))]

+             # Compute embeddings on the selected device
+             embeddings = embedding_model.encode(documents, convert_to_tensor=True, device=device)

+             collection.add(
+                 documents=documents,
+                 ids=ids,
+                 embeddings=embeddings.cpu().numpy()
+             )

+         print(f"Collection created with {collection.count()} documents.")
+         return collection

+ # Load or create the collection
+ collection = load_or_create_collection()
+
+ # Debugging print
+ print(f"Number of documents in collection: {collection.count()}")
+
+ # Model and Tokenizer Loading
+ model_name = "meta-llama/Llama-3.2-3B-Instruct"
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+ llm = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16, pad_token_id=tokenizer.eos_token_id)
+
+ # Embedding Model
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ model = SentenceTransformer('all-MiniLM-L6-v2').to(device)
+
+
+ # Helper Functions (Querying, Generation, Grading) - No changes needed
+ def query_collection(query_text, top_k=3):
+     query_embedding = model.encode(query_text, convert_to_tensor=True, device=device).cpu().numpy()
+     results = collection.query(
+         query_embeddings=[query_embedding],
+         n_results=top_k
+     )
+     return results
+
+ def speculative_generation(context, question, num_candidates=3):
+     responses = []
+     for _ in range(num_candidates):
+         prompt = f"Context:\n{context}\n\nQuestion:\n{question}\n\nAnswer:"
+         inputs = tokenizer(prompt, return_tensors="pt").to(llm.device)  # follow the model's device instead of assuming CUDA
+         try:
+             outputs = llm.generate(**inputs, max_length=2048, num_return_sequences=1, num_beams=5, do_sample=True, temperature=0.9, pad_token_id=tokenizer.eos_token_id)  # do_sample so temperature applies and candidates differ
+             response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)  # decode only the newly generated tokens, not the echoed prompt
+             responses.append(response)
+         except Exception as e:
+             print(f"Error during generation: {e}")
+             responses.append("An error occurred during generation.")
+     return responses
+
+ def grade_responses(responses, query):
+     best_score = -1
+     best_response = ""
+     for response in responses:
+         score = 0
+         score += sum(1 for word in query.lower().split() if word in response.lower())
+         query_embedding = model.encode(query, convert_to_tensor=True, device=device)
+         response_embedding = model.encode(response, convert_to_tensor=True, device=device)
+         similarity = torch.nn.functional.cosine_similarity(query_embedding, response_embedding, dim=0).item()
+         score += similarity * 10
+         if score > best_score:
+             best_score = score
+             best_response = response
+     return best_response
+
+ def chatbot_response(user_query, top_k=3, num_candidates=3):
+     results = query_collection(user_query, top_k)
+     context = "\n\n".join(results['documents'][0])
+     speculative_responses = speculative_generation(context, user_query, num_candidates)
+     best_response = grade_responses(speculative_responses, user_query)
+     return best_response
+
+ # Chatbot Function (Adjusted for Error Handling and Default Message)
+ def chatbot(query):
+     if not query.strip():
+         return "Please ask a question about hadiths."
+
+     try:
+         answer = chatbot_response(query)
+         if "don't know" in answer.lower() or "not sure" in answer.lower():
+             return "Sorry, I don't have information about the related hadith. It might be dhoif or maudhu, or I simply don't have the knowledge."
+         else:
+             return answer
+     except Exception as e:
+         print(f"Error in chatbot: {e}")
+         return f"An error occurred: {e}"
+
+ # Gradio Interface
+ if __name__ == "__main__":  # Ensures this only runs when the script is executed directly
+     iface = gr.Interface(
+         fn=chatbot,
+         inputs=gr.Textbox(lines=7, placeholder="Ask me a question about hadiths...", label="Question"),
+         outputs=gr.Textbox(label="Answer"),
+         title="Hadith QA Chatbot",
+         description="Ask questions related to Hadiths."
+     )
+     iface.launch()
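
To sanity-check the retrieval side of this change without launching the full Gradio app, a minimal sketch mirroring the new query_collection helper could look like the following. It assumes app.py has already been run once so the persistent "new_hadith_rag_source" directory and the "hadiths_new_complete" collection exist; the question string is only an illustrative placeholder.

import chromadb
from sentence_transformers import SentenceTransformer

# Re-open the collection that app.py builds on its first run
chroma_client = chromadb.PersistentClient(path="new_hadith_rag_source")
collection = chroma_client.get_collection(name="hadiths_new_complete")
embedder = SentenceTransformer("all-MiniLM-L6-v2")

question = "What do the hadiths say about intention?"  # hypothetical example query
query_embedding = embedder.encode(question)  # 1-D numpy array by default

results = collection.query(query_embeddings=[query_embedding.tolist()], n_results=3)
for doc in results["documents"][0]:
    print(doc)
    print("---")

Note that the new imports suggest the Space's requirements.txt also needs transformers, datasets, chromadb, torch and sentence-transformers alongside gradio and huggingface_hub (and likely accelerate, since the model is loaded with device_map="auto").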