rishi002 committed · verified
Commit 38d15b8 · 1 Parent(s): 05347e1

Update app.py

Files changed (1):
  1. app.py +383 -3
app.py CHANGED
@@ -33,9 +33,389 @@ if not HF_TOKEN:
     print("Warning: HF_TOKEN not found in environment variables. You may not be able to access gated models.")
 
 class MedicalReportAnalyzer:
-    ...
-    # Keep your existing MedicalReportAnalyzer implementation here
-    # No changes needed
+    def __init__(self):
+        self.vector_store = None
+        self.llm = None
+        self.qa_chain = None
+        self.user_report_data = "No report data available."  # Default value
+        self.original_report_data = "No original report data available."  # Store original data
+        # Initialize everything
+        self._load_or_create_vector_store()
+        self._initialize_llm()
+        self._setup_qa_chain()
+
+    def _load_or_create_vector_store(self):
+        """Load the existing vector store or create a new one from knowledge documents"""
+        embeddings = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL)
+
+        # Check if a vector store already exists
+        if os.path.exists(VECTOR_STORE_PATH):
+            print("Loading existing vector store...")
+            # Note: newer LangChain releases also require
+            # allow_dangerous_deserialization=True when loading a pickled index.
+            self.vector_store = FAISS.load_local(VECTOR_STORE_PATH, embeddings)
+        else:
+            print("Creating new vector store from documents...")
+            # Create the knowledge directory if it doesn't exist
+            os.makedirs(KNOWLEDGE_DIR, exist_ok=True)
+
+            # Check whether there are documents to process
+            if len(os.listdir(KNOWLEDGE_DIR)) == 0:
+                print(f"Warning: No documents found in {KNOWLEDGE_DIR}. Please add medical PDFs.")
+                # Initialize an empty vector store
+                self.vector_store = FAISS.from_texts(["No medical knowledge available yet."], embeddings)
+                self.vector_store.save_local(VECTOR_STORE_PATH)
+                return
+
+            # Load all PDFs from the knowledge directory
+            try:
+                # First try with DirectoryLoader
+                loader = DirectoryLoader(KNOWLEDGE_DIR, glob="**/*.pdf", loader_cls=PyPDFLoader)
+                documents = loader.load()
+
+                # Split documents into chunks
+                text_splitter = RecursiveCharacterTextSplitter(
+                    chunk_size=1000,
+                    chunk_overlap=200,
+                    length_function=len
+                )
+                chunks = text_splitter.split_documents(documents)
+
+                # Create and save the vector store
+                self.vector_store = FAISS.from_documents(chunks, embeddings)
+                self.vector_store.save_local(VECTOR_STORE_PATH)
+            except Exception as e:
+                print(f"Error loading documents with DirectoryLoader: {str(e)}")
+                # Initialize with minimal data
+                self.vector_store = FAISS.from_texts(["No medical knowledge available yet."], embeddings)
+                self.vector_store.save_local(VECTOR_STORE_PATH)
+
+    def _initialize_llm(self):
+        """Initialize the language model with HF token authentication"""
+        print(f"Loading model {MODEL_NAME} on {DEVICE}...")
+        try:
+            # Use the HF_TOKEN for authentication
+            tokenizer = AutoTokenizer.from_pretrained(
+                MODEL_NAME,
+                token=HF_TOKEN
+            )
+            model = AutoModelForCausalLM.from_pretrained(
+                MODEL_NAME,
+                token=HF_TOKEN,
+                torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
+                device_map="auto",
+                load_in_8bit=(DEVICE == "cuda"),  # 8-bit quantization on CUDA (requires bitsandbytes)
+            )
+
+            # Create a text generation pipeline
+            pipe = pipeline(
+                "text-generation",
+                model=model,
+                tokenizer=tokenizer,
+                max_new_tokens=512,
+                temperature=0.1,
+                top_p=0.95,
+                repetition_penalty=1.1
+            )
+
+            # Create a LangChain wrapper around the pipeline
+            self.llm = HuggingFacePipeline(pipeline=pipe)
+        except Exception as e:
+            print(f"Error loading the model: {str(e)}")
+            print("Falling back to a non-gated model...")
+            # Fallback to a non-gated model. flan-t5 is an encoder-decoder model,
+            # so it must be loaded with AutoModelForSeq2SeqLM (imported from
+            # transformers) and served through a text2text-generation pipeline,
+            # not AutoModelForCausalLM / text-generation.
+            fallback_model = "google/flan-t5-large"
+            tokenizer = AutoTokenizer.from_pretrained(fallback_model)
+            model = AutoModelForSeq2SeqLM.from_pretrained(
+                fallback_model,
+                torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
+                device_map="auto"
+            )
+            pipe = pipeline(
+                "text2text-generation",
+                model=model,
+                tokenizer=tokenizer,
+                max_new_tokens=512
+            )
+            self.llm = HuggingFacePipeline(pipeline=pipe)
+
+    def _setup_qa_chain(self):
+        """Set up the question-answering chain"""
+        # Define a custom prompt template for medical analysis
+        template = """
+        You are a medical assistant analyzing patient medical reports. Use the following pieces of context to answer the question. If you don't know the answer, just say that you don't know; don't try to make up an answer.
+
+        Patient Report Summary: {patient_data}
+
+        Context from medical knowledge base: {context}
+
+        Question: {question}
+
+        Answer:
+        """
+
+        prompt = PromptTemplate(
+            template=template,
+            input_variables=["context", "question", "patient_data"]
+        )
+
+        # RetrievalQA's "stuff" chain only forwards `context` and `question`,
+        # so the extra `patient_data` variable would never reach this prompt
+        # (it raises a missing-variable error at format time). Use a plain
+        # LLMChain instead (requires `from langchain.chains import LLMChain`)
+        # and retrieve knowledge-base context manually in answer_question().
+        self.qa_chain = LLMChain(llm=self.llm, prompt=prompt)
+
+    def remove_header_information(self, text):
+        """Remove header information from the report text"""
+        # Store the original text
+        self.original_report_data = text
+
+        # Split the text into lines to analyze
+        lines = text.split('\n')
+
+        # Define patterns that identify header information
+        header_patterns = [
+            r'(Name\s*:)',
+            r'(Patient\s*Name\s*:)',
+            r'(DOB|Date of Birth\s*:)',
+            r'(Age\s*:)',
+            r'(Gender\s*:)',
+            r'(Lab No\.|Laboratory Number\s*:)',
+            r'(Patient ID\s*:)',
+            r'(Report Status\s*:)',
+            r'(Ref By|Referred By\s*:)',
+            r'(Collected\s*:)',
+            r'(Reported\s*:)',
+            r'(A/c Status\s*:)',
+            r'(Processed at\s*:)',
+            r'(Collected at\s*:)',
+            r'(Address\s*:)',
+            r'(Phone|Mobile|Mob\s*:)',
+        ]
+
+        # Create a regex pattern that matches any of the header patterns
+        combined_pattern = '|'.join(header_patterns)
+
+        # Find where the actual test results begin
+        test_results_start = -1
+        for i, line in enumerate(lines):
+            if re.search(r'(Test\s*Report|Test\s*Name|Test\s*Results|Results|HEMOGRAM|ROUTINE|EXAMINATION)', line, re.IGNORECASE):
+                test_results_start = i
+                break
+
+        # If we couldn't find the start of test results, look for key medical terms
+        if test_results_start == -1:
+            for i, line in enumerate(lines):
+                # Look for common test result sections
+                if re.search(r'(Hemoglobin|Blood|Urine|CBC|Glucose|Cholesterol|Protein|RBC|WBC)', line, re.IGNORECASE):
+                    test_results_start = max(0, i - 3)  # Start a few lines before the first test result
+                    break
+
+        # If we still couldn't find the start of test results, use a heuristic:
+        # skip the first few lines, which usually contain header information
+        if test_results_start == -1:
+            # Count lines with patient-identifiable information
+            header_count = 0
+            for i, line in enumerate(lines):
+                if re.search(combined_pattern, line, re.IGNORECASE):
+                    header_count += 1
+
+            # If we found several header lines, skip those plus a few more
+            if header_count > 0:
+                test_results_start = min(header_count + 5, len(lines) // 3)
+            else:
+                # If no clear header pattern was found, skip the first 10% of lines as a fallback
+                test_results_start = max(1, len(lines) // 10)
+
+        # Return text from the determined start point
+        clean_text = '\n'.join(lines[test_results_start:])
+
+        # If this dramatically shortened the text, use a less aggressive approach
+        if len(clean_text) < len(text) * 0.5:
+            print("Warning: Header removal may have removed too much content. Using alternative approach.")
+            # Alternative approach: just drop individual lines that match header patterns
+            filtered_lines = []
+            for line in lines:
+                if not re.search(combined_pattern, line, re.IGNORECASE):
+                    filtered_lines.append(line)
+            clean_text = '\n'.join(filtered_lines)
+
+        return clean_text
+
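+    # Illustrative note (not part of this commit): for a report whose first
+    # lines are "Name: ...", "Age: ...", "Lab No. : ..." followed by a
+    # "HEMOGRAM" section, the first keyword scan above sets test_results_start
+    # at the "HEMOGRAM" line, so the identifying header lines are dropped.
+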
+    def extract_text_from_pdf_pymupdf(self, pdf_path):
+        """Extract text from a PDF using PyMuPDF (more robust than PyPDF)"""
+        text = ""
+        try:
+            doc = fitz.open(pdf_path)
+            for page in doc:
+                text += page.get_text()
+            doc.close()
+            return text
+        except Exception as e:
+            print(f"PyMuPDF extraction error: {str(e)}")
+            return None
+
+    def extract_text_from_pdf_pypdf(self, pdf_path):
+        """Extract text using PyPDF as a backup method"""
+        try:
+            loader = PyPDFLoader(pdf_path)
+            pages = loader.load()
+            return "\n".join([page.page_content for page in pages])
+        except Exception as e:
+            print(f"PyPDF extraction error: {str(e)}")
+            return None
+
+    def process_user_report(self, report_file):
+        """Process the uploaded medical report with multiple fallback methods"""
+        if report_file is None:
+            return "No file uploaded. Please upload a medical report."
+
+        # Work in a temporary directory so the upload can be saved as bytes
+        temp_dir = tempfile.mkdtemp()
+        try:
+            # Copy the uploaded file into the temp directory
+            temp_file_path = os.path.join(temp_dir, "user_report.pdf")
+
+            # Handle the file according to its type
+            try:
+                if isinstance(report_file, str):  # A plain file path
+                    shutil.copy(report_file, temp_file_path)
+                elif hasattr(report_file, 'name'):  # A Gradio file object
+                    with open(temp_file_path, 'wb') as f:
+                        with open(report_file.name, 'rb') as source:
+                            f.write(source.read())
+                else:  # Try to handle bytes or a file-like object
+                    with open(temp_file_path, 'wb') as f:
+                        f.write(report_file.read() if hasattr(report_file, 'read') else report_file)
+            except Exception as e:
+                print(f"Error saving file: {str(e)}")
+                return f"Error saving the uploaded file: {str(e)}"
+
+            # Try multiple methods to extract text from the PDF
+            text = None
+
+            # Method 1: PyMuPDF
+            text = self.extract_text_from_pdf_pymupdf(temp_file_path)
+
+            # Method 2: PyPDF as fallback
+            if not text:
+                text = self.extract_text_from_pdf_pypdf(temp_file_path)
+
+            # Method 3: last resort - try to read the file as raw text
+            if not text:
+                try:
+                    with open(temp_file_path, 'r', errors='ignore') as f:
+                        text = f.read()
+                except Exception as e:
+                    print(f"Raw text reading error: {str(e)}")
+
+            # If we got text, process it
+            if text and len(text.strip()) > 0:
+                # Remove header information from the text
+                cleaned_text = self.remove_header_information(text)
+
+                # Store the cleaned text
+                self.user_report_data = cleaned_text
+
+                # Split into chunks if needed
+                text_splitter = RecursiveCharacterTextSplitter(
+                    chunk_size=1000,
+                    chunk_overlap=200,
+                    length_function=len
+                )
+                chunks = text_splitter.split_text(cleaned_text)
+
+                # Check whether too much text was removed
+                original_length = len(text.strip())
+                cleaned_length = len(cleaned_text.strip())
+                removal_percentage = (original_length - cleaned_length) / original_length * 100
+
+                if removal_percentage > 80:
+                    return f"Report processed successfully, but significant content may have been filtered. Original length: {original_length} chars. Cleaned length: {cleaned_length} chars. Extracted approximately {len(chunks)} text chunks."
+                else:
+                    return f"Report processed successfully. Removed approximately {removal_percentage:.1f}% of header content. Extracted {len(chunks)} text chunks."
+            else:
+                self.user_report_data = "Unable to extract text from the provided PDF. This is an empty report placeholder."
+                return "Warning: Could not extract text from the PDF. The file may be corrupted, password-protected, or contain only images. Processing will continue with limited data."
+
+        finally:
+            # Clean up the temporary directory and file
+            shutil.rmtree(temp_dir)
+
+    def answer_question(self, question):
+        """Answer a question based on the uploaded report and knowledge base"""
+        if not self.user_report_data or self.user_report_data == "No report data available.":
+            return "No report has been processed or text extraction failed. Please upload a medical report first."
+
+        # Get context from the knowledge base
+        try:
+            retrieved_docs = self.vector_store.similarity_search(question, k=5)
+            context = "\n\n".join([doc.page_content for doc in retrieved_docs])
+
+            # Check whether the question is about patient demographics or identification
+            demographic_patterns = [
+                r'(patient|name|age|gender|birth|dob|address|phone|contact|id|identification)',
+                r'(doctor|physician|referring|referred by)',
+                r'(date|time|collected|processed|reported)',
+                r'(lab|laboratory|number|id)'
+            ]
+
+            combined_demo_pattern = '|'.join(demographic_patterns)
+
+            # If the question might be about demographics, check whether it asks
+            # for specific identification details
+            if re.search(combined_demo_pattern, question, re.IGNORECASE):
+                specific_id_patterns = [
+                    r'(name of|patient name|who is|what is the name)',
+                    r'(exact age|age of|how old)',
+                    r'(address of|where|location|contact details)',
+                    r'(doctor name|name of doctor|referring doctor|who referred)',
+                    r'(date of|when was|time of|report date)',
+                    r'(lab number|patient id|identification number)'
+                ]
+
+                specific_id_pattern = '|'.join(specific_id_patterns)
+
+                # If it's a direct question about patient identity, don't answer
+                if re.search(specific_id_pattern, question, re.IGNORECASE):
+                    return "I'm unable to provide specific patient identification information. This feature is disabled to protect patient privacy. Please ask about medical test results or interpretations instead."
+
+            # Build the inputs dict for the LLMChain; its prompt expects
+            # `question`, `context`, and `patient_data`
+            inputs = {
+                "question": question,
+                "context": context,
+                "patient_data": self.user_report_data
+            }
+
+            # Run the chain
+            result = self.qa_chain(inputs)
+
+            # Extract the answer from the result (LLMChain returns it under "text")
+            if isinstance(result, dict) and 'text' in result:
+                return result['text']
+            else:
+                return str(result)
+
+        except Exception as e:
+            print(f"Error answering question: {str(e)}")
+            error_msg = f"Error processing your question: {str(e)}."
+
+            # Try a direct LLM call as a fallback
+            try:
+                direct_prompt = f"""
+                Question about medical report: {question}
+
+                Patient data available: {self.user_report_data[:800]}... (truncated)
+
+                Please answer based on this information:
+                """
+
+                direct_result = self.llm(direct_prompt)
+                return f"{error_msg} Fallback answer: {direct_result}"
+            except Exception:
+                return f"{error_msg} Please try a different question or report."
+
 
 # Initialize the analyzer
 analyzer = MedicalReportAnalyzer()
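
The hunk ends at the analyzer's construction; the Gradio UI that calls it sits outside this diff. A minimal sketch of how the two entry points might be wired up, assuming a Blocks layout (the component names below are illustrative, not from this commit):

import gradio as gr

with gr.Blocks() as demo:
    gr.Markdown("## Medical Report Analyzer")
    # Step 1: upload and process a report
    report_file = gr.File(label="Upload medical report (PDF)")
    process_status = gr.Textbox(label="Processing status")
    process_btn = gr.Button("Process report")
    process_btn.click(analyzer.process_user_report, inputs=report_file, outputs=process_status)

    # Step 2: ask questions against the processed report
    question = gr.Textbox(label="Ask about the report")
    answer = gr.Textbox(label="Answer")
    ask_btn = gr.Button("Ask")
    ask_btn.click(analyzer.answer_question, inputs=question, outputs=answer)

demo.launch()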