# LLM
# Ollama for local tests
from langchain.llms import Ollama
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
# Ref.: https://mistral.ai/news/mixtral-of-experts/#instructed-models
# Q5_K_M quantization flavor: recommended for best quality (memory is not a constraint here)
# Ref.: https://huggingface.co/TheBloke/Mistral-7B-Instruct-v0.2-GGUF#provided-files
MISTRAL = "mistral:7b-instruct-v0.2-q5_K_M"
# Q4_K quantization flavor: recommended memory/quality tradeoff
# Ref.: https://huggingface.co/TheBloke/Mixtral-8x7B-Instruct-v0.1-GGUF#provided-files
# mixtral:8x7b-instruct-v0.1-q4_K_M was sadly still too big for my Mac
MIXTRAL = "mixtral:8x7b-instruct-v0.1-q3_K_L"
# Llama2 13B 
# Ref.: https://huggingface.co/TheBloke/Llama-2-13B-GGUF
LLAMA2 = "llama2:13b-chat-q5_K_M"
mistral = Ollama(
    model=MISTRAL,
    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
    # Ref.: https://api.python.langchain.com/en/latest/llms/langchain_community.llms.ollama.Ollama.html#langchain_community.llms.ollama.Ollama.format
    # format="json"
)
mixtral = Ollama(
    model=MIXTRAL,
    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()])
)
llama2 = Ollama(
    model=LLAMA2,
    callback_manager=CallbackManager([StreamingStdOutCallbackHandler()])
)
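
# Quick sanity check (a sketch: assumes the models above have been pulled with
# `ollama pull <tag>` and that a local Ollama server is running):
def check_llm(llm=mistral, prompt="Réponds en une phrase : qu'est-ce qu'un CTI ?"):
    # The streaming callback already prints tokens to stdout as they arrive.
    return llm(prompt)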


# LOAD
from langchain_community.document_loaders import TextLoader
from langchain_community.document_loaders import PyPDFLoader

FILES = {
    'md': [
        # "Présentation modes dégradés-20230120_112423-Enregistrement de la réunion.md",
        "YouTube - Mode secours telephonie.md"
    ],
    'pdf': [
        # "SI-Samu_Fiche procédure_Mode dégradé_Perte de CRRA.pdf",
        # "[SI-Samu] Fiche mémo - Procédure Mode dégradé.pdf",
        "SI-Samu_Documentation_produit_SF4_J18HF2_20231219 - mode secours seul.pdf",
        # "SI-Samu_Documentation_produit_SF4_J18HF2_20231219.pdf"
    ]
}

def load_data(files):
    data = {'md': [], 'pdf': []}
    for pdf in files['pdf']:
        data['pdf'].extend(PyPDFLoader('resources/' + pdf).load())
    for md in files['md']:
        data['md'].extend(TextLoader('resources/' + md).load())
    return data

def to_full_data(data):
    return [
        *data['md'],
        *data['pdf']
    ]
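
# Example usage (a sketch: assumes the files listed in FILES are present under
# ./resources):
def example_load():
    data = load_data(FILES)         # {'md': [Document, ...], 'pdf': [Document, ...]}
    full_data = to_full_data(data)  # flat list of Documents, MD first then PDF
    return data, full_data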

# SPLIT
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.text_splitter import MarkdownHeaderTextSplitter

def split_MD_then_recursive(data):
    # First split .md files on their Markdown headers, then run the recursive character splitter on everything
    # MD splits
    markdown_splitter = MarkdownHeaderTextSplitter(headers_to_split_on=[
        ("#", "Titre 1"),
        ("##", "Titre 2"),
        ("###", "Titre 3"),
    ], strip_headers=False)
    md_header_splits = data['pdf'].copy()  # PDF pages pass through unchanged
    for md in data['md']:
        md_header_splits.extend(markdown_splitter.split_text(md.page_content))

    # Char-level splits
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=500,
        chunk_overlap=50  # overlap improves results quality at chunk boundaries
    )
    # Split
    return text_splitter.split_documents(md_header_splits)
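
# Example usage (a sketch building on example_load above): .md content is split
# on its headers first, while PDF pages go straight to the character splitter.
def example_split():
    data, _ = example_load()
    return split_MD_then_recursive(data)  # list of ~500-character Documents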

# EMBED
# Directly done in the different scripts
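# A minimal sketch of what those scripts can do; the embedding model name below
# is an assumption (any Ollama embedding model would fit), not from the original:
def example_embeddings(model="nomic-embed-text"):
    from langchain_community.embeddings import OllamaEmbeddings
    return OllamaEmbeddings(model=model)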

# RETRIEVE
from langchain.storage import InMemoryStore
from langchain.retrievers import ParentDocumentRetriever, BM25Retriever, EnsembleRetriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import Chroma

# Ensemble is based on weight fusion (Reciprocal Rank Fusion) | Ref.: https://safjan.com/implementing-rank-fusion-in-python/
def get_parent_ensemble_retriever(embeddings, full_data, all_splits, k=4, parent_chunk_size=2000, child_chunk_size=400, collection_name="store"):
    # - ParentDocumentRetriever: embed small chunks but retrieve with bigger context
    # This text splitter is used to create the parent documents
    parent_splitter = RecursiveCharacterTextSplitter(chunk_size=parent_chunk_size)
    # This text splitter is used to create the child documents
    # It should create documents smaller than the parent (keep them under ~512 tokens, as most embedding models truncate beyond that)
    child_splitter = RecursiveCharacterTextSplitter(chunk_size=child_chunk_size)
    # The vectorstore to use to index the child chunks
    parent_vectorstore = Chroma(
        collection_name=collection_name, 
        embedding_function=embeddings
    )
    # The storage layer for the parent documents
    parent_store = InMemoryStore()
    parent_retriever = ParentDocumentRetriever(
        vectorstore=parent_vectorstore,
        docstore=parent_store,
        child_splitter=child_splitter,
        parent_splitter=parent_splitter,
        search_kwargs={
            "k": k,
            # "score_threshold": 0.5
        },
        # search_type="mmr"
    )
    parent_retriever.add_documents(full_data)

    # - EnsembleRetriever
    # BM25 logic
    bm25_retriever = BM25Retriever.from_texts(
        [s.page_content for s in all_splits],
        metadatas=[{"retriever": "BM25 sparse similarity", **s.metadata} for s in all_splits]
    )
    bm25_retriever.k = k

    # Ensemble of BM25 + vectorstore on parent retriever
    return EnsembleRetriever(
        retrievers=[parent_retriever, bm25_retriever], weights=[0.5, 0.5]
    )
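
# Example usage (a sketch chaining the helpers above; the embeddings normally
# come from the calling script, example_embeddings is used here as a stand-in):
def example_retriever(k=4):
    data, full_data = example_load()
    all_splits = split_MD_then_recursive(data)
    return get_parent_ensemble_retriever(example_embeddings(), full_data, all_splits, k=k)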

# PROMPT
# Add retrieved context to the query and adapt the system prompt so the model answers in French
# Ref.: https://stackoverflow.com/questions/76554411/unable-to-pass-prompt-template-to-retrievalqa-in-langchain
# Ref.: https://community.openai.com/t/how-to-prevent-chatgpt-from-answering-questions-that-are-outside-the-scope-of-the-provided-context-in-the-system-role-message/112027/7
from langchain import PromptTemplate
template = """
System: You are helping a user of "bandeau téléphonique SI-SAMU" (a CTI - Computer Telephony Integration - system) during a system failure, when they need to fall back to the local backup phone.
Context information is below. Given the context information and not prior knowledge, answer the query.
Language: Answer in French and using "vous".
---
Context: {context}
---
Question: {question}
---
Réponse :
"""
PROMPT = PromptTemplate(template=template, input_variables=['question', 'context'])
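
# Wiring the prompt into a chain (a sketch following the StackOverflow ref above;
# parse_answer below relies on the keys this chain returns):
def example_chain(llm, retriever):
    from langchain.chains import RetrievalQA
    return RetrievalQA.from_chain_type(
        llm=llm,
        chain_type="stuff",  # stuff all retrieved chunks into the {context} slot
        retriever=retriever,
        return_source_documents=True,  # needed for the sources listing below
        chain_type_kwargs={"prompt": PROMPT},
    )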

# RESULTS
# ANSI escape codes for coloured terminal output
class color:
    PURPLE = '\033[95m'
    CYAN = '\033[96m'
    DARKCYAN = '\033[36m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'
    END = '\033[0m'

def parse_answer(answer):
    print(f">> {answer['query']}")
    print(f">> {answer['result']}")
    print(">> Sources :")
    for doc in answer['source_documents']:
        page = ''
        if 'page' in doc.metadata:
            page = f" (page {doc.metadata['page']})"
        source = ''
        if 'source' in doc.metadata:
            source = doc.metadata['source']
        titles = ['Titre 1', 'Titre 2', 'Titre 3']
        for title in titles:
            if title in doc.metadata:
                source += f" > {doc.metadata[title]}"
        retriever = f"B25" if 'retriever' in doc.metadata else "vectorstore"
        print(f">>> {color.BOLD}{source}{page} [{retriever}]{color.END}: {doc.page_content}\n---")
    print("--------\n\n")