import os
import requests
import logging
import gradio as gr
from dotenv import load_dotenv
from pydub import AudioSegment
from io import BytesIO
import time
import sqlite3
import re
import base64
# Configure logging
logging.basicConfig(level=logging.DEBUG)
# Load environment variables
load_dotenv()
# Configure Hugging Face API URL and headers for Meta-Llama-3-70B-Instruct
api_url = "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-70B-Instruct"
huggingface_api_key = os.getenv("HF_API_TOKEN")
headers = {"Authorization": f"Bearer {huggingface_api_key}"}
# Function to query the Hugging Face model
def query_huggingface(payload):
    logging.debug(f"Querying model with payload: {payload}")
    response = requests.post(api_url, headers=headers, json=payload)
    logging.debug(f"Received response: {response.status_code} {response.text}")
    return response.json()
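# Illustrative usage (not called at import time):
#   query_huggingface({"inputs": "Question: What does DDLS stage 5 mean? Answer:"})
# The Inference API also accepts an optional "parameters" dict (e.g. max_new_tokens,
# temperature) alongside "inputs" if finer control over generation is needed.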
# Function to query the Whisper model for audio transcription
def query_whisper(audio_path):
    API_URL_WHISPER = "https://api-inference.huggingface.co/models/openai/whisper-large-v2"
    headers = {"Authorization": f"Bearer {huggingface_api_key}"}
    MAX_RETRIES = 5
    RETRY_DELAY = 1  # seconds
    for attempt in range(MAX_RETRIES):
        try:
            if not os.path.exists(audio_path):
                raise FileNotFoundError(f"Audio file does not exist: {audio_path}")
            with open(audio_path, "rb") as f:
                data = f.read()
            response = requests.post(API_URL_WHISPER, headers=headers, data=data)
            response.raise_for_status()
            return response.json()
        except Exception as e:
            if attempt < MAX_RETRIES - 1:
                time.sleep(RETRY_DELAY)
            else:
                return {"error": str(e)}
# Function to generate speech from text using Nithu TTS
def generate_speech_nithu(answer):
    API_URL_TTS_Nithu = "https://api-inference.huggingface.co/models/Nithu/text-to-speech"
    headers = {"Authorization": f"Bearer {huggingface_api_key}"}
    payload = {"inputs": answer}
    MAX_RETRIES = 5
    RETRY_DELAY = 1  # seconds
    for attempt in range(MAX_RETRIES):
        try:
            response = requests.post(API_URL_TTS_Nithu, headers=headers, json=payload)
            response.raise_for_status()
            audio_segment = AudioSegment.from_file(BytesIO(response.content), format="flac")
            audio_file_path = "/tmp/answer_nithu.wav"
            audio_segment.export(audio_file_path, format="wav")
            return audio_file_path
        except Exception as e:
            if attempt < MAX_RETRIES - 1:
                time.sleep(RETRY_DELAY)
            else:
                return {"error": str(e)}
# Function to generate speech from text using Ryan TTS
def generate_speech_ryan(answer):
    API_URL_TTS_Ryan = "https://api-inference.huggingface.co/models/espnet/english_male_ryanspeech_fastspeech2"
    headers = {"Authorization": f"Bearer {huggingface_api_key}"}
    payload = {"inputs": answer}
    MAX_RETRIES = 5
    RETRY_DELAY = 1  # seconds
    for attempt in range(MAX_RETRIES):
        try:
            response = requests.post(API_URL_TTS_Ryan, headers=headers, json=payload)
            response.raise_for_status()
            response_json = response.json()
            audio = response_json.get("audio", None)
            sampling_rate = response_json.get("sampling_rate", None)
            if audio and sampling_rate:
                # JSON cannot carry raw bytes, so the "audio" field is assumed to be
                # base64-encoded WAV data; decode it before handing it to pydub.
                if isinstance(audio, str):
                    audio = base64.b64decode(audio)
                audio_segment = AudioSegment.from_file(BytesIO(audio), format="wav")
                audio_file_path = "/tmp/answer_ryan.wav"
                audio_segment.export(audio_file_path, format="wav")
                return audio_file_path
            else:
                raise ValueError("Invalid response format from Ryan TTS API")
        except Exception as e:
            if attempt < MAX_RETRIES - 1:
                time.sleep(RETRY_DELAY)
            else:
                return {"error": str(e)}
# Function to fetch patient data from both databases
def fetch_patient_data(cataract_db_path, glaucoma_db_path):
    patient_data = {}
    # Fetch data from cataract_results table
    try:
        conn = sqlite3.connect(cataract_db_path)
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM cataract_results")
        cataract_data = cursor.fetchall()
        conn.close()
        patient_data['cataract_results'] = cataract_data
    except Exception as e:
        patient_data['cataract_results'] = f"Error fetching cataract results: {str(e)}"
    # Fetch data from results table (glaucoma)
    try:
        conn = sqlite3.connect(glaucoma_db_path)
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM results")
        glaucoma_data = cursor.fetchall()
        conn.close()
        patient_data['results'] = glaucoma_data
    except Exception as e:
        patient_data['results'] = f"Error fetching glaucoma results: {str(e)}"
    return patient_data
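# Column layout assumed by the code below (inferred from the indices it reads; the
# authoritative schemas live in the two SQLite files):
#   cataract_results: (patient_id, <unused>, red_quantity, green_quantity, blue_quantity, stage, ...)
#   results:          (patient_id, <unused>, cup_area, disk_area, rim_area, rim_to_disc_line_ratio, ddls_stage, ...)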
# Function to transform fetched data into a readable format
def transform_patient_data(patient_data):
    readable_data = "Readable Patient Data:\n\n"
    if 'cataract_results' in patient_data:
        if isinstance(patient_data['cataract_results'], str):
            readable_data += patient_data['cataract_results'] + "\n"
        else:
            readable_data += "Cataract Results:\n"
            for row in patient_data['cataract_results']:
                if len(row) >= 6:
                    readable_data += f"Patient ID: {row[0]}, Red Quantity: {row[2]}, Green Quantity: {row[3]}, Blue Quantity: {row[4]}, Stage: {row[5]}\n"
                else:
                    readable_data += "Error: Incomplete data row in cataract results\n"
        readable_data += "\n"
    if 'results' in patient_data:
        if isinstance(patient_data['results'], str):
            readable_data += patient_data['results'] + "\n"
        else:
            readable_data += "Glaucoma Results:\n"
            for row in patient_data['results']:
                if len(row) >= 7:
                    readable_data += f"Patient ID: {row[0]}, Cup Area: {row[2]}, Disk Area: {row[3]}, Rim Area: {row[4]}, Rim to Disc Line Ratio: {row[5]}, DDLS Stage: {row[6]}\n"
                else:
                    readable_data += "Error: Incomplete data row in glaucoma results\n"
        readable_data += "\n"
    return readable_data
# Paths to your databases
cataract_db_path = 'cataract_results.db'
glaucoma_db_path = 'glaucoma_results.db'
# Fetch and transform patient data
patient_data = fetch_patient_data(cataract_db_path, glaucoma_db_path)
readable_patient_data = transform_patient_data(patient_data)
# Function to extract details from the input prompt
def extract_details_from_prompt(prompt):
    pattern = re.compile(r"(Glaucoma|Cataract) (\d+)", re.IGNORECASE)
    matches = pattern.findall(prompt)
    return [(match[0].capitalize(), int(match[1])) for match in matches]
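# Example: extract_details_from_prompt("Compare glaucoma 12 with cataract 3")
# returns [("Glaucoma", 12), ("Cataract", 3)].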
# Function to fetch specific patient data based on the condition and ID
def get_specific_patient_data(patient_data, condition, patient_id):
    specific_data = ""
    if condition == "Cataract":
        specific_data = "Cataract Results:\n"
        for row in patient_data.get('cataract_results', []):
            if isinstance(row, tuple) and row[0] == patient_id:
                specific_data += f"Patient ID: {row[0]}, Red Quantity: {row[2]}, Green Quantity: {row[3]}, Blue Quantity: {row[4]}, Stage: {row[5]}\n"
                break
    elif condition == "Glaucoma":
        specific_data = "Glaucoma Results:\n"
        for row in patient_data.get('results', []):
            if isinstance(row, tuple) and row[0] == patient_id:
                specific_data += f"Patient ID: {row[0]}, Cup Area: {row[2]}, Disk Area: {row[3]}, Rim Area: {row[4]}, Rim to Disc Line Ratio: {row[5]}, DDLS Stage: {row[6]}\n"
                break
    return specific_data
# Function to aggregate patient history for all mentioned IDs in the question
def get_aggregated_patient_history(patient_data, details):
    history = ""
    for condition, patient_id in details:
        history += get_specific_patient_data(patient_data, condition, patient_id) + "\n"
    return history.strip()
# Toggle visibility of input elements based on input type
def toggle_visibility(input_type):
    if input_type == "Voice":
        return gr.update(visible=True), gr.update(visible=False)
    else:
        return gr.update(visible=False), gr.update(visible=True)
# Cleanup response text
def cleanup_response(response):
    # Extract only the part after "Answer:" and remove any trailing spaces
    answer_start = response.find("Answer:")
    if answer_start != -1:
        response = response[answer_start + len("Answer:"):].strip()
    return response
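# Example: cleanup_response("... question: ... Answer: The DDLS stage is 4.")
# returns "The DDLS stage is 4."; text without an "Answer:" marker is returned unchanged.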
# Gradio interface for the chatbot
def chatbot(audio, input_type, text):
    if input_type == "Voice":
        # `audio` is expected to expose a .name attribute pointing at the recorded
        # file; this depends on how the corresponding Gradio audio input is configured.
        transcription = query_whisper(audio.name)
        if "error" in transcription:
            return "Error transcribing audio: " + transcription["error"], None
        query = transcription['text']
    else:
        query = text
    details = extract_details_from_prompt(query)
    patient_history = get_aggregated_patient_history(patient_data, details)
    payload = {
        "inputs": f"role: ophthalmologist assistant patient history: {patient_history} question: {query}"
    }
    logging.debug(f"Raw input to the LLM: {payload['inputs']}")
    response = query_huggingface(payload)
    if isinstance(response, list):
        raw_response = response[0].get("generated_text", "Sorry, I couldn't generate a response.")
    else:
        raw_response = response.get("generated_text", "Sorry, I couldn't generate a response.")
    logging.debug(f"Raw output from the LLM: {raw_response}")
    clean_response = cleanup_response(raw_response)
    return clean_response, None
# Gradio interface for generating voice response
def generate_voice_response(tts_model, text_response):
    if tts_model == "Nithu (Custom)":
        audio_file_path = generate_speech_nithu(text_response)
        return audio_file_path, None
    elif tts_model == "Ryan (ESPnet)":
        audio_file_path = generate_speech_ryan(text_response)
        return audio_file_path, None
    else:
        return None, None
# Function to update patient history in the interface
def update_patient_history():
    return readable_patient_data
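# Minimal sketch of how these helpers could be wired into a Gradio UI. The Space's
# actual interface is not defined in this file, so every component below is an
# assumption for illustration only. Note that chatbot() calls `audio.name`, so the
# audio component must be configured to pass a file-like object with a .name path.
if __name__ == "__main__":
    with gr.Blocks() as demo:
        input_type = gr.Radio(["Voice", "Text"], value="Text", label="Input type")
        audio_in = gr.Audio(label="Voice question", visible=False)
        text_in = gr.Textbox(label="Text question", visible=True)
        tts_model = gr.Radio(["Nithu (Custom)", "Ryan (ESPnet)"], value="Nithu (Custom)", label="TTS voice")
        answer = gr.Textbox(label="Answer")
        voice_answer = gr.Audio(label="Spoken answer")
        history_box = gr.Textbox(label="Patient history", value=update_patient_history(), lines=10)
        unused = gr.Textbox(visible=False)  # absorbs the unused second return value

        # Show either the microphone or the text box, mirroring toggle_visibility().
        input_type.change(toggle_visibility, inputs=input_type, outputs=[audio_in, text_in])

        ask_btn = gr.Button("Ask")
        ask_btn.click(chatbot, inputs=[audio_in, input_type, text_in], outputs=[answer, unused])

        speak_btn = gr.Button("Speak answer")
        speak_btn.click(generate_voice_response, inputs=[tts_model, answer], outputs=[voice_answer, unused])

    demo.launch()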