import gc

import streamlit as st
import torch

from rag import load_all, run_query


@st.cache_resource(
    show_spinner="Loading models and indices. This might take a while..."
)
def get_rag_qa() -> dict:
    # Free host and GPU memory before loading the embedder, index, and reader.
    gc.collect()
    torch.cuda.empty_cache()
    return load_all(
        embedder_path="Snowflake/snowflake-arctic-embed-l",
        context_file="data/bioasq_contexts.jsonl",
        index_file="data/bioasq_contexts__snowflake-arctic-embed-l__float32_hnsw.index",
        reader_path="meta-llama/Llama-3.2-1B-Instruct",
    )


left_column, cent_column, last_column = st.columns(3)
with cent_column:
    st.image("cover.webp", width=400)

st.title("Ask the BioASQ Database Anything!")

# Initialize the RagQA model; it may already be cached.
_ = get_rag_qa()

# Run QA
st.subheader("Ask away:")
question = st.text_input("Ask away:", "", label_visibility="collapsed")
submit = st.button("Submit")

st.markdown(
    """
> **For example, ask things like:**
>
> What is the Bartter syndrome?
> Which genes have been found to be associated with restless leg syndrome?
> Which diseases can be treated with Afamelanotide?

---
""",
    unsafe_allow_html=False,
)

if submit:
    if not question.strip():
        st.error("Machine Learning still can't read minds. Please enter a question.")
    else:
        try:
            with st.spinner(
                "Combing through 3000+ documents from the BioASQ database..."
            ):
                rag_qa = get_rag_qa()
                retrieved_context_ids, sources, answer = run_query(question, **rag_qa)

            # Log raw results to the server console for debugging.
            print(answer)
            print(retrieved_context_ids)
            print(sources)

            st.subheader("Answer:")
            st.write(answer)
            st.write("")
            with st.expander("Show Sources"):
                st.subheader("Sources:")
                for i, (context_id, source) in enumerate(
                    zip(retrieved_context_ids, sources)
                ):
                    st.markdown(f"**BioASQ Document ID:** {context_id}")
                    st.markdown("**Text:**")
                    st.write(source)
                    # Separate consecutive sources with a horizontal rule.
                    if i < len(sources) - 1:
                        st.markdown("---")
        except Exception as e:
            st.error(f"An error occurred: {e}")