Phoenix21 committed on
Commit 9a82698 · 1 Parent(s): 0bda508

removed punkt
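
This removes the NLTK punkt dependency: ensure_complete_sentences now trims a response that was cut off mid-sentence with a plain regular expression instead of calling nltk.tokenize.sent_tokenize. A minimal sketch of the new behaviour (the sample string is illustrative, not taken from the app):

    import re

    def ensure_complete_sentences(text):
        # Keep only text that ends at a sentence terminator (., ! or ?).
        sentences = re.findall(r'[^.!?]*[.!?]', text)
        if sentences:
            return ' '.join(sentences).strip()
        return text  # no complete sentence found; return unchanged

    sample = "Drink water regularly. Stretch every morning. Also you shou"
    print(ensure_complete_sentences(sample))
    # The dangling fragment "Also you shou" is dropped.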

Files changed (1)
  1. app.py +87 -145
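
The spaCy-based validator is_valid_input_nlp is removed as well; the new is_valid_input only checks for non-empty text, at least one alphabetic character, and a minimum length of five characters. A short usage sketch (the example inputs are illustrative):

    import re

    def is_valid_input(text):
        # Reject empty/whitespace-only, letter-free, or very short input.
        if not text or text.strip() == "":
            return False
        if not re.search('[A-Za-z]', text):
            return False
        if len(text.strip()) < 5:
            return False
        return True

    print(is_valid_input("How can I sleep better?"))  # True
    print(is_valid_input("???"))                      # False (no letters)
    print(is_valid_input("hi"))                       # False (too short)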
app.py CHANGED
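
Besides the punkt removal, the diff below builds the RAG pipeline once at startup instead of on every question, and initialize_llm no longer estimates prompt tokens from the template: it reserves a flat 20% of max_tokens for the prompt and gives the remainder to the response, raising ValueError if 50 or fewer tokens are left. With the default max_tokens of 500 used at startup, the budget works out as follows:

    max_tokens = 500                                      # default set at startup
    prompt_allocation = int(max_tokens * 0.2)             # 100 tokens reserved for the prompt
    response_max_tokens = max_tokens - prompt_allocation  # 400 tokens left for the response
    assert response_max_tokens > 50                       # initialize_llm raises ValueError otherwise
    print(prompt_allocation, response_max_tokens)         # prints: 100 400
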
@@ -1,10 +1,6 @@
 import os
 import logging
 import re
-import nltk
-import spacy
-import traceback
-from nltk.tokenize import sent_tokenize
 from langchain.vectorstores import Chroma
 from langchain_core.output_parsers import StrOutputParser
 from langchain_core.runnables import RunnablePassthrough
@@ -20,13 +16,9 @@ import pandas as pd
 import json
 
 # Enable logging for debugging
-logging.basicConfig(level=logging.DEBUG)
+logging.basicConfig(level=logging.INFO)  # Changed to INFO to reduce verbosity
 logger = logging.getLogger(__name__)
 
-# Set NLTK data path to the local 'nltk_data' directory
-nltk.data.path.append(os.path.join(os.path.dirname(__file__), 'nltk_data'))
-logger.debug("Configured NLTK data path to local 'nltk_data' directory.")
-
 # Function to clean the API key
 def clean_api_key(key):
     return ''.join(c for c in key if ord(c) < 128)
@@ -77,83 +69,41 @@ def load_documents(file_paths):
             logger.warning(f"Unsupported file format: {file_path}")
         except Exception as e:
             logger.error(f"Error processing file {file_path}: {e}")
-            logger.error(traceback.format_exc())
+            logger.debug("Exception details:", exc_info=True)
     return docs
 
-# Function to ensure the response ends with complete sentences using NLTK
+# Function to ensure the response ends with complete sentences
 def ensure_complete_sentences(text):
-    logger.debug("Ensuring complete sentences for the given text.")
-    try:
-        sentences = sent_tokenize(text)
-        if sentences:
-            return ' '.join(sentences).strip()
-        return text  # Return as is if no complete sentence is found
-    except LookupError as e:
-        logger.error("NLTK resource 'punkt' not found. Attempting to download again.")
-        try:
-            nltk.download('punkt', download_dir=os.path.join(os.path.dirname(__file__), 'nltk_data'))
-            nltk.data.path.append(os.path.join(os.path.dirname(__file__), 'nltk_data'))
-            sentences = sent_tokenize(text)
-            return ' '.join(sentences).strip()
-        except Exception as e_inner:
-            logger.error("Failed to download 'punkt' resource.")
-            logger.error(traceback.format_exc())
-            raise e_inner
-    except Exception as e:
-        logger.error("Unexpected error during sentence tokenization.")
-        logger.error(traceback.format_exc())
-        raise e
-
-# Advanced input validation using spaCy (Section 8a)
-def is_valid_input_nlp(text, threshold=0.5):
+    # Use regex to find all complete sentences
+    sentences = re.findall(r'[^.!?]*[.!?]', text)
+    if sentences:
+        # Join all complete sentences to form the complete answer
+        return ' '.join(sentences).strip()
+    return text  # Return as is if no complete sentence is found
+
+# Function to check if input is valid
+def is_valid_input(text):
     """
-    Validates input text using spaCy's NLP capabilities.
-
-    Parameters:
-    - text (str): The input text to validate.
-    - threshold (float): The minimum ratio of meaningful tokens required.
-
-    Returns:
-    - bool: True if the input is valid, False otherwise.
+    Checks if the input text is meaningful.
+    Returns True if the text contains alphabetic characters and is of sufficient length.
     """
     if not text or text.strip() == "":
-        logger.debug("Input text is empty or contains only whitespace.")
         return False
-    doc = nlp(text)
-    meaningful_tokens = [token for token in doc if token.is_alpha]
-    if not meaningful_tokens:
-        logger.debug("No meaningful (alphabetic) tokens found in input.")
+    # Regex to check for at least one alphabetic character
+    if not re.search('[A-Za-z]', text):
         return False
-    ratio = len(meaningful_tokens) / len(doc)
-    logger.debug(f"Meaningful tokens ratio: {ratio}")
-    return ratio >= threshold
-
-# Function to estimate prompt tokens (simple word count approximation)
-def estimate_prompt_tokens(prompt):
-    """
-    Estimates the number of tokens in the prompt.
-    This is a placeholder function. Replace it with actual token estimation logic.
-
-    Parameters:
-    - prompt (str): The prompt text.
-
-    Returns:
-    - int: Estimated number of tokens.
-    """
-    return len(prompt.split())
+    # Additional check: minimum length
+    if len(text.strip()) < 5:
+        return False
+    return True
 
 # Initialize the LLM using ChatGroq with GROQ's API
-def initialize_llm(model, temperature, max_tokens, prompt_template):
+def initialize_llm(model, temperature, max_tokens):
     try:
-        # Estimate prompt tokens
-        estimated_prompt_tokens = estimate_prompt_tokens(prompt_template)
-        logger.debug(f"Estimated prompt tokens: {estimated_prompt_tokens}")
-
-        # Allocate remaining tokens to response
-        response_max_tokens = max_tokens - estimated_prompt_tokens
-        logger.debug(f"Response max tokens: {response_max_tokens}")
-
-        if response_max_tokens <= 100:
+        # Allocate a portion of tokens for the prompt, e.g., 20%
+        prompt_allocation = int(max_tokens * 0.2)
+        response_max_tokens = max_tokens - prompt_allocation
+        if response_max_tokens <= 50:
             raise ValueError("max_tokens is too small to allocate for the response.")
 
         llm = ChatGroq(
@@ -162,53 +112,26 @@ def initialize_llm(model, temperature, max_tokens, prompt_template):
             max_tokens=response_max_tokens,  # Adjusted max_tokens
             api_key=api_key  # Ensure the API key is passed correctly
         )
-        logger.debug("LLM initialized successfully.")
+        logger.info("LLM initialized successfully.")
         return llm
     except Exception as e:
         logger.error(f"Error initializing LLM: {e}")
-        logger.error(traceback.format_exc())
-        raise e
+        raise
 
 # Create the RAG pipeline
 def create_rag_pipeline(file_paths, model, temperature, max_tokens):
     try:
-        # Define the prompt template first to estimate tokens
-        custom_prompt_template = PromptTemplate(
-            input_variables=["context", "question"],
-            template="""
-            You are an AI assistant with expertise in daily wellness. Your aim is to provide detailed and comprehensive solutions regarding daily wellness topics without unnecessary verbosity.
-
-            Context:
-            {context}
-
-            Question:
-            {question}
-
-            Provide a thorough and complete answer, including relevant examples and a suggested schedule. Ensure that the response does not end abruptly.
-            """
-        )
-
-        # Estimate prompt tokens
-        estimated_prompt_tokens = estimate_prompt_tokens(custom_prompt_template.template)
-        logger.debug(f"Estimated prompt tokens from template: {estimated_prompt_tokens}")
-
-        # Initialize the LLM with token allocation
-        llm = initialize_llm(model, temperature, max_tokens, custom_prompt_template.template)
-
-        # Load and process documents
+        llm = initialize_llm(model, temperature, max_tokens)
         docs = load_documents(file_paths)
         if not docs:
             logger.warning("No documents were loaded. Please check your file paths and formats.")
            return None, "No documents were loaded. Please check your file paths and formats."
 
-        # Split documents into chunks
         text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
         splits = text_splitter.split_documents(docs)
-        logger.debug(f"Documents split into {len(splits)} chunks.")
 
         # Initialize the embedding model
         embedding_model = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
-        logger.debug("Embedding model initialized successfully.")
 
         # Use a temporary directory for Chroma vectorstore to prevent caching issues on Hugging Face Spaces
         vectorstore = Chroma.from_documents(
@@ -217,25 +140,39 @@ def create_rag_pipeline(file_paths, model, temperature, max_tokens):
             persist_directory="/tmp/chroma_db"  # Temporary storage directory
         )
         vectorstore.persist()  # Save the database to disk
-        logger.debug("Vectorstore initialized and persisted successfully.")
+        logger.info("Vectorstore initialized and persisted successfully.")
 
         retriever = vectorstore.as_retriever()
 
-        # Create the RetrievalQA chain
+        custom_prompt_template = PromptTemplate(
+            input_variables=["context", "question"],
+            template="""
+            You are an AI assistant with expertise in daily wellness. Your aim is to provide detailed and comprehensive solutions regarding daily wellness topics without unnecessary verbosity.
+
+            Context:
+            {context}
+
+            Question:
+            {question}
+
+            Provide a thorough and complete answer, including relevant examples and a suggested schedule. Ensure that the response does not end abruptly.
+            """
+        )
+
         rag_chain = RetrievalQA.from_chain_type(
             llm=llm,
             chain_type="stuff",
             retriever=retriever,
             chain_type_kwargs={"prompt": custom_prompt_template}
         )
-        logger.debug("RAG pipeline created successfully.")
+        logger.info("RAG pipeline created successfully.")
         return rag_chain, "Pipeline created successfully."
     except Exception as e:
         logger.error(f"Error creating RAG pipeline: {e}")
-        logger.error(traceback.format_exc())
+        logger.debug("Exception details:", exc_info=True)
         return None, f"Error creating RAG pipeline: {e}"
 
-# Function to handle feedback (Section 8d)
+# Function to handle feedback (Optional Enhancement)
 def handle_feedback(feedback_text):
     """
     Handles user feedback by logging it.
@@ -254,43 +191,48 @@ def handle_feedback(feedback_text):
     else:
         return "No feedback provided."
 
+# Initialize the RAG pipeline once at startup
+# Define the file paths (ensure 'AIChatbot.csv' is in the root directory of your Space)
+file_paths = ['AIChatbot.csv']
+model = "llama3-8b-8192"  # Default model name
+temperature = 0.7  # Default temperature
+max_tokens = 500  # Default max tokens
+
+rag_chain, message = create_rag_pipeline(file_paths, model, temperature, max_tokens)
+if rag_chain is None:
+    logger.error("Failed to initialize RAG pipeline at startup.")
+    # Depending on your preference, you might want to exit or continue. Here, we'll continue.
+
 # Function to answer questions with input validation and post-processing
-def answer_question(file_paths, model, temperature, max_tokens, question, feedback):
+def answer_question(model, temperature, max_tokens, question, feedback):
+    # Validate input
+    if not is_valid_input(question):
+        logger.info("Received invalid input from user.")
+        return "Please provide a valid question or input containing meaningful text.", ""
+
+    # Check if the RAG pipeline needs to be re-initialized (e.g., if model or parameters have changed)
+    # For simplicity, we'll assume the pipeline remains the same. For dynamic models, implement re-initialization here.
+
     try:
-        # Validate input using spaCy-based validation
-        if not is_valid_input_nlp(question):
-            logger.debug("Invalid input detected.")
-            return "Please provide a valid question or input containing meaningful text.", ""
+        answer = rag_chain.run(question)
+        logger.info("Question answered successfully.")
+        # Post-process to ensure the answer ends with complete sentences
+        complete_answer = ensure_complete_sentences(answer)
 
-        rag_chain, message = create_rag_pipeline(file_paths, model, temperature, max_tokens)
-        if rag_chain is None:
-            logger.debug("RAG pipeline creation failed.")
-            return message, ""
+        # Handle feedback
+        feedback_response = handle_feedback(feedback)
 
-        try:
-            answer = rag_chain.run(question)
-            logger.debug("Question answered successfully.")
-            # Post-process to ensure the answer ends with complete sentences
-            complete_answer = ensure_complete_sentences(answer)
-
-            # Handle feedback
-            feedback_response = handle_feedback(feedback)
-
-            return complete_answer, feedback_response
-        except Exception as e_inner:
-            logger.error(f"Error during RAG pipeline execution: {e_inner}")
-            logger.error(traceback.format_exc())
-            return f"Error during RAG pipeline execution: {e_inner}", ""
-
-    except Exception as e_outer:
-        logger.error(f"Unexpected error in answer_question: {e_outer}")
-        logger.error(traceback.format_exc())
-        return f"Unexpected error: {e_outer}", ""
-
-# Gradio Interface with Feedback Mechanism (Section 8d)
+        return complete_answer, feedback_response
+    except Exception as e_inner:
+        logger.error(f"Error during RAG pipeline execution: {e_inner}")
+        logger.debug("Exception details:", exc_info=True)
+        return f"Error during RAG pipeline execution: {e_inner}", ""
+
+# Gradio Interface with Feedback Mechanism
 def gradio_interface(model, temperature, max_tokens, question, feedback):
-    file_paths = ['AIChatbot.csv']  # Ensure this file is present in your Space root directory
-    return answer_question(file_paths, model, temperature, max_tokens, question, feedback)
+    # Optionally, you can add functionality to update the RAG pipeline if model or parameters change
+    # For now, we'll ignore changes to model parameters after initialization
+    return answer_question(model, temperature, max_tokens, question, feedback)
 
 # Define Gradio UI
 interface = gr.Interface(
@@ -298,7 +240,7 @@ interface = gr.Interface(
     inputs=[
         gr.Textbox(
             label="Model Name",
-            value="llama3-8b-8192",
+            value=model,
             placeholder="e.g., llama3-8b-8192"
         ),
         gr.Slider(
@@ -306,7 +248,7 @@ interface = gr.Interface(
             minimum=0,
             maximum=1,
             step=0.01,
-            value=0.7,
+            value=temperature,
             info="Controls the randomness of the response. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic."
         ),
         gr.Slider(
@@ -314,7 +256,7 @@ interface = gr.Interface(
             minimum=200,
             maximum=2048,
             step=1,
-            value=500,
+            value=max_tokens,
             info="Determines the maximum number of tokens in the response. Higher values allow for longer answers."
         ),
         gr.Textbox(