ariankhalfani commited on
Commit
7c98099
·
verified ·
1 Parent(s): 8ad240b

Update chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +29 -41
chatbot.py CHANGED
@@ -167,12 +167,8 @@ readable_patient_data = transform_patient_data(patient_data)
167
  # Function to extract details from the input prompt
168
  def extract_details_from_prompt(prompt):
169
  pattern = re.compile(r"(Glaucoma|Cataract) (\d+)", re.IGNORECASE)
170
- match = pattern.search(prompt)
171
- if match:
172
- condition = match.group(1).capitalize()
173
- patient_id = int(match.group(2))
174
- return condition, patient_id
175
- return None, None
176
 
177
  # Function to fetch specific patient data based on the condition and ID
178
  def get_specific_patient_data(patient_data, condition, patient_id):
@@ -191,18 +187,25 @@ def get_specific_patient_data(patient_data, condition, patient_id):
191
  break
192
  return specific_data
193
 
194
- # Toggle visibility of input fields based on the selected input type
195
- def toggle_input_visibility(input_type):
 
 
 
 
 
 
 
196
  if input_type == "Voice":
197
  return gr.update(visible=True), gr.update(visible=False)
198
  else:
199
  return gr.update(visible=False), gr.update(visible=True)
200
 
201
- # Function to clean up the response text
202
  def cleanup_response(response):
203
  # Extract only the part after "Answer:" and remove any trailing spaces
204
  answer_start = response.find("Answer:")
205
- if (answer_start != -1):
206
  response = response[answer_start + len("Answer:"):].strip()
207
  return response
208
 
@@ -213,38 +216,23 @@ def chatbot(audio, input_type, text):
213
  if "error" in transcription:
214
  return "Error transcribing audio: " + transcription["error"], None
215
  query = transcription['text']
216
- condition, patient_id = extract_details_from_prompt(query)
217
- patient_history = ""
218
- if condition and patient_id:
219
- patient_history = get_specific_patient_data(patient_data, condition, patient_id)
220
- payload = {
221
- "inputs": f"role: ophthalmologist assistant patient history: {patient_history} question: {query}"
222
- }
223
- response = query_huggingface(payload)
224
- if isinstance(response, list):
225
- raw_response = response[0].get("generated_text", "Sorry, I couldn't generate a response.")
226
- else:
227
- raw_response = response.get("generated_text", "Sorry, I couldn't generate a response.")
228
-
229
- clean_response = cleanup_response(raw_response)
230
- return clean_response, None
231
-
232
- elif input_type == "Text":
233
- condition, patient_id = extract_details_from_prompt(text)
234
- patient_history = ""
235
- if condition and patient_id:
236
- patient_history = get_specific_patient_data(patient_data, condition, patient_id)
237
- payload = {
238
- "inputs": f"role: ophthalmologist assistant patient history: {patient_history} question: {text}"
239
- }
240
- response = query_huggingface(payload)
241
- if isinstance(response, list):
242
- raw_response = response[0].get("generated_text", "Sorry, I couldn't generate a response.")
243
- else:
244
- raw_response = response.get("generated_text", "Sorry, I couldn't generate a response.")
245
 
246
- clean_response = cleanup_response(raw_response)
247
- return clean_response, None
248
 
249
  # Gradio interface for generating voice response
250
  def generate_voice_response(tts_model, text_response):
 
167
  # Function to extract details from the input prompt
168
  def extract_details_from_prompt(prompt):
169
  pattern = re.compile(r"(Glaucoma|Cataract) (\d+)", re.IGNORECASE)
170
+ matches = pattern.findall(prompt)
171
+ return [(match[0].capitalize(), int(match[1])) for match in matches]
 
 
 
 
172
 
173
  # Function to fetch specific patient data based on the condition and ID
174
  def get_specific_patient_data(patient_data, condition, patient_id):
 
187
  break
188
  return specific_data
189
 
190
+ # Function to aggregate patient history for all mentioned IDs in the question
191
+ def get_aggregated_patient_history(patient_data, details):
192
+ history = ""
193
+ for condition, patient_id in details:
194
+ history += get_specific_patient_data(patient_data, condition, patient_id) + "\n"
195
+ return history.strip()
196
+
197
+ # Toggle visibility of input elements based on input type
198
+ def toggle_visibility(input_type):
199
  if input_type == "Voice":
200
  return gr.update(visible=True), gr.update(visible=False)
201
  else:
202
  return gr.update(visible=False), gr.update(visible=True)
203
 
204
+ # Cleanup response text
205
  def cleanup_response(response):
206
  # Extract only the part after "Answer:" and remove any trailing spaces
207
  answer_start = response.find("Answer:")
208
+ if answer_start != -1:
209
  response = response[answer_start + len("Answer:"):].strip()
210
  return response
211
 
 
216
  if "error" in transcription:
217
  return "Error transcribing audio: " + transcription["error"], None
218
  query = transcription['text']
219
+ else:
220
+ query = text
221
+
222
+ details = extract_details_from_prompt(query)
223
+ patient_history = get_aggregated_patient_history(patient_data, details)
224
+
225
+ payload = {
226
+ "inputs": f"role: ophthalmologist assistant patient history: {patient_history} question: {query}"
227
+ }
228
+ response = query_huggingface(payload)
229
+ if isinstance(response, list):
230
+ raw_response = response[0].get("generated_text", "Sorry, I couldn't generate a response.")
231
+ else:
232
+ raw_response = response.get("generated_text", "Sorry, I couldn't generate a response.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
 
234
+ clean_response = cleanup_response(raw_response)
235
+ return clean_response, None
236
 
237
  # Gradio interface for generating voice response
238
  def generate_voice_response(tts_model, text_response):