import streamlit as st
import os
import nest_asyncio
import re
from pathlib import Path
import typing as t
import base64
from mimetypes import guess_type

from llama_parse import LlamaParse
from llama_index.core import VectorStoreIndex, StorageContext, load_index_from_storage, Settings
from llama_index.core.schema import TextNode, ImageNode, NodeWithScore
from llama_index.core.retrievers import BaseRetriever
from llama_index.embeddings.openai import OpenAIEmbedding
from llama_index.llms.openai import OpenAI
from llama_index.core.query_engine import CustomQueryEngine
from llama_index.multi_modal_llms.openai import OpenAIMultiModal
from llama_index.core.prompts import PromptTemplate
from llama_index.core.base.response.schema import Response

nest_asyncio.apply()
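# nest_asyncio lets LlamaParse's internal asyncio calls run in an environment
# that may already have a live event loop (common in notebooks and some app
# hosts); without it, sync parser calls can raise
# "RuntimeError: This event loop is already running".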
# API keys are expected in the environment (e.g. Space secrets); fail fast if missing
if not os.getenv("OPENAI_API_KEY") or not os.getenv("LLAMA_CLOUD_API_KEY"):
    st.error("Set OPENAI_API_KEY and LLAMA_CLOUD_API_KEY in the environment.")
    st.stop()
# Initialize Streamlit app
st.title("Medical Knowledge Base & Query System")
st.sidebar.title("Settings")

# User input for file upload
st.sidebar.subheader("Upload Knowledge Base")
uploaded_file = st.sidebar.file_uploader(
    "Upload a medical textbook (PDF) or a scanned page (JPG/PNG)",
    type=["jpg", "png", "pdf"],
)
# Initialize the parser
parser = LlamaParse(
    result_type="markdown",
    parsing_instruction="You are given a medical textbook.",
    use_vendor_multimodal_model=True,
    vendor_multimodal_model_name="gpt-4o-mini-2024-07-18",
    show_progress=True,
    verbose=True,
    invalidate_cache=True,
    do_not_cache=True,
    num_workers=8,
    language="en",
)
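# Rough shape of each item returned by parser.get_json_result() below (an
# assumption based on LlamaParse's JSON mode; exact fields may vary by version):
# {
#     "file_path": "textbook.pdf",
#     "pages": [{"page": 1, "md": "...markdown...", "images": [...]}, ...],
# }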
# Initialize md_json_objs as an empty list; it stays empty until a file is uploaded
md_json_objs = []
# Upload and process file
if uploaded_file:
    st.sidebar.write("Processing file...")
    file_path = uploaded_file.name
    with open(file_path, "wb") as f:
        f.write(uploaded_file.read())

    # Parse the uploaded document and download per-page screenshots
    md_json_objs = parser.get_json_result([file_path])
    image_dicts = parser.get_images(md_json_objs, download_path="data_images")

    # Extract and display parsed information
    st.write("File successfully processed!")
    st.write(f"Processed file: {uploaded_file.name}")
# Function to encode a local image file as a data URL
def local_image_to_data_url(image_path):
    mime_type, _ = guess_type(image_path)
    if mime_type is None:
        mime_type = "image/png"
    with open(image_path, "rb") as image_file:
        base64_encoded_data = base64.b64encode(image_file.read()).decode("utf-8")
    return f"data:{mime_type};base64,{base64_encoded_data}"
# Functions to sort page screenshots by the page number in the file name
def get_page_number(file_name):
    match = re.search(r"-page-(\d+)\.jpg$", str(file_name))
    if match:
        return int(match.group(1))
    return 0

def _get_sorted_image_files(image_dir):
    raw_files = [f for f in Path(image_dir).iterdir() if f.is_file()]
    sorted_files = sorted(raw_files, key=get_page_number)
    return sorted_files
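# Example with hypothetical file names:
#   get_page_number(Path("doc42-page-12.jpg"))  -> 12
#   _get_sorted_image_files("data_images")      -> [...-page-1.jpg, ...-page-2.jpg, ...]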
def get_text_nodes(md_json_objs, image_dir) -> t.List[TextNode]:
    nodes = []
    for result in md_json_objs:
        json_dicts = result["pages"]
        document_name = Path(result["file_path"]).name
        docs = [doc["md"] for doc in json_dicts]
        image_files = _get_sorted_image_files(image_dir)
        for idx, doc in enumerate(docs):
            node = TextNode(
                text=doc,
                metadata={"image_path": str(image_files[idx]), "page_num": idx + 1, "document_name": document_name},
            )
            nodes.append(node)
    return nodes
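# Note: get_text_nodes() pairs page N's markdown with the N-th screenshot in
# image_dir, which assumes exactly one image per page and a single document;
# indexing several books at once would need per-document image directories.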
# Load text nodes if md_json_objs is not empty
if md_json_objs:
    text_nodes = get_text_nodes(md_json_objs, "data_images")
else:
    text_nodes = []
# Setup index and LLM
embed_model = OpenAIEmbedding(model="text-embedding-3-large")
llm = OpenAI(model="gpt-4o-mini-2024-07-18")
Settings.llm = llm
Settings.embed_model = embed_model
# Build a fresh index when a file was just parsed; otherwise try to reload a
# previously persisted one. This avoids persisting an empty index on startup.
index = None
if text_nodes:
    index = VectorStoreIndex(text_nodes, embed_model=embed_model)
    index.storage_context.persist(persist_dir="./storage_manuals")
elif os.path.exists("storage_manuals"):
    ctx = StorageContext.from_defaults(persist_dir="./storage_manuals")
    index = load_index_from_storage(ctx)
retriever = index.as_retriever() if index is not None else None
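# Caveat: the persisted index in ./storage_manuals survives across sessions,
# so with no upload the previous book is reloaded; with neither an upload nor
# prior storage, retriever stays None and querying is disabled below.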
# Query input
st.subheader("Ask a Question")
query_text = st.text_input("Enter your query:")
uploaded_query_image = st.file_uploader("Upload a query image (if any):", type=["jpg", "png"])

# Encode query image if provided
encoded_image_url = None
if uploaded_query_image:
    query_image_path = uploaded_query_image.name
    with open(query_image_path, "wb") as img_file:
        img_file.write(uploaded_query_image.read())
    encoded_image_url = local_image_to_data_url(query_image_path)
# Setup query engine
QA_PROMPT_TMPL = """\
You are a friendly medical chatbot designed to assist users by providing accurate and detailed responses to medical questions based on information from medical books.

### Context:
---------------------
{context_str}
---------------------

### Query Text:
{query_str}

### Query Image:
---------------------
{encoded_image_url}
---------------------

### Answer:
"""
QA_PROMPT = PromptTemplate(QA_PROMPT_TMPL)

gpt_4o_mm = OpenAIMultiModal(model="gpt-4o-mini-2024-07-18")
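# Note: the template above receives the query image only as a base64 data URL
# inside the text, so the model sees it as a string rather than as pixels. A
# minimal alternative (a sketch, assuming llama_index's ImageDocument schema)
# would pass it to the LLM alongside the retrieved page screenshots:
#
#   from llama_index.core.schema import ImageDocument
#   query_image_doc = ImageDocument(image_path=query_image_path)
#   # ...then include query_image_doc in image_documents inside custom_query().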
class MultimodalQueryEngine(CustomQueryEngine):
    # CustomQueryEngine is a Pydantic model, so components are declared as typed fields
    qa_prompt: PromptTemplate
    retriever: BaseRetriever
    multi_modal_llm: OpenAIMultiModal

    def custom_query(self, query_str):
        # Retrieve relevant text nodes and attach each page's screenshot as an ImageNode
        nodes = self.retriever.retrieve(query_str)
        image_nodes = [NodeWithScore(node=ImageNode(image_path=n.node.metadata["image_path"])) for n in nodes]
        ctx_str = "\n\n".join(r.node.get_content().strip() for r in nodes)
        fmt_prompt = self.qa_prompt.format(context_str=ctx_str, query_str=query_str, encoded_image_url=encoded_image_url)
        llm_response = self.multi_modal_llm.complete(prompt=fmt_prompt, image_documents=[n.node for n in image_nodes])
        return Response(response=str(llm_response), source_nodes=nodes, metadata={"text_nodes": text_nodes, "image_nodes": image_nodes})

query_engine = None
if retriever is not None:
    query_engine = MultimodalQueryEngine(qa_prompt=QA_PROMPT, retriever=retriever, multi_modal_llm=gpt_4o_mm)
# Handle query
if query_text:
    if query_engine is None:
        st.warning("Please upload a knowledge base before querying.")
    else:
        st.write("Querying...")
        response = query_engine.custom_query(query_text)
        st.markdown(response.response)
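        # Optional sketch: to also render the supporting page screenshots
        # under the answer (st.image accepts a file path):
        #   for n in response.source_nodes:
        #       st.image(n.node.metadata["image_path"])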