import os
import logging
from typing import Optional, TypedDict

import gradio as gr
from PIL import Image

# LangChain & LangGraph
from langgraph.graph import StateGraph, END
from langgraph.checkpoint.memory import MemorySaver
from langchain.tools import tool
from langchain_community.vectorstores import FAISS
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.document_loaders import PyPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter

# Web Search
from duckduckgo_search import DDGS

# Llama GGUF Model Loader
from llama_cpp import Llama
# ------------------------------
# 🔹 Setup Logging
# ------------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ------------------------------
# 🔹 Load GGUF Model with llama-cpp-python
# ------------------------------
model_path = "./Bio-Medical-MultiModal-Llama-3-8B-V1.i1-Q6_K.gguf"  # Update with the actual GGUF model path
llm = Llama(model_path=model_path, n_ctx=2048, n_gpu_layers=35)  # n_gpu_layers=35 targets a Hugging Face T4 GPU
logger.info("Llama GGUF model loaded successfully.")
# ------------------------------
# 🔹 Define Expert System Prompts
# ------------------------------
GP_PROMPT = "You are a General Practitioner AI Assistant. Answer medical questions with scientifically accurate information."
RADIOLOGY_PROMPT = "You are a Radiology AI expert. Analyze medical images and provide diagnostic insights."
WEBSEARCH_PROMPT = "You are a Web Search AI. Retrieve up-to-date medical information."
# ------------------------------
# 🔹 FAISS Vector Store for RAG
# ------------------------------
_vector_store_cache = None
def load_vectorstore(pdf_path="medical_docs.pdf"):
    """Loads a PDF file into a FAISS vector store for RAG."""
    try:
        loader = PyPDFLoader(pdf_path)
        documents = loader.load()
        text_splitter = RecursiveCharacterTextSplitter(chunk_size=512, chunk_overlap=50)
        docs = text_splitter.split_documents(documents)
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        vector_store = FAISS.from_documents(docs, embeddings)
        logger.info(f"Vector store loaded with {len(docs)} chunks.")
        return vector_store
    except Exception as e:
        logger.error(f"Error loading vector store: {e}")
        return None
def update_vector_store(pdf_file):
    """Rebuilds the FAISS vector store when a new PDF is uploaded."""
    try:
        # Gradio's gr.File hands back a temp-file path (or an object exposing
        # .name), so load it from disk rather than calling .read() on it.
        pdf_path = pdf_file if isinstance(pdf_file, str) else pdf_file.name
        return load_vectorstore(pdf_path)
    except Exception as e:
        logger.error(f"Error updating vector store: {e}")
        return _vector_store_cache  # Fall back to the cached store


if os.path.exists("medical_docs.pdf"):
    _vector_store_cache = load_vectorstore("medical_docs.pdf")
vector_store = _vector_store_cache
# ------------------------------
# 🔹 Define AI Tools
# ------------------------------
@tool
def analyze_medical_image(image_path: str):
    """Analyzes a medical image and returns a diagnostic explanation."""
    try:
        image = Image.open(image_path)
    except Exception as e:
        logger.error(f"Error opening image: {e}")
        return "Error processing image."
    # NOTE: plain llama-cpp text completion cannot consume raw pixels, and
    # interpolating a PIL object only embeds its repr. Until a multimodal
    # handler is wired in, pass a textual description of the image instead.
    description = f"{image.format or 'unknown'} image, {image.size[0]}x{image.size[1]}, mode {image.mode}"
    output = llm(f"{RADIOLOGY_PROMPT}\nAnalyze this medical image ({description}) and provide a diagnosis:")
    return output["choices"][0]["text"]
@tool
def retrieve_medical_knowledge(query: str):
    """Retrieves medical knowledge from the FAISS vector store."""
    if vector_store is None:
        return "No external medical knowledge available."
    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 3})
    docs = retriever.invoke(query)  # get_relevant_documents() is deprecated in recent LangChain
    citations = [f"[{i + 1}] {doc.metadata.get('source', 'Unknown Source')}" for i, doc in enumerate(docs)]
    citations_text = "\n".join(citations)
    content = "\n".join(doc.page_content for doc in docs)
    return content + f"\n\n**Citations:**\n{citations_text}"
@tool
def web_search(query: str):
    """Performs a real-time web search using DuckDuckGo."""
    try:
        # The module-level ddg() helper was removed from duckduckgo_search;
        # DDGS().text() is the current API.
        results = DDGS().text(query, max_results=3)
        summary = "\n".join(f"{r['title']}: {r['body']} ({r['href']})" for r in results) or "No relevant results found."
        return summary
    except Exception as e:
        logger.error(f"Web search error: {e}")
        return "Error retrieving web search results."
# ------------------------------
# 🔹 Define Multi-Agent Workflow (LangGraph)
# ------------------------------
class AgentState(TypedDict):
    """Shared graph state. StateGraph expects a dict-like schema such as a
    TypedDict, and nodes return partial state updates rather than new objects."""
    query: str
    response: str
    image_path: Optional[str]
    expert: str  # "GP", "Radiology", "Web Search"


# Memory checkpointing
checkpointer = MemorySaver()

# Create LangGraph state graph
agent_graph = StateGraph(AgentState)
def route_query(state: AgentState) -> str:
    """Determines which AI expert should handle the query."""
    if state.get("image_path"):
        return "radiology_specialist"
    elif any(word in state["query"].lower() for word in ["latest", "update", "breaking news"]):
        return "web_search_expert"
    else:
        return "general_practitioner"


def general_practitioner(state: AgentState) -> dict:
    """GP Expert: Handles medical text queries and retrieves knowledge."""
    query = state["query"]
    retrieved_info = retrieve_medical_knowledge.run(query)
    output = llm(f"{GP_PROMPT}\nQ: {query}\nA:")
    return {"response": output["choices"][0]["text"] + "\n\n" + retrieved_info, "expert": "GP"}


def radiology_specialist(state: AgentState) -> dict:
    """Radiology Expert: Analyzes medical images."""
    image_analysis = analyze_medical_image.run(state["image_path"])
    return {"response": image_analysis, "expert": "Radiology"}


def web_search_expert(state: AgentState) -> dict:
    """Web Search Expert: Retrieves the latest information."""
    search_result = web_search.run(state["query"])
    return {"response": search_result, "expert": "Web Search"}
# Add nodes; conditional edges must originate from a real node, so "router"
# is a no-op node whose conditional edges dispatch to the chosen expert.
agent_graph.add_node("router", lambda state: {})
agent_graph.add_node("general_practitioner", general_practitioner)
agent_graph.add_node("radiology_specialist", radiology_specialist)
agent_graph.add_node("web_search_expert", web_search_expert)
agent_graph.add_conditional_edges(
    "router",
    route_query,
    {
        "general_practitioner": "general_practitioner",
        "radiology_specialist": "radiology_specialist",
        "web_search_expert": "web_search_expert",
    },
)
agent_graph.add_edge("general_practitioner", END)
agent_graph.add_edge("radiology_specialist", END)
agent_graph.add_edge("web_search_expert", END)
agent_graph.set_entry_point("router")

# Compile graph
app = agent_graph.compile(checkpointer=checkpointer)
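
# ------------------------------
# 🔹 Gradio Handler
# ------------------------------
# The Submit button below references `chat_with_agent`, which the original
# file never defined. The sketch here is one plausible implementation: it
# optionally rebuilds the vector store from an uploaded PDF, then invokes the
# compiled graph. The "gradio-session" thread_id is an assumed placeholder
# required by the MemorySaver checkpointer; a per-user id would be better.
def chat_with_agent(user_query, image_path, pdf_file):
    global vector_store
    if pdf_file is not None:
        vector_store = update_vector_store(pdf_file) or vector_store
    initial_state: AgentState = {
        "query": user_query or "",
        "response": "",
        "image_path": image_path,
        "expert": "",
    }
    result = app.invoke(initial_state, config={"configurable": {"thread_id": "gradio-session"}})
    return result["response"]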
# ------------------------------
# 🔹 Gradio Interface
# ------------------------------
with gr.Blocks(title="Llama3-Med Multi-Agent AI") as demo:
    gr.Markdown("# 🏥 AI Medical Assistant")
    with gr.Row():
        user_input = gr.Textbox(label="Your Question")
        image_file = gr.Image(label="Upload Medical Image (Optional)", type="filepath")  # "file" is not a valid gr.Image type
        pdf_file = gr.File(label="Upload PDF (Optional)", file_types=[".pdf"])
    submit_btn = gr.Button("Submit")
    output_text = gr.Textbox(label="Assistant's Response", interactive=False)
    submit_btn.click(fn=chat_with_agent, inputs=[user_input, image_file, pdf_file], outputs=output_text)

if __name__ == "__main__":
    demo.launch()