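"""Streamlit helpers: load statements, build the Haystack retrieval and
entailment-checking pipeline, and query it."""
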
import shutil
from typing import List

from haystack import Document
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import EmbeddingRetriever, PromptNode
from haystack.pipelines import Pipeline
import streamlit as st

from haystack_entailment_checker import EntailmentChecker
from app_utils.config import (
    STATEMENTS_PATH,
    INDEX_DIR,
    RETRIEVER_MODEL,
    RETRIEVER_MODEL_FORMAT,
    NLI_MODEL,
    PROMPT_MODEL,
)


@st.cache_data
def load_statements():
    """Load statements from file"""
    with open(STATEMENTS_PATH) as fin:
        statements = [
            line.strip() for line in fin.readlines() if not line.startswith("#")
        ]
    return statements


# cached so that the index and models are loaded only once, at startup
@st.cache_resource
def start_haystack():
    """
    load document store, retriever, entailment checker and create pipeline
    """
    shutil.copy(f"{INDEX_DIR}/faiss_document_store.db", ".")
    document_store = FAISSDocumentStore(
        faiss_index_path=f"{INDEX_DIR}/my_faiss_index.faiss",
        faiss_config_path=f"{INDEX_DIR}/my_faiss_index.json",
    )
    print(f"Index size: {document_store.get_document_count()}")
    retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model=RETRIEVER_MODEL,
        model_format=RETRIEVER_MODEL_FORMAT,
    )
    entailment_checker = EntailmentChecker(
        model_name_or_path=NLI_MODEL,
        use_gpu=False,
        entailment_contradiction_threshold=0.5,
    )

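    # query pipeline: the retriever fetches candidate documents,
    # then the entailment checker scores them against the statement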
    pipe = Pipeline()
    pipe.add_node(component=retriever, name="retriever", inputs=["Query"])
    pipe.add_node(component=entailment_checker, name="ec", inputs=["retriever"])

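    # the PromptNode is kept outside the pipeline; it is only used by
    # explain_using_llm to generate natural-language explanations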
    prompt_node = PromptNode(model_name_or_path=PROMPT_MODEL, max_length=150)

    return pipe, prompt_node


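# built once at import time; st.cache_resource keeps the pipeline and
# prompt node alive across Streamlit reruns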
pipe, prompt_node = start_haystack()

# the pipeline is not passed as a parameter of the following function,
# because Streamlit cannot easily hash it for caching
@st.cache_resource
def check_statement(statement: str, retriever_top_k: int = 5):
    """Run query and verify statement"""
    params = {"retriever": {"top_k": retriever_top_k}}
    return pipe.run(query=statement, params=params)


@st.cache_resource
def explain_using_llm(
    statement: str, documents: List[Document], entailment_or_contradiction: str
) -> str:
    """Explain entailment/contradiction, by prompting a LLM"""
    premise = " \n".join([doc.content.replace("\n", ". ") for doc in documents])
    if entailment_or_contradiction == "entailment":
        verb = "entails"
    elif entailment_or_contradiction == "contradiction":
        verb = "contradicts"
    else:
        # avoid a NameError on unexpected values
        raise ValueError(f"Unexpected value: {entailment_or_contradiction!r}")

    prompt = f"Premise: {premise}; Hypothesis: {statement}; Please explain in detail why the Premise {verb} the Hypothesis. Step by step Explanation:"

    print(prompt)
    return prompt_node(prompt)[0]