Spaces:
Sleeping
Sleeping
import os | |
import requests | |
import logging | |
import gradio as gr | |
from dotenv import load_dotenv | |
from pydub import AudioSegment | |
from io import BytesIO | |
import time | |
import sqlite3 | |
import re | |
# Emit DEBUG-level (and above) log records from the root logger.
logging.basicConfig(level=logging.DEBUG)

# Load environment variables before the API token is read below.
load_dotenv()

# Credentials and endpoint for the Meta-Llama-3-70B-Instruct model on the
# Hugging Face Inference API. The token is read from HF_API_TOKEN.
huggingface_api_key = os.getenv("HF_API_TOKEN")
api_url = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-70B-Instruct"
headers = {"Authorization": f"Bearer {huggingface_api_key}"}
def query_huggingface(payload, timeout=120):
    """Query the Hugging Face text-generation model and return the decoded JSON.

    Args:
        payload: JSON-serialisable request body posted to ``api_url``.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests`` may block indefinitely on a stalled connection.

    Returns:
        The decoded JSON body on success. If the request fails at the
        transport level or the body is not valid JSON, a dict of the form
        ``{"error": "<message>"}`` is returned instead of raising, matching
        the error shape used by the other API helpers in this file.
    """
    logging.debug(f"Querying model with payload: {payload}")
    try:
        response = requests.post(api_url, headers=headers, json=payload, timeout=timeout)
    except requests.RequestException as exc:
        logging.error("Request to the language model failed: %s", exc)
        return {"error": str(exc)}

    logging.debug(f"Received response: {response.status_code} {response.text}")
    try:
        return response.json()
    except ValueError as exc:
        # Gateways can answer with HTML or an empty body; do not crash on it.
        logging.error("Language model returned a non-JSON body: %s", exc)
        return {"error": f"Invalid JSON response (HTTP {response.status_code}): {exc}"}
def query_whisper(audio_path, timeout=120):
    """Transcribe an audio file with the hosted Whisper model.

    Args:
        audio_path: Path of the audio file to upload.
        timeout: Seconds to wait for each HTTP request.

    Returns:
        The decoded JSON response on success, or ``{"error": "<message>"}``
        if the file is missing/unreadable or every attempt failed.
    """
    API_URL_WHISPER = "https://api-inference.huggingface.co/models/openai/whisper-large-v2"
    MAX_RETRIES = 5
    RETRY_DELAY = 1  # seconds

    # A missing file cannot start existing between retries, so fail fast
    # instead of sleeping through the whole retry budget.
    if not isinstance(audio_path, (str, bytes, os.PathLike)) or not os.path.exists(audio_path):
        return {"error": f"Audio file does not exist: {audio_path}"}

    # Read the audio once; only the network call is worth retrying.
    try:
        with open(audio_path, "rb") as f:
            data = f.read()
    except OSError as e:
        return {"error": str(e)}

    headers = {"Authorization": f"Bearer {huggingface_api_key}"}
    for attempt in range(MAX_RETRIES):
        try:
            response = requests.post(API_URL_WHISPER, headers=headers, data=data, timeout=timeout)
            response.raise_for_status()
            return response.json()
        except Exception as e:  # best-effort boundary: always report, never raise
            logging.warning("Whisper request failed (attempt %d/%d): %s", attempt + 1, MAX_RETRIES, e)
            if attempt < MAX_RETRIES - 1:
                time.sleep(RETRY_DELAY)
            else:
                return {"error": str(e)}
def generate_speech_nithu(answer, timeout=120):
    """Synthesise ``answer`` with the Nithu text-to-speech model.

    The response body is decoded as FLAC and re-exported as WAV to the fixed
    path ``/tmp/answer_nithu.wav`` (overwritten on every call).

    Args:
        answer: Text to convert to speech.
        timeout: Seconds to wait for each HTTP request, so a stalled
            connection cannot block the caller forever.

    Returns:
        The path of the exported WAV file on success, or
        ``{"error": "<message>"}`` once all retries are exhausted.
    """
    API_URL_TTS_Nithu = "https://api-inference.huggingface.co/models/Nithu/text-to-speech"
    headers = {"Authorization": f"Bearer {huggingface_api_key}"}
    payload = {"inputs": answer}
    MAX_RETRIES = 5
    RETRY_DELAY = 1  # seconds
    for attempt in range(MAX_RETRIES):
        try:
            response = requests.post(API_URL_TTS_Nithu, headers=headers, json=payload, timeout=timeout)
            response.raise_for_status()
            audio_segment = AudioSegment.from_file(BytesIO(response.content), format="flac")
            audio_file_path = "/tmp/answer_nithu.wav"
            audio_segment.export(audio_file_path, format="wav")
            return audio_file_path
        except Exception as e:  # best-effort boundary: always report, never raise
            logging.warning("Nithu TTS request failed (attempt %d/%d): %s", attempt + 1, MAX_RETRIES, e)
            if attempt < MAX_RETRIES - 1:
                time.sleep(RETRY_DELAY)
            else:
                return {"error": str(e)}
