ariankhalfani commited on
Commit
6344cc8
·
verified ·
1 Parent(s): 68c4c07

Create chatbot.py

Browse files
Files changed (1) hide show
  1. chatbot.py +259 -0
chatbot.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import requests
3
+ import logging
4
+ import gradio as gr
5
+ from dotenv import load_dotenv
6
+ from pydub import AudioSegment
7
+ from io import BytesIO
8
+ import time
9
+ import sqlite3
10
+ import re
11
+
12
+ # Configure logging
13
+ logging.basicConfig(level=logging.DEBUG)
14
+
15
+ # Configure Hugging Face API URL and headers for Meta-Llama-3-70B-Instruct
16
+ api_url = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-70B-Instruct"
17
+ huggingface_api_key = os.getenv("HF_API_TOKEN")
18
+ headers = {"Authorization": f"Bearer {huggingface_api_key}"}
19
+
20
+ # Function to query the Hugging Face model
21
+ def query_huggingface(payload):
22
+ logging.debug(f"Querying model with payload: {payload}")
23
+ response = requests.post(api_url, headers=headers, json=payload)
24
+ logging.debug(f"Received response: {response.status_code} {response.text}")
25
+ return response.json()
26
+
27
+ # Function to query the Whisper model for audio transcription
28
+ def query_whisper(audio_path):
29
+ API_URL_WHISPER = "https://api-inference.huggingface.co/models/openai/whisper-large-v2"
30
+ headers = {"Authorization": f"Bearer {huggingface_api_key}"}
31
+ MAX_RETRIES = 5
32
+ RETRY_DELAY = 1 # seconds
33
+
34
+ for attempt in range(MAX_RETRIES):
35
+ try:
36
+ if not os.path.exists(audio_path):
37
+ raise FileNotFoundError(f"Audio file does not exist: {audio_path}")
38
+
39
+ with open(audio_path, "rb") as f:
40
+ data = f.read()
41
+
42
+ response = requests.post(API_URL_WHISPER, headers=headers, data=data)
43
+ response.raise_for_status()
44
+ return response.json()
45
+ except Exception as e:
46
+ if attempt < MAX_RETRIES - 1:
47
+ time.sleep(RETRY_DELAY)
48
+ else:
49
+ return {"error": str(e)}
50
+
51
+ # Function to generate speech from text using Nithu TTS
52
+ def generate_speech_nithu(answer):
53
+ API_URL_TTS_Nithu = "https://api-inference.huggingface.co/models/Nithu/text-to-speech"
54
+ headers = {"Authorization": f"Bearer {huggingface_api_key}"}
55
+ payload = {"inputs": answer}
56
+ MAX_RETRIES = 5
57
+ RETRY_DELAY = 1 # seconds
58
+
59
+ for attempt in range(MAX_RETRIES):
60
+ try:
61
+ response = requests.post(API_URL_TTS_Nithu, headers=headers, json=payload)
62
+ response.raise_for_status()
63
+ audio_segment = AudioSegment.from_file(BytesIO(response.content), format="flac")
64
+ audio_file_path = "/tmp/answer_nithu.wav"
65
+ audio_segment.export(audio_file_path, format="wav")
66
+ return audio_file_path
67
+ except Exception as e:
68
+ if attempt < MAX_RETRIES - 1:
69
+ time.sleep(RETRY_DELAY)
70
+ else:
71
+ return {"error": str(e)}
72
+
73
+ # Function to generate speech from text using Ryan TTS
74
+ def generate_speech_ryan(answer):
75
+ API_URL_TTS_Ryan = "https://api-inference.huggingface.co/models/espnet/english_male_ryanspeech_fastspeech2"
76
+ headers = {"Authorization": f"Bearer {huggingface_api_key}"}
77
+ payload = {"inputs": answer}
78
+ MAX_RETRIES = 5
79
+ RETRY_DELAY = 1 # seconds
80
+
81
+ for attempt in range(MAX_RETRIES):
82
+ try:
83
+ response = requests.post(API_URL_TTS_Ryan, headers=headers, json=payload)
84
+ response.raise_for_status()
85
+ response_json = response.json()
86
+ audio = response_json.get("audio", None)
87
+ sampling_rate = response_json.get("sampling_rate", None)
88
+ if audio and sampling_rate:
89
+ audio_segment = AudioSegment.from_file(BytesIO(audio), format="wav")
90
+ audio_file_path = "/tmp/answer_ryan.wav"
91
+ audio_segment.export(audio_file_path, format="wav")
92
+ return audio_file_path
93
+ else:
94
+ raise ValueError("Invalid response format from Ryan TTS API")
95
+ except Exception as e:
96
+ if attempt < MAX_RETRIES - 1:
97
+ time.sleep(RETRY_DELAY)
98
+ else:
99
+ return {"error": str(e)}
100
+
101
+ # Function to fetch patient data from both databases
102
+ def fetch_patient_data(cataract_db_path, glaucoma_db_path):
103
+ patient_data = {}
104
+
105
+ # Fetch data from cataract_results table
106
+ try:
107
+ conn = sqlite3.connect(cataract_db_path)
108
+ cursor = conn.cursor()
109
+ cursor.execute("SELECT * FROM cataract_results")
110
+ cataract_data = cursor.fetchall()
111
+ conn.close()
112
+ patient_data['cataract_results'] = cataract_data
113
+ except Exception as e:
114
+ patient_data['cataract_results'] = f"Error fetching cataract results: {str(e)}"
115
+
116
+ # Fetch data from results table (glaucoma)
117
+ try:
118
+ conn = sqlite3.connect(glaucoma_db_path)
119
+ cursor = conn.cursor()
120
+ cursor.execute("SELECT * FROM results")
121
+ glaucoma_data = cursor.fetchall()
122
+ conn.close()
123
+ patient_data['results'] = glaucoma_data
124
+ except Exception as e:
125
+ patient_data['results'] = f"Error fetching glaucoma results: {str(e)}"
126
+
127
+ return patient_data
128
+
129
+ # Function to transform fetched data into a readable format
130
+ def transform_patient_data(patient_data):
131
+ readable_data = "Readable Patient Data:\n\n"
132
+
133
+ if 'cataract_results' in patient_data:
134
+ if isinstance(patient_data['cataract_results'], str):
135
+ readable_data += patient_data['cataract_results'] + "\n"
136
+ else:
137
+ readable_data += "Cataract Results:\n"
138
+ for row in patient_data['cataract_results']:
139
+ if len(row) >= 6:
140
+ readable_data += f"Patient ID: {row[0]}, Red Quantity: {row[2]}, Green Quantity: {row[3]}, Blue Quantity: {row[4]}, Stage: {row[5]}\n"
141
+ else:
142
+ readable_data += "Error: Incomplete data row in cataract results\n"
143
+ readable_data += "\n"
144
+
145
+ if 'results' in patient_data:
146
+ if isinstance(patient_data['results'], str):
147
+ readable_data += patient_data['results'] + "\n"
148
+ else:
149
+ readable_data += "Glaucoma Results:\n"
150
+ for row in patient_data['results']:
151
+ if len(row) >= 7:
152
+ readable_data += f"Patient ID: {row[0]}, Cup Area: {row[2]}, Disk Area: {row[3]}, Rim Area: {row[4]}, Rim to Disc Line Ratio: {row[5]}, DDLS Stage: {row[6]}\n"
153
+ else:
154
+ readable_data += "Error: Incomplete data row in glaucoma results\n"
155
+ readable_data += "\n"
156
+
157
+ return readable_data
158
+
159
+ # Paths to your databases
160
+ cataract_db_path = 'cataract_results.db'
161
+ glaucoma_db_path = 'glaucoma_results.db'
162
+
163
+ # Fetch and transform patient data
164
+ patient_data = fetch_patient_data(cataract_db_path, glaucoma_db_path)
165
+ readable_patient_data = transform_patient_data(patient_data)
166
+
167
+ # Function to extract details from the input prompt
168
+ def extract_details_from_prompt(prompt):
169
+ pattern = re.compile(r"(Glaucoma|Cataract) (\d+)", re.IGNORECASE)
170
+ match = pattern.search(prompt)
171
+ if match:
172
+ condition = match.group(1).capitalize()
173
+ patient_id = int(match.group(2))
174
+ return condition, patient_id
175
+ return None, None
176
+
177
+ # Function to fetch specific patient data based on the condition and ID
178
+ def get_specific_patient_data(patient_data, condition, patient_id):
179
+ specific_data = ""
180
+ if condition == "Cataract":
181
+ specific_data = "Cataract Results:\n"
182
+ for row in patient_data.get('cataract_results', []):
183
+ if isinstance(row, tuple) and row[0] == patient_id:
184
+ specific_data += f"Patient ID: {row[0]}, Red Quantity: {row[2]}, Green Quantity: {row[3]}, Blue Quantity: {row[4]}, Stage: {row[5]}\n"
185
+ break
186
+ elif condition == "Glaucoma":
187
+ specific_data = "Glaucoma Results:\n"
188
+ for row in patient_data.get('results', []):
189
+ if isinstance(row, tuple) and row[0] == patient_id:
190
+ specific_data += f"Patient ID: {row[0]}, Cup Area: {row[2]}, Disk Area: {row[3]}, Rim Area: {row[4]}, Rim to Disc Line Ratio: {row[5]}, DDLS Stage: {row[6]}\n"
191
+ break
192
+ return specific_data
193
+
194
+ # Toggle visibility of input fields based on the selected input type
195
+ def toggle_input_visibility(input_type):
196
+ if input_type == "Voice":
197
+ return gr.update(visible=True), gr.update(visible=False)
198
+ else:
199
+ return gr.update(visible=False), gr.update(visible=True)
200
+
201
+ # Function to clean up the response text
202
+ def cleanup_response(response):
203
+ # Extract only the part after "Answer:" and remove any trailing spaces
204
+ answer_start = response.find("Answer:")
205
+ if answer_start != -1:
206
+ response = response[answer_start + len("Answer:"):].strip()
207
+ return response
208
+
209
+ # Gradio interface for the chatbot
210
+ def chatbot(audio, input_type, text):
211
+ if input_type == "Voice":
212
+ transcription = query_whisper(audio.name)
213
+ if "error" in transcription:
214
+ return "Error transcribing audio: " + transcription["error"], None
215
+ query = transcription['text']
216
+ condition, patient_id = extract_details_from_prompt(query)
217
+ if condition and patient_id:
218
+ patient_history = get_specific_patient_data(patient_data, condition, patient_id)
219
+ payload = {
220
+ "inputs": f"role: ophthalmologist assistant patient history: {patient_history} question: {query}"
221
+ }
222
+ response = query_huggingface(payload)
223
+ if isinstance(response, list):
224
+ raw_response = response[0].get("generated_text", "Sorry, I couldn't generate a response.")
225
+ else:
226
+ raw_response = response.get("generated_text", "Sorry, I couldn't generate a response.")
227
+
228
+ clean_response = cleanup_response(raw_response)
229
+ return clean_response, None
230
+
231
+ elif input_type == "Text":
232
+ condition, patient_id = extract_details_from_prompt(text)
233
+ if condition and patient_id:
234
+ patient_history = get_specific_patient_data(patient_data, condition, patient_id)
235
+ payload = {
236
+ "inputs": f"role: ophthalmologist assistant patient history: {patient_history} question: {text}"
237
+ }
238
+ response = query_huggingface(payload)
239
+ if isinstance(response, list):
240
+ raw_response = response[0].get("generated_text", "Sorry, I couldn't generate a response.")
241
+ else:
242
+ raw_response = response.get("generated_text", "Sorry, I couldn't generate a response.")
243
+
244
+ clean_response = cleanup_response(raw_response)
245
+ return clean_response, None
246
+
247
+ # Gradio interface for generating voice response
248
+ def generate_voice_response(tts_model, text_response):
249
+ if tts_model == "Nithu (Custom)":
250
+ audio_file_path = generate_speech_nithu(text_response)
251
+ return audio_file_path, None
252
+ elif tts_model == "Ryan (ESPnet)":
253
+ audio_file_path = generate_speech_ryan(text_response)
254
+ return audio_file_path, None
255
+ else:
256
+ return None, None
257
+
258
+ def update_patient_history():
259
+ return readable_patient_data