# medVedaReportAnalysis / backend_app.py
# Source revision: 05d79e2 by rishi002 ("Update backend_app.py", verified).
import os
import shutil
import tempfile
import io
import re
import json
import base64
import uuid
from datetime import datetime
from pathlib import Path
import gradio as gr
import torch
from langchain.chains import RetrievalQA
from langchain.document_loaders import PyPDFLoader, DirectoryLoader, TextLoader
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import fitz # PyMuPDF for more robust PDF handling
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from fastapi import FastAPI, File, UploadFile, HTTPException, Request, Form, Query
from fastapi.responses import JSONResponse, HTMLResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from typing import List, Optional, Dict, Any
# Constants
# Folder of reference PDFs that get indexed into the vector store.
KNOWLEDGE_DIR = "medical_knowledge"
# On-disk location of the persisted FAISS index.
VECTOR_STORE_PATH = "vectorstore"
# Primary generation model; it is loaded with HF_TOKEN below (needed if access is restricted).
MODEL_NAME = "microsoft/Phi-3-mini-4k-instruct"
# Run on the GPU whenever PyTorch can see one.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"
EMBEDDING_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
# Where processed report text and metadata JSON are written.
REPORTS_DIR = "user_reports"

# Hugging Face access token (e.g. provided through Space secrets).
HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    print("Warning: HF_TOKEN not found in environment variables. You may not be able to access gated models.")

# Ensure the working directories exist before anything reads or writes them.
os.makedirs(KNOWLEDGE_DIR, exist_ok=True)
os.makedirs(REPORTS_DIR, exist_ok=True)
# Standard medical parameter reference ranges, keyed by canonical parameter name.
# Every entry has a "unit" plus either flat "min"/"max" bounds or nested
# {"min", "max"} dicts keyed by gender ("male"/"female") or, for glucose, by
# sampling context ("fasting"/"random").  "hb" duplicates "hemoglobin".
STANDARD_RANGES = dict(
    hemoglobin=dict(unit="g/dL", male=dict(min=13.5, max=17.5), female=dict(min=12.0, max=15.5)),
    hb=dict(unit="g/dL", male=dict(min=13.5, max=17.5), female=dict(min=12.0, max=15.5)),
    rbc=dict(unit="million/μL", male=dict(min=4.5, max=5.9), female=dict(min=4.1, max=5.1)),
    wbc=dict(unit="thousand/μL", min=4.5, max=11.0),
    platelets=dict(unit="thousand/μL", min=150, max=450),
    glucose=dict(unit="mg/dL", fasting=dict(min=70, max=100), random=dict(min=70, max=140)),
    hba1c=dict(unit="%", min=4.0, max=5.7),
    cholesterol=dict(unit="mg/dL", min=0, max=200),
    ldl=dict(unit="mg/dL", min=0, max=100),
    # 999 acts as an "no upper limit" sentinel for HDL.
    hdl=dict(unit="mg/dL", male=dict(min=40, max=999), female=dict(min=50, max=999)),
    triglycerides=dict(unit="mg/dL", min=0, max=150),
    ast=dict(unit="U/L", min=5, max=40),
    alt=dict(unit="U/L", min=7, max=56),
    creatinine=dict(unit="mg/dL", male=dict(min=0.7, max=1.3), female=dict(min=0.6, max=1.1)),
    tsh=dict(unit="μIU/mL", min=0.4, max=4.0),
    t3=dict(unit="ng/dL", min=80, max=200),
    t4=dict(unit="μg/dL", min=5.0, max=12.0),
    vitamin_d=dict(unit="ng/mL", min=30, max=100),
    vitamin_b12=dict(unit="pg/mL", min=200, max=900),
    iron=dict(unit="μg/dL", male=dict(min=65, max=175), female=dict(min=50, max=170)),
    ferritin=dict(unit="ng/mL", male=dict(min=20, max=250), female=dict(min=10, max=120)),
    sodium=dict(unit="mEq/L", min=135, max=145),
    potassium=dict(unit="mEq/L", min=3.5, max=5.0),
    calcium=dict(unit="mg/dL", min=8.5, max=10.2),
    urea=dict(unit="mg/dL", min=7, max=20),
    uric_acid=dict(unit="mg/dL", male=dict(min=3.4, max=7.0), female=dict(min=2.4, max=6.0)),
)
# Parameter synonyms - to standardize parameter names across different reports.
# Maps a canonical key (a STANDARD_RANGES key) to the lower-case labels that
# reports use for it.  Each label must appear under at most one key.
# Labels for keys containing underscores ("uric_acid", "vitamin_d",
# "vitamin_b12") are listed explicitly: the report parser's name regex only
# accepts letters, whitespace, hyphens and parentheses, so it can never emit
# the underscore form itself.
PARAMETER_SYNONYMS = {
    "hemoglobin": ["hb", "hgb", "haemoglobin"],
    "rbc": ["red blood cells", "red blood cell count", "erythrocytes"],
    "wbc": ["white blood cells", "white blood cell count", "leukocytes", "total leucocyte count", "tlc"],
    "platelets": ["platelet count", "thrombocytes", "plt"],
    "glucose": ["blood glucose", "blood sugar", "fasting glucose", "fasting blood sugar", "fbs", "rbs"],
    "hba1c": ["glycated hemoglobin", "hemoglobin a1c", "glycosylated hemoglobin"],
    "cholesterol": ["total cholesterol", "serum cholesterol", "tc"],
    "ldl": ["ldl cholesterol", "low density lipoprotein", "ldl-c"],
    "hdl": ["hdl cholesterol", "high density lipoprotein", "hdl-c"],
    "triglycerides": ["tg", "trigs"],
    "ast": ["aspartate aminotransferase", "sgot"],
    "alt": ["alanine aminotransferase", "sgpt"],
    "creatinine": ["serum creatinine", "cr"],
    "tsh": ["thyroid stimulating hormone", "thyrotropin"],
    "t3": ["triiodothyronine", "total t3"],
    "t4": ["thyroxine", "total t4"],
    "vitamin_d": ["25-oh vitamin d", "25-hydroxyvitamin d", "25(oh)d", "vitamin d"],
    "vitamin_b12": ["cobalamin", "b12", "vitamin b12"],
    "ferritin": ["serum ferritin"],
    # Keys that previously had no synonyms at all.
    "uric_acid": ["uric acid", "serum uric acid"],
    "iron": ["serum iron"],
    "sodium": ["serum sodium"],
    "potassium": ["serum potassium"],
    "calcium": ["serum calcium", "total calcium"],
}
class MedicalReport:
    """One uploaded medical report plus the data extracted from it.

    Attributes:
        id: caller-supplied report identifier.
        raw_text: report text (callers pass the header-stripped text).
        date: report date; ``datetime.now()`` when none was supplied.
        name: display name; defaults to "Report <first 8 chars of id>".
        gender: "male", "female" or "unknown".
        parameters: extracted parameters, filled in by the analyzer.
        abnormal_parameters: out-of-range parameters, filled in by the analyzer.
    """

    def __init__(self, report_id, report_text, report_date=None, report_name=None, gender=None):
        self.id = report_id
        self.raw_text = report_text
        # Any falsy value (None, "") falls back to the defaults below.
        self.date = report_date if report_date else datetime.now()
        self.name = report_name if report_name else f"Report {report_id[:8]}"
        self.gender = gender if gender else "unknown"
        self.parameters = {}
        self.abnormal_parameters = []

    def to_dict(self):
        """Return a JSON-serializable dict of the report's metadata and extracted data.

        ``date`` is rendered with ``isoformat()`` when it is a ``datetime`` and
        passed through unchanged otherwise.  ``raw_text`` is not included.
        """
        if isinstance(self.date, datetime):
            serialized_date = self.date.isoformat()
        else:
            serialized_date = self.date
        return dict(
            id=self.id,
            name=self.name,
            date=serialized_date,
            gender=self.gender,
            parameters=self.parameters,
            abnormal_parameters=self.abnormal_parameters,
        )
