jiviteshjain committed
Commit 8042e59 · Parent: 7e24f8a

Track files with lfs.

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+data/** filter=lfs diff=lfs merge=lfs -text
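For reference, this is exactly the line that running git lfs track "data/**" appends: everything under data/ (the JSONL corpus and the FAISS index added below) is stored as a Git LFS pointer rather than committed directly.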
app.py ADDED
@@ -0,0 +1,80 @@
+import gc
+
+import streamlit as st
+import torch
+from rag import load_all, run_query
+
+
+@st.cache_resource(
+    show_spinner="Loading models and indices. This might take a while..."
+)
+def get_rag_qa() -> dict:
+    gc.collect()
+    torch.cuda.empty_cache()
+    return load_all(
+        embedder_path="Snowflake/snowflake-arctic-embed-l",
+        embedder_device="cpu",
+        context_file="data/bioasq_contexts.jsonl",
+        index_file="data/bioasq_contexts__snowflake-arctic-embed-l__float32_hnsw.index",
+        reader_path="meta-llama/Llama-3.2-1B-Instruct",
+        reader_device="mps",
+    )
+
+
+left_column, cent_column, last_column = st.columns(3)
+with cent_column:
+    st.image("cover.webp", width=400)
+st.title("Ask the BioASQ Database Anything!")
+
+# Initialize the RagQA model, which might already be cached.
+_ = get_rag_qa()
+
+# Run QA
+st.subheader("Ask away:")
+question = st.text_input("Ask away:", "", label_visibility="collapsed")
+submit = st.button("Submit")
+
+st.markdown(
+    """
+> **For example, ask things like:**
+>
+> What is the Bartter syndrome?
+> Which genes have been found to be associated with restless leg syndrome?
+> Which diseases can be treated with Afamelanotide?
+---
+""",
+    unsafe_allow_html=False,
+)
+
+if submit:
+    if not question.strip():
+        st.error("Machine Learning still can't read minds. Please enter a question.")
+    else:
+        try:
+            with st.spinner(
+                "Combing through 3000+ documents from the BioASQ database..."
+            ):
+                rag_qa = get_rag_qa()
+                retrieved_context_ids, sources, answer = run_query(question, **rag_qa)
+                print(answer)
+                print(retrieved_context_ids)
+                print(sources)
+
+            st.subheader("Answer:")
+            st.write(answer)
+
+            st.write("")
+
+            with st.expander("Show Sources"):
+                st.subheader("Sources:")
+                for i, (context_id, source) in enumerate(
+                    zip(retrieved_context_ids, sources)
+                ):
+                    st.markdown(f"**BioASQ Document ID:** {context_id}")
+                    st.markdown("**Text:**")
+                    st.write(source)
+                    if i < len(sources) - 1:
+                        st.markdown("---")
+
+        except Exception as e:
+            st.error(f"An error occurred: {e}")
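One portability note on app.py: reader_device is hardcoded to "mps", which assumes Apple Silicon. A minimal sketch of choosing the device at runtime instead; pick_device is a hypothetical helper, not part of this commit:

import torch

def pick_device() -> str:
    # Prefer CUDA, fall back to Apple's MPS backend, then CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

# Usage (hypothetical): load_all(..., reader_device=pick_device())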
cover.webp ADDED
data/bioasq_contexts.jsonl ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:1bb0fb8e100386c48d37f3a489593c326a474ed8bde13b834c929637a0c0bbc7
+size 4753372
data/bioasq_contexts__snowflake-arctic-embed-l__float32_hnsw.index ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:0f4fe738c0ca9c5846dacb07d932360fa9d41d967f0028fcb329fc55958f0834
+size 15377790
rag.py ADDED
@@ -0,0 +1,160 @@
+# %%
+import os
+import json
+
+import torch
+import faiss
+import numpy as np
+from sentence_transformers import SentenceTransformer
+from transformers import (
+    pipeline,
+    TextGenerationPipeline,
+    AutoModelForCausalLM,
+    AutoTokenizer,
+)
+
+HF_TOKEN = os.environ["hf_token"]
+
+SYSTEM_PROMPT = """You are a helpful question answering assistant. You will be given a context and a question. You need to provide the answer to the question based on the context. Answer briefly, based on the context. Only output the answer, and nothing else. Here is an example:
+
+>> Context
+Fascin is an actin-bundling protein that induces membrane protrusions and cell motility after the formation of lamellipodia or filopodia. Fascin expression has been associated with progression or prognosis in various neoplasms; however, its role in intrahepatic cholangiocarcinoma is unknown.
+
+>> Question
+What type of protein is fascin?
+
+>> Answer
+Actin-bundling protein
+
+Now answer the user's question based on the user's given context.
+"""
+
+USER_PROMPT = """
+>> Context
+{context}
+
+>> Question
+{question}
+
+>> Answer
+"""
+
+
+def load_embedder(model_path: str, device: str) -> SentenceTransformer:
+    embedder = SentenceTransformer(model_path)
+    embedder.to(device)
+    return embedder
+
+
+def load_contexts(context_file: str) -> list[str]:
+    contexts = []
+    with open(context_file, "r") as f_in:
+        for line in f_in:
+            context = json.loads(line)
+            contexts.append(context["context"])
+
+    return contexts
+
+
+def load_index(index_file: str) -> faiss.Index:
+    return faiss.read_index(index_file)
+
+
+def load_reader(model_path: str, device: str) -> TextGenerationPipeline:
+    model = AutoModelForCausalLM.from_pretrained(model_path)
+
+    tokenizer = AutoTokenizer.from_pretrained(model_path)
+    tokenizer.pad_token = tokenizer.eos_token
+
+    reader = pipeline(
+        "text-generation",
+        model=model,
+        tokenizer=tokenizer,
+        torch_dtype=torch.bfloat16,
+        token=HF_TOKEN,
+        device=device,
+    )
+
+    return reader
+
+
+def construct_prompt(contexts: list[str], question: str) -> list[dict]:
+    return [
+        {"role": "system", "content": SYSTEM_PROMPT},
+        {
+            "role": "user",
+            "content": USER_PROMPT.format(
+                context="\n".join(contexts), question=question
+            ),
+        },
+    ]
+
+
+def load_all(
+    embedder_path: str,
+    embedder_device: str,
+    context_file: str,
+    index_file: str,
+    reader_path: str,
+    reader_device: str,
+) -> dict:
+    embedder = load_embedder(embedder_path, embedder_device)
+    contexts = load_contexts(context_file)
+    index = load_index(index_file)
+    reader = load_reader(reader_path, reader_device)
+
+    return {
+        "embedder": embedder,
+        "contexts": contexts,
+        "index": index,
+        "reader": reader,
+    }
+
+
+def run_query(
+    question: str,
+    embedder: SentenceTransformer,
+    index: faiss.Index,
+    contexts: list[str],
+    reader: TextGenerationPipeline,
+    top_k: int = 3,
+) -> tuple[list[int], list[str], str]:
+    query_embedding = embedder.encode([question], normalize_embeddings=True)
+    _, retrieved_context_ids = index.search(query_embedding, top_k)
+    retrieved_context_ids = np.array(retrieved_context_ids)  # shape: (1, top_k)
+
+    retrieved_contexts = []
+    for row in retrieved_context_ids:
+        retrieved_contexts.append(
+            [contexts[i] if contexts[i] is not None else "" for i in row]
+        )
+
+    # The code below handles a single question at a time.
+    prompt = construct_prompt(retrieved_contexts[0], question)
+    answer = reader(prompt, max_new_tokens=128, return_full_text=False)
+    print(answer)
+    answer_text = answer[0]["generated_text"]
+    if ">> Answer" in answer_text:
+        answer_text = answer_text.split(">> Answer")[1].strip()
+
+    return retrieved_context_ids[0].tolist(), retrieved_contexts[0], answer_text
+
+
+# %%
+# embedder_path = "Snowflake/snowflake-arctic-embed-l"
+# reader_path = "meta-llama/Llama-3.2-1B-Instruct"
+# context_file = "../data/bioasq_contexts.jsonl"
+# index_file = "../data/bioasq_contexts__snowflake-arctic-embed-l__float32_hnsw.index"

+# rag_qa = load_all(
+#     embedder_path, "cpu", context_file, index_file, reader_path, "mps"
+# )
+
+# query = "What cellular structures does fascin induce?"
+
+# retrieved_context_ids, retrieved_contexts, answer_text = run_query(
+#     query, **rag_qa
+# )
+
+
+# %%
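The HNSW index itself ships as a prebuilt LFS blob, so the build step is not part of this commit. A sketch of how an equivalent index could be rebuilt from the contexts file, assuming the same embedder; the HNSW parameter M=32 is a guess, as is the script itself:

import json

import faiss
import numpy as np
from sentence_transformers import SentenceTransformer

# Read the same JSONL corpus that load_contexts() parses.
contexts = []
with open("data/bioasq_contexts.jsonl", "r") as f_in:
    for line in f_in:
        contexts.append(json.loads(line)["context"])

# Embed with the model the app queries with; unit-normalized vectors make
# L2 distance rank results the same way as cosine similarity.
embedder = SentenceTransformer("Snowflake/snowflake-arctic-embed-l")
embeddings = embedder.encode(
    contexts, normalize_embeddings=True, convert_to_numpy=True
).astype(np.float32)

# Build a float32 HNSW index (M=32 neighbors per node is an assumption).
index = faiss.IndexHNSWFlat(embeddings.shape[1], 32)
index.add(embeddings)
faiss.write_index(
    index, "data/bioasq_contexts__snowflake-arctic-embed-l__float32_hnsw.index"
)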