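"""Interactive command-line RAG demo.

Loads a PDF, builds a parent-document retriever over it, reranks the
retrieved chunks, and streams answers to user questions from a local
Ollama chat model. The retrieval helpers are imported from the
accompanying `retriever` module (not shown here).
"""
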
import time

from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

from retriever import (
    create_parent_retriever,
    load_embedding_model,
    load_pdf,
    load_reranker_model,
    retrieve_context,
)


def main(
    file: str = "2401.08406v3.pdf",
    llm_name: str = "mistral",
):
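    # Parse the PDF into LangChain documents.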
    docs = load_pdf(files=file)
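
    # Embed the chunks and build a parent-document retriever: presumably,
    # small child chunks are matched against the query while the larger
    # parent chunks they belong to are returned as context.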
    embedding_model = load_embedding_model()
    retriever = create_parent_retriever(docs, embedding_model)
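
    # Reranker used to re-score the retrieved chunks against the query.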
    reranker_model = load_reranker_model()
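
    # Local Ollama chat model plus a prompt that pins the answer to the
    # retrieved context.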
    llm = ChatOllama(model=llm_name)
    prompt_template = ChatPromptTemplate.from_template(
        "Please answer the following question based on the provided `context` that follows the question.\n"
        "If you do not know the answer then just say 'I do not know'\n"
        "question: {question}\n"
        "context: ```{context}```\n"
    )
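
    # LCEL pipeline: prompt -> chat model -> plain-string output.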
    chain = prompt_template | llm | StrOutputParser()
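
    # Simple REPL: retrieve and rerank context, then stream an answer for
    # each question.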
    while True:
        query = input("Ask question: ")
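        # Take the top-ranked hit; retrieve_context appears to return a list
        # of (document, score)-like pairs, so context[0] below is the document.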
        context = retrieve_context(
            query, retriever=retriever, reranker_model=reranker_model
        )[0]
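        # Stream the answer token by token as it is generated.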
print("LLM Response: ", end="")
for e in chain.stream({"context": context[0].page_content, "question": query}):
print(e, end="")
print()
time.sleep(0.1)


if __name__ == "__main__":
    from jsonargparse import CLI

    CLI(main)
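
# Example invocation (assumes this file is saved as main.py and that an
# Ollama server is running locally with the model already pulled, e.g. via
# `ollama pull mistral`). jsonargparse maps the parameters of `main` to
# command-line flags:
#
#   python main.py --file 2401.08406v3.pdf --llm_name mistral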