# %%
import os
import json

import faiss
import numpy as np
import torch
from sentence_transformers import SentenceTransformer
from transformers import (
    pipeline,
    TextGenerationPipeline,
    AutoModelForCausalLM,
    AutoTokenizer,
)

HF_TOKEN = os.environ["hf_token"]
SYSTEM_PROMPT = """You are a helpful question answering assistant. You will be given a context and a question. You need to provide the answer to the question based on the context. Answer briefly, based on the context. Only output the answer, and nothing else. Here is an example: | |
>> Context | |
Fascin is an actin-bundling protein that induces membrane protrusions and cell motility after the formation of lamellipodia or filopodia. Fascin expression has been associated with progression or prognosis in various neoplasms; however, its role in intrahepatic cholangiocarcinoma is unknown. | |
>> Question | |
What type of protein is fascin? | |
>> Answer | |
Actin-bundling protein | |
Now answer the user's question based on the user's given context. | |
""" | |
USER_PROMPT = """ | |
>> Context | |
{context} | |
>> Question | |
{question} | |
>> Answer | |
""" | |


def load_embedder(model_path: str, device: str) -> SentenceTransformer:
    embedder = SentenceTransformer(model_path)
    embedder.to(device)
    return embedder


def load_contexts(context_file: str) -> list[str]:
    # Each line of the JSONL file is an object with a "context" field.
    contexts = []
    with open(context_file, "r") as f_in:
        for line in f_in:
            context = json.loads(line)
            contexts.append(context["context"])
    return contexts


def load_index(index_file: str) -> faiss.Index:
    return faiss.read_index(index_file)
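

# For reference, a minimal sketch of how a compatible index could be built.
# This is an assumption, not the original build script: it presumes
# normalized float32 embeddings (so inner product equals cosine similarity)
# and an HNSW index, as the "...__float32_hnsw.index" file name suggests.
# `build_hnsw_index` is a hypothetical helper and is not called by this app.
def build_hnsw_index(
    contexts: list[str], embedder: SentenceTransformer, index_file: str
) -> faiss.Index:
    embeddings = embedder.encode(
        contexts, normalize_embeddings=True, show_progress_bar=True
    )
    embeddings = np.asarray(embeddings, dtype=np.float32)
    # 32 neighbors per node is a common HNSW setting; inner product matches
    # the normalized query embeddings used in run_query below.
    index = faiss.IndexHNSWFlat(embeddings.shape[1], 32, faiss.METRIC_INNER_PRODUCT)
    index.add(embeddings)
    faiss.write_index(index, index_file)
    return index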


def load_reader(model_path: str, device: str) -> TextGenerationPipeline:
    model = AutoModelForCausalLM.from_pretrained(
        model_path, token=HF_TOKEN, torch_dtype=torch.bfloat16
    )
    tokenizer = AutoTokenizer.from_pretrained(model_path, token=HF_TOKEN)
    # Llama tokenizers ship without a pad token; reuse EOS so the pipeline
    # can pad batched inputs.
    tokenizer.pad_token = tokenizer.eos_token
    reader = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        token=HF_TOKEN,
        device=device,
    )
    return reader


def construct_prompt(contexts: list[str], question: str) -> list[dict]:
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": USER_PROMPT.format(
                context="\n".join(contexts), question=question
            ),
        },
    ]
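
# Illustrative only: with one retrieved context, construct_prompt returns a
# two-message chat, e.g.
#   construct_prompt(["Fascin is an actin-bundling protein ..."],
#                    "What type of protein is fascin?")
#   -> [{"role": "system", "content": SYSTEM_PROMPT},
#       {"role": "user", "content": "\n>> Context\nFascin is ...\n>> Question\n..."}]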


def load_all(
    embedder_path: str,
    context_file: str,
    index_file: str,
    reader_path: str,
) -> tuple[SentenceTransformer, list[str], faiss.Index, TextGenerationPipeline]:
    # The embedder stays on CPU; the reader gets the GPU if one is available.
    embedder = load_embedder(embedder_path, "cpu")
    contexts = load_contexts(context_file)
    index = load_index(index_file)
    reader_device = "cuda" if torch.cuda.is_available() else "cpu"
    reader = load_reader(reader_path, reader_device)
    return embedder, contexts, index, reader


def run_query(
    question: str,
    embedder: SentenceTransformer,
    index: faiss.Index,
    contexts: list[str],
    reader: TextGenerationPipeline,
    top_k: int = 3,
) -> tuple[list[int], list[str], str]:
    query_embedding = embedder.encode([question], normalize_embeddings=True)
    # FAISS returns ids with shape (1, top_k); an id of -1 means no neighbor
    # was found for that slot.
    _, retrieved_context_ids = index.search(query_embedding, top_k)
    retrieved_contexts = []
    for row in retrieved_context_ids:
        retrieved_contexts.append(
            [contexts[i] if i >= 0 and contexts[i] is not None else "" for i in row]
        )
    # The code below handles a single question.
    prompt = construct_prompt(retrieved_contexts[0], question)
    answer = reader(prompt, max_new_tokens=128, return_full_text=False)
    answer_text = answer[0]["generated_text"]
    # The model sometimes echoes the ">> Answer" marker; strip it if present.
    if ">> Answer" in answer_text:
        answer_text = answer_text.split(">> Answer")[1].strip()
    return retrieved_context_ids[0].tolist(), retrieved_contexts[0], answer_text


# %%
# embedder_path = "Snowflake/snowflake-arctic-embed-l"
# reader_path = "meta-llama/Llama-3.2-1B-Instruct"
# context_file = "../data/bioasq_contexts.jsonl"
# index_file = "../data/bioasq_contexts__snowflake-arctic-embed-l__float32_hnsw.index"
# embedder, contexts, index, reader = load_all(
#     embedder_path, context_file, index_file, reader_path
# )
# query = "What cellular structures does fascin induce?"
# retrieved_context_ids, retrieved_contexts, answer_text = run_query(
#     query, embedder, index, contexts, reader
# )
# %%