"""Streamlit app: a multimodal medical knowledge base built with LlamaParse and LlamaIndex.

Uploaded books are parsed by LlamaParse into per-page markdown plus page images.
Each page becomes a ``TextNode`` that carries the matching page-image path when
one was downloaded. The nodes are added to a ``VectorStoreIndex`` persisted under
``./storage_manuals``, and ``retriever`` exposes that index to the query UI.
"""
import base64
import os
import re
import typing as t
from mimetypes import guess_type
from pathlib import Path
from typing import List, Optional

import nest_asyncio
import streamlit as st
from llama_index.core import (
    Settings,
    StorageContext,
    VectorStoreIndex,
    load_index_from_storage,
)
from llama_index.core.base.response.schema import Response
from llama_index.core.prompts import PromptTemplate
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.core.schema import ImageNode, TextNode
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_parse import LlamaParse

# Allow re-entrant asyncio event loops. This is presumably needed because
# LlamaParse's synchronous API drives asyncio internally — TODO confirm.
nest_asyncio.apply()

REQUIRED_ENV_VARS = ("OPENAI_API_KEY", "LLAMA_CLOUD_API_KEY")
UPLOAD_DIR = Path("uploads")
IMAGE_ROOT = Path("data_images")
STORAGE_DIR = "./storage_manuals"
PAGE_NUMBER_RE = re.compile(r"-page-(\d+)\.jpg$")

# Initialize Streamlit app
st.title("Medical Knowledge Base & Query System")
st.sidebar.title("Settings")

# The OpenAI and LlamaParse clients are expected to read their keys from the
# environment. Stop with a readable message if one is missing. Assigning
# os.getenv(...) (None) into os.environ would raise TypeError instead.
_missing_keys = [name for name in REQUIRED_ENV_VARS if not os.getenv(name)]
if _missing_keys:
    st.error(f"Missing required environment variable(s): {', '.join(_missing_keys)}")
    st.stop()

# User input for file upload
st.sidebar.subheader("Upload Knowledge Base")
uploaded_file = st.sidebar.file_uploader(
    "Upload a medical text book (pdf)", type=["jpg", "png", "pdf"]
)

# Parser configured with caching disabled, so every call re-parses the file.
parser = LlamaParse(
    result_type="markdown",
    parsing_instruction="You are given a medical textbook on medicine",
    use_vendor_multimodal_model=True,
    vendor_multimodal_model_name="gpt-4o-mini-2024-07-18",
    show_progress=True,
    verbose=True,
    invalidate_cache=True,
    do_not_cache=True,
    num_workers=8,
    language="en",
)


def local_image_to_data_url(image_path):
    """Return the file at *image_path* encoded as a ``data:<mime>;base64,...`` URL.

    The MIME type is guessed from the file name and falls back to ``image/png``.
    """
    mime_type, _ = guess_type(image_path)
    if mime_type is None:
        mime_type = "image/png"
    with open(image_path, "rb") as image_file:
        base64_encoded_data = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:{mime_type};base64,{base64_encoded_data}"


def get_page_number(file_name):
    """Return the page number encoded as ``-page-<n>.jpg`` in *file_name*, or 0 if absent."""
    match = PAGE_NUMBER_RE.search(str(file_name))
    return int(match.group(1)) if match else 0


def _get_sorted_image_files(image_dir):
    """List the regular files in *image_dir* ordered by page number.

    Returns ``[]`` when the directory does not exist (e.g. no images were downloaded).
    """
    directory = Path(image_dir)
    if not directory.is_dir():
        return []
    return sorted((f for f in directory.iterdir() if f.is_file()), key=get_page_number)