def generate_speech_ryan(answer, timeout=120):
    """Synthesise ``answer`` with the ESPnet "Ryan" text-to-speech model.

    The JSON response is expected to contain ``audio`` and ``sampling_rate``
    fields; the audio is re-exported as WAV to the fixed path
    ``/tmp/answer_ryan.wav`` (overwritten on every call).

    Args:
        answer: Text to convert to speech.
        timeout: Seconds to wait for each HTTP request, so a stalled
            connection cannot block the caller forever.

    Returns:
        The path of the exported WAV file on success, or
        ``{"error": "<message>"}`` once all retries are exhausted.
    """
    API_URL_TTS_Ryan = "https://api-inference.huggingface.co/models/espnet/english_male_ryanspeech_fastspeech2"
    headers = {"Authorization": f"Bearer {huggingface_api_key}"}
    payload = {"inputs": answer}
    MAX_RETRIES = 5
    RETRY_DELAY = 1  # seconds
    for attempt in range(MAX_RETRIES):
        try:
            response = requests.post(API_URL_TTS_Ryan, headers=headers, json=payload, timeout=timeout)
            response.raise_for_status()
            response_json = response.json()
            audio = response_json.get("audio", None)
            sampling_rate = response_json.get("sampling_rate", None)
            if audio and sampling_rate:
                # NOTE(review): BytesIO needs a bytes-like object, but a value
                # decoded from JSON is not bytes -- confirm the API's actual
                # response format before relying on this path.
                audio_segment = AudioSegment.from_file(BytesIO(audio), format="wav")
                audio_file_path = "/tmp/answer_ryan.wav"
                audio_segment.export(audio_file_path, format="wav")
                return audio_file_path
            else:
                raise ValueError("Invalid response format from Ryan TTS API")
        except Exception as e:  # best-effort boundary: always report, never raise
            logging.warning("Ryan TTS request failed (attempt %d/%d): %s", attempt + 1, MAX_RETRIES, e)
            if attempt < MAX_RETRIES - 1:
                time.sleep(RETRY_DELAY)
            else:
                return {"error": str(e)}
def _fetch_all_rows(db_path, query):
    """Run ``query`` on the SQLite database at ``db_path`` and return all rows."""
    conn = sqlite3.connect(db_path)
    try:
        return conn.execute(query).fetchall()
    finally:
        # Close even when the query raises, so the connection is never leaked.
        conn.close()


def fetch_patient_data(cataract_db_path, glaucoma_db_path):
    """Fetch every row of the cataract and glaucoma result tables.

    Args:
        cataract_db_path: Path of the SQLite database holding the
            ``cataract_results`` table.
        glaucoma_db_path: Path of the SQLite database holding the
            ``results`` (glaucoma) table.

    Returns:
        A dict with the keys ``'cataract_results'`` and ``'results'``. Each
        value is the list of row tuples from the corresponding table, or an
        error-message string if that table could not be read.
    """
    patient_data = {}

    # Each source is fetched independently so one failure cannot hide the other.
    try:
        patient_data['cataract_results'] = _fetch_all_rows(
            cataract_db_path, "SELECT * FROM cataract_results"
        )
    except Exception as e:
        patient_data['cataract_results'] = f"Error fetching cataract results: {str(e)}"

    try:
        patient_data['results'] = _fetch_all_rows(
            glaucoma_db_path, "SELECT * FROM results"
        )
    except Exception as e:
        patient_data['results'] = f"Error fetching glaucoma results: {str(e)}"

    return patient_data
def _format_section(rows, title, min_columns, format_row, label):
    """Return the text fragments for one result section.

    ``rows`` is either an error-message string (emitted as-is) or an iterable
    of row tuples; rows with fewer than ``min_columns`` columns are reported
    as incomplete instead of being formatted.
    """
    if isinstance(rows, str):
        fragments = [rows + "\n"]
    else:
        fragments = [title]
        for row in rows:
            if len(row) >= min_columns:
                fragments.append(format_row(row))
            else:
                fragments.append(f"Error: Incomplete data row in {label} results\n")
    # A blank line closes each section that is present.
    fragments.append("\n")
    return fragments


