Update app.py
Browse files
app.py
CHANGED
@@ -33,9 +33,389 @@ if not HF_TOKEN:
|
|
33 |
print("Warning: HF_TOKEN not found in environment variables. You may not be able to access gated models.")
|
34 |
|
35 |
class MedicalReportAnalyzer:
|
36 |
-
|
37 |
-
|
38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
# Initialize the analyzer
|
41 |
analyzer = MedicalReportAnalyzer()
|
|
|
33 |
print("Warning: HF_TOKEN not found in environment variables. You may not be able to access gated models.")
|
34 |
|
35 |
class MedicalReportAnalyzer:
|
36 |
+
def __init__(self):
|
37 |
+
self.vector_store = None
|
38 |
+
self.llm = None
|
39 |
+
self.qa_chain = None
|
40 |
+
self.user_report_data = "No report data available." # Default value
|
41 |
+
self.original_report_data = "No original report data available." # Store original data
|
42 |
+
# Initialize everything
|
43 |
+
self._load_or_create_vector_store()
|
44 |
+
self._initialize_llm()
|
45 |
+
self._setup_qa_chain()
|
46 |
+
|
47 |
+
def _load_or_create_vector_store(self):
|
48 |
+
"""Load existing vector store or create a new one from knowledge documents"""
|
49 |
+
embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
|
50 |
+
|
51 |
+
# Check if vector store exists
|
52 |
+
if os.path.exists(VECTOR_STORE_PATH):
|
53 |
+
print("Loading existing vector store...")
|
54 |
+
self.vector_store = FAISS.load_local(VECTOR_STORE_PATH, embeddings)
|
55 |
+
else:
|
56 |
+
print("Creating new vector store from documents...")
|
57 |
+
# Create knowledge directory if it doesn't exist
|
58 |
+
os.makedirs(KNOWLEDGE_DIR, exist_ok=True)
|
59 |
+
|
60 |
+
# Check if there are documents to process
|
61 |
+
if len(os.listdir(KNOWLEDGE_DIR)) == 0:
|
62 |
+
print(f"Warning: No documents found in {KNOWLEDGE_DIR}. Please add medical PDFs.")
|
63 |
+
# Initialize empty vector store
|
64 |
+
self.vector_store = FAISS.from_texts(["No medical knowledge available yet."], embeddings)
|
65 |
+
self.vector_store.save_local(VECTOR_STORE_PATH)
|
66 |
+
return
|
67 |
+
|
68 |
+
# Load all PDFs from the knowledge directory
|
69 |
+
try:
|
70 |
+
# First try with DirectoryLoader
|
71 |
+
loader = DirectoryLoader(KNOWLEDGE_DIR, glob="**/*.pdf", loader_cls=PyPDFLoader)
|
72 |
+
documents = loader.load()
|
73 |
+
|
74 |
+
# Split documents into chunks
|
75 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
76 |
+
chunk_size=1000,
|
77 |
+
chunk_overlap=200,
|
78 |
+
length_function=len
|
79 |
+
)
|
80 |
+
chunks = text_splitter.split_documents(documents)
|
81 |
+
|
82 |
+
# Create and save the vector store
|
83 |
+
self.vector_store = FAISS.from_documents(chunks, embeddings)
|
84 |
+
self.vector_store.save_local(VECTOR_STORE_PATH)
|
85 |
+
except Exception as e:
|
86 |
+
print(f"Error loading documents with DirectoryLoader: {str(e)}")
|
87 |
+
# Initialize with minimal data
|
88 |
+
self.vector_store = FAISS.from_texts(["No medical knowledge available yet."], embeddings)
|
89 |
+
self.vector_store.save_local(VECTOR_STORE_PATH)
|
90 |
+
|
91 |
+
def _initialize_llm(self):
|
92 |
+
"""Initialize the language model with HF token authentication"""
|
93 |
+
print(f"Loading model {MODEL_NAME} on {DEVICE}...")
|
94 |
+
try:
|
95 |
+
# Use the HF_TOKEN for authentication
|
96 |
+
tokenizer = AutoTokenizer.from_pretrained(
|
97 |
+
MODEL_NAME,
|
98 |
+
token=HF_TOKEN
|
99 |
+
)
|
100 |
+
model = AutoModelForCausalLM.from_pretrained(
|
101 |
+
MODEL_NAME,
|
102 |
+
token=HF_TOKEN,
|
103 |
+
torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
|
104 |
+
device_map="auto",
|
105 |
+
load_in_8bit=DEVICE == "cuda", # Use 8-bit quantization if on CUDA
|
106 |
+
)
|
107 |
+
|
108 |
+
# Create a text generation pipeline
|
109 |
+
pipe = pipeline(
|
110 |
+
"text-generation",
|
111 |
+
model=model,
|
112 |
+
tokenizer=tokenizer,
|
113 |
+
max_new_tokens=512,
|
114 |
+
temperature=0.1,
|
115 |
+
top_p=0.95,
|
116 |
+
repetition_penalty=1.1
|
117 |
+
)
|
118 |
+
|
119 |
+
# Create LangChain wrapper around the pipeline
|
120 |
+
self.llm = HuggingFacePipeline(pipeline=pipe)
|
121 |
+
except Exception as e:
|
122 |
+
print(f"Error loading the model: {str(e)}")
|
123 |
+
print("Falling back to a non-gated model...")
|
124 |
+
# Fallback to a non-gated model
|
125 |
+
fallback_model = "google/flan-t5-large"
|
126 |
+
tokenizer = AutoTokenizer.from_pretrained(fallback_model)
|
127 |
+
model = AutoModelForCausalLM.from_pretrained(
|
128 |
+
fallback_model,
|
129 |
+
torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
|
130 |
+
device_map="auto"
|
131 |
+
)
|
132 |
+
pipe = pipeline(
|
133 |
+
"text-generation",
|
134 |
+
model=model,
|
135 |
+
tokenizer=tokenizer,
|
136 |
+
max_new_tokens=512
|
137 |
+
)
|
138 |
+
self.llm = HuggingFacePipeline(pipeline=pipe)
|
139 |
+
|
140 |
+
def _setup_qa_chain(self):
|
141 |
+
"""Set up the question-answering chain"""
|
142 |
+
# Define a custom prompt template for medical analysis
|
143 |
+
template = """
|
144 |
+
You are a medical assistant analyzing patient medical reports. Use the following pieces of context to answer the question. If you don't know the answer, just say that you don't know, don't try to make up an answer.
|
145 |
+
|
146 |
+
Patient Report Summary: {patient_data}
|
147 |
+
|
148 |
+
Context from medical knowledge base: {context}
|
149 |
+
|
150 |
+
Question: {question}
|
151 |
+
|
152 |
+
Answer:
|
153 |
+
"""
|
154 |
+
|
155 |
+
prompt = PromptTemplate(
|
156 |
+
template=template,
|
157 |
+
input_variables=["context", "question", "patient_data"]
|
158 |
+
)
|
159 |
+
|
160 |
+
# Create the QA chain
|
161 |
+
self.qa_chain = RetrievalQA.from_chain_type(
|
162 |
+
llm=self.llm,
|
163 |
+
chain_type="stuff",
|
164 |
+
retriever=self.vector_store.as_retriever(search_kwargs={"k": 5}),
|
165 |
+
chain_type_kwargs={"prompt": prompt},
|
166 |
+
return_source_documents=False
|
167 |
+
)
|
168 |
+
|
169 |
+
def remove_header_information(self, text):
|
170 |
+
"""Remove header information from the report text"""
|
171 |
+
# Store the original text
|
172 |
+
self.original_report_data = text
|
173 |
+
|
174 |
+
# Split the text into lines to analyze
|
175 |
+
lines = text.split('\n')
|
176 |
+
|
177 |
+
# Define patterns to identify header information
|
178 |
+
header_patterns = [
|
179 |
+
r'(Name\s*:)',
|
180 |
+
r'(Patient\s*Name\s*:)',
|
181 |
+
r'(DOB|Date of Birth\s*:)',
|
182 |
+
r'(Age\s*:)',
|
183 |
+
r'(Gender\s*:)',
|
184 |
+
r'(Lab No\.|Laboratory Number\s*:)',
|
185 |
+
r'(Patient ID\s*:)',
|
186 |
+
r'(Report Status\s*:)',
|
187 |
+
r'(Ref By|Referred By\s*:)',
|
188 |
+
r'(Collected\s*:)',
|
189 |
+
r'(Reported\s*:)',
|
190 |
+
r'(A/c Status\s*:)',
|
191 |
+
r'(Processed at\s*:)',
|
192 |
+
r'(Collected at\s*:)',
|
193 |
+
r'(Address\s*:)',
|
194 |
+
r'(Phone|Mobile|Mob\s*:)',
|
195 |
+
]
|
196 |
+
|
197 |
+
# Create a regex pattern that matches any of the header patterns
|
198 |
+
combined_pattern = '|'.join(header_patterns)
|
199 |
+
|
200 |
+
# Find where the actual test results begin
|
201 |
+
test_results_start = -1
|
202 |
+
for i, line in enumerate(lines):
|
203 |
+
if re.search(r'(Test\s*Report|Test\s*Name|Test\s*Results|Results|HEMOGRAM|ROUTINE|EXAMINATION)', line, re.IGNORECASE):
|
204 |
+
test_results_start = i
|
205 |
+
break
|
206 |
+
|
207 |
+
# If we couldn't find the start of test results, look for key medical terms
|
208 |
+
if test_results_start == -1:
|
209 |
+
for i, line in enumerate(lines):
|
210 |
+
# Look for common test result sections
|
211 |
+
if re.search(r'(Hemoglobin|Blood|Urine|CBC|Glucose|Cholesterol|Protein|RBC|WBC)', line, re.IGNORECASE):
|
212 |
+
test_results_start = max(0, i-3) # Start a few lines before the first test result
|
213 |
+
break
|
214 |
+
|
215 |
+
# If we still couldn't find the start of test results, use a heuristic:
|
216 |
+
# Skip the first few lines which usually contain header information
|
217 |
+
if test_results_start == -1:
|
218 |
+
# Count lines with patient identifiable information
|
219 |
+
header_count = 0
|
220 |
+
for i, line in enumerate(lines):
|
221 |
+
if re.search(combined_pattern, line, re.IGNORECASE):
|
222 |
+
header_count += 1
|
223 |
+
|
224 |
+
# If we found several header lines, skip those plus a few more
|
225 |
+
if header_count > 0:
|
226 |
+
test_results_start = min(header_count + 5, len(lines) // 3)
|
227 |
+
else:
|
228 |
+
# If no clear header pattern was found, just skip the first 10% of lines as a fallback
|
229 |
+
test_results_start = max(1, len(lines) // 10)
|
230 |
+
|
231 |
+
# Return text from the determined start point
|
232 |
+
clean_text = '\n'.join(lines[test_results_start:])
|
233 |
+
|
234 |
+
# If this dramatically shortened the text, use a less aggressive approach
|
235 |
+
if len(clean_text) < len(text) * 0.5:
|
236 |
+
print("Warning: Header removal may have removed too much content. Using alternative approach.")
|
237 |
+
# Alternative approach: Just remove lines with header patterns
|
238 |
+
filtered_lines = []
|
239 |
+
for line in lines:
|
240 |
+
if not re.search(combined_pattern, line, re.IGNORECASE):
|
241 |
+
filtered_lines.append(line)
|
242 |
+
clean_text = '\n'.join(filtered_lines)
|
243 |
+
|
244 |
+
return clean_text
|
245 |
+
|
246 |
+
def extract_text_from_pdf_pymupdf(self, pdf_path):
|
247 |
+
"""Extract text from PDF using PyMuPDF (more robust than PyPDF)"""
|
248 |
+
text = ""
|
249 |
+
try:
|
250 |
+
doc = fitz.open(pdf_path)
|
251 |
+
for page in doc:
|
252 |
+
text += page.get_text()
|
253 |
+
doc.close()
|
254 |
+
return text
|
255 |
+
except Exception as e:
|
256 |
+
print(f"PyMuPDF extraction error: {str(e)}")
|
257 |
+
return None
|
258 |
+
|
259 |
+
def extract_text_from_pdf_pypdf(self, pdf_path):
|
260 |
+
"""Extract text using PyPDF as a backup method"""
|
261 |
+
try:
|
262 |
+
loader = PyPDFLoader(pdf_path)
|
263 |
+
pages = loader.load()
|
264 |
+
return "\n".join([page.page_content for page in pages])
|
265 |
+
except Exception as e:
|
266 |
+
print(f"PyPDF extraction error: {str(e)}")
|
267 |
+
return None
|
268 |
+
|
269 |
+
def process_user_report(self, report_file):
|
270 |
+
"""Process the uploaded medical report with multiple fallback methods"""
|
271 |
+
if report_file is None:
|
272 |
+
return "No file uploaded. Please upload a medical report."
|
273 |
+
|
274 |
+
# Ensure the uploaded file is read as bytes
|
275 |
+
temp_dir = tempfile.mkdtemp()
|
276 |
+
try:
|
277 |
+
# Copy the uploaded file to the temp directory
|
278 |
+
temp_file_path = os.path.join(temp_dir, "user_report.pdf")
|
279 |
+
|
280 |
+
# Handle file based on its type
|
281 |
+
try:
|
282 |
+
if isinstance(report_file, str): # If it's a file path
|
283 |
+
shutil.copy(report_file, temp_file_path)
|
284 |
+
elif hasattr(report_file, 'name'): # Gradio file object
|
285 |
+
with open(temp_file_path, 'wb') as f:
|
286 |
+
with open(report_file.name, 'rb') as source:
|
287 |
+
f.write(source.read())
|
288 |
+
else: # Try to handle as bytes or file-like object
|
289 |
+
with open(temp_file_path, 'wb') as f:
|
290 |
+
f.write(report_file.read() if hasattr(report_file, 'read') else report_file)
|
291 |
+
except Exception as e:
|
292 |
+
print(f"Error saving file: {str(e)}")
|
293 |
+
return f"Error saving the uploaded file: {str(e)}"
|
294 |
+
|
295 |
+
# Try multiple methods to extract text from the PDF
|
296 |
+
text = None
|
297 |
+
|
298 |
+
# Method 1: PyMuPDF
|
299 |
+
text = self.extract_text_from_pdf_pymupdf(temp_file_path)
|
300 |
+
|
301 |
+
# Method 2: PyPDF as fallback
|
302 |
+
if not text:
|
303 |
+
text = self.extract_text_from_pdf_pypdf(temp_file_path)
|
304 |
+
|
305 |
+
# Method 3: Last resort - try to read as raw text
|
306 |
+
if not text:
|
307 |
+
try:
|
308 |
+
with open(temp_file_path, 'r', errors='ignore') as f:
|
309 |
+
text = f.read()
|
310 |
+
except Exception as e:
|
311 |
+
print(f"Raw text reading error: {str(e)}")
|
312 |
+
|
313 |
+
# If we got text, process it
|
314 |
+
if text and len(text.strip()) > 0:
|
315 |
+
# Remove header information from the text
|
316 |
+
cleaned_text = self.remove_header_information(text)
|
317 |
+
|
318 |
+
# Store the cleaned text
|
319 |
+
self.user_report_data = cleaned_text
|
320 |
+
|
321 |
+
# Split into chunks if needed
|
322 |
+
text_splitter = RecursiveCharacterTextSplitter(
|
323 |
+
chunk_size=1000,
|
324 |
+
chunk_overlap=200,
|
325 |
+
length_function=len
|
326 |
+
)
|
327 |
+
chunks = text_splitter.split_text(cleaned_text)
|
328 |
+
|
329 |
+
# Check if too much text was removed
|
330 |
+
original_length = len(text.strip())
|
331 |
+
cleaned_length = len(cleaned_text.strip())
|
332 |
+
removal_percentage = (original_length - cleaned_length) / original_length * 100
|
333 |
+
|
334 |
+
if removal_percentage > 80:
|
335 |
+
return f"Report processed successfully, but significant content may have been filtered. Original length: {original_length} chars. Cleaned length: {cleaned_length} chars. Extracted approximately {len(chunks)} text chunks."
|
336 |
+
else:
|
337 |
+
return f"Report processed successfully. Removed approximately {removal_percentage:.1f}% of header content. Extracted {len(chunks)} text chunks."
|
338 |
+
else:
|
339 |
+
self.user_report_data = "Unable to extract text from the provided PDF. This is an empty report placeholder."
|
340 |
+
return "Warning: Could not extract text from the PDF. The file may be corrupted, password-protected, or contain only images. Processing will continue with limited data."
|
341 |
+
|
342 |
+
finally:
|
343 |
+
# Clean up the temporary directory and file
|
344 |
+
shutil.rmtree(temp_dir)
|
345 |
+
|
346 |
+
def answer_question(self, question):
|
347 |
+
"""Answer a question based on the uploaded report and knowledge base"""
|
348 |
+
if not self.user_report_data or self.user_report_data == "No report data available.":
|
349 |
+
return "No report has been processed or text extraction failed. Please upload a medical report first."
|
350 |
+
|
351 |
+
# Get context from knowledge base
|
352 |
+
try:
|
353 |
+
retrieved_docs = self.vector_store.similarity_search(question, k=5)
|
354 |
+
context = "\n\n".join([doc.page_content for doc in retrieved_docs])
|
355 |
+
|
356 |
+
# Check if question is about patient demographics or identification
|
357 |
+
demographic_patterns = [
|
358 |
+
r'(patient|name|age|gender|birth|dob|address|phone|contact|id|identification)',
|
359 |
+
r'(doctor|physician|referring|referred by)',
|
360 |
+
r'(date|time|collected|processed|reported)',
|
361 |
+
r'(lab|laboratory|number|id)'
|
362 |
+
]
|
363 |
+
|
364 |
+
combined_demo_pattern = '|'.join(demographic_patterns)
|
365 |
+
|
366 |
+
# If question might be about demographics, check if we need to use original data
|
367 |
+
if re.search(combined_demo_pattern, question, re.IGNORECASE):
|
368 |
+
# For demographic questions, we can use the original report that includes headers
|
369 |
+
# But only if we have specific identification information requests
|
370 |
+
specific_id_patterns = [
|
371 |
+
r'(name of|patient name|who is|what is the name)',
|
372 |
+
r'(exact age|age of|how old)',
|
373 |
+
r'(address of|where|location|contact details)',
|
374 |
+
r'(doctor name|name of doctor|referring doctor|who referred)',
|
375 |
+
r'(date of|when was|time of|report date)',
|
376 |
+
r'(lab number|patient id|identification number)'
|
377 |
+
]
|
378 |
+
|
379 |
+
specific_id_pattern = '|'.join(specific_id_patterns)
|
380 |
+
|
381 |
+
# If it's a direct question about patient identity, don't answer
|
382 |
+
if re.search(specific_id_pattern, question, re.IGNORECASE):
|
383 |
+
return "I'm unable to provide specific patient identification information. This feature is disabled to protect patient privacy. Please ask about medical test results or interpretations instead."
|
384 |
+
|
385 |
+
# Create the inputs dict for the QA chain
|
386 |
+
inputs = {
|
387 |
+
"query": question,
|
388 |
+
"context": context,
|
389 |
+
"patient_data": self.user_report_data
|
390 |
+
}
|
391 |
+
|
392 |
+
# Run the chain with the correct parameter structure
|
393 |
+
result = self.qa_chain(inputs)
|
394 |
+
|
395 |
+
# Extract the answer from the result
|
396 |
+
if isinstance(result, dict) and 'result' in result:
|
397 |
+
return result['result']
|
398 |
+
else:
|
399 |
+
return str(result)
|
400 |
+
|
401 |
+
except Exception as e:
|
402 |
+
print(f"Error answering question: {str(e)}")
|
403 |
+
error_msg = f"Error processing your question: {str(e)}."
|
404 |
+
|
405 |
+
# Try direct LLM call as fallback
|
406 |
+
try:
|
407 |
+
direct_prompt = f"""
|
408 |
+
Question about medical report: {question}
|
409 |
+
|
410 |
+
Patient data available: {self.user_report_data[:800]}... (truncated)
|
411 |
+
|
412 |
+
Please answer based on this information:
|
413 |
+
"""
|
414 |
+
|
415 |
+
direct_result = self.llm(direct_prompt)
|
416 |
+
return f"{error_msg} Fallback answer: {direct_result}"
|
417 |
+
except:
|
418 |
+
return f"{error_msg} Please try a different question or report."
|
419 |
|
420 |
# Initialize the analyzer
|
421 |
analyzer = MedicalReportAnalyzer()
|