jiviteshjain commited on
Commit
6965281
·
1 Parent(s): afd9c09
Files changed (3) hide show
  1. app.py +0 -2
  2. rag.py +2 -3
  3. requirements.txt +0 -1
app.py CHANGED
@@ -13,11 +13,9 @@ def get_rag_qa() -> dict:
13
  torch.cuda.empty_cache()
14
  return load_all(
15
  embedder_path="Snowflake/snowflake-arctic-embed-l",
16
- embedder_device="cpu",
17
  context_file="data/bioasq_contexts.jsonl",
18
  index_file="data/bioasq_contexts__snowflake-arctic-embed-l__float32_hnsw.index",
19
  reader_path="meta-llama/Llama-3.2-1B-Instruct",
20
- reader_device="mps",
21
  )
22
 
23
 
 
13
  torch.cuda.empty_cache()
14
  return load_all(
15
  embedder_path="Snowflake/snowflake-arctic-embed-l",
 
16
  context_file="data/bioasq_contexts.jsonl",
17
  index_file="data/bioasq_contexts__snowflake-arctic-embed-l__float32_hnsw.index",
18
  reader_path="meta-llama/Llama-3.2-1B-Instruct",
 
19
  )
20
 
21
 
rag.py CHANGED
@@ -92,15 +92,14 @@ def construct_prompt(contexts: list[str], question: str) -> list[dict]:
92
 
93
  def load_all(
94
  embedder_path: str,
95
- embedder_device: str,
96
  context_file: str,
97
  index_file: str,
98
  reader_path: str,
99
- reader_device: str,
100
  ) -> tuple[SentenceTransformer, list[str], faiss.Index, TextGenerationPipeline]:
101
- embedder = load_embedder(embedder_path, embedder_device)
102
  contexts = load_contexts(context_file)
103
  index = load_index(index_file)
 
104
  reader = load_reader(reader_path, reader_device)
105
 
106
  return {
 
92
 
93
  def load_all(
94
  embedder_path: str,
 
95
  context_file: str,
96
  index_file: str,
97
  reader_path: str,
 
98
  ) -> tuple[SentenceTransformer, list[str], faiss.Index, TextGenerationPipeline]:
99
+ embedder = load_embedder(embedder_path, "cpu")
100
  contexts = load_contexts(context_file)
101
  index = load_index(index_file)
102
+ reader_device = "cuda" if torch.cuda.is_available() else "cpu"
103
  reader = load_reader(reader_path, reader_device)
104
 
105
  return {
requirements.txt CHANGED
@@ -68,7 +68,6 @@ setuptools==75.6.0
68
  six==1.17.0
69
  smmap==5.0.1
70
  stack-data==0.6.3
71
- streamlit==1.41.1
72
  sympy==1.13.1
73
  tenacity==9.0.0
74
  threadpoolctl==3.5.0
 
68
  six==1.17.0
69
  smmap==5.0.1
70
  stack-data==0.6.3
 
71
  sympy==1.13.1
72
  tenacity==9.0.0
73
  threadpoolctl==3.5.0