# GitChat / rag_101 / rag.py
# Uploaded by kartavya23 ("Upload 4 files", commit 47ad957).
import time
from typing import List, Optional, Union
from langchain_community.chat_models import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from retriever import (
create_parent_retriever,
load_embedding_model,
load_pdf,
load_reranker_model,
retrieve_context,
)
def main(
    file: str = "2401.08406v3.pdf",
    llm_name: str = "mistral",
):
    """Run an interactive question-answering loop over a PDF document.

    The document is loaded and indexed once. Then every question typed at the
    prompt is answered by an Ollama chat model, using the first item returned
    by ``retrieve_context`` as the prompt context. The answer is streamed to
    stdout token by token. Ctrl-D or Ctrl-C at the prompt ends the session.

    Args:
        file: PDF file passed to ``load_pdf``.
        llm_name: Ollama model name passed to ``ChatOllama``.
    """
    docs = load_pdf(files=file)
    embedding_model = load_embedding_model()
    retriever = create_parent_retriever(docs, embedding_model)
    reranker_model = load_reranker_model()

    llm = ChatOllama(model=llm_name)
    prompt_template = ChatPromptTemplate.from_template(
        (
            "Please answer the following question based on the provided `context` that follows the question.\n"
            "If you do not know the answer then just say 'I do not know'\n"
            "question: {question}\n"
            "context: ```{context}```\n"
        )
    )
    chain = prompt_template | llm | StrOutputParser()

    while True:
        try:
            query = input("Ask question: ")
        except (EOFError, KeyboardInterrupt):
            # End the session cleanly on Ctrl-D / Ctrl-C instead of
            # crashing with a traceback.
            print()
            break
        if not query.strip():
            # A blank line carries nothing to retrieve; just prompt again.
            continue

        # NOTE(review): the first item of retrieve_context(...) is indexed
        # again with [0] below, so items appear to be (document, score)-like
        # pairs -- confirm against retriever.retrieve_context.
        context = retrieve_context(
            query, retriever=retriever, reranker_model=reranker_model
        )[0]

        print("LLM Response: ", end="", flush=True)
        for token in chain.stream(
            {"context": context[0].page_content, "question": query}
        ):
            # Flush every chunk so streamed tokens show up immediately;
            # without it stdout may hold them until the final newline.
            print(token, end="", flush=True)
        print()
        time.sleep(0.1)
if __name__ == "__main__":
    # jsonargparse is only imported when running as a script, so importing
    # this module elsewhere does not require it. The CLI maps main()'s
    # parameters to command-line options.
    import jsonargparse

    jsonargparse.CLI(main)