File size: 3,017 Bytes
139fefe
 
38ed905
 
 
 
139fefe
088e816
 
38ed905
139fefe
 
 
 
 
4b4bf28
 
 
 
 
 
caf1faa
 
 
 
 
4b4bf28
 
 
139fefe
 
 
4b4bf28
 
 
 
 
 
088e816
139fefe
088e816
 
 
 
 
 
 
4b4bf28
088e816
 
 
 
8edfef8
 
088e816
139fefe
088e816
 
 
 
139fefe
088e816
 
 
4b4bf28
088e816
caf1faa
 
 
 
088e816
caf1faa
088e816
 
 
 
 
caf1faa
088e816
 
caf1faa
088e816
caf1faa
 
4b4bf28
 
 
 
088e816
481f3b1
088e816
481f3b1
088e816
 
 
 
481f3b1
088e816
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from operator import itemgetter

from langchain_core.prompts import ChatPromptTemplate
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.base import format_document

from climateqa.engine.chains.prompts import answer_prompt_template,answer_prompt_without_docs_template,answer_prompt_images_template
from climateqa.engine.chains.prompts import papers_prompt_template

# Default ``document_prompt`` for ``_combine_documents``: renders a non-string
# document as just its ``page_content`` field.
DEFAULT_DOCUMENT_PROMPT = PromptTemplate.from_template("{page_content}")

def _combine_documents(
    docs, document_prompt=DEFAULT_DOCUMENT_PROMPT, sep="\n\n"
):
    """Render ``docs`` into a single numbered context string.

    Each entry becomes ``"Doc <n>: <text>"`` with ``n`` counting from 1.
    Every entry is labelled ``"Doc"``, whatever kind of chunk it is.

    Plain ``str`` entries are used verbatim. Any other entry is rendered
    with ``format_document(entry, document_prompt)``. Every newline
    character inside an entry is replaced by a space, so each document sits
    on one line. The entries are then joined with ``sep``.

    Args:
        docs: Iterable of strings and/or documents accepted by
            ``format_document``.
        document_prompt: Prompt used to render non-string entries.
        sep: Separator placed between rendered entries.

    Returns:
        The joined string, or ``""`` when ``docs`` is empty.
    """

    def render(item):
        # Strings pass through; anything else goes through the prompt.
        return item if isinstance(item, str) else format_document(item, document_prompt)

    # Flatten each entry onto one line so that ``sep`` alone delimits documents.
    return sep.join(
        (f"Doc {number}: " + render(item)).replace("\n", " ")
        for number, item in enumerate(docs, start=1)
    )


def get_text_docs(x):
    """Return, in order, the documents of ``x`` whose chunk type is ``"text"``.

    Each document is tested with ``doc.metadata["chunk_type"] == "text"``.
    A document whose metadata lacks that key raises whatever the lookup
    raises (``KeyError`` for a plain dict).
    """
    return list(filter(lambda document: document.metadata["chunk_type"] == "text", x))

def get_image_docs(x):
    """Return, in order, the documents of ``x`` whose chunk type is ``"image"``.

    Each document is tested with ``doc.metadata["chunk_type"] == "image"``.
    A document whose metadata lacks that key raises whatever the lookup
    raises (``KeyError`` for a plain dict).
    """
    return list(filter(lambda document: document.metadata["chunk_type"] == "image", x))

def make_rag_chain(llm):
    """Build the answer chain that is grounded on retrieved documents.

    The returned pipeline expects a mapping with the keys ``"documents"``,
    ``"query"``, ``"language"`` and ``"audience"``. The documents are
    collapsed into a ``context`` string by ``_combine_documents``. The other
    three keys are passed through unchanged. The result feeds
    ``answer_prompt_template``, then ``llm``, then ``StrOutputParser``.

    Args:
        llm: Model placed between the prompt and the output parser.

    Returns:
        The composed runnable, which produces the answer as a string.
    """
    answer_prompt = ChatPromptTemplate.from_template(answer_prompt_template)
    prompt_inputs = {
        "context": lambda state: _combine_documents(state["documents"]),
        # These fields are forwarded to the prompt verbatim.
        **{key: itemgetter(key) for key in ("query", "language", "audience")},
    }
    return prompt_inputs | answer_prompt | llm | StrOutputParser()

def make_rag_chain_without_docs(llm):
    """Build the answer chain used when no documents are supplied.

    The input state is fed straight into ``answer_prompt_without_docs_template``,
    then ``llm``, then ``StrOutputParser``. No document context is computed.

    Args:
        llm: Model placed between the prompt and the output parser.

    Returns:
        The composed runnable, which produces the answer as a string.
    """
    return (
        ChatPromptTemplate.from_template(answer_prompt_without_docs_template)
        | llm
        | StrOutputParser()
    )


def make_rag_node(llm, with_docs=True):
    """Build an async node function that produces an answer from the state.

    Args:
        llm: Model handed to the selected chain factory.
        with_docs: If truthy, use ``make_rag_chain``, which builds a context
            from ``state["documents"]``. Otherwise use
            ``make_rag_chain_without_docs``.

    Returns:
        An async callable ``answer_rag(state, config)`` that awaits
        ``chain.ainvoke(state, config)`` and returns ``{"answer": <result>}``.
    """
    # The chain is built once, when the node is created, not on every call.
    build_chain = make_rag_chain if with_docs else make_rag_chain_without_docs
    chain = build_chain(llm)

    async def answer_rag(state, config):
        return {"answer": await chain.ainvoke(state, config)}

    return answer_rag




# def make_rag_papers_chain(llm):

#     prompt = ChatPromptTemplate.from_template(papers_prompt_template)
#     input_documents = {
#         "context":lambda x : _combine_documents(x["docs"]),
#         **pass_values(["question","language"])
#     }

#     chain = input_documents | prompt | llm | StrOutputParser()
#     chain = rename_chain(chain,"answer")

#     return chain






# def make_illustration_chain(llm):

#     prompt_with_images = ChatPromptTemplate.from_template(answer_prompt_images_template)

#     input_description_images = {
#         "images":lambda x : _combine_documents(get_image_docs(x["docs"])),
#         **pass_values(["question","audience","language","answer"]),
#     }

#     illustration_chain = input_description_images | prompt_with_images | llm | StrOutputParser()
#     return illustration_chain