def transform_patient_data(patient_data):
    """Render the fetched patient data as human-readable text.

    Args:
        patient_data: Dict that may contain ``'cataract_results'`` and/or
            ``'results'`` (glaucoma). Each value is either a list of row
            tuples or an error-message string.

    Returns:
        A string starting with a ``"Readable Patient Data:"`` header followed
        by one section per key present in ``patient_data``.
    """
    # Collect fragments and join once instead of repeated string ``+=``.
    parts = ["Readable Patient Data:\n\n"]

    if 'cataract_results' in patient_data:
        parts.extend(_format_section(
            patient_data['cataract_results'],
            "Cataract Results:\n",
            6,
            lambda row: f"Patient ID: {row[0]}, Red Quantity: {row[2]}, Green Quantity: {row[3]}, Blue Quantity: {row[4]}, Stage: {row[5]}\n",
            "cataract",
        ))

    if 'results' in patient_data:
        parts.extend(_format_section(
            patient_data['results'],
            "Glaucoma Results:\n",
            7,
            lambda row: f"Patient ID: {row[0]}, Cup Area: {row[2]}, Disk Area: {row[3]}, Rim Area: {row[4]}, Rim to Disc Line Ratio: {row[5]}, DDLS Stage: {row[6]}\n",
            "glaucoma",
        ))

    return "".join(parts)
# Locations of the two SQLite result databases.
cataract_db_path = 'cataract_results.db'
glaucoma_db_path = 'glaucoma_results.db'

# Load every patient record once at import time and keep both the raw rows
# and their readable rendering available as module-level values.
patient_data = fetch_patient_data(cataract_db_path, glaucoma_db_path)
readable_patient_data = transform_patient_data(patient_data)
# Compiled once at import time: a condition name followed by one space and a
# numeric patient ID, matched case-insensitively.
_CONDITION_ID_PATTERN = re.compile(r"(Glaucoma|Cataract) (\d+)", re.IGNORECASE)


def extract_details_from_prompt(prompt):
    """Extract ``(condition, patient_id)`` pairs mentioned in ``prompt``.

    Args:
        prompt: User question text. ``None`` or an empty string is accepted
            and yields no matches.

    Returns:
        A list of ``(condition, patient_id)`` tuples in order of appearance,
        where ``condition`` is normalised to ``"Glaucoma"`` or ``"Cataract"``
        and ``patient_id`` is an ``int``.
    """
    if not prompt:
        # Nothing to scan; also avoids a TypeError from the regex on None.
        return []
    return [
        (condition.capitalize(), int(patient_id))
        for condition, patient_id in _CONDITION_ID_PATTERN.findall(prompt)
    ]
def get_specific_patient_data(patient_data, condition, patient_id):
    """Return the formatted record of one patient for one condition.

    Args:
        patient_data: Dict with ``'cataract_results'`` and/or ``'results'``
            (glaucoma) entries; each is a list of row tuples or an
            error-message string.
        condition: ``"Cataract"`` or ``"Glaucoma"``; any other value yields
            an empty string.
        patient_id: Value compared against the first column of each row.

    Returns:
        The section header followed by the first matching row, or the header
        alone when no row matches. A matching row with too few columns is
        reported as incomplete rather than raising ``IndexError``, in line
        with ``transform_patient_data``.
    """
    specific_data = ""
    if condition == "Cataract":
        specific_data = "Cataract Results:\n"
        for row in patient_data.get('cataract_results', []):
            # Non-tuple items (e.g. characters of an error string) and empty
            # rows can never match a patient ID.
            if isinstance(row, tuple) and row and row[0] == patient_id:
                if len(row) >= 6:
                    specific_data += f"Patient ID: {row[0]}, Red Quantity: {row[2]}, Green Quantity: {row[3]}, Blue Quantity: {row[4]}, Stage: {row[5]}\n"
                else:
                    specific_data += "Error: Incomplete data row in cataract results\n"
                break
    elif condition == "Glaucoma":
        specific_data = "Glaucoma Results:\n"
        for row in patient_data.get('results', []):
            if isinstance(row, tuple) and row and row[0] == patient_id:
                if len(row) >= 7:
                    specific_data += f"Patient ID: {row[0]}, Cup Area: {row[2]}, Disk Area: {row[3]}, Rim Area: {row[4]}, Rim to Disc Line Ratio: {row[5]}, DDLS Stage: {row[6]}\n"
                else:
                    specific_data += "Error: Incomplete data row in glaucoma results\n"
                break
    return specific_data
