Spaces:
Runtime error
Runtime error
import os | |
import logging | |
import torch | |
import gradio as gr | |
from tqdm import tqdm | |
from PIL import Image | |
# LangChain & LangGraph | |
from langgraph.graph import StateGraph | |
from langgraph.checkpoint.memory import MemorySaver | |
from langchain.tools import tool | |
from langchain_community.vectorstores import FAISS | |
from langchain_community.embeddings import HuggingFaceEmbeddings | |
from langchain_community.document_loaders import PyPDFLoader | |
from langchain.text_splitter import RecursiveCharacterTextSplitter | |
# Web Search | |
from duckduckgo_search import DDGS | |
# Llama GGUF Model Loader | |
from llama_cpp import Llama | |
# ------------------------------
# 🔹 Setup Logging
# ------------------------------
# Root logging at INFO so model-loading and indexing progress shows up in the
# Space logs; this module logs through its own named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
# ------------------------------
# 🔹 Load GGUF Model with llama-cpp-python
# ------------------------------
# The GGUF weights path can be overridden with the MODEL_PATH environment
# variable (e.g. when the file is mounted elsewhere); the default is the
# original hard-coded location next to the app.
model_path = os.environ.get(
    "MODEL_PATH", "./Bio-Medical-MultiModal-Llama-3-8B-V1.i1-Q6_K.gguf"
)
# n_ctx bounds prompt + completion tokens; n_gpu_layers is the number of layers
# offloaded to the GPU (the original author tuned 35 for a Hugging Face T4).
llm = Llama(model_path=model_path, n_ctx=2048, n_gpu_layers=35)
logger.info("Llama GGUF Model Loaded Successfully.")
# ------------------------------
# 🔹 Define Expert System Prompts
# ------------------------------
# Role instructions for each expert. Only GP_PROMPT is interpolated into a
# prompt in this file; NOTE(review): RADIOLOGY_PROMPT and WEBSEARCH_PROMPT
# appear unused here — confirm before relying on them.
GP_PROMPT = (
    "You are a General Practitioner AI Assistant. "
    "Answer medical questions with scientifically accurate information."
)
RADIOLOGY_PROMPT = (
    "You are a Radiology AI expert. "
    "Analyze medical images and provide diagnostic insights."
)
WEBSEARCH_PROMPT = (
    "You are a Web Search AI. "
    "Retrieve up-to-date medical information."
)
# ------------------------------
# 🔹 FAISS Vector Store for RAG
# ------------------------------
# Startup vector store (or None). update_vector_store returns it as a
# fallback when re-indexing an upload fails.
_vector_store_cache = None


def load_vectorstore(pdf_path="medical_docs.pdf", chunk_size=512, chunk_overlap=50):
    """Build a FAISS vector store from the pages of one PDF, for RAG.

    Args:
        pdf_path: Path of the PDF to index.
        chunk_size: Maximum characters per text chunk (default 512).
        chunk_overlap: Characters shared between consecutive chunks (default 50).

    Returns:
        The FAISS vector store, or None when the PDF yields no text or any step
        fails (the failure is logged with its traceback, not raised).
    """
    try:
        documents = PyPDFLoader(pdf_path).load()
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=chunk_size, chunk_overlap=chunk_overlap
        )
        docs = splitter.split_documents(documents)
        if not docs:
            # FAISS cannot be built from zero chunks (e.g. a scanned PDF with
            # no text layer); report that plainly instead of an opaque error.
            logger.warning(f"No text extracted from {pdf_path}; vector store not built.")
            return None
        embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        vector_store = FAISS.from_documents(docs, embeddings)
    except Exception as e:
        # Best-effort by design: callers treat None as "no knowledge base".
        # logger.exception keeps the traceback that logger.error dropped.
        logger.exception(f"Error loading vector store: {str(e)}")
        return None
    logger.info(f"Vector store loaded with {len(docs)} documents.")
    return vector_store
def _load_vectorstore_from_stream(pdf_stream):
    """Copy a readable binary PDF stream to a private temp file and index it.

    The temp file has a unique name (concurrent uploads cannot clobber each
    other) and is removed even when indexing fails.
    """
    import tempfile

    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as tmp:
        tmp.write(pdf_stream.read())
        tmp_path = tmp.name
    try:
        return load_vectorstore(tmp_path)
    finally:
        os.remove(tmp_path)