def get_text_nodes(md_json_objs, image_dir) -> t.List[TextNode]:
    """Build one ``TextNode`` per parsed page.

    Args:
        md_json_objs: LlamaParse JSON results. Each item has ``"file_path"`` and
            ``"pages"``, and every page has an ``"md"`` markdown string.
        image_dir: Directory holding the page images downloaded for these results.

    Returns:
        Nodes whose metadata holds ``page_num`` (1-based), ``document_name`` and,
        when an image exists at the same position, ``image_path``.
    """
    # Images are matched to pages by sorted position. This assumes image_dir
    # holds the page images of a single document — TODO confirm the naming scheme.
    image_files = _get_sorted_image_files(image_dir)
    nodes = []
    for result in md_json_objs:
        document_name = Path(result["file_path"]).name
        for idx, page in enumerate(result["pages"]):
            metadata = {}
            # Pages without a matching image get no image_path instead of an IndexError.
            if idx < len(image_files):
                metadata["image_path"] = str(image_files[idx])
            metadata["page_num"] = idx + 1
            metadata["document_name"] = document_name
            nodes.append(TextNode(text=page["md"], metadata=metadata))
    return nodes


def _parse_upload(upload):
    """Save a Streamlit upload locally, parse it with LlamaParse and build its page nodes.

    Returns:
        ``(json_results, text_nodes)``.
    """
    UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
    # Keep only the base name so a crafted upload name cannot escape UPLOAD_DIR.
    file_path = UPLOAD_DIR / Path(upload.name).name
    file_path.write_bytes(upload.getvalue())
    json_results = parser.get_json_result([str(file_path)])
    # Give each upload its own image directory. Otherwise page images of
    # different books get mixed when images are matched to pages by position.
    image_dir = IMAGE_ROOT / file_path.stem
    parser.get_images(json_results, download_path=str(image_dir))
    return json_results, get_text_nodes(json_results, image_dir)


# Setup index and LLM
embed_model = OpenAIEmbedding(model="text-embedding-3-large")
llm = OpenAI("gpt-4o-mini-2024-07-18")
Settings.llm = llm
Settings.embed_model = embed_model

if os.path.exists(STORAGE_DIR):
    index = load_index_from_storage(StorageContext.from_defaults(persist_dir=STORAGE_DIR))
else:
    # Start empty. Nothing is persisted until real pages are inserted, so a run
    # without an upload can no longer lock in an empty index on disk.
    index = VectorStoreIndex([], embed_model=embed_model)

md_json_objs = []
text_nodes = []
if "processed_uploads" not in st.session_state:
    st.session_state["processed_uploads"] = set()

if uploaded_file:
    # Streamlit re-runs this script on every widget interaction. Remember which
    # uploads were already handled so they are not re-parsed (the parser's cache
    # is disabled) and not inserted into the index twice.
    upload_key = (uploaded_file.name, uploaded_file.size)
    if upload_key not in st.session_state["processed_uploads"]:
        st.sidebar.write("Processing file...")
        md_json_objs, text_nodes = _parse_upload(uploaded_file)
        if text_nodes:
            index.insert_nodes(text_nodes)
            index.storage_context.persist(persist_dir=STORAGE_DIR)
        else:
            st.sidebar.warning("No pages could be extracted from this file.")
        st.session_state["processed_uploads"].add(upload_key)
    st.write("File successfully processed!")
    st.write(f"Processed file: {uploaded_file.name}")

retriever = index.as_retriever()

# Query input
st.subheader("Ask a Question")
# Names used by MultimodalQueryEngine that the top-of-file imports do not provide.
from llama_index.core.postprocessor.types import BaseNodePostprocessor
from llama_index.core.retrievers import BaseRetriever
from llama_index.core.schema import MetadataMode, NodeWithScore, QueryBundle

query_text = st.text_input("Enter your query:")
uploaded_query_image = st.file_uploader("Upload a query image (if any):", type=["jpg", "png"])

# Encode the optional query image as a data URL directly from the upload buffer.
# Nothing is written to disk under a user-controlled file name.
encoded_image_url = None
if uploaded_query_image:
    image_mime = (
        uploaded_query_image.type
        or guess_type(uploaded_query_image.name)[0]
        or "image/png"
    )
    image_b64 = base64.b64encode(uploaded_query_image.getvalue()).decode("utf-8")
    encoded_image_url = f"data:{image_mime};base64,{image_b64}"