class MedicalReportAnalyzer:
def __init__(self):
    """Create an analyzer with an empty report registry and a ready QA stack.

    Construction is heavy: it loads (or builds) the FAISS vector store, then
    loads the language model, and finally builds the QA chain from
    ``self.llm`` and ``self.vector_store`` -- so that step must run last.
    """
    # Retrieval / generation components; populated by the setup steps below.
    self.vector_store = None
    self.llm = None
    self.qa_chain = None
    # Processed reports keyed by report id, plus the id of the latest upload.
    self.reports = {}
    self.current_report_id = None
    for setup_step in (
        self._load_or_create_vector_store,
        self._initialize_llm,
        self._setup_qa_chain,
    ):
        setup_step()
def _load_or_create_vector_store(self):
    """Set ``self.vector_store``: load the saved FAISS index or build one from PDFs.

    Building reads every ``**/*.pdf`` under KNOWLEDGE_DIR, splits it into
    1000-character chunks with 200 characters of overlap, embeds the chunks and
    saves the index to VECTOR_STORE_PATH.  When the folder is empty, or loading
    fails, a one-document placeholder index is built and saved instead.

    NOTE(review): the placeholder is persisted, so PDFs added later are not
    indexed until VECTOR_STORE_PATH is deleted -- confirm this is intended.
    """
    embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)

    def use_placeholder_store():
        # Minimal index so retrieval calls keep working without any knowledge base.
        self.vector_store = FAISS.from_texts(["No medical knowledge available yet."], embeddings)
        self.vector_store.save_local(VECTOR_STORE_PATH)

    if os.path.exists(VECTOR_STORE_PATH):
        print("Loading existing vector store...")
        self.vector_store = FAISS.load_local(VECTOR_STORE_PATH, embeddings)
        return

    print("Creating new vector store from documents...")
    os.makedirs(KNOWLEDGE_DIR, exist_ok=True)

    if not os.listdir(KNOWLEDGE_DIR):
        print(f"Warning: No documents found in {KNOWLEDGE_DIR}. Please add medical PDFs.")
        use_placeholder_store()
        return

    try:
        pdf_loader = DirectoryLoader(KNOWLEDGE_DIR, glob="**/*.pdf", loader_cls=PyPDFLoader)
        documents = pdf_loader.load()
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200,
            length_function=len,
        )
        chunks = splitter.split_documents(documents)
        self.vector_store = FAISS.from_documents(chunks, embeddings)
        self.vector_store.save_local(VECTOR_STORE_PATH)
    except Exception as e:
        print(f"Error loading documents with DirectoryLoader: {str(e)}")
        use_placeholder_store()
def _initialize_llm(self):
    """Load the generation model and wrap it as a LangChain LLM in ``self.llm``.

    The primary model (MODEL_NAME) is loaded with HF_TOKEN as a causal LM
    behind a ``text-generation`` pipeline.  If anything in that path fails,
    the public ``google/flan-t5-large`` model is used instead.  FLAN-T5 is an
    encoder-decoder (T5) architecture, so it has to be loaded with
    ``AutoModelForSeq2SeqLM`` and served through the ``text2text-generation``
    task; ``AutoModelForCausalLM`` rejects T5 configurations, which made the
    previous fallback fail every time.
    """
    print(f"Loading model {MODEL_NAME} on {DEVICE}...")
    # Half precision only makes sense on the GPU.
    dtype = torch.float16 if DEVICE == "cuda" else torch.float32
    try:
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_NAME,
            token=HF_TOKEN,
        )
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_NAME,
            token=HF_TOKEN,
            torch_dtype=dtype,
            device_map="auto",
            # 8-bit weights on CUDA to save memory.
            # NOTE(review): presumably needs bitsandbytes installed; any failure
            # here falls through to the FLAN-T5 fallback below.
            load_in_8bit=DEVICE == "cuda",
        )
        # NOTE(review): without do_sample=True, transformers decodes greedily and
        # temperature/top_p have no effect -- confirm whether sampling was intended.
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
            temperature=0.1,
            top_p=0.95,
            repetition_penalty=1.1,
        )
        self.llm = HuggingFacePipeline(pipeline=pipe)
    except Exception as e:
        print(f"Error loading the model: {str(e)}")
        print("Falling back to a non-gated model...")
        # Local import: only needed on this fallback path.
        from transformers import AutoModelForSeq2SeqLM

        fallback_model = "google/flan-t5-large"
        tokenizer = AutoTokenizer.from_pretrained(fallback_model)
        model = AutoModelForSeq2SeqLM.from_pretrained(
            fallback_model,
            torch_dtype=dtype,
            device_map="auto",
        )
        pipe = pipeline(
            "text2text-generation",
            model=model,
            tokenizer=tokenizer,
            max_new_tokens=512,
        )
        self.llm = HuggingFacePipeline(pipeline=pipe)
def _setup_qa_chain(self):
    """Build ``self.qa_chain``: a "stuff" RetrievalQA chain over the vector store.

    The prompt expects three inputs -- ``patient_data``, ``context`` (the
    top-5 retrieved chunks) and ``question``.

    NOTE(review): RetrievalQA itself only forwards the question and the
    retrieved documents to its combine-documents chain, so ``patient_data``
    cannot be supplied through ``self.qa_chain(...)``; callers need to invoke
    ``self.qa_chain.combine_documents_chain`` directly -- confirm against the
    installed LangChain version.
    """
    template = (
        "\n"
        "You are a medical assistant analyzing patient medical reports. Use the following pieces of context to answer the question. If you don't know the answer, just say that you don't know, don't try to make up an answer.\n"
        "Also summarize you answer strictly in not more than 350 words and keep the language of your answer simple and easy to understand. Make sure you use easy and simple terms for explanation.\n"
        "Patient Report Summary: {patient_data}\n"
        "Context from medical knowledge base: {context}\n"
        "Question: {question}\n"
        "Answer:\n"
    )
    medical_prompt = PromptTemplate(
        template=template,
        input_variables=["context", "question", "patient_data"],
    )
    retriever = self.vector_store.as_retriever(search_kwargs={"k": 5})
    self.qa_chain = RetrievalQA.from_chain_type(
        llm=self.llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs={"prompt": medical_prompt},
        return_source_documents=False,
    )
