import os
import gradio as gr
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain.chains import RetrievalQA
from langchain_core.prompts import PromptTemplate
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from collections import OrderedDict
import torch
import re
import google.generativeai as genai
# Constants
DATA_PATH = "dataFolder/"
DB_FAISS_PATH = "/tmp/vectorstore/db_faiss"
CACHE_DIR = "/tmp/models_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
# Google AI API setup
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
try:
    if GOOGLE_API_KEY:
        genai.configure(api_key=GOOGLE_API_KEY)
except Exception as e:
    print(f"Google AI API setup error: {e}")
# Conversation state tracking
conversation_state = {
    "awaiting_health_info": False,
    "original_query": "",
    "health_info": "",
    "last_query": ""
}
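# Flow implemented by this state: a remedy-seeking query sets awaiting_health_info and stores
# the question in original_query; the user's next message is treated as their health details
# and answered against the stored question, after which the state is reset.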
# Load the embedding model
embedding_model = HuggingFaceEmbeddings(
    model_name="rishi002/all-MiniLM-L6-v2",
    cache_folder=CACHE_DIR
)
# Load or create FAISS database
def load_or_create_faiss():
    if not os.path.exists(DB_FAISS_PATH):
        print("Creating FAISS database...")
        from embeddings import load_pdf_files, create_chunks  # Your custom chunking logic
        documents = load_pdf_files(DATA_PATH)
        text_chunks = create_chunks(documents)
        db = FAISS.from_documents(text_chunks, embedding_model)
        db.save_local(DB_FAISS_PATH)
    else:
        print("FAISS database exists. Loading...")
    return FAISS.load_local(DB_FAISS_PATH, embedding_model, allow_dangerous_deserialization=True)

db = load_or_create_faiss()
# Load Phi-3 locally
def load_phi3_pipeline():
    model_name = "microsoft/Phi-3-mini-4k-instruct"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
    )
    pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device=0 if torch.cuda.is_available() else -1)
    return pipe

phi3_pipe = load_phi3_pipeline()
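# Note: the text-generation pipeline returns a list of dicts with a "generated_text" field,
# e.g. (illustrative) phi3_pipe("Hello", max_new_tokens=8)[0]["generated_text"]; the LLM
# wrapper defined below relies on this output shape.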
# Prompt templates
CUSTOM_PROMPT_TEMPLATE = """
Use the pieces of information provided in the context to answer the user's question.
If you don't know the answer, just say that you don't know. Don't make up an answer.
Only provide information from the given context.
Keep your answer concise and avoid repeating the same information.
Each important point should be stated only once.
NOTE: SUMMARIZE YOUR ANSWERS STRICTLY WITHIN 300 WORDS.
Context: {context}
Question: {question}
Start the answer directly.
"""

HEALTH_ASSESSMENT_PROMPT = """
Before I provide specific recommendations for your health concern, I need to understand your situation better.
Please share the following information so I can give you more appropriate guidance:
- What symptoms are you experiencing and for how long?
- Do you have any existing medical conditions?
- Are you currently taking any medications?
- Do you have any known allergies?
- Have you tried any treatments for this condition already?
Just share what you're comfortable with - even partial information will help me provide better advice.
"""

MEDICAL_ADVICE_PROMPT = """
Use the pieces of information provided in the context to answer the user's health question.
Consider the user's personal health information while formulating your response.
If the information in the context doesn't fully address the user's specific situation, acknowledge this limitation.
Always prioritize safety and encourage consulting a healthcare professional for personalized medical advice.
Context: {context}
User's question: {question}
User's health information: {health_info}
Start the answer directly. Keep your answer concise and avoid repeating the same information.
Each important point should be stated only once.
NOTE: SUMMARIZE YOUR ANSWERS STRICTLY WITHIN 300 WORDS.
"""
# Function to detect whether a query requires a health assessment
def requires_health_assessment(query):
    # Health conditions, symptoms, and remedy-seeking phrases
    health_conditions = [
        'headache', 'migraine', 'pain', 'ache', 'fever', 'cold', 'flu', 'cough',
        'sore throat', 'nausea', 'vomiting', 'diarrhea', 'constipation', 'allergy',
        'rash', 'itch', 'insomnia', 'anxiety', 'depression', 'stress', 'fatigue',
        'tired', 'dizzy', 'dizziness', 'burn', 'cut', 'wound', 'injury', 'sprain',
        'inflammation', 'swelling', 'infection', 'diabetes', 'hypertension', 'blood pressure',
        'cholesterol', 'heart', 'arthritis', 'asthma', 'indigestion', 'heartburn',
        'acne', 'eczema', 'psoriasis', 'sinus', 'congestion', 'back pain', 'joint pain',
        'stomach', 'bloating', 'gas', 'cramp', 'menstrual', 'period', 'pregnancy'
    ]

    # Phrases kept lowercase so they can match the lowercased query
    remedy_phrases = [
        'remedy', 'remedies', 'treatment', 'cure', 'heal', 'relief', 'relieve',
        'manage', 'help with', 'help for', 'works for', 'good for', 'effective for',
        'medicine for', 'medication for', 'drug for', 'herb for', 'supplement for',
        'therapy for', 'solution for', 'way to treat', 'way to cure', 'how to treat',
        'how to cure', 'how to manage', 'how to relieve', 'what helps', 'what works',
        'home remedy', 'natural remedy', 'alternative treatment', 'recommendation for',
        'advice for', 'tips for', 'should i take', 'is it safe', 'can i use'
    ]

    query_lower = query.lower()

    # Check for personal health indicators
    personal_indicators = ['i have', 'i am experiencing', 'i feel', 'i am having',
                           'my headache', 'my pain', 'my symptom', 'my condition',
                           'for me', 'my health', 'my problem', 'i suffer']
    for indicator in personal_indicators:
        if indicator in query_lower:
            return True

    # Check if any health condition is mentioned
    for condition in health_conditions:
        if condition in query_lower:
            # Check if a remedy phrase is also present
            for phrase in remedy_phrases:
                if phrase in query_lower:
                    return True
            # Also check for common remedy-seeking structures
            remedy_patterns = [
                r'(?i)for (my|the|this|a) ' + re.escape(condition),
                r'(?i)' + re.escape(condition) + r' (remedy|treatment|cure|relief)',
                r'(?i)(help|treat|cure|relieve|manage) (my|the|this|a) ' + re.escape(condition),
                r'(?i)how (to|can|do) (I|you) (treat|cure|relieve|manage) (my|the|this|a) ' + re.escape(condition)
            ]
            for pattern in remedy_patterns:
                if re.search(pattern, query_lower):
                    return True

    return False
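# Illustrative behaviour (not exhaustive): "what helps with my headache" -> True (personal
# indicator plus condition), "home remedy for a cough" -> True (condition plus remedy phrase),
# "what is diabetes" -> False (condition mentioned but no remedy-seeking phrasing).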
# Function to get health advice from the Google AI (Gemini Flash) API
def get_google_ai_health_advice(query, health_info):
    try:
        if GOOGLE_API_KEY:
            # Use Google's Gemini Flash model for health queries
            model = genai.GenerativeModel('gemini-2.0-flash')

            # Prepare a prompt grounded in context retrieved from the FAISS store
            retriever = db.as_retriever(search_kwargs={'k': 3})
            docs = retriever.get_relevant_documents(query)
            context = "\n\n".join([doc.page_content for doc in docs])

            prompt = f"""
Based on the following medical information:
{context}

And considering the user's health information:
{health_info}

Please provide advice for this question: {query}

Keep your answer concise, evidence-based, and focus on safe recommendations.
Include a reminder to consult healthcare professionals for personalized advice.
"""
            response = model.generate_content(prompt)
            return response.text
        else:
            # Fall back to the local RAG chains if no API key is available
            return None
    except Exception as e:
        print(f"Google AI API error: {e}")
        return None
# Minimal LangChain LLM wrapper around the local Phi-3 pipeline (shared by both QA chains)
from langchain.llms.base import LLM
from typing import Optional, List

class HuggingFaceLocalLLM(LLM):
    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        output = phi3_pipe(prompt, max_new_tokens=512, do_sample=False)[0]["generated_text"]
        # The pipeline echoes the prompt, so strip it to keep only the generated answer
        return output.replace(prompt, "").strip()

    @property
    def _identifying_params(self):
        return {"name": "local-phi-3"}

    @property
    def _llm_type(self) -> str:
        return "custom-local-llm"

# Create the main QA chain over the FAISS retriever with the standard prompt
def create_qa_chain():
    prompt = PromptTemplate(template=CUSTOM_PROMPT_TEMPLATE, input_variables=["context", "question"])
    return RetrievalQA.from_chain_type(
        llm=HuggingFaceLocalLLM(),
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={'k': 3}),
        return_source_documents=False,
        chain_type_kwargs={'prompt': prompt}
    )
# Create the health-aware QA chain; the user's health information is pre-filled into the
# prompt, since RetrievalQA only forwards the question and retrieved context at call time
def create_health_qa_chain(health_info: str):
    prompt = PromptTemplate(
        template=MEDICAL_ADVICE_PROMPT,
        input_variables=["context", "question", "health_info"]
    ).partial(health_info=health_info)

    return RetrievalQA.from_chain_type(
        llm=HuggingFaceLocalLLM(),
        chain_type="stuff",
        retriever=db.as_retriever(search_kwargs={'k': 3}),
        return_source_documents=False,
        chain_type_kwargs={'prompt': prompt}
    )
# Main QA chain (the health-aware chain is built per request inside ask_question,
# once the user's health information is known)
qa_chain = create_qa_chain()
# Keep the ask_question function with the original signature
def ask_question(query: str):
    global conversation_state
    try:
        # Detect repeated questions to prevent loops
        if query == conversation_state["last_query"]:
            conversation_state["awaiting_health_info"] = False
        conversation_state["last_query"] = query

        # If we're waiting for health info
        if conversation_state["awaiting_health_info"]:
            # User has provided health info
            health_info = query
            original_query = conversation_state["original_query"]
            conversation_state["awaiting_health_info"] = False
            # Try the Google AI Flash API first
            google_response = get_google_ai_health_advice(original_query, health_info)
            if google_response:
                # Add disclaimer to the Google AI response
                disclaimer = "\n\nPlease note: This information is for educational purposes only and is not a substitute for professional medical advice. Always consult with a qualified healthcare provider for personalized recommendations."
                final_response = google_response + disclaimer

                # Reset conversation state
                conversation_state = {
                    "awaiting_health_info": False,
                    "original_query": "",
                    "health_info": "",
                    "last_query": query
                }
                return final_response, []
            # Fall back to the local health-aware QA chain
            try:
                # Build the health-aware chain with this user's health info filled into the prompt
                health_qa_chain = create_health_qa_chain(health_info)
                response = health_qa_chain.invoke({'query': original_query})
                result = response["result"]

                # Clean up response: drop duplicated sentences while preserving order
                sentences = [s.strip() for s in result.split('.') if s.strip()]
                unique_sentences = list(OrderedDict.fromkeys(sentences))
                cleaned_result = '. '.join(unique_sentences) + '.'

                # Add disclaimer
                disclaimer = "\n\nPlease note: This information is for educational purposes only and is not a substitute for professional medical advice. Always consult with a qualified healthcare provider for personalized recommendations."
                final_response = cleaned_result + disclaimer

                # Reset conversation state
                conversation_state = {
                    "awaiting_health_info": False,
                    "original_query": "",
                    "health_info": "",
                    "last_query": query
                }
                return final_response, []
            except Exception as e:
                # If the health QA chain fails, use the standard QA chain as the ultimate fallback
                response = qa_chain.invoke({'query': original_query})
                result = response["result"]
                sentences = [s.strip() for s in result.split('.') if s.strip()]
                unique_sentences = list(OrderedDict.fromkeys(sentences))
                cleaned_result = '. '.join(unique_sentences) + '.'
                disclaimer = "\n\nNote: I could not fully incorporate your health information. This is general advice only. Please consult a healthcare professional."
                final_response = cleaned_result + disclaimer

                # Reset conversation state
                conversation_state = {
                    "awaiting_health_info": False,
                    "original_query": "",
                    "health_info": "",
                    "last_query": query
                }
                return final_response, []
        # New query - check if it requires a health assessment
        elif requires_health_assessment(query):
            # Store the original query and set the flag
            conversation_state["awaiting_health_info"] = True
            conversation_state["original_query"] = query
            # Return the health assessment prompt
            return HEALTH_ASSESSMENT_PROMPT, []

        # Standard query - use the original QA chain
        else:
            response = qa_chain.invoke({'query': query})
            result = response["result"]
            sentences = [s.strip() for s in result.split('.') if s.strip()]
            unique_sentences = list(OrderedDict.fromkeys(sentences))
            cleaned_result = '. '.join(unique_sentences) + '.'
            return cleaned_result, []
    except Exception as e:
        return f"Error: {str(e)}", []
# Keep the original interface
iface = gr.Interface(fn=ask_question, inputs="text", outputs=["text", "json"])
iface.launch(share=True)
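# Optional local smoke test (illustrative sketch; assumes the models above loaded successfully
# and would be run instead of launching the interface):
#   answer, _ = ask_question("What is hypertension?")
#   print(answer)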