# Setup query engine.
# The query image is attached to the LLM call as a real image. Only a short note
# about it goes into the prompt text, never the base64 payload itself.
QA_PROMPT_TMPL = """\
You are a friendly medical chatbot designed to assist users by providing accurate and detailed responses to medical questions based on information from medical books.

### Context:
---------------------
{context_str}
---------------------

### Query Text:
{query_str}

### Query Image:
---------------------
{query_image_str}
---------------------

### Answer:
"""
QA_PROMPT = PromptTemplate(QA_PROMPT_TMPL)

QUERY_IMAGE_ATTACHED_NOTE = (
    "The user attached an image to this query; it is the first image provided. "
    "Any further images are pages from the medical books."
)
QUERY_IMAGE_ABSENT_NOTE = "No image was provided with this query."

gpt_4o_mm = OpenAIMultiModal(model="gpt-4o-mini-2024-07-18")


class MultimodalQueryEngine(CustomQueryEngine):
    """Answer a query from retrieved page text, the pages' images and an optional user image."""

    qa_prompt: PromptTemplate
    retriever: BaseRetriever
    multi_modal_llm: OpenAIMultiModal
    node_postprocessors: List[BaseNodePostprocessor]
    # Optional data URL (or any image URL) supplied by the user with the query.
    query_image_url: Optional[str] = None

    def __init__(
        self,
        qa_prompt: PromptTemplate,
        retriever: BaseRetriever,
        multi_modal_llm: OpenAIMultiModal,
        node_postprocessors: Optional[List[BaseNodePostprocessor]] = None,
        query_image_url: Optional[str] = None,
    ) -> None:
        # None instead of a shared mutable [] default; each engine gets its own list.
        super().__init__(
            qa_prompt=qa_prompt,
            retriever=retriever,
            multi_modal_llm=multi_modal_llm,
            node_postprocessors=list(node_postprocessors or []),
            query_image_url=query_image_url,
        )

    def custom_query(self, query_str: str) -> Response:
        """Retrieve relevant pages and ask the multimodal LLM to answer *query_str*.

        Returns:
            A ``Response`` whose ``source_nodes`` are the retrieved nodes. Its
            metadata holds those nodes and the page-image nodes sent to the LLM.
        """
        nodes = self.retriever.retrieve(query_str)
        for postprocessor in self.node_postprocessors:
            nodes = postprocessor.postprocess_nodes(
                nodes, query_bundle=QueryBundle(query_str)
            )

        # One image node per retrieved page that has an image on disk.
        page_image_nodes = [
            NodeWithScore(node=ImageNode(image_path=n.node.metadata["image_path"]))
            for n in nodes
            if n.node.metadata.get("image_path")
        ]
        image_documents = [image_node.node for image_node in page_image_nodes]

        # The user's image goes first so the prompt note can refer to it by position.
        if self.query_image_url:
            image_documents.insert(0, ImageNode(image_url=self.query_image_url))
            query_image_str = QUERY_IMAGE_ATTACHED_NOTE
        else:
            query_image_str = QUERY_IMAGE_ABSENT_NOTE

        # Context = parsed markdown of each page plus its LLM-visible metadata.
        ctx_str = "\n\n".join(
            r.node.get_content(metadata_mode=MetadataMode.LLM).strip() for r in nodes
        )
        fmt_prompt = self.qa_prompt.format(
            context_str=ctx_str,
            query_str=query_str,
            query_image_str=query_image_str,
        )
        llm_response = self.multi_modal_llm.complete(
            prompt=fmt_prompt,
            image_documents=image_documents,
        )
        return Response(
            response=str(llm_response),
            source_nodes=nodes,
            metadata={"text_nodes": nodes, "image_nodes": page_image_nodes},
        )


query_engine = MultimodalQueryEngine(
    QA_PROMPT, retriever, gpt_4o_mm, query_image_url=encoded_image_url
)

# Handle query
if query_text:
    with st.spinner("Querying..."):
        response = query_engine.query(query_text)
    st.markdown(response.response)