Spaces:
Running
Running
File size: 2,820 Bytes
4c2a969 5b26a96 1434337 5b26a96 4c2a969 5b26a96 4c2a969 cbd0b83 1434337 5b26a96 1434337 f79211f 1434337 4c2a969 f79211f 4c2a969 a147158 4c2a969 1434337 4c2a969 1434337 4c2a969 1434337 4c2a969 4c41de2 4c2a969 5b26a96 1434337 5b26a96 1434337 4c2a969 f79211f 5b26a96 35f0167 4c2a969 4c41de2 5b26a96 f79211f 5b26a96 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
import shutil
from typing import List
from haystack import Document
from haystack.document_stores import FAISSDocumentStore
from haystack.nodes import EmbeddingRetriever, PromptNode
from haystack.pipelines import Pipeline
import streamlit as st
from haystack_entailment_checker import EntailmentChecker
from app_utils.config import (
STATEMENTS_PATH,
INDEX_DIR,
RETRIEVER_MODEL,
RETRIEVER_MODEL_FORMAT,
NLI_MODEL,
PROMPT_MODEL,
)
@st.cache_data
def load_statements():
    """Read the statements file and return its non-comment lines, stripped."""
    kept = []
    with open(STATEMENTS_PATH) as handle:
        for raw_line in handle:
            # only lines whose very first character is "#" count as comments
            if raw_line.startswith("#"):
                continue
            kept.append(raw_line.strip())
    return kept
# wrapped in st.cache_resource so the index and the models are loaded only once
@st.cache_resource
def start_haystack():
    """
    Build the objects used for fact checking.

    Loads the FAISS document store from INDEX_DIR, then creates the embedding
    retriever, the entailment checker, a pipeline chaining the two
    (Query -> "retriever" -> "ec"), and a separate prompt node.

    Returns a tuple (pipeline, prompt node).
    """
    # A copy of the .db file is placed in the current working directory before
    # the store is loaded.
    # NOTE(review): presumably the saved FAISS config refers to the db by a
    # path relative to the working directory -- confirm.
    db_source = f"{INDEX_DIR}/faiss_document_store.db"
    shutil.copy(db_source, ".")

    index_file = f"{INDEX_DIR}/my_faiss_index.faiss"
    config_file = f"{INDEX_DIR}/my_faiss_index.json"
    store = FAISSDocumentStore(
        faiss_index_path=index_file,
        faiss_config_path=config_file,
    )
    print(f"Index size: {store.get_document_count()}")

    embedding_retriever = EmbeddingRetriever(
        document_store=store,
        embedding_model=RETRIEVER_MODEL,
        model_format=RETRIEVER_MODEL_FORMAT,
    )
    nli_checker = EntailmentChecker(
        model_name_or_path=NLI_MODEL,
        use_gpu=False,
        entailment_contradiction_threshold=0.5,
    )

    # the retriever receives the query; its output is fed to the checker
    pipeline = Pipeline()
    pipeline.add_node(
        component=embedding_retriever, name="retriever", inputs=["Query"]
    )
    pipeline.add_node(component=nli_checker, name="ec", inputs=["retriever"])

    # the prompt node is created on its own and is not a node of the pipeline
    llm_node = PromptNode(model_name_or_path=PROMPT_MODEL, max_length=150)
    return pipeline, llm_node
# Build the pipeline and the prompt node when this module is imported; the
# module-level names `pipe` and `prompt_node` are read by the functions below.
pipe, prompt_node = start_haystack()
# `pipe` is read from module scope instead of being passed as an argument,
# because it is difficult to cache
@st.cache_resource
def check_statement(statement: str, retriever_top_k: int = 5):
    """
    Run the statement through the module-level pipeline and return its output.

    retriever_top_k is forwarded as the "top_k" setting of the "retriever" node.
    """
    retriever_settings = {"top_k": retriever_top_k}
    return pipe.run(statement, params={"retriever": retriever_settings})
@st.cache_resource
def explain_using_llm(
    statement: str, documents: List[Document], entailment_or_contradiction: str
) -> str:
    """Explain entailment/contradiction, by prompting a LLM.

    Args:
        statement: the hypothesis to be explained.
        documents: documents whose contents are joined to form the premise.
        entailment_or_contradiction: either "entailment" or "contradiction";
            selects the verb used in the prompt.

    Returns:
        The first element of the output of the module-level prompt node.

    Raises:
        ValueError: if entailment_or_contradiction is neither "entailment"
            nor "contradiction".
    """
    # Validate the label before doing any work: previously an unexpected
    # value left `verb` unassigned and failed later with UnboundLocalError.
    if entailment_or_contradiction == "entailment":
        verb = "entails"
    elif entailment_or_contradiction == "contradiction":
        verb = "contradicts"
    else:
        raise ValueError(
            "entailment_or_contradiction must be 'entailment' or "
            f"'contradiction', got {entailment_or_contradiction!r}"
        )

    # newlines inside a document are flattened so that each document
    # occupies a single line of the premise
    premise = " \n".join(doc.content.replace("\n", ". ") for doc in documents)
    prompt = f"Premise: {premise}; Hypothesis: {statement}; Please explain in detail why the Premise {verb} the Hypothesis. Step by step Explanation:"
    print(prompt)
    return prompt_node(prompt)[0]
|