def remove_header_information(self, text):
    """Strip patient-identifying header lines from raw report text.

    The cut point is the first line that looks like the start of the results
    section.  If none is found, it is three lines before the first line that
    names a common test, or else a heuristic offset based on how many
    header-looking lines exist.  If cutting there would discard more than half
    of the text, the method instead drops only the header-looking lines.
    Gender is read from the first 20 lines that mention "gender"/"sex".

    Returns:
        tuple ``(clean_text, original_text, gender)`` where ``original_text`` is
        the input unchanged and ``gender`` is "male", "female" or "unknown".
    """
    lines = text.split('\n')
    header_regex = '|'.join([
        r'(Name\s*:)',
        r'(Patient\s*Name\s*:)',
        r'(DOB|Date of Birth\s*:)',
        r'(Age\s*:)',
        r'(Gender\s*:)',
        r'(Lab No\.|Laboratory Number\s*:)',
        r'(Patient ID\s*:)',
        r'(Report Status\s*:)',
        r'(Ref By|Referred By\s*:)',
        r'(Collected\s*:)',
        r'(Reported\s*:)',
        r'(A/c Status\s*:)',
        r'(Processed at\s*:)',
        r'(Collected at\s*:)',
        r'(Address\s*:)',
        r'(Phone|Mobile|Mob\s*:)',
    ])

    def first_index_matching(pattern):
        # Index of the first line matching ``pattern`` (case-insensitive), else -1.
        return next(
            (idx for idx, row in enumerate(lines) if re.search(pattern, row, re.IGNORECASE)),
            -1,
        )

    start = first_index_matching(
        r'(Test\s*Report|Test\s*Name|Test\s*Results|Results|HEMOGRAM|ROUTINE|EXAMINATION)'
    )
    if start == -1:
        first_result = first_index_matching(
            r'(Hemoglobin|Blood|Urine|CBC|Glucose|Cholesterol|Protein|RBC|WBC)'
        )
        if first_result != -1:
            # Keep a little leading context before the first recognised test.
            start = max(0, first_result - 3)
    if start == -1:
        header_hits = sum(1 for row in lines if re.search(header_regex, row, re.IGNORECASE))
        if header_hits > 0:
            start = min(header_hits + 5, len(lines) // 3)
        else:
            # No recognisable header at all: skip roughly the first 10% of lines.
            start = max(1, len(lines) // 10)

    clean_text = '\n'.join(lines[start:])
    if len(clean_text) < len(text) * 0.5:
        print("Warning: Header removal may have removed too much content. Using alternative approach.")
        clean_text = '\n'.join(
            row for row in lines if not re.search(header_regex, row, re.IGNORECASE)
        )

    # Later matching lines override earlier ones; "female" wins over "male".
    gender = "unknown"
    for row in lines[:20]:
        if not re.search(r'gender|sex', row, re.IGNORECASE):
            continue
        mentions_female = re.search(r'female', row, re.IGNORECASE)
        if re.search(r'male', row, re.IGNORECASE) and not mentions_female:
            gender = "male"
        elif mentions_female:
            gender = "female"

    return clean_text, text, gender
def extract_text_from_pdf_pymupdf(self, pdf_path):
    """Return the concatenated text of every page via PyMuPDF, or None on failure.

    The document is opened as a context manager so its handle is released even
    when extraction fails part-way through (previously ``doc.close()`` was
    skipped on error).  Any exception is printed and reported as ``None`` so
    the caller can try another extractor.
    """
    try:
        with fitz.open(pdf_path) as doc:
            return "".join(page.get_text() for page in doc)
    except Exception as e:
        print(f"PyMuPDF extraction error: {str(e)}")
        return None
def extract_text_from_pdf_pypdf(self, pdf_path):
    """Backup extractor: join each page's text (newline-separated) using PyPDFLoader.

    Returns None, after printing the error, if loading fails.
    """
    try:
        page_texts = [page.page_content for page in PyPDFLoader(pdf_path).load()]
    except Exception as e:
        print(f"PyPDF extraction error: {str(e)}")
        return None
    return "\n".join(page_texts)
def extract_report_date(self, text):
    """Find the report/collection date in ``text``; fall back to ``datetime.now()``.

    Labelled dates ("Report Date", "Date of Report", "Date", "Collected",
    "Sample Date", "Collected On") are preferred over the first unlabelled date
    in the text.  Numeric (``15/03/2024``, ``05-06-24``) and textual
    (``12 March 2023``) forms are parsed, day-first before month-first.

    Date-of-birth values are blanked out before searching.  The bare "Date"
    label and the unlabelled fallbacks would otherwise match
    "Date of Birth: 01/02/1980" and return the patient's DOB as the report date.
    Two-digit years are accepted too: the patterns capture 2-4 digit years,
    but ``%Y`` alone only parses four digits.
    """
    dob_value = re.compile(
        r'(?:date\s*of\s*birth|\bdob\b)[^\n\d]*'
        r'(?:\d{1,2}[/-]\d{1,2}[/-]\d{2,4}|\d{1,2}\s+[A-Za-z]+\s+\d{2,4})',
        re.IGNORECASE,
    )
    searchable = dob_value.sub(' ', text)

    date_patterns = [
        r'(?:Report\s*Date|Date\s*of\s*Report|Date)[^\n\d]*(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
        r'(?:Report\s*Date|Date\s*of\s*Report|Date)[^\n\d]*(\d{1,2}\s+[A-Za-z]+\s+\d{2,4})',
        r'(?:Collected|Sample\s*Date|Collected\s*On)[^\n\d]*(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})',
        r'(?:Collected|Sample\s*Date|Collected\s*On)[^\n\d]*(\d{1,2}\s+[A-Za-z]+\s+\d{2,4})',
        r'\b(\d{1,2}[/-]\d{1,2}[/-]\d{2,4})\b',
        r'\b(\d{1,2}\s+[A-Za-z]+\s+\d{2,4})\b',
    ]
    # Four-digit-year formats first, then their two-digit-year counterparts.
    date_formats = [
        '%d/%m/%Y', '%d-%m-%Y', '%m/%d/%Y', '%m-%d-%Y',
        '%d %b %Y', '%d %B %Y', '%b %d %Y', '%B %d %Y',
        '%d/%m/%y', '%d-%m-%y', '%m/%d/%y', '%m-%d-%y',
        '%d %b %y', '%d %B %y',
    ]

    for pattern in date_patterns:
        match = re.search(pattern, searchable)
        if not match:
            continue
        date_str = match.group(1)
        for fmt in date_formats:
            try:
                return datetime.strptime(date_str, fmt)
            except ValueError:
                continue
        # Unparseable candidate: fall through to the next, less specific pattern.

    return datetime.now()
def standardize_parameter_name(self, param_name):
    """Map a raw parameter label to its canonical key.

    Resolution order:

    1. The label is lower-cased and runs of whitespace collapse to one space.
    2. PARAMETER_SYNONYMS is consulted first.  A label that is both a
       STANDARD_RANGES key and a listed synonym (``"hb"``) then resolves to
       its canonical name (``"hemoglobin"``) instead of producing a second
       key for the same test.
    3. STANDARD_RANGES keys are matched directly, also trying the
       underscore form so "uric acid" reaches ``uric_acid``.

    Unmatched labels are returned in their normalised form.
    """
    name = " ".join(param_name.lower().split())

    for standard_name, synonyms in PARAMETER_SYNONYMS.items():
        if name in synonyms:
            return standard_name

    for candidate in (name, name.replace(" ", "_")):
        if candidate in STANDARD_RANGES:
            return candidate

    return name
def extract_parameters_from_text(self, text):
    """Extract numeric lab parameters (value, unit, reference range) from report text.

    Three passes run in order, and the first pass to record a canonical name
    wins; later, less specific matches never overwrite it.

    1. Rows of the form ``name value [unit] low-high``, which capture the
       printed reference range.
    2. ``name[:] value [unit]`` anywhere in the text.
    3. Line by line, ``name[:] value`` at the start of a line, also picking up
       a ``reference/normal/range: low-high`` on the same line.

    Passes 1-2 only allow spaces/tabs between tokens.  A match therefore never
    spans a line break, which previously merged section titles into parameter
    names (e.g. "complete blood count\\nhemoglobin").

    Returns:
        dict mapping the canonical name (``self.standardize_parameter_name``)
        to ``{"value": float, "unit": str, "reference_range": str,
        "original_name": str}``.
    """
    parameters = {}

    def record(raw_name, value_str, unit, ref_range):
        # Store one parameter unless an earlier (more specific) pass already did.
        original_name = raw_name.strip().lower()
        standard_param = self.standardize_parameter_name(original_name)
        if standard_param in parameters:
            return
        try:
            value = float(value_str.strip())
        except ValueError:
            return
        parameters[standard_param] = {
            "value": value,
            "unit": unit,
            "reference_range": ref_range,
            "original_name": original_name,
        }

    table_patterns = [
        # name value [unit] low-high
        re.compile(
            r'(?P<param>[A-Za-z \t\-\(\)]+)[ \t]+(?P<value>\d+\.?\d*)'
            r'(?:[ \t]+(?P<unit>[A-Za-z%\/]+))?'
            r'[ \t]+(?P<range>\d+\.?\d*[ \t]*[-–][ \t]*\d+\.?\d*)'
        ),
        # name[:] value [unit]
        re.compile(
            r'(?P<param>[A-Za-z \t\-\(\)]+)[: \t]+(?P<value>\d+\.?\d*)'
            r'(?:[ \t]+(?P<unit>[A-Za-z%\/]+))?'
        ),
    ]
    for pattern in table_patterns:
        for match in pattern.finditer(text):
            groups = match.groupdict()
            record(
                match.group('param'),
                match.group('value'),
                (groups.get('unit') or '').strip(),
                groups.get('range') or '',
            )

    line_start = re.compile(r'^([A-Za-z\s\-\(\)]+)[:\s]+(\d+\.?\d*)')
    unit_after_number = re.compile(r'(\d+\.?\d*)\s+([A-Za-z%\/]+)')
    labelled_range = re.compile(
        r'(?:reference|normal|range)[:\s]*(\d+\.?\d*\s*[-–]\s*\d+\.?\d*)', re.IGNORECASE
    )
    for raw_line in text.split('\n'):
        line = raw_line.strip()
        if not line:
            continue
        param_match = line_start.search(line)
        if not param_match:
            continue
        unit_match = unit_after_number.search(line)
        range_match = labelled_range.search(line)
        record(
            param_match.group(1),
            param_match.group(2),
            unit_match.group(2).strip() if unit_match else "",
            range_match.group(1).strip() if range_match else "",
        )

    return parameters
def check_parameter_ranges(self, parameters, gender="unknown", reference_ranges=None):
    """Flag parameters outside their reference range and annotate each one in place.

    Range selection per entry of the reference table:

    - gender-split entries ("male"/"female") use the patient's gender, or the
      widest combined range when gender is neither;
    - flat entries use their "min"/"max";
    - other entries (e.g. glucose's "fasting"/"random") pick the variant
      named in the parameter's ``original_name`` ("fbs"/"rbs" count as
      fasting/random), else the widest range across variants.  Previously
      such entries raised KeyError on ``["min"]`` and aborted report processing.

    Args:
        parameters: canonical name -> dict with at least "value".  Every
            parameter with a known range gets "min", "max" and "status"
            ("low"/"normal"/"high") added.
        gender: "male", "female", or anything else for unknown.
        reference_ranges: optional table in STANDARD_RANGES format; defaults
            to STANDARD_RANGES.

    Returns:
        list of ``{"name", "value", "min", "max", "status"}`` dicts for the
        out-of-range parameters, in the order of ``parameters``.
    """
    ranges = STANDARD_RANGES if reference_ranges is None else reference_ranges
    context_aliases = (("fbs", "fasting"), ("rbs", "random"))
    abnormal_parameters = []

    for param_name, param_data in parameters.items():
        standard_range = ranges.get(param_name)
        if standard_range is None:
            continue
        value = param_data["value"]

        if "male" in standard_range and "female" in standard_range:
            if gender in ("male", "female"):
                param_min = standard_range[gender]["min"]
                param_max = standard_range[gender]["max"]
            else:
                # Unknown gender: accept anything normal for either gender.
                param_min = min(standard_range["male"]["min"], standard_range["female"]["min"])
                param_max = max(standard_range["male"]["max"], standard_range["female"]["max"])
        elif "min" in standard_range and "max" in standard_range:
            param_min = standard_range["min"]
            param_max = standard_range["max"]
        else:
            variants = {
                key: bounds
                for key, bounds in standard_range.items()
                if isinstance(bounds, dict) and "min" in bounds and "max" in bounds
            }
            if not variants:
                continue
            hint = str(param_data.get("original_name", "")).lower()
            context = next((key for key in variants if key in hint), None)
            if context is None:
                context = next(
                    (target for alias, target in context_aliases if alias in hint and target in variants),
                    None,
                )
            if context is not None:
                param_min = variants[context]["min"]
                param_max = variants[context]["max"]
            else:
                param_min = min(bounds["min"] for bounds in variants.values())
                param_max = max(bounds["max"] for bounds in variants.values())

        if value < param_min:
            status = "low"
        elif value > param_max:
            status = "high"
        else:
            status = "normal"

        param_data["min"] = param_min
        param_data["max"] = param_max
        param_data["status"] = status

        if status != "normal":
            abnormal_parameters.append({
                "name": param_name,
                "value": value,
                "min": param_min,
                "max": param_max,
                "status": status,
            })

    return abnormal_parameters
def process_user_report(self, report_file):
    """Save an uploaded report, extract its text, and register it as a MedicalReport.

    ``report_file`` may be a filesystem path (str), an object whose ``.name``
    points at a file on disk (e.g. a Gradio upload), a readable file-like
    object, or raw bytes.

    Text extraction tries PyMuPDF, then PyPDF.  Only if the upload is *not* a
    PDF (no ``%PDF`` marker in its first 1 KiB) is it read as plain text.
    Previously the binary bytes of an image-only or password-protected PDF
    were decoded as "text", so such files were reported as successfully
    processed garbage and the error branch below was unreachable.
    Whitespace-only extraction results count as failures.

    Returns:
        - "No file uploaded..." (str) when ``report_file`` is None, or
          "Error saving the uploaded file: ..." (str) if the upload can't be copied;
        - otherwise a dict with "status" ("success"/"error"), "message",
          "report_id" and, on success, "abnormal_count".  In both dict cases the
          report is stored in ``self.reports`` and becomes ``self.current_report_id``;
          on success its text and metadata are also written to REPORTS_DIR.
    """
    if report_file is None:
        return "No file uploaded. Please upload a medical report."

    temp_dir = tempfile.mkdtemp()
    try:
        temp_file_path = os.path.join(temp_dir, "user_report.pdf")

        try:
            if isinstance(report_file, str):  # a file path
                shutil.copy(report_file, temp_file_path)
            elif hasattr(report_file, 'name'):  # Gradio-style file object
                with open(temp_file_path, 'wb') as f, open(report_file.name, 'rb') as source:
                    f.write(source.read())
            else:  # bytes or a file-like object
                with open(temp_file_path, 'wb') as f:
                    f.write(report_file.read() if hasattr(report_file, 'read') else report_file)
        except Exception as e:
            print(f"Error saving file: {str(e)}")
            return f"Error saving the uploaded file: {str(e)}"

        # Method 1: PyMuPDF; Method 2: PyPDF.
        text = self.extract_text_from_pdf_pymupdf(temp_file_path)
        if not (text and text.strip()):
            text = self.extract_text_from_pdf_pypdf(temp_file_path)

        # Method 3: plain-text read, only for uploads that are not PDFs.
        if not (text and text.strip()):
            try:
                with open(temp_file_path, 'rb') as f:
                    looks_like_pdf = b'%PDF' in f.read(1024)
                if not looks_like_pdf:
                    with open(temp_file_path, 'r', encoding='utf-8', errors='ignore') as f:
                        text = f.read()
            except Exception as e:
                print(f"Raw text reading error: {str(e)}")

        if text and len(text.strip()) > 0:
            cleaned_text, original_text, gender = self.remove_header_information(text)
            report_id = str(uuid.uuid4())
            # The date is searched in the original text, which still has the header.
            report_date = self.extract_report_date(original_text)

            report = MedicalReport(
                report_id=report_id,
                report_text=cleaned_text,
                report_date=report_date,
                report_name=f"Report {report_date.strftime('%Y-%m-%d')}",
                gender=gender,
            )
            parameters = self.extract_parameters_from_text(cleaned_text)
            report.parameters = parameters
            abnormal_parameters = self.check_parameter_ranges(parameters, gender)
            report.abnormal_parameters = abnormal_parameters

            self.reports[report_id] = report
            self.current_report_id = report_id

            # Persist text and metadata; utf-8 so symbols such as "μ" always encode.
            report_save_path = os.path.join(REPORTS_DIR, f"{report_id}.txt")
            with open(report_save_path, 'w', encoding='utf-8') as f:
                f.write(cleaned_text)
            metadata_save_path = os.path.join(REPORTS_DIR, f"{report_id}_meta.json")
            with open(metadata_save_path, 'w', encoding='utf-8') as f:
                json.dump(report.to_dict(), f)

            return {
                "status": "success",
                "message": f"Report processed successfully. Extracted {len(parameters)} parameters.",
                "report_id": report_id,
                "abnormal_count": len(abnormal_parameters),
            }

        # No usable text: register an in-memory placeholder report.
        report_id = str(uuid.uuid4())
        report = MedicalReport(
            report_id=report_id,
            report_text="Unable to extract text from the provided PDF. This is an empty report placeholder.",
            report_name="Error Report",
        )
        self.reports[report_id] = report
        self.current_report_id = report_id
        return {
            "status": "error",
            "message": "Warning: Could not extract text from the PDF. The file may be corrupted, password-protected, or contain only images.",
            "report_id": report_id,
        }
    finally:
        # Best-effort cleanup; a failure here must not mask the result above.
        shutil.rmtree(temp_dir, ignore_errors=True)
def process_multiple_reports(self, report_files):
    """Process up to three uploaded reports with ``self.process_user_report``.

    Stops at the first report that does not come back as a success dict and
    returns an error naming it.  Reports processed before that point stay
    registered on the analyzer.

    Returns:
        On success, a dict with "status", "message", "report_ids" and
        "reports" (the per-report result dicts).  Otherwise a dict with
        "status" == "error" and a "message".
    """
    if not report_files:
        return {
            "status": "error",
            "message": "No files uploaded. Please upload medical reports.",
        }
    if len(report_files) > 3:
        return {
            "status": "error",
            "message": "Too many files. Maximum 3 reports can be compared at once.",
        }

    successes = []
    for upload in report_files:
        outcome = self.process_user_report(upload)
        succeeded = isinstance(outcome, dict) and outcome.get("status") == "success"
        if not succeeded:
            # process_user_report returns plain strings for some failures.
            reason = outcome if isinstance(outcome, str) else outcome.get("message", "Unknown error")
            return {
                "status": "error",
                "message": f"Error processing one of the reports: {reason}",
            }
        successes.append(outcome)

    return {
        "status": "success",
        "message": f"Successfully processed {len(successes)} reports.",
        "report_ids": [item.get("report_id") for item in successes],
        "reports": successes,
    }
def answer_question(self, question, report_id=None):
    """Answer a free-text question about a processed report.

    Flow:

    1. Resolve the report (explicit ``report_id`` or the most recent upload).
    2. Refuse direct requests for identifying details (name, age, address,
       doctor, dates, IDs) before doing any work.  The specific phrases are
       word-bounded, so e.g. "percentage of" / "average of" no longer trip
       the "age of" rule.
    3. Retrieve the top-5 knowledge chunks and run the stuff chain's prompt
       with ``question``, ``context`` and the report text as ``patient_data``.
       The combine-documents chain is called directly, because
       ``RetrievalQA`` never forwards ``patient_data``, which made the
       primary path fail on every call.
    4. On any error, fall back to Gemini (``GEMINI_API_KEY``).  NOTE: this
       sends up to 2000 characters of the header-stripped report to an
       external API.

    Returns:
        The answer text, a refusal, or an error message (always a str).
    """
    target_report_id = report_id or self.current_report_id
    if not target_report_id or target_report_id not in self.reports:
        return "No report has been processed or the specified report ID is invalid. Please upload a medical report first."

    report = self.reports[target_report_id]
    report_text = report.raw_text

    # Privacy guard: a broad demographic hint AND a specific identification request.
    demographic_pattern = '|'.join([
        r'(patient|name|age|gender|birth|dob|address|phone|contact|id|identification)',
        r'(doctor|physician|referring|referred by)',
        r'(date|time|collected|processed|reported)',
        r'(lab|laboratory|number|id)',
    ])
    specific_id_pattern = '|'.join([
        r'\b(name of|patient name|who is|what is the name)\b',
        r'\b(exact age|age of|how old)\b',
        r'\b(address of|where|location|contact details)\b',
        r'\b(doctor name|name of doctor|referring doctor|who referred)\b',
        r'\b(date of|when was|time of|report date)\b',
        r'\b(lab number|patient id|identification number)\b',
    ])
    if (re.search(demographic_pattern, question, re.IGNORECASE)
            and re.search(specific_id_pattern, question, re.IGNORECASE)):
        return "I'm unable to provide specific patient identification information. This feature is disabled to protect patient privacy. Please ask about medical test results or interpretations instead."

    try:
        retrieved_docs = self.vector_store.similarity_search(question, k=5)
        # The stuff chain forwards extra keyword inputs (here patient_data) to its prompt.
        result = self.qa_chain.combine_documents_chain.run(
            input_documents=retrieved_docs,
            question=question,
            patient_data=report_text,
        )
        return result if isinstance(result, str) else str(result)
    except Exception as e:
        print(f"Error answering question: {str(e)}")
        error_msg = f"Error processing your question: {str(e)}."

        try:
            import google.generativeai as genai

            gemini_api_key = os.environ.get("GEMINI_API_KEY")
            if not gemini_api_key:
                return f"{error_msg} Gemini API key not found in environment variables. Please set the GEMINI_API_KEY environment variable."

            genai.configure(api_key=gemini_api_key)
            model = genai.GenerativeModel('gemini-2.0-flash')
            gemini_prompt = (
                "\n"
                f"Question about medical report: {question}\n"
                f"Patient data available: {report_text[:2000]}... (truncated)\n"
                "Please analyze this medical report data and answer the question.\n"
                "Your answer should:\n"
                "1. Be strictly under 350 words\n"
                "2. Use simple language that's easy to understand\n"
                "3. Focus on the medical information relevant to the question\n"
                "4. Avoid making assumptions beyond what's in the data\n"
            )
            response = model.generate_content(gemini_prompt)
            gemini_answer = response.text if hasattr(response, 'text') else str(response)
            return f"{gemini_answer}"
        except Exception as gemini_error:
            return f"{error_msg} Gemini fallback also failed: {str(gemini_error)}. Please try a different question or report."
def generate_single_report_analysis(self, report_id=None):
    """Build a full analysis dict for one report (explicit id or the latest upload).

    Parameters are grouped into fixed categories (anything unlisted is
    "uncategorized").  Each abnormal parameter becomes a one-line LOW/HIGH
    finding; those findings are sent through ``self.answer_question`` to get
    plain-language health suggestions.  Chart data comes from
    ``self.create_single_report_visualizations``.

    Returns:
        ``{"status": "success", ...}`` with report metadata, categorized and
        uncategorized parameters, suggestions and visualizations, or
        ``{"status": "error", "message": ...}``.
    """
    target_report_id = report_id or self.current_report_id
    if not target_report_id or target_report_id not in self.reports:
        return {
            "status": "error",
            "message": "No report has been processed or the specified report ID is invalid.",
        }

    report = self.reports[target_report_id]
    try:
        category_members = {
            "blood_count": ["hemoglobin", "hb", "rbc", "wbc", "platelets"],
            "glucose": ["glucose", "hba1c"],
            "lipids": ["cholesterol", "ldl", "hdl", "triglycerides"],
            "liver_function": ["ast", "alt", "bilirubin", "alp", "ggt"],
            "kidney_function": ["creatinine", "urea", "uric_acid"],
            "thyroid": ["tsh", "t3", "t4"],
            "vitamins": ["vitamin_d", "vitamin_b12"],
            "electrolytes": ["sodium", "potassium", "calcium"],
        }

        categorized_params = {}
        uncategorized_params = []
        for param_name, param_data in report.parameters.items():
            entry = {"name": param_name, **param_data}
            home = next(
                (category for category, members in category_members.items() if param_name in members),
                None,
            )
            if home is None:
                uncategorized_params.append(entry)
            else:
                categorized_params.setdefault(home, []).append(entry)

        abnormal_analysis = []
        for finding in report.abnormal_parameters:
            unit = report.parameters.get(finding["name"], {}).get("unit", "")
            if finding["status"] == "low":
                level, relation = "LOW", "below"
            else:
                level, relation = "HIGH", "above"
            abnormal_analysis.append(
                f"{finding['name'].upper()} is {level} at {finding['value']} {unit} "
                f"({relation} reference range of {finding['min']}-{finding['max']} {unit})"
            )

        findings_text = ' '.join(abnormal_analysis)
        suggestions_prompt = (
            "\n"
            "As a medical assistant, provide simple health suggestions for a patient with the following abnormal results:\n"
            f"{findings_text}\n"
            "Please provide:\n"
            "1. A brief explanation of what each abnormal result might indicate (in simple terms)\n"
            "2. General lifestyle suggestions that might help improve these values\n"
            "3. When the patient should consider consulting a doctor\n"
            "Keep your response under 400 words and use simple, non-technical language. DO NOT include disclaimers about not being a doctor or medical advice, just provide the information directly.\n"
        )
        # Note: the caller-supplied report_id (possibly None) is forwarded as-is.
        health_suggestions = self.answer_question(suggestions_prompt, report_id)
        visualization_data = self.create_single_report_visualizations(report)

        if isinstance(report.date, datetime):
            report_date = report.date.isoformat()
        else:
            report_date = report.date

        return {
            "status": "success",
            "report_id": target_report_id,
            "report_date": report_date,
            "gender": report.gender,
            "parameters_count": len(report.parameters),
            "abnormal_count": len(report.abnormal_parameters),
            "abnormal_parameters": report.abnormal_parameters,
            "categorized_parameters": categorized_params,
            "uncategorized_parameters": uncategorized_params,
            "health_suggestions": health_suggestions,
            "visualizations": visualization_data,
        }
    except Exception as e:
        print(f"Error generating report analysis: {str(e)}")
        return {
            "status": "error",
            "message": f"Error generating analysis: {str(e)}",
        }
def create_single_report_visualizations(self, report):
    """Build chart specifications for a single report.

    Args:
        report: Report object exposing ``parameters`` (a mapping keyed by
            parameter name) and ``abnormal_parameters`` (a list of dicts with
            ``name``, ``value``, ``min`` and ``max`` keys).

    Returns:
        A dict with three chart specs, each shaped ``{"type", "data", "title"}``:

        * ``status_chart``: pie of normal vs. abnormal parameter counts.
        * ``abnormal_chart``: bar of each abnormal parameter's deviation from
          the bound it violates, as a percentage of that bound's magnitude,
          capped at 100; ``None`` when nothing is abnormal.
        * ``category_chart``: pie of parameter counts per known category, with
          unmatched names counted under ``"Other"``.

        If anything fails, ``{"status": "error", "message": ...}`` instead.
    """
    # Known parameter names per category (same grouping as compare_reports).
    categories = {
        "blood_count": ["hemoglobin", "hb", "rbc", "wbc", "platelets"],
        "glucose": ["glucose", "hba1c"],
        "lipids": ["cholesterol", "ldl", "hdl", "triglycerides"],
        "liver_function": ["ast", "alt", "bilirubin", "alp", "ggt"],
        "kidney_function": ["creatinine", "urea", "uric_acid"],
        "thyroid": ["tsh", "t3", "t4"],
        "vitamins": ["vitamin_d", "vitamin_b12"],
        "electrolytes": ["sodium", "potassium", "calcium"]
    }
    # Reverse lookup parameter name -> category; every name is listed only once
    # above, so this matches a first-match scan over the categories.
    category_of = {name: category for category, names in categories.items() for name in names}
    try:
        abnormal_params = report.abnormal_parameters
        abnormal_count = len(abnormal_params)
        normal_count = len(report.parameters) - abnormal_count

        # 1. Parameters status chart (normal vs abnormal)
        status_chart = {
            "type": "pie",
            "data": {
                "labels": ["Normal", "Abnormal"],
                "values": [normal_count, abnormal_count]
            },
            "title": "Parameter Status Distribution"
        }

        # 2. Abnormal parameters chart
        abnormal_chart = None
        if abnormal_count > 0:
            abnormal_names = []
            abnormal_percentages = []
            for abnormal in abnormal_params:
                param_name = abnormal["name"]
                value = abnormal["value"]
                if value < abnormal["min"]:
                    bound = abnormal["min"]
                    excess = bound - value
                    abnormal_names.append(f"{param_name} (Low)")
                else:
                    bound = abnormal["max"]
                    excess = value - bound
                    abnormal_names.append(f"{param_name} (High)")
                if bound == 0:
                    # Relative deviation is undefined for a zero bound; report the
                    # cap instead of failing every chart with ZeroDivisionError.
                    deviation = 100 if excess > 0 else 0
                else:
                    # Divide by the bound's magnitude so negative bounds don't
                    # flip the sign of the deviation.
                    deviation = excess / abs(bound) * 100
                # Cap at 100% for very extreme values
                abnormal_percentages.append(min(deviation, 100))
            abnormal_chart = {
                "type": "bar",
                "data": {
                    "labels": abnormal_names,
                    "values": abnormal_percentages
                },
                "title": "Abnormal Parameters (% Deviation from Reference)"
            }

        # 3. Category distribution chart ("Other" is always the first label)
        category_counts = {"Other": 0}
        for param_name in report.parameters:
            category = category_of.get(param_name, "Other")
            category_counts[category] = category_counts.get(category, 0) + 1
        category_chart = {
            "type": "pie",
            "data": {
                "labels": list(category_counts.keys()),
                "values": list(category_counts.values())
            },
            "title": "Parameter Categories"
        }

        return {
            "status_chart": status_chart,
            "abnormal_chart": abnormal_chart,
            "category_chart": category_chart
        }
    except Exception as e:
        print(f"Error creating visualizations: {str(e)}")
        return {
            "status": "error",
            "message": f"Error creating visualizations: {str(e)}"
        }
def compare_reports(self, report_ids):
    """Compare 2-3 stored reports and build trend charts plus LLM insights.

    Args:
        report_ids: IDs of reports held in ``self.reports``; between two and
            three IDs are accepted.

    Returns:
        On success, a dict with ``status == "success"`` holding the sorted
        ``report_dates``, per-parameter ``parameter_trends``, line-chart specs
        grouped as ``categorized_charts`` / ``uncategorized_charts``, a pie
        ``summary_chart`` of trend outcomes, matching ``statistics`` counts and
        the LLM-generated ``health_insights``. On failure, a dict with
        ``status == "error"`` and a ``message``.
    """
    if not report_ids or len(report_ids) < 2:
        return {
            "status": "error",
            "message": "At least two report IDs are required for comparison."
        }
    if len(report_ids) > 3:
        return {
            "status": "error",
            "message": "Maximum 3 reports can be compared at once."
        }
    # Verify all report IDs exist
    for report_id in report_ids:
        if report_id not in self.reports:
            return {
                "status": "error",
                "message": f"Report ID {report_id} not found."
            }

    # Known parameter names per category, used to group charts for the card UI.
    categories = {
        "blood_count": ["hemoglobin", "hb", "rbc", "wbc", "platelets"],
        "glucose": ["glucose", "hba1c"],
        "lipids": ["cholesterol", "ldl", "hdl", "triglycerides"],
        "liver_function": ["ast", "alt", "bilirubin", "alp", "ggt"],
        "kidney_function": ["creatinine", "urea", "uric_acid"],
        "thyroid": ["tsh", "t3", "t4"],
        "vitamins": ["vitamin_d", "vitamin_b12"],
        "electrolytes": ["sodium", "potassium", "calcium"]
    }
    category_of = {name: category for category, names in categories.items() for name in names}

    def _format_date(report):
        """Render a report date as YYYY-MM-DD, or via str() if it is not a datetime."""
        if isinstance(report.date, datetime):
            return report.date.strftime('%Y-%m-%d')
        return str(report.date)

    def _reference_range(param, gender):
        """Return (min, max) reference bounds for *param*, or (None, None).

        STANDARD_RANGES entries may carry ``min``/``max`` directly or nest them
        under a gender key such as ``"male"`` / ``"female"``.
        """
        ranges = STANDARD_RANGES.get(param)
        if not isinstance(ranges, dict):
            return None, None
        if "min" in ranges or "max" in ranges:
            return ranges.get("min"), ranges.get("max")
        gendered = ranges.get(str(gender).lower()) if gender else None
        if isinstance(gendered, dict):
            return gendered.get("min"), gendered.get("max")
        return None, None

    def _classify_trend(first_status, last_status, first_val, last_val):
        """Label how a parameter moved between the first and last report."""
        if "unknown" in (first_status, last_status):
            # A missing status on either end means the direction can't be judged.
            return "unknown"
        # Improved if: was abnormal and now normal OR was high and decreased OR was low and increased
        if (first_status != "normal" and last_status == "normal") or \
                (first_status == "high" and last_val < first_val) or \
                (first_status == "low" and last_val > first_val):
            return "improved"
        # Worsened if: was normal and now abnormal OR was high and increased OR was low and decreased
        if (first_status == "normal" and last_status != "normal") or \
                (first_status == "high" and last_val > first_val) or \
                (first_status == "low" and last_val < first_val):
            return "worsened"
        return "unchanged"

    try:
        report_objects = [self.reports[report_id] for report_id in report_ids]
        # Oldest to newest; reports without a datetime date sort as "now" (last).
        report_objects.sort(key=lambda r: r.date if isinstance(r.date, datetime) else datetime.now())

        # Parameters present in every report; sorted so output order is stable
        # between runs (plain set iteration order is not).
        common_parameters = sorted(
            set.intersection(*(set(r.parameters.keys()) for r in report_objects)),
            key=str,
        )
        if not common_parameters:
            return {
                "status": "error",
                "message": "No common parameters found across the reports for comparison."
            }

        report_dates = [_format_date(r) for r in report_objects]
        parameter_trends = {}
        for param in common_parameters:
            parameter_trends[param] = {
                "name": param,
                "values": [r.parameters[param]["value"] for r in report_objects],
                "dates": list(report_dates),
                "statuses": [r.parameters[param].get("status", "unknown") for r in report_objects]
            }

        # Gender-specific reference ranges are taken from the most recent report.
        gender = getattr(report_objects[-1], "gender", None)

        trend_charts = []
        for param, trend_data in parameter_trends.items():
            ref_min, ref_max = _reference_range(param, gender)
            values = trend_data["values"]
            statuses = trend_data["statuses"]
            if len(values) >= 2:
                first_val = values[0]
                last_val = values[-1]
                # Avoid division by zero
                percent_change = ((last_val - first_val) / first_val) * 100 if first_val != 0 else 0
                trend = _classify_trend(statuses[0], statuses[-1], first_val, last_val)
            else:
                percent_change = 0
                trend = "unknown"
            trend_charts.append({
                "type": "line",
                "data": {
                    "labels": trend_data["dates"],
                    "values": values
                },
                "metadata": {
                    "parameter": param,
                    "percent_change": round(percent_change, 2),
                    "trend": trend,
                    "reference_min": ref_min,
                    "reference_max": ref_max
                },
                "title": f"{param.upper()} Trend"
            })

        # Organize charts by category for the card-based UI
        categorized_charts = {}
        uncategorized_charts = []
        for chart in trend_charts:
            category = category_of.get(chart["metadata"]["parameter"])
            if category is None:
                uncategorized_charts.append(chart)
            else:
                categorized_charts.setdefault(category, []).append(chart)

        # Summary of overall health trends
        trend_counts = {"improved": 0, "worsened": 0, "unchanged": 0, "unknown": 0}
        for chart in trend_charts:
            trend_counts[chart["metadata"]["trend"]] += 1
        summary_chart = {
            "type": "pie",
            "data": {
                "labels": ["Improved", "Worsened", "Unchanged", "Unknown"],
                "values": [trend_counts["improved"], trend_counts["worsened"],
                           trend_counts["unchanged"], trend_counts["unknown"]]
            },
            "title": "Overall Health Trends"
        }

        # Only significant changes (>5%) are sent to the LLM.
        change_lines = [
            f"- {chart['metadata']['parameter']}: {chart['metadata']['percent_change']:+.1f}% change "
            f"({chart['metadata']['trend']})"
            for chart in trend_charts
            if abs(chart["metadata"]["percent_change"]) > 5
        ]
        if not change_lines:
            change_lines = ["- No parameter changed by more than 5% between the reports."]
        insights_prompt = "\n".join([
            "",
            "As a medical assistant, I need to generate insights about how a patient's health has changed between medical reports.",
            "Here are the key changes:",
            *change_lines,
            "",
            "Based on these changes, please provide:",
            "1. A brief overview of the overall health trend (improved, worsened, or mixed)",
            "2. The most significant positive changes and what they might indicate",
            "3. The most significant concerns and what they might indicate",
            "4. 3-5 specific recommendations based on these trends",
            "Keep your response under 400 words and use simple, non-technical language that a patient can understand.",
            "DO NOT include disclaimers about not being a doctor or medical advice, just provide the information directly.",
            "",
        ])
        health_insights = self.answer_question(insights_prompt)

        return {
            "status": "success",
            "report_count": len(report_ids),
            "report_dates": report_dates,
            "common_parameters_count": len(common_parameters),
            "parameter_trends": parameter_trends,
            "categorized_charts": categorized_charts,
            "uncategorized_charts": uncategorized_charts,
            "summary_chart": summary_chart,
            "health_insights": health_insights,
            "statistics": {
                "improved": trend_counts["improved"],
                "worsened": trend_counts["worsened"],
                "unchanged": trend_counts["unchanged"],
                "unknown": trend_counts["unknown"]
            }
        }
    except Exception as e:
        print(f"Error comparing reports: {str(e)}")
        return {
            "status": "error",
            "message": f"Error comparing reports: {str(e)}"
        }
def generate_visualization_image(self, chart_data, width=600, height=400):
    """Render a chart spec to PNG and return it base64-encoded.

    Args:
        chart_data: Chart spec dict with ``type`` ("bar", "line" or "pie"),
            ``data`` holding ``labels``/``values`` lists, an optional ``title``
            and, for line charts, optional ``metadata.reference_min`` /
            ``metadata.reference_max`` drawn as dashed red horizontal lines.
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        The PNG bytes as a base64 ``str``, or ``None`` if rendering failed.
    """
    fig = None
    try:
        # figsize is in inches; at dpi=100 this gives width x height pixels.
        fig, ax = plt.subplots(figsize=(width / 100, height / 100), dpi=100)
        chart_type = chart_data.get("type", "bar")
        data = chart_data.get("data", {})
        title = chart_data.get("title", "Chart")
        labels = data.get("labels", [])
        values = data.get("values", [])

        if chart_type == "bar":
            ax.bar(labels, values)
        elif chart_type == "line":
            ax.plot(labels, values, marker='o')
            # Add reference range if available
            metadata = chart_data.get("metadata") or {}
            for ref in (metadata.get("reference_min"), metadata.get("reference_max")):
                if ref is not None:
                    ax.axhline(y=ref, color='r', linestyle='--', alpha=0.5)
        elif chart_type == "pie":
            ax.pie(values, labels=labels, autopct='%1.1f%%', startangle=90)
            ax.axis('equal')
        else:
            raise ValueError(f"Unsupported chart type: {chart_type}")

        ax.set_title(title)
        if chart_type in ("bar", "line"):
            plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
            # Lay out only after the title and rotated labels exist so that
            # neither gets clipped.
            fig.tight_layout()

        buf = io.BytesIO()
        fig.savefig(buf, format='png')
        return base64.b64encode(buf.getvalue()).decode('utf-8')
    except Exception as e:
        print(f"Error generating chart: {str(e)}")
        return None
    finally:
        # pyplot keeps every open figure alive, so close it on every path
        # (including errors) to avoid leaking memory in the long-running server.
        if fig is not None:
            plt.close(fig)
def generate_interactive_chart(self, chart_data):
    """Build a Plotly figure from a chart spec and return it as JSON.

    Args:
        chart_data: Chart spec dict with ``type`` ("bar", "line", "pie" or
            "gauge"), a ``data`` dict holding ``labels``/``values`` (plus
            optional ``min``/``max`` for gauges), an optional ``title`` and,
            for line charts, optional ``metadata.reference_min`` /
            ``metadata.reference_max``.

    Returns:
        The figure serialized via ``fig.to_json()``, or ``None`` on error.
    """
    try:
        kind = chart_data.get("type", "bar")
        payload = chart_data.get("data", {})
        heading = chart_data.get("title", "Chart")
        xs = payload.get("labels", [])
        ys = payload.get("values", [])

        if kind == "pie":
            fig = px.pie(values=ys, names=xs, title=heading)
        elif kind == "bar":
            fig = px.bar(x=xs, y=ys, title=heading)
            fig.update_layout(xaxis_title="", yaxis_title="Value")
        elif kind == "line":
            fig = px.line(x=xs, y=ys, markers=True, title=heading)
            fig.update_layout(xaxis_title="Date", yaxis_title="Value")
            meta = chart_data.get("metadata", {})
            # Dashed red horizontal line across the full plot width for each
            # reference bound that is present (min first, then max).
            for level in (meta.get("reference_min"), meta.get("reference_max")):
                if level is None:
                    continue
                fig.add_shape(type="line", line_color="red", line_dash="dash",
                              x0=0, y0=level, x1=1, y1=level,
                              xref="paper", yref="y")
        elif kind == "gauge":
            gauge_value = ys[0] if ys else 0
            low = payload.get("min", 0)
            high = payload.get("max", 100)
            span = high - low
            # Split the axis into red / yellow / green thirds.
            first_cut = low + span / 3
            second_cut = low + 2 * span / 3
            bands = [
                {"range": [low, first_cut], "color": "red"},
                {"range": [first_cut, second_cut], "color": "yellow"},
                {"range": [second_cut, high], "color": "green"}
            ]
            fig = go.Figure(go.Indicator(
                mode="gauge+number",
                value=gauge_value,
                title={"text": heading},
                gauge={
                    "axis": {"range": [low, high]},
                    "bar": {"color": "darkblue"},
                    "steps": bands
                }
            ))
        else:
            raise ValueError(f"Unsupported chart type: {kind}")

        # Shared layout for every chart type
        fig.update_layout(
            title_x=0.5,
            margin=dict(l=50, r=50, b=50, t=80),
            height=400,
            width=600
        )
        return fig.to_json()
    except Exception as e:
        print(f"Error generating interactive chart: {str(e)}")
        return None
analyzer = MedicalReportAnalyzer()
app = FastAPI()
# NOTE(review): a wildcard origin together with allow_credentials=True makes
# Starlette echo back any requesting Origin for credentialed requests; restrict
# allow_origins to the known frontend(s) if cookies/auth headers are ever used.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# StaticFiles refuses to start when its directory is missing, so create it up
# front (as is already done for KNOWLEDGE_DIR and REPORTS_DIR).
STATIC_DIR = "static"
os.makedirs(STATIC_DIR, exist_ok=True)
app.mount("/static", StaticFiles(directory=STATIC_DIR), name="static")
@app.post("/process_user_report")
async def process_user_report_endpoint(report_file: UploadFile = File(...)):
    """Save an uploaded report to a temporary file and run the analyzer on it.

    Returns ``{"status": "success", "data": <analyzer result>}``, or
    ``{"status": "error", "message": ...}`` if anything fails.
    """
    try:
        # TemporaryDirectory removes the upload even when processing raises.
        with tempfile.TemporaryDirectory() as temp_dir:
            # The filename is client-controlled: keep only its final path
            # component so names like "../../x" cannot escape temp_dir, and use
            # a fixed name when none (or only "."/"..") was sent.
            safe_name = os.path.basename(report_file.filename or "")
            if safe_name in ("", ".", ".."):
                safe_name = "uploaded_report"
            temp_file_path = os.path.join(temp_dir, safe_name)
            with open(temp_file_path, "wb") as f:
                shutil.copyfileobj(report_file.file, f)
            result = analyzer.process_user_report(temp_file_path)
        return {
            "status": "success",
            "data": result
        }
    except Exception as e:
        return {
            "status": "error",
            "message": str(e)
        }
@app.get("/get_reference_ranges")
def get_reference_ranges():
    """Expose the built-in STANDARD_RANGES reference table."""
    response = {"status": "success", "data": STANDARD_RANGES}
    return response
@app.post("/generate_suggestions")
async def generate_suggestions(data: dict):
    """Ask the analyzer for health suggestions for the request's ``metrics``.

    Responds with ``{"status": "success", "suggestions": ...}``, or
    ``{"status": "error", "message": ...}`` when the analyzer raises.
    """
    try:
        produced = analyzer.generate_health_suggestions(data.get("metrics", {}))
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
    return {"status": "success", "suggestions": produced}
@app.get("/metrics_comparison")
def metrics_comparison(metric_name: str = Query(...)):
    """Return comparison data for the required ``metric_name`` query parameter.

    Responds with ``{"status": "success", "data": ...}``, or
    ``{"status": "error", "message": ...}`` when the analyzer raises.
    """
    try:
        payload = analyzer.get_comparison_data(metric_name)
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
    return {"status": "success", "data": payload}
@app.get("/user_history/{user_id}")
def get_user_history(user_id: str):
    """Return the analyzer's stored history for the user in the URL path.

    Responds with ``{"status": "success", "history": ...}``, or
    ``{"status": "error", "message": ...}`` when the analyzer raises.
    """
    try:
        records = analyzer.get_user_history(user_id)
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
    return {"status": "success", "history": records}
@app.post("/save_report_data")
async def save_report_data(data: dict):
    """Persist report data for a user via the analyzer.

    Body keys: ``user_id`` and ``report_data`` (both required) and an optional
    ``date`` string that defaults to the current local time in ISO format.

    Returns ``{"status": "success" | "error", "message": ...}``.
    """
    try:
        user_id = data.get("user_id")
        report_data = data.get("report_data")
        if user_id in (None, "") or report_data is None:
            # Reject incomplete payloads instead of passing None to the analyzer.
            return {
                "status": "error",
                "message": "Both 'user_id' and 'report_data' are required"
            }
        # `or` also replaces an explicit JSON null / empty string for "date".
        date = data.get("date") or datetime.now().isoformat()
        success = analyzer.save_user_data(user_id, report_data, date)
        return {
            "status": "success" if success else "error",
            "message": "Data saved successfully" if success else "Failed to save data"
        }
    except Exception as e:
        return {
            "status": "error",
            "message": str(e)
        }
if __name__ == "__main__":
    # Serve the API directly with uvicorn when this file is run as a script.
    from uvicorn import run as serve
    serve(app, host="0.0.0.0", port=8000)