def update_vector_store(pdf_file):
    """Rebuild the FAISS vector store from an uploaded PDF.

    Args:
        pdf_file: The upload, in any of these forms:
            - a filesystem path (``str`` / ``os.PathLike``), which is what
              ``gr.File`` passes by default in Gradio 4;
            - an object whose ``.name`` is an existing file path (Gradio 3's
              tempfile wrapper);
            - a readable binary stream with ``.read()``.
            ``None`` means nothing was uploaded.

    Returns:
        The newly built vector store. Returns ``_vector_store_cache`` instead
        when there is no upload or indexing fails, including when
        ``load_vectorstore`` reports failure by returning None.
    """
    if pdf_file is None:
        return _vector_store_cache

    if isinstance(pdf_file, (str, os.PathLike)):
        source_path = os.fspath(pdf_file)
    else:
        name = getattr(pdf_file, "name", None)
        source_path = name if isinstance(name, str) and os.path.isfile(name) else None

    try:
        if source_path is not None:
            # Already on disk: index in place, no copy and nothing to clean up.
            store = load_vectorstore(source_path)
        else:
            store = _load_vectorstore_from_stream(pdf_file)
    except Exception as e:
        logger.error(f"Error updating vector store: {str(e)}")
        return _vector_store_cache  # Fallback to cached version
    return store if store is not None else _vector_store_cache
if os.path.exists("medical_docs.pdf"): | |
_vector_store_cache = load_vectorstore("medical_docs.pdf") | |
else: | |
_vector_store_cache = None | |
vector_store = _vector_store_cache | |
# ------------------------------
# 🔹 Define AI Tools
# ------------------------------
# Helpers called by the expert nodes. They are plain functions (no LangChain
# @tool decorator is applied), so callers invoke them directly.
def analyze_medical_image(image_path: str, max_tokens: int = 512):
    """Ask the LLM for a diagnostic explanation of the image at ``image_path``.

    Args:
        image_path: Filesystem path of the uploaded image.
        max_tokens: Completion budget. llama-cpp-python defaults to 16 tokens,
            which truncates answers mid-sentence.

    Returns:
        The generated text, or "Error processing image." if the file cannot be
        opened as an image.
    """
    try:
        # The context manager releases the file handle, which the original left
        # open. Only header metadata is read here.
        with Image.open(image_path) as image:
            width, height = image.size
            description = (
                f"format={image.format}, size={width}x{height} px, mode={image.mode}"
            )
    except Exception as e:
        logger.error(f"Error opening image: {str(e)}")
        return "Error processing image."
    # NOTE(review): `llm` is constructed with only model_path/n_ctx/n_gpu_layers
    # (no multimodal chat handler is visible), so the model receives this text
    # description, not pixel data. The original interpolated the PIL object's
    # repr, which carried even less information.
    output = llm(
        f"Analyze this medical image and provide a diagnosis:\n{description}",
        max_tokens=max_tokens,
    )
    return output["choices"][0]["text"]
def retrieve_medical_knowledge(query: str, k: int = 3):
    """Return the ``k`` most similar PDF passages for ``query`` plus citations.

    Args:
        query: Free-text question to search for.
        k: Number of passages to retrieve (default 3).

    Returns:
        The passages joined by newlines, followed by a numbered
        "**Citations:**" list of each passage's ``source`` metadata. If no
        vector store is loaded, or the search returns nothing, a short
        explanatory message is returned instead.
    """
    if vector_store is None:
        return "No external medical knowledge available."
    retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": k})
    # `invoke` is the Runnable entry point; `get_relevant_documents` is deprecated.
    docs = retriever.invoke(query)
    if not docs:
        # Avoid emitting an empty "Citations" section.
        return "No relevant medical knowledge found."
    citations_text = "\n".join(
        f"[{i}] {doc.metadata.get('source', 'Unknown Source')}"
        for i, doc in enumerate(docs, start=1)
    )
    content = "\n".join(doc.page_content for doc in docs)
    return f"{content}\n\n**Citations:**\n{citations_text}"
