Spaces:
Runtime error
Runtime error
ariankhalfani
commited on
Update chatbot.py
Browse files- chatbot.py +29 -41
chatbot.py
CHANGED
@@ -167,12 +167,8 @@ readable_patient_data = transform_patient_data(patient_data)
|
|
167 |
# Function to extract details from the input prompt
|
168 |
def extract_details_from_prompt(prompt):
|
169 |
pattern = re.compile(r"(Glaucoma|Cataract) (\d+)", re.IGNORECASE)
|
170 |
-
|
171 |
-
|
172 |
-
condition = match.group(1).capitalize()
|
173 |
-
patient_id = int(match.group(2))
|
174 |
-
return condition, patient_id
|
175 |
-
return None, None
|
176 |
|
177 |
# Function to fetch specific patient data based on the condition and ID
|
178 |
def get_specific_patient_data(patient_data, condition, patient_id):
|
@@ -191,18 +187,25 @@ def get_specific_patient_data(patient_data, condition, patient_id):
|
|
191 |
break
|
192 |
return specific_data
|
193 |
|
194 |
-
#
|
195 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
196 |
if input_type == "Voice":
|
197 |
return gr.update(visible=True), gr.update(visible=False)
|
198 |
else:
|
199 |
return gr.update(visible=False), gr.update(visible=True)
|
200 |
|
201 |
-
#
|
202 |
def cleanup_response(response):
|
203 |
# Extract only the part after "Answer:" and remove any trailing spaces
|
204 |
answer_start = response.find("Answer:")
|
205 |
-
if
|
206 |
response = response[answer_start + len("Answer:"):].strip()
|
207 |
return response
|
208 |
|
@@ -213,38 +216,23 @@ def chatbot(audio, input_type, text):
|
|
213 |
if "error" in transcription:
|
214 |
return "Error transcribing audio: " + transcription["error"], None
|
215 |
query = transcription['text']
|
216 |
-
|
217 |
-
|
218 |
-
|
219 |
-
|
220 |
-
|
221 |
-
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
|
227 |
-
|
228 |
-
|
229 |
-
|
230 |
-
return clean_response, None
|
231 |
-
|
232 |
-
elif input_type == "Text":
|
233 |
-
condition, patient_id = extract_details_from_prompt(text)
|
234 |
-
patient_history = ""
|
235 |
-
if condition and patient_id:
|
236 |
-
patient_history = get_specific_patient_data(patient_data, condition, patient_id)
|
237 |
-
payload = {
|
238 |
-
"inputs": f"role: ophthalmologist assistant patient history: {patient_history} question: {text}"
|
239 |
-
}
|
240 |
-
response = query_huggingface(payload)
|
241 |
-
if isinstance(response, list):
|
242 |
-
raw_response = response[0].get("generated_text", "Sorry, I couldn't generate a response.")
|
243 |
-
else:
|
244 |
-
raw_response = response.get("generated_text", "Sorry, I couldn't generate a response.")
|
245 |
|
246 |
-
|
247 |
-
|
248 |
|
249 |
# Gradio interface for generating voice response
|
250 |
def generate_voice_response(tts_model, text_response):
|
|
|
167 |
# Function to extract details from the input prompt
|
168 |
def extract_details_from_prompt(prompt):
|
169 |
pattern = re.compile(r"(Glaucoma|Cataract) (\d+)", re.IGNORECASE)
|
170 |
+
matches = pattern.findall(prompt)
|
171 |
+
return [(match[0].capitalize(), int(match[1])) for match in matches]
|
|
|
|
|
|
|
|
|
172 |
|
173 |
# Function to fetch specific patient data based on the condition and ID
|
174 |
def get_specific_patient_data(patient_data, condition, patient_id):
|
|
|
187 |
break
|
188 |
return specific_data
|
189 |
|
190 |
+
# Function to aggregate patient history for all mentioned IDs in the question
|
191 |
+
def get_aggregated_patient_history(patient_data, details):
|
192 |
+
history = ""
|
193 |
+
for condition, patient_id in details:
|
194 |
+
history += get_specific_patient_data(patient_data, condition, patient_id) + "\n"
|
195 |
+
return history.strip()
|
196 |
+
|
197 |
+
# Toggle visibility of input elements based on input type
|
198 |
+
def toggle_visibility(input_type):
|
199 |
if input_type == "Voice":
|
200 |
return gr.update(visible=True), gr.update(visible=False)
|
201 |
else:
|
202 |
return gr.update(visible=False), gr.update(visible=True)
|
203 |
|
204 |
+
# Cleanup response text
|
205 |
def cleanup_response(response):
|
206 |
# Extract only the part after "Answer:" and remove any trailing spaces
|
207 |
answer_start = response.find("Answer:")
|
208 |
+
if answer_start != -1:
|
209 |
response = response[answer_start + len("Answer:"):].strip()
|
210 |
return response
|
211 |
|
|
|
216 |
if "error" in transcription:
|
217 |
return "Error transcribing audio: " + transcription["error"], None
|
218 |
query = transcription['text']
|
219 |
+
else:
|
220 |
+
query = text
|
221 |
+
|
222 |
+
details = extract_details_from_prompt(query)
|
223 |
+
patient_history = get_aggregated_patient_history(patient_data, details)
|
224 |
+
|
225 |
+
payload = {
|
226 |
+
"inputs": f"role: ophthalmologist assistant patient history: {patient_history} question: {query}"
|
227 |
+
}
|
228 |
+
response = query_huggingface(payload)
|
229 |
+
if isinstance(response, list):
|
230 |
+
raw_response = response[0].get("generated_text", "Sorry, I couldn't generate a response.")
|
231 |
+
else:
|
232 |
+
raw_response = response.get("generated_text", "Sorry, I couldn't generate a response.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
233 |
|
234 |
+
clean_response = cleanup_response(raw_response)
|
235 |
+
return clean_response, None
|
236 |
|
237 |
# Gradio interface for generating voice response
|
238 |
def generate_voice_response(tts_model, text_response):
|