import os
import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain_core.prompts import PromptTemplate
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from collections import OrderedDict
import torch
import re
import google.generativeai as genai
# Constants
DATA_PATH = "dataFolder/"
DB_FAISS_PATH = "/tmp/vectorstore/db_faiss"
CACHE_DIR = "/tmp/models_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
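# /tmp paths are presumably used because /tmp is reliably writable in the hosting
# container; it is also ephemeral, so the FAISS index and model cache are rebuilt
# after every restart.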
# Google AI API setup
GOOGLE_API_KEY = None  # ensure the name exists even if setup fails
try:
    GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
    if GOOGLE_API_KEY:
        genai.configure(api_key=GOOGLE_API_KEY)
except Exception as e:
    print(f"Google AI API setup error: {e}")
# Conversation state tracking
conversation_state = {
    "awaiting_health_info": False,
    "original_query": "",
    "health_info": "",
    "last_query": ""
}
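# Two-turn flow (implemented in ask_question below): a remedy-seeking question sets
# awaiting_health_info and stores the query; the next user message is then treated
# as the user's health details. This state is module-global, so concurrent users
# would share it; a single-user demo is the working assumption here.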
# Load the embedding model
embedding_model = HuggingFaceEmbeddings(
    model_name="rishi002/all-MiniLM-L6-v2",
    cache_folder=CACHE_DIR
)
# Load or create FAISS database
def load_or_create_faiss():
    if not os.path.exists(DB_FAISS_PATH):
        print("🔄 Creating FAISS Database...")
        from embeddings import load_pdf_files, create_chunks  # custom PDF loading / chunking helpers
        documents = load_pdf_files(DATA_PATH)
        text_chunks = create_chunks(documents)
        db = FAISS.from_documents(text_chunks, embedding_model)
        db.save_local(DB_FAISS_PATH)
    else:
        print("✅ FAISS Database Exists. Loading...")
    # allow_dangerous_deserialization is needed because FAISS indexes are pickled;
    # it is acceptable here since the index is produced by this app itself.
    return FAISS.load_local(DB_FAISS_PATH, embedding_model, allow_dangerous_deserialization=True)

db = load_or_create_faiss()
# 🔽 Load Phi-3 locally
def load_phi3_pipeline():
    model_name = "microsoft/Phi-3-mini-4k-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
    )
    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=0 if torch.cuda.is_available() else -1
    )
    return pipe

phi3_pipe = load_phi3_pipeline()
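# Rough sizing assumption: Phi-3-mini has ~3.8B parameters, so the weights alone
# take ~15 GB in float32 (the CPU path) or ~7.6 GB in float16 (the GPU path).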
# Prompt templates
CUSTOM_PROMPT_TEMPLATE = """
Use the pieces of information provided in the context to answer the user's question.
If you don't know the answer, just say that you don't know. Don't make up an answer.
Only provide information from the given context.
Keep your answer concise and avoid repeating the same information.
Each important point should be stated only once.
NOTE: SUMMARIZE YOUR ANSWERS STRICTLY WITHIN 300 WORDS.
Context: {context}
Question: {question}
Start the answer directly.
"""
HEALTH_ASSESSMENT_PROMPT = """
Before I provide specific recommendations for your health concern, I need to understand your situation better.
Please share the following information so I can give you more appropriate guidance:
- What symptoms are you experiencing and for how long?
- Do you have any existing medical conditions?
- Are you currently taking any medications?
- Do you have any known allergies?
- Have you tried any treatments for this condition already?
Just share what you're comfortable with - even partial information will help me provide better advice.
"""
# NOTE: RetrievalQA only forwards its single 'query' input to the document chain,
# so the user's health details are folded into {question} by ask_question rather
# than passed as a separate template variable.
MEDICAL_ADVICE_PROMPT = """
Use the pieces of information provided in the context to answer the user's health question.
Consider the user's personal health information while formulating your response.
If the information in the context doesn't fully address the user's specific situation, acknowledge this limitation.
Always prioritize safety and encourage consulting a healthcare professional for personalized medical advice.
Context: {context}
User's question and health information: {question}
Start the answer directly. Keep your answer concise and avoid repeating the same information.
Each important point should be stated only once.
NOTE: SUMMARIZE YOUR ANSWERS STRICTLY WITHIN 300 WORDS.
"""
# Function to detect if query requires health assessment
def requires_health_assessment(query):
    # Health conditions, symptoms, and remedy-seeking phrases
    # (all entries must be lowercase to match query_lower below)
    health_conditions = [
        'headache', 'migraine', 'pain', 'ache', 'fever', 'cold', 'flu', 'cough',
        'sore throat', 'nausea', 'vomiting', 'diarrhea', 'constipation', 'allergy',
        'rash', 'itch', 'insomnia', 'anxiety', 'depression', 'stress', 'fatigue',
        'tired', 'dizzy', 'dizziness', 'burn', 'cut', 'wound', 'injury', 'sprain',
        'inflammation', 'swelling', 'infection', 'diabetes', 'hypertension', 'blood pressure',
        'cholesterol', 'heart', 'arthritis', 'asthma', 'indigestion', 'heartburn',
        'acne', 'eczema', 'psoriasis', 'sinus', 'congestion', 'back pain', 'joint pain',
        'stomach', 'bloating', 'gas', 'cramp', 'menstrual', 'period', 'pregnancy'
    ]
    remedy_phrases = [
        'remedy', 'remedies', 'treatment', 'cure', 'heal', 'relief', 'relieve',
        'manage', 'help with', 'help for', 'works for', 'good for', 'effective for',
        'medicine for', 'medication for', 'drug for', 'herb for', 'supplement for',
        'therapy for', 'solution for', 'way to treat', 'way to cure', 'how to treat',
        'how to cure', 'how to manage', 'how to relieve', 'what helps', 'what works',
        'home remedy', 'natural remedy', 'alternative treatment', 'recommendation for',
        'advice for', 'tips for', 'should i take', 'is it safe', 'can i use'
    ]
    query_lower = query.lower()
    # Check for personal health indicators
    personal_indicators = ['i have', 'i am experiencing', 'i feel', 'i am having',
                           'my headache', 'my pain', 'my symptom', 'my condition',
                           'for me', 'my health', 'my problem', 'i suffer']
    for indicator in personal_indicators:
        if indicator in query_lower:
            return True
    # Check if any health condition is mentioned
    for condition in health_conditions:
        if condition in query_lower:
            # Check if a remedy phrase is also present
            for phrase in remedy_phrases:
                if phrase in query_lower:
                    return True
            # Also check for common remedy-seeking structures
            remedy_patterns = [
                r'(?i)for (my|the|this|a) ' + re.escape(condition),
                r'(?i)' + re.escape(condition) + r' (remedy|treatment|cure|relief)',
                r'(?i)(help|treat|cure|relieve|manage) (my|the|this|a) ' + re.escape(condition),
                r'(?i)how (to|can|do) (I|you) (treat|cure|relieve|manage) (my|the|this|a) ' + re.escape(condition)
            ]
            for pattern in remedy_patterns:
                if re.search(pattern, query_lower):
                    return True
    return False
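# Illustrative behaviour of the heuristic above (not exhaustive):
#   requires_health_assessment("What helps with my headache?")  -> True  ('my headache' is a personal indicator)
#   requires_health_assessment("remedy for a cold")             -> True  (condition 'cold' plus remedy phrase)
#   requires_health_assessment("What is Panchakarma?")          -> False (no condition or personal indicator)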
# Function to get health advice from Google AI Flash API
def get_google_ai_health_advice(query, health_info):
    try:
        if GOOGLE_API_KEY:
            # Use Google's Gemini Flash API for health queries
            model = genai.GenerativeModel('gemini-2.0-flash')
            # Prepare the prompt with retrieved context from the FAISS store
            retriever = db.as_retriever(search_kwargs={'k': 3})
            docs = retriever.get_relevant_documents(query)
            context = "\n\n".join([doc.page_content for doc in docs])
            prompt = f"""
Based on the following medical information:
{context}
And considering the user's health information:
{health_info}
Please provide advice for this question: {query}
Keep your answer concise, evidence-based, and focus on safe recommendations.
Include a reminder to consult healthcare professionals for personalized advice.
"""
            response = model.generate_content(prompt)
            return response.text
        else:
            # Fall back to the local RAG chain if no API key is available
            return None
    except Exception as e:
        print(f"Google AI API error: {e}")
        return None
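# Contract: returning None (no API key, or an API error) signals ask_question to
# fall back to the local Phi-3 RetrievalQA chain.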
# Shared LangChain LLM wrapper around the local Phi-3 pipeline
from langchain.llms.base import LLM
from typing import Optional, List

class HuggingFaceLocalLLM(LLM):
    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        output = phi3_pipe(prompt, max_new_tokens=512, do_sample=False)[0]["generated_text"]
        # The text-generation pipeline echoes the prompt; strip it from the output
        return output.replace(prompt, "").strip()

    @property
    def _identifying_params(self):
        return {"name": "local-phi-3"}

    @property
    def _llm_type(self):
        return "custom-local-llm"

# Build a RetrievalQA chain over the FAISS store with the given prompt template
def _build_qa_chain(template):
    prompt = PromptTemplate(template=template, input_variables=["context", "question"])
    return RetrievalQA.from_chain_type(
        llm=HuggingFaceLocalLLM(),
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={'k': 3}),
        return_source_documents=False,
        chain_type_kwargs={'prompt': prompt}
    )

# Standard QA chain
def create_qa_chain():
    return _build_qa_chain(CUSTOM_PROMPT_TEMPLATE)

# Health-aware QA chain (the health info travels inside the question; see MEDICAL_ADVICE_PROMPT)
def create_health_qa_chain():
    return _build_qa_chain(MEDICAL_ADVICE_PROMPT)
# Main QA chain
qa_chain = create_qa_chain()
# Additional health-aware QA chain
health_qa_chain = create_health_qa_chain()
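# Illustrative invocation (hypothetical query; both chains take a single 'query' key):
#   qa_chain.invoke({'query': 'What does the text recommend for indigestion?'})['result']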
# Main entry point; keeps the original (query) -> (text, sources) signature
def ask_question(query: str):
    global conversation_state
    try:
        # Detect a repeated question to break out of the health-info loop
        if query == conversation_state["last_query"]:
            conversation_state["awaiting_health_info"] = False
        conversation_state["last_query"] = query

        # If we're waiting for health info, this message is the user's health details
        if conversation_state["awaiting_health_info"]:
            health_info = query
            original_query = conversation_state["original_query"]
            conversation_state["awaiting_health_info"] = False

            disclaimer = ("\n\nPlease note: This information is for educational purposes only and is not a "
                          "substitute for professional medical advice. Always consult with a qualified "
                          "healthcare provider for personalized recommendations.")

            # Try Google's Gemini Flash API first
            google_response = get_google_ai_health_advice(original_query, health_info)
            if google_response:
                final_response = google_response + disclaimer
                # Reset conversation state
                conversation_state = {
                    "awaiting_health_info": False,
                    "original_query": "",
                    "health_info": "",
                    "last_query": query
                }
                return final_response, []

            # Fall back to the local health-aware QA chain
            try:
                # Fold the health info into the query; RetrievalQA only forwards 'query'
                combined_query = f"{original_query}\n\nUser's health information: {health_info}"
                response = health_qa_chain.invoke({'query': combined_query})
                result = response["result"]
                # Deduplicate repeated sentences in the model output
                sentences = [s.strip() for s in result.split('.') if s.strip()]
                unique_sentences = list(OrderedDict.fromkeys(sentences))
                cleaned_result = '. '.join(unique_sentences) + '.'
                final_response = cleaned_result + disclaimer
            except Exception:
                # If the health QA chain fails, use the standard QA chain as the last fallback
                response = qa_chain.invoke({'query': original_query})
                result = response["result"]
                sentences = [s.strip() for s in result.split('.') if s.strip()]
                unique_sentences = list(OrderedDict.fromkeys(sentences))
                cleaned_result = '. '.join(unique_sentences) + '.'
                final_response = cleaned_result + ("\n\nNote: I could not fully incorporate your health "
                                                   "information. This is general advice only. Please consult "
                                                   "a healthcare professional.")

            # Reset conversation state
            conversation_state = {
                "awaiting_health_info": False,
                "original_query": "",
                "health_info": "",
                "last_query": query
            }
            return final_response, []

        # New query: check whether it first requires a health assessment
        elif requires_health_assessment(query):
            conversation_state["awaiting_health_info"] = True
            conversation_state["original_query"] = query
            return HEALTH_ASSESSMENT_PROMPT, []

        # Standard query: use the original QA chain
        else:
            response = qa_chain.invoke({'query': query})
            result = response["result"]
            sentences = [s.strip() for s in result.split('.') if s.strip()]
            unique_sentences = list(OrderedDict.fromkeys(sentences))
            cleaned_result = '. '.join(unique_sentences) + '.'
            return cleaned_result, []
    except Exception as e:
        return f"Error: {str(e)}", []
# Gradio interface
iface = gr.Interface(fn=ask_question, inputs="text", outputs=["text", "json"])
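# share=True only matters for local runs (it creates a public tunnel); a hosted
# Space is already publicly served and typically ignores this flag.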
iface.launch(share=True)