def web_search(query: str, max_results: int = 3):
    """Search DuckDuckGo and summarise the top hits as plain text.

    Args:
        query: Search keywords.
        max_results: Maximum number of hits to include (default 3).

    Returns:
        One "title: body (url)" line per hit. Returns "No relevant results
        found." when there are no hits, or an error message if the search fails.
    """
    try:
        # The original called an undefined `ddg` function (NameError on every
        # call). DDGS is the client this file imports.
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=max_results) or [])
    except Exception as e:
        logger.error(f"Web search error: {str(e)}")
        return "Error retrieving web search results."
    # .get() so a hit missing a field doesn't discard the whole summary.
    summary = "\n".join(
        f"{r.get('title', '')}: {r.get('body', '')} ({r.get('href', '')})" for r in results
    )
    return summary or "No relevant results found."
# ------------------------------
# 🔹 Define Multi-Agent Workflow (LangGraph)
# ------------------------------
from typing import Optional


class AgentState:
    """State passed between the LangGraph nodes of the medical assistant.

    LangGraph's StateGraph builds its state channels from the schema class's
    type annotations. The original class had none, so the graph had no state
    to read or write. The keyword constructor is kept so nodes can still
    return ``AgentState(...)``, and so the graph can rebuild state objects as
    ``AgentState(**values)``.
    """

    query: str                 # user's question text
    response: str              # answer produced by the selected expert
    image_path: Optional[str]  # uploaded image path, or None
    expert: str                # "GP", "Radiology", "Web Search"

    def __init__(self, query="", response="", image_path=None, expert=""):
        self.query = query
        self.response = response
        self.image_path = image_path
        self.expert = expert

    def __repr__(self):
        return (
            f"{type(self).__name__}(query={self.query!r}, response={self.response!r}, "
            f"image_path={self.image_path!r}, expert={self.expert!r})"
        )
# In-process MemorySaver checkpointer (handed to compile()) and the LangGraph
# state graph whose state schema is AgentState.
checkpointer, agent_graph = MemorySaver(), StateGraph(AgentState)
def route_query(state: AgentState):
    """Return the name of the expert node that should answer ``state``.

    An attached image always goes to radiology. Otherwise, a query mentioning
    "latest", "update" or "breaking news" (case-insensitive substring match)
    goes to web search, and everything else goes to the GP.
    """
    if state.image_path:
        return "radiology_specialist"
    lowered = state.query.lower()
    freshness_markers = ("latest", "update", "breaking news")
    wants_fresh_info = any(marker in lowered for marker in freshness_markers)
    return "web_search_expert" if wants_fresh_info else "general_practitioner"
def general_practitioner(state: AgentState):
    """GP expert: answer a text question, grounded in retrieved PDF passages.

    The retrieved passages are placed in the prompt so the model can use them,
    and are also appended to the answer so their citations reach the user.
    retrieve_medical_knowledge is a plain function, so it is called directly
    (it has no ``.run``).
    """
    query = state.query
    retrieved_info = retrieve_medical_knowledge(query)
    # Without a loaded store the prompt is the same as the original prompt.
    context = f"Reference material:\n{retrieved_info}\n\n" if vector_store is not None else ""
    # llama-cpp-python's default max_tokens is 16. The stop sequence keeps the
    # model from inventing a follow-up "Q:" turn.
    output = llm(f"{GP_PROMPT}\n{context}Q: {query}\nA:", max_tokens=512, stop=["\nQ:"])
    answer = output["choices"][0]["text"].strip()
    return AgentState(query=query, response=answer + "\n\n" + retrieved_info, expert="GP")
def radiology_specialist(state: AgentState):
    """Radiology expert: analyze the image at ``state.image_path``.

    analyze_medical_image is a plain function, not a LangChain tool, so it is
    called directly; the original ``.run(...)`` raised AttributeError.
    """
    image_analysis = analyze_medical_image(state.image_path)
    return AgentState(query=state.query, response=image_analysis, expert="Radiology")
def web_search_expert(state: AgentState):
    """Web-search expert: answer ``state.query`` with live DuckDuckGo results.

    web_search is a plain function, not a LangChain tool, so it is called
    directly; the original ``.run(...)`` raised AttributeError.
    """
    search_result = web_search(state.query)
    return AgentState(query=state.query, response=search_result, expert="Web Search")