def get_aggregated_patient_history(patient_data, details):
    """Combine the records of every (condition, patient ID) pair in ``details``.

    Each pair is rendered with ``get_specific_patient_data`` and followed by a
    newline; the concatenated text is returned with surrounding whitespace
    stripped.
    """
    sections = [
        get_specific_patient_data(patient_data, condition, patient_id) + "\n"
        for condition, patient_id in details
    ]
    return "".join(sections).strip()
def toggle_visibility(input_type):
    """Return visibility updates for two input components.

    Args:
        input_type: The selected input mode; ``"Voice"`` shows the first
            component and hides the second, any other value does the reverse.

    Returns:
        A pair of ``gr.update`` objects (presumably for the audio and text
        inputs, in that order -- confirm against the event wiring).
    """
    voice_selected = input_type == "Voice"
    # ``visible`` must be passed as a keyword argument; the previous
    # ``visible(False)`` called an undefined name and raised NameError.
    return gr.update(visible=voice_selected), gr.update(visible=not voice_selected)
def cleanup_response(response):
    """Keep only the text after the first "Answer:" marker.

    When the marker is present, the text following it is returned with
    leading and trailing whitespace removed; otherwise ``response`` is
    returned unchanged.
    """
    _, marker, tail = response.partition("Answer:")
    if not marker:
        return response
    return tail.strip()
def chatbot(audio, input_type, text):
    """Answer a typed or spoken question using the stored patient history.

    Args:
        audio: Recorded audio when ``input_type`` is ``"Voice"``; either an
            object exposing the file path as ``.name`` or the path itself.
        input_type: ``"Voice"`` to transcribe ``audio``; any other value uses
            ``text`` as the question.
        text: The typed question.

    Returns:
        A ``(response_text, None)`` tuple. ``response_text`` is the model's
        generated text, a fallback apology when none was produced, or an
        error message when transcription failed.
    """
    if input_type == "Voice":
        if audio is None:
            return "Error transcribing audio: no audio input was provided", None
        # Accept both file-like objects (with .name) and plain path strings.
        audio_path = getattr(audio, "name", audio)
        transcription = query_whisper(audio_path)
        if "error" in transcription:
            return "Error transcribing audio: " + str(transcription["error"]), None
        if not isinstance(transcription, dict) or "text" not in transcription:
            return "Error transcribing audio: no transcription text was returned", None
        query = transcription['text']
    else:
        # A missing text value is treated as an empty question.
        query = text or ""

    # Extract details from the prompt
    details = extract_details_from_prompt(query)
    # Get aggregated patient history based on the extracted details
    patient_history = get_aggregated_patient_history(patient_data, details)
    # Create the payload with the patient history and the user's query
    payload = {
        "inputs": f"role: ophthalmologist assistant patient history: {patient_history} question: {query}"
    }
    logging.debug(f"Raw input to the LLM: {payload['inputs']}")

    # Query the Hugging Face model with the payload
    response = query_huggingface(payload)
    fallback = "Sorry, I couldn't generate a response."
    if isinstance(response, list):
        # Guard against an empty list or non-dict items.
        first = response[0] if response else None
        raw_response = first.get("generated_text", fallback) if isinstance(first, dict) else fallback
    elif isinstance(response, dict):
        if "error" in response:
            logging.error("Language model returned an error: %s", response["error"])
        raw_response = response.get("generated_text", fallback)
    else:
        raw_response = fallback
    logging.debug(f"Raw output from the LLM: {raw_response}")
    return raw_response, None
def generate_voice_response(tts_model, text_response):
    """Generate a spoken version of ``text_response`` with the chosen TTS model.

    Args:
        tts_model: ``"Nithu (Custom)"`` or ``"Ryan (ESPnet)"``; any other
            value produces no audio.
        text_response: Text to synthesise.

    Returns:
        ``(audio_file_path, None)`` on success, or ``(None, None)`` when the
        model name is unknown or speech generation failed.
    """
    if tts_model == "Nithu (Custom)":
        result = generate_speech_nithu(text_response)
    elif tts_model == "Ryan (ESPnet)":
        result = generate_speech_ryan(text_response)
    else:
        return None, None

    # The TTS helpers report failure as an {"error": ...} dict; never pass
    # that on as if it were an audio file path.
    if isinstance(result, dict):
        logging.error("Speech generation with %s failed: %s", tts_model, result.get("error"))
        return None, None
    return result, None
def update_patient_history():
    """Return the readable patient-data text prepared at module load."""
    history_text = readable_patient_data
    return history_text