# Add one node per expert.
agent_graph.add_node("general_practitioner", general_practitioner)
agent_graph.add_node("radiology_specialist", radiology_specialist)
agent_graph.add_node("web_search_expert", web_search_expert)

# route_query is a routing function, not a node. The original used it both as
# a node name ("route_query" was never added) and as the entry point, so
# compile failed. Route straight from the graph entry instead, with an
# explicit dict path map (the original passed a set).
agent_graph.set_conditional_entry_point(
    route_query,
    {
        "general_practitioner": "general_practitioner",
        "radiology_specialist": "radiology_specialist",
        "web_search_expert": "web_search_expert",
    },
)

# Each expert produces the final answer, so the run ends after it.
agent_graph.set_finish_point("general_practitioner")
agent_graph.set_finish_point("radiology_specialist")
agent_graph.set_finish_point("web_search_expert")

# Compile graph. With a checkpointer, every invoke needs a
# config={"configurable": {"thread_id": ...}}.
app = agent_graph.compile(checkpointer=checkpointer)
# ------------------------------
# 🔹 Gradio Interface
# ------------------------------
# Identity (path) of the PDF most recently indexed from the UI, so an attached
# PDF is embedded once instead of on every submit.
_last_indexed_pdf = None


def chat_with_agent(query, image_path, pdf_file, request: gr.Request):
    """Handle one Gradio submit: optionally index a PDF, then run the agent graph.

    This handler was referenced by ``submit_btn.click`` but never defined,
    which crashed the Space at startup.

    Args:
        query: Text from the question box (may be empty).
        image_path: Filepath of the uploaded image, or None (the Image
            component below uses ``type="filepath"``).
        pdf_file: Upload from ``gr.File`` (a filepath in Gradio 4, a tempfile
            wrapper in Gradio 3), or None.
        request: Injected by Gradio. Its ``session_hash`` is used as the
            checkpointer thread id, so browser sessions don't share state.

    Returns:
        The selected expert's response, prefixed with the expert's name, or a
        short message when there is nothing to answer or the workflow fails.
    """
    global vector_store, _vector_store_cache, _last_indexed_pdf

    pdf_key = pdf_file if isinstance(pdf_file, str) else getattr(pdf_file, "name", None)
    if pdf_file is not None and (pdf_key is None or pdf_key != _last_indexed_pdf):
        _last_indexed_pdf = pdf_key
        new_store = update_vector_store(pdf_file)
        if new_store is not None:
            # retrieve_medical_knowledge reads the module-level vector_store.
            vector_store = _vector_store_cache = new_store

    query = (query or "").strip()
    if not query and not image_path:
        return "Please enter a question or upload a medical image."

    thread_id = getattr(request, "session_hash", None) or "default"
    try:
        # Pass every state key so values from a previous turn on this thread
        # (e.g. an old image_path) cannot leak into this one.
        result = app.invoke(
            {"query": query, "response": "", "image_path": image_path or None, "expert": ""},
            config={"configurable": {"thread_id": thread_id}},
        )
    except Exception:
        logger.exception("Agent workflow failed")
        return "Sorry, something went wrong while processing your request."

    # invoke() normally returns a dict of state values; tolerate an object too.
    if isinstance(result, dict):
        response = result.get("response", "")
        expert = result.get("expert", "")
    else:
        response = getattr(result, "response", "")
        expert = getattr(result, "expert", "")
    return f"[{expert}] {response}" if expert else response


with gr.Blocks(title="Llama3-Med Multi-Agent AI") as demo:
    gr.Markdown("# 🏥 AI Medical Assistant")
    with gr.Row():
        user_input = gr.Textbox(label="Your Question")
        # "file" is not a valid gr.Image type; "filepath" delivers a path
        # string, which is what analyze_medical_image opens.
        image_file = gr.Image(label="Upload Medical Image (Optional)", type="filepath")
        pdf_file = gr.File(label="Upload PDF (Optional)", file_types=[".pdf"])
    submit_btn = gr.Button("Submit")
    output_text = gr.Textbox(label="Assistant's Response", interactive=False)
    submit_btn.click(fn=chat_with_agent, inputs=[user_input, image_file, pdf_file], outputs=output_text)

if __name__ == "__main__":
    demo.launch()