ricoh51 commited on
Commit
65d97fa
·
1 Parent(s): b70b72a

First commit

Browse files
.gitignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ venv/
2
+ __pycache__/
3
+ .vscode/
4
+ .gradio/
5
+ .env
6
+ files/rag_app/
.python-version ADDED
@@ -0,0 +1 @@
 
 
1
+ 3.11
files/drane.jpg ADDED
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ pypdf
2
+ openai
3
+ huggingface-hub
4
+ ollama
5
+ mistralai
src/amodel.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from enum import Enum
3
+
4
+ class ModelType(Enum):
5
+ ''' Les différentes technos de models '''
6
+ MTOPENAI = 1
7
+ MTOLLAMA = 2
8
+ MTHUGGINGFACE = 3
9
+ MTMISTRAL = 4
10
+
11
+ @classmethod
12
+ def to_str(self, mt:int)->str:
13
+ match mt:
14
+ case 1: return "MTOPENAI"
15
+ case 2: return "MTOLLAMA"
16
+ case 3: return "MTHUGGINGFACE"
17
+ case 4: return "MTMISTRAL"
18
+ case _: return "UNKNOWN"
19
+
20
+ class AModel(ABC):
21
+ '''
22
+ Classe abstraite de base pour tous les models :
23
+ Ollama en local
24
+ OpenAI distant
25
+ HuggingFace distant
26
+ HuggingFace dans une app
27
+ ...
28
+ '''
29
+
30
+ @abstractmethod
31
+ def ask_llm(self, question:str)->str:
32
+ pass
33
+
34
+ @abstractmethod
35
+ def create_vector(self, chunk:str)->list[float]:
36
+ pass
37
+
38
+ @abstractmethod
39
+ def create_vectors(self, chunks:list[str])->list[list[float]]:
40
+ pass
41
+
42
+ def get_llm_name(self):
43
+ return self.llm_name
44
+
45
+ def set_llm_name(self, llm_name:str):
46
+ self.llm_name = llm_name
47
+
48
+ def get_feature_name(self):
49
+ return self.feature_name
50
+
51
+ def set_feature_name(self, feature_name:str):
52
+ self.feature_name = feature_name
53
+
54
+ def get_temperature(self):
55
+ return self.temperature
56
+
57
+ def set_temperature(self, temperature:float):
58
+ self.temperature = temperature
src/astore.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
+ class AStore(ABC):
5
+ '''
6
+ Classe abstraite de base pour tous les stores :
7
+ Chroma
8
+ Perso
9
+ ...
10
+ '''
11
+
12
+ @abstractmethod
13
+ def reset(self)->None:
14
+ pass
15
+
16
+ @abstractmethod
17
+ def print_infos(self)->None:
18
+ pass
19
+
20
+ @abstractmethod
21
+ def add_to_collection(self, collection_name:str, source:str, vectors:list[list[float]], chunks:list[str])->None:
22
+ pass
23
+
24
+ @abstractmethod
25
+ def delete_collection(self, name:str)->None:
26
+ pass
27
+
28
+ @abstractmethod
29
+ def get_similar_vector(self, vector:list[float], collection_name:str)->list[float]:
30
+ pass
31
+
32
+ @abstractmethod
33
+ def get_similar_chunk(self, query_vector:list[float], collection_name:str)->tuple[str, str]:
34
+ pass
35
+
36
+ @abstractmethod
37
+ def get_similar_chunks(self, query_vector:list[float], count:int, collection_name:str):
38
+ pass
39
+
src/chunker.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+
3
+ class Chunker:
4
+ '''
5
+ Tronçonnage d'un texte en chunks
6
+ '''
7
+
8
+ def __init__(self):
9
+ pass
10
+
11
+ def split_basic(self, text:str, char_count:int, overlap:int)->list[str]:
12
+ '''
13
+ Découpe le texte avec des '\n'.
14
+ La taille d'un chunk est de max count + 2 * overlap
15
+ chunk = o1-c-o2
16
+ o1: les mots du chunk précédent ajoutés, il y en a 'overlap' ou 0 pour le premier chunk
17
+ c: partie centrale du chunk
18
+ o2: les mots du chunk suivant ajoutés, il y en a 'overlap' ou 0 pour le dernier chunk
19
+ Args:
20
+ char_count: le nombre de caractères dans un chunk (sans compter les mots ajoutés par recouvrement)
21
+ overlap: le nombre de caractères du chunk précédent (et suivant) ajoutés au début (et à la fin) du chunk
22
+ Return:
23
+ La liste des chunks
24
+ '''
25
+ # La liste qui sera renvoyée
26
+ chunks:list[str] = [] # la liste qui sera renvoyée
27
+ # Découpage du texte en morceaux de 'char_count' caractères
28
+ n:int = len(text)
29
+ size:int = n // char_count + 1 # nombre de chunks
30
+ for i in range(size):
31
+ start = i*char_count if i == 0 else i*char_count - overlap
32
+ stop = (i+1)*char_count if i == size - 1 else (i+1)*char_count + overlap
33
+ s = slice(start, stop)
34
+ chunk:str = text[s]
35
+ chunks.append(chunk)
36
+ return chunks
src/model_huggingface.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from .amodel import AModel
3
+ from huggingface_hub import InferenceClient
4
+ import numpy as np # feature_extraction renvoie un array numpy...
5
+
6
+
7
+ class HuggingFaceModel(AModel):
8
+
9
+ def __init__(self, llm_name:str, feature_name:str, temperature:float=0.0):
10
+ self.llm_name:str = llm_name
11
+ self.feature_name:str = feature_name
12
+ self.temperature = temperature
13
+ # La variable HF_ACTIVE a été créée dans les settings de l'app sur HuggingFace
14
+ if (os.getenv("HF_ACTIVE")): # Lancement depuis l'app sur HuggingFace
15
+ api_token = os.getenv("HF_TOKEN")
16
+ else: # Lancement depuis mon ordi
17
+ # print("Launch Rag in HuggingFace local")
18
+ from dotenv import load_dotenv # Trick: ne passe pas dans une app sur HuggingFace
19
+ load_dotenv()
20
+ api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
21
+ try:
22
+ self.model = InferenceClient(api_key=api_token)
23
+ except:
24
+ raise
25
+
26
+ def ask_llm(self, question:str)->str:
27
+ messages = [{"role": "user", "content": question}]
28
+ try:
29
+ resp = self.model.chat.completions.create(
30
+ model=self.llm_name,
31
+ messages=messages,
32
+ max_tokens=500,
33
+ temperature=self.temperature,
34
+ # stream=True
35
+ )
36
+ return resp.choices[0].message.content
37
+ except:
38
+ raise
39
+
40
+ def create_vector(self, chunk:str)->list[float]:
41
+ resp = self.model.feature_extraction(
42
+ text=chunk,
43
+ # normalize=True, # Only available on server powered by Text-Embedding-Inference.
44
+ model=self.feature_name, # normalisé ??
45
+ )
46
+ return resp
47
+
48
+ def create_vectors(self, chunks:list[str])->list[list[float]]:
49
+ '''
50
+ Pas de batch pour la création de vectors sur HuggingFace, on les passe un par un
51
+ '''
52
+ vectors = []
53
+ try:
54
+ for chunk in chunks:
55
+ v = self.create_vector(chunk)
56
+ if not isinstance(v, np.ndarray):
57
+ raise
58
+ vectors.append(v.tolist())
59
+ return vectors
60
+ except:
61
+ raise
src/model_mistral.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from dotenv import load_dotenv
4
+ from .amodel import AModel
5
+ from mistralai import Mistral
6
+
7
+ class MistralModel(AModel):
8
+ '''
9
+ https://docs.mistral.ai/capabilities/completion/
10
+ https://docs.mistral.ai/capabilities/embeddings/
11
+ temperature entre 0.0 et 0.7
12
+ '''
13
+
14
+ def __init__(self, llm_name:str, feature_name:str, temperature:float=0.0):
15
+ self.llm_name:str = llm_name
16
+ self.feature_name:str = feature_name
17
+ self.temperature = temperature
18
+ load_dotenv()
19
+ try:
20
+ self.model = Mistral(api_key=os.getenv("MISTRAL_API_KEY"))
21
+ except:
22
+ raise
23
+
24
+
25
+ def ask_llm(self, question:str)->str:
26
+ try:
27
+ response = self.model.chat.complete(
28
+ model=self.llm_name,
29
+ messages = [{ "role": "user", "content": question, },],
30
+ temperature=self.temperature
31
+ )
32
+ return response.choices[0].message.content
33
+ except:
34
+ raise
35
+
36
+ def create_vector(self, chunk:str)->list[float]:
37
+ '''
38
+ Renvoie un vecteur de taille 1024 à partir de chunk
39
+ '''
40
+ try:
41
+ response = self.model.embeddings.create(
42
+ model=self.feature_name,
43
+ # inputs=["Embed this sentence.", "As well as this one."],
44
+ inputs=[chunk]
45
+ )
46
+ return response.data[0].embedding
47
+ except:
48
+ raise
49
+
50
+ def create_vectors(self, chunks:list[str])->list[list[float]]:
51
+ '''
52
+ Renvoie n vecteurs de taille 1024 à partir de la liste chunks
53
+ '''
54
+ try:
55
+ response = self.model.embeddings.create(
56
+ model=self.feature_name,
57
+ inputs=chunks,
58
+ )
59
+ n:int = len(chunks)
60
+ result = [response.data[i].embedding for i in range(n)]
61
+ return result
62
+ except:
63
+ raise
src/model_ollama.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .amodel import AModel
2
+ import ollama
3
+ import numpy as np
4
+
5
+ class OllamaModel(AModel):
6
+
7
+ def __init__(self, llm_name:str, feature_name:str, temperature:float=0.0):
8
+ self.llm_name:str = llm_name
9
+ self.feature_name:str = feature_name
10
+ self.temperature = temperature
11
+
12
+ def ask_llm(self, question:str)->str:
13
+ try:
14
+ resp = ollama.chat(
15
+ model=self.llm_name,
16
+ messages=[{'role':'user', 'content':question}],
17
+ stream=False,
18
+ options={"temperature":self.temperature})
19
+ return resp.message.content
20
+ except:
21
+ raise
22
+
23
+ def create_vector(self, chunk:str)->list[float]:
24
+ '''
25
+ TODO: Vérifier s'il ne faut pas utiliser 'embed' plutôt que 'embeddings'
26
+ '''
27
+ try:
28
+ resp = ollama.embeddings(
29
+ model=self.feature_name,
30
+ prompt=chunk)
31
+ return self.normalize(resp.embedding).tolist()
32
+ except:
33
+ raise
34
+
35
+ def normalize(self, v:list[float]):
36
+ norm = np.linalg.norm(v)
37
+ if norm == 0:
38
+ return v
39
+ return v / norm
40
+
41
+ def create_vectors(self, chunks:list[str])->list[list[float]]:
42
+ try:
43
+ resp = ollama.embed(
44
+ model=self.feature_name,
45
+ input=chunks)
46
+ # print(resp.embeddings)
47
+ return resp.embeddings
48
+ except:
49
+ raise
src/model_openai.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from dotenv import load_dotenv
4
+ from .amodel import AModel
5
+ from openai import OpenAI
6
+
7
+ class OpenAIModel(AModel):
8
+ '''
9
+ https://platform.openai.com/docs/guides/text-generation
10
+
11
+ '''
12
+
13
+ def __init__(self, llm_name:str, feature_name:str, temperature:float=0.0):
14
+ self.llm_name:str = llm_name
15
+ self.feature_name:str = feature_name
16
+ self.temperature = temperature
17
+ load_dotenv()
18
+ try:
19
+ self.model = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
20
+ except:
21
+ raise
22
+
23
+ def ask_llm(self, question:str)->str:
24
+ try:
25
+ response = self.model.chat.completions.create(
26
+ # model="gpt-4o-mini",
27
+ model=self.llm_name,
28
+ messages=[
29
+ {"role":"system", "content":""},
30
+ {"role":"user", "content":question},
31
+ ],
32
+ temperature=self.temperature
33
+ )
34
+ return response.choices[0].message.content
35
+ except:
36
+ raise
37
+
38
+ def create_vector(self, chunk:str)->list[float]:
39
+ '''
40
+ 8192 tokens max
41
+ '''
42
+ # les embeddings d'OpenAI sont normalisés à 1
43
+ try:
44
+ response = self.model.embeddings.create(
45
+ input=chunk,
46
+ model=self.feature_name
47
+ )
48
+ return response.data[0].embedding
49
+ except:
50
+ raise
51
+
52
+ def create_vectors(self, chunks:list[str])->list[list[float]]:
53
+ '''
54
+ Pas plus de 2048 chunks
55
+ '''
56
+ try:
57
+ response = self.model.embeddings.create(
58
+ input=chunks,
59
+ model=self.feature_name
60
+ )
61
+ n:int = len(chunks)
62
+ result = [response.data[i].embedding for i in range(n)]
63
+ return result
64
+ except:
65
+ raise
src/rag.py ADDED
@@ -0,0 +1,249 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ from pypdf import PdfReader
4
+
5
+ from .chunker import Chunker
6
+ from .amodel import ModelType
7
+ from .model_openai import OpenAIModel
8
+ from .model_huggingface import HuggingFaceModel
9
+ from .model_ollama import OllamaModel
10
+ from .model_mistral import MistralModel
11
+ from .store import Store
12
+
13
+
14
+ class Rag:
15
+ '''
16
+ Classe qui s'occupe de toute la chaine du RAG.
17
+ Elle permet :
18
+ d'interroger un llm directement (sans RAG) avec ask_llm()
19
+ d'interroger le RAG lui même avec ask_rag()
20
+ d'ajouter des documents à la base de données du RAG
21
+ de remettre la base à zéro
22
+ de créer des vecteurs
23
+ de charger des pdf
24
+ '''
25
+
26
+ # Le prompt qui sera utilisé uniquement avec le RAG
27
+ prompt_template = """
28
+ En vous basant **uniquement** sur les informations fournies dans le contexte
29
+ ci-dessous, répondez à la question posée.
30
+ Les équations seront écrites en latex.
31
+ Si vous ne trouvez pas la réponse dans le contexte, répondez "Je ne sais pas".
32
+ Contexte : {context}
33
+ Question : {question}
34
+ """
35
+
36
+ def __init__(self, model_type:ModelType, store_dir:str) -> None:
37
+ '''
38
+ Constructeur du Rag
39
+ Args:
40
+ model_type: la techno utilisée
41
+ store_dir: le répertoire de persistance de la base de données ou None
42
+ Exception:
43
+ Si le model ne peut pas être créé
44
+ Si le type de model est inconnu
45
+ '''
46
+ self.model_type = model_type
47
+ try:
48
+ match model_type:
49
+ case ModelType.MTOPENAI:
50
+ self.model = OpenAIModel("gpt-4o-mini", "text-embedding-3-small", 0)
51
+ case ModelType.MTHUGGINGFACE:
52
+ self.model = HuggingFaceModel("meta-llama/Meta-Llama-3-8B-Instruct", "sentence-transformers/all-MiniLM-l6-v2", 0)
53
+ case ModelType.MTOLLAMA:
54
+ self.model = OllamaModel("llama3.2:1b", "nomic-embed-text", 0.0)
55
+ case ModelType.MTMISTRAL:
56
+ self.model = MistralModel("mistral-large-latest", "mistral-embed", 0.0)
57
+ case _:
58
+ raise Exception("Rag.__init__: Unknown model type: {mt} : {v}".format(mt=ModelType.to_str(model_type), v=model_type))
59
+ self.emb_store = Store(store_dir) # persistant
60
+ # self.emb_store = Store(None) # éphémère
61
+ except Exception as e:
62
+ raise
63
+
64
+ def get_llm_name(self):
65
+ return self.model.get_llm_name()
66
+
67
+ def get_feature_name(self):
68
+ return self.model.get_feature_name()
69
+
70
+ def get_temperature(self):
71
+ return self.model.get_temperature()
72
+
73
+ def set_temperature(self, temperature:float):
74
+ self.model.set_temperature(temperature)
75
+
76
+ def reset_store(self):
77
+ self.emb_store.reset()
78
+
79
+ def delete_collection(self, name:str)->None:
80
+ self.emb_store.delete_collection(name)
81
+
82
+ def create_vectors(self, chunks:list[str])->list[list[float]]:
83
+ '''
84
+ Renvoie les vecteurs correspondant à 'chunks', calculés par 'emb_model'
85
+ Args:
86
+ chunks: les extraits de texte à calculer
87
+ Return:
88
+ la liste des vecteurs calculés
89
+ '''
90
+ vectors:list = []
91
+ tokens:int = 0
92
+ vectors:list[list[float]] = self.model.create_vectors(chunks) # batch si le model le permet
93
+ # for chunk in chunks:
94
+ # vector:list[float] = self.model.create_vector(chunk=chunk)
95
+ # vectors.append(vector)
96
+ return vectors
97
+
98
+ def load_pdf(self, file_name:str)->str:
99
+ ''' Charge le fichier 'file_name' et renvoie son contenu sous forme de texte. '''
100
+ reader = PdfReader(file_name)
101
+ content = ""
102
+ for page in reader.pages:
103
+ content += page.extract_text() + "\n"
104
+ return content
105
+
106
+ def get_chunks(self, text:str)->list:
107
+ '''
108
+ Découpe le 'text' en chunks de taille chunk_size avec un recouvrement
109
+ Args:
110
+ text: Le texte à découper
111
+ Return:
112
+ La liste des chunks
113
+ '''
114
+ # splitter = RecursiveCharacterTextSplitter(
115
+ # # separator="\n",
116
+ # chunk_size=1000,
117
+ # chunk_overlap=200,
118
+ # length_function=len,
119
+ # is_separator_regex=False
120
+ # )
121
+ # chunks = splitter.split_text(text)
122
+ # print("get_chunks: " + str(len(chunks)))
123
+ chunker = Chunker()
124
+ chunks = chunker.split_basic(text=text, char_count=1000, overlap=200)
125
+ return chunks
126
+
127
+ def add_pdf_to_store(self, file_name:str, collection_name:str)->None:
128
+ '''
129
+ Ajoute un pdf à la base de données du RAG.
130
+ Args:
131
+ file_name: le chemin vers le fichier à ajouter
132
+ collection_name: Le nom de la collection dans laquelle il faut ajouter les chunks
133
+ La collection est créée si elle n'existe pas.
134
+ '''
135
+ text:str = self.load_pdf(file_name)
136
+ chunks:list[str] = self.get_chunks(text)
137
+ self.add_chunks_to_store(chunks=chunks, collection_name=collection_name, source=file_name)
138
+
139
+ def add_pdf_stream_to_store(self, stream, collection_name:str)->None:
140
+ '''
141
+ Ajoute un stream provenant de file_uploader de streamlit par exemple
142
+ '''
143
+ text:str = self.load_pdf(stream)
144
+ chunks:list[str] = self.get_chunks(text)
145
+ self.add_chunks_to_store(chunks=chunks, collection_name=collection_name, source="stream")
146
+
147
+ def add_chunks_to_store(self, chunks:list[str], collection_name:str, source:str)->None:
148
+ '''
149
+ Ajoute des chunks à la base de données du RAG.
150
+ Args:
151
+ chunks: les chunks à ajouter
152
+ collection_name: Le nom de la collection dans laquelle il faut ajouter les chunks
153
+ La collection est créée si elle n'existe pas.
154
+ source: la source des chunks (nom du fichier, url ...)
155
+ '''
156
+ vectors = self.create_vectors(chunks=chunks)
157
+ self.emb_store.add_to_collection(
158
+ collection_name=collection_name,
159
+ source=source,
160
+ vectors=vectors,
161
+ chunks=chunks
162
+ )
163
+
164
+
165
+ def ask_llm(self, question:str)->str:
166
+ '''
167
+ Pose une question au llm_model, attend sa réponse et la renvoie.
168
+ Args:
169
+ question: La question qu'on veut lui poser
170
+ Returns:
171
+ La réponse du llm_model
172
+ '''
173
+ return self.model.ask_llm(question=question)
174
+
175
+ def ask_rag(self, question:str, collection_name:str)->tuple[str, str, list[str], list[str]]:
176
+ '''
177
+ Pose une question au RAG, attend sa réponse et la renvoie.
178
+ Args:
179
+ question: La question qu'on veut lui poser
180
+ collection_name: le nom de la collection que l'on veut interroger
181
+ Returns:
182
+ Le prompt effectivement donné au llm_model
183
+ La réponse du llm_model
184
+ Les sources du RAG utilisées
185
+ Les ids des documents du RAG
186
+ '''
187
+ if not question:
188
+ return "", "Error: No question !", [], []
189
+ if not collection_name:
190
+ return "", "Error: No collection specified !", [], []
191
+ if not collection_name in self.emb_store.get_collection_names():
192
+ return "", "Error: {name} is no more in the database !".format(name=collection_name), [], []
193
+ # Transformer la 'question' en vecteur avec emb_model
194
+ query_vector:list[float] = self.model.create_vector(question)
195
+ # Récupérer les chunks du store similaires à la question
196
+ chunks, sources, ids = self.emb_store.get_similar_chunks(
197
+ query_vector=query_vector,
198
+ count=2,
199
+ collection_name=collection_name
200
+ )
201
+ # Préparer le prompt final à partir du prompt_template
202
+ prompt:str = self.prompt_template.format(
203
+ context="\n\n\n".join(chunks),
204
+ question=question
205
+ )
206
+ # demander au llm_model de répondre
207
+ resp:str = self.ask_llm(question=prompt)
208
+
209
+ return prompt, resp, sources, ids
210
+
211
+ def test_cours_TSTL()->None:
212
+ # Test placé ici pendant la mise au point
213
+ STORE_DIR = "./db/chroma_vectors"
214
+ # rag = Rag(ModelType.MTOPENAI, store_dir=STORE_DIR)
215
+ rag = Rag(ModelType.MTHUGGINGFACE, store_dir=STORE_DIR)
216
+ # rag = Rag(llm_type=ModelType.MTHUGGINGFACE, emb_type=ModelType.MTHUGGINGFACE, store_dir=STORE_DIR)
217
+
218
+ rag.reset_store()
219
+ rag.add_pdf_to_store("chap-1-Statique.pdf", "T_SPCL")
220
+ # rag.add_pdf_to_store("chap-2-Regulation.pdf", "T_SPCL")
221
+ # rag.add_pdf_to_store("chap-3-Dynamique.pdf", "T_SPCL")
222
+ # rag.add_pdf_to_store("chap-4-Echangeurs.pdf", "T_SPCL")
223
+
224
+ rag.emb_store.print_infos()
225
+
226
+ prompt, resp, sources, ids = rag.ask_rag(
227
+ question="Quelle est la différence entre une pression relative et une pression absolue?",
228
+ # question="Qu'est-ce qu'un échangeur à contre-courant?",
229
+ # question="Quelle est la formule de la résistance thermique? Réponds brièvement",
230
+ # question="Quelle est l'équation de Bernouilli avec les termes de pompe et pertes de charges? Réponds brièvement",
231
+ # question="Que signifie le terme de vitesse dans l'équation de Bernouilli ?",
232
+ # question="Transforme 1 bar en mètre de colonne d'eau",
233
+ # question="A quoi correspond HMT d'une pompe?",
234
+ collection_name="T_SPCL"
235
+ )
236
+ print(prompt)
237
+ print("---------------------------")
238
+ print(resp)
239
+ print("---------------------------")
240
+ print("sources:", sources)
241
+ print("ids=", ids)
242
+
243
+ # print(rag.ask_llm("Quelle est l'équation de Bernouilli avec les termes de pompe et pertes de charges? Réponds brièvement"))
244
+
245
+ if __name__ == "__main__":
246
+ test_cours_TSTL()
247
+
248
+
249
+
src/store.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from math import sqrt
2
+ import operator
3
+ import json
4
+ import os
5
+ from pathlib import Path
6
+ from .astore import AStore
7
+
8
+ class Document:
9
+ '''
10
+ Un document est :
11
+ une chaîne de caractère, le chunk
12
+ une source, livre ou page ou chapitre
13
+ un vecteur issu d'un modèle d'embedding
14
+ un id, calculé automatiquement par la collection
15
+ '''
16
+ def __init__(self, chunk:str, source:str, vec:list[float], idd:int):
17
+ self.chunk = chunk
18
+ self.source = source
19
+ self.vec = vec
20
+ self.idd = idd
21
+
22
+ def get_json(self):
23
+ json = {
24
+ 'c':self.chunk,
25
+ 's':self.source,
26
+ 'v':self.vec
27
+ }
28
+
29
+ return json
30
+
31
+ class Collection:
32
+ '''
33
+ Une Collection est :
34
+ un nom
35
+ une liste de documents
36
+ un id, calculé automatiquement par le Store
37
+ Une collection est sauvée dans un fichier idc.col
38
+ le nom de la collection
39
+ la liste des documents:
40
+ chunk
41
+ source
42
+ vector
43
+ '''
44
+ def __init__(self,
45
+ name:str,
46
+ docs:list[Document],
47
+ idc:int):
48
+ self.name = name
49
+ self.docs = docs
50
+ self.idc = idc
51
+
52
+ def add_document(self, chunk:str, source:str, vec:list[float])->Document:
53
+ '''
54
+ Ajoute un document à la collection
55
+ Args:
56
+ chunk: le texte du document
57
+ source: la source du document (livre, chap...)
58
+ vec: la représentation vectorielle du document
59
+ Returns:
60
+ un Document ou None si problème rencontré
61
+ Raise:
62
+ si un des paramètres n'est pas défini
63
+ '''
64
+ if chunk == None or source == None or vec == None:
65
+ raise Exception("Document error: chunk, source or vec is None !")
66
+ idd:int = len(self.docs) + 1
67
+ doc:Document = Document(chunk, source, vec, idd)
68
+ self.docs.append(doc)
69
+ return doc
70
+
71
+ def get_length_octets(self)->int:
72
+ '''
73
+ Return la taille en octets de la collection
74
+ '''
75
+ if len(self.docs) == 0:
76
+ return 0
77
+ vector_size = len(self.docs[0])
78
+ return len(self.docs) * vector_size * 4 # un float sur 4 octets
79
+
80
+ @classmethod
81
+ def from_disk(self, file_path:str):
82
+ '''
83
+ Méthode de classe qui renvoie une Collection à partir d'un fichier de la base
84
+ Args:
85
+ file_path: le chemin vers le fichier
86
+ Return:
87
+ la Collection
88
+ Exception:
89
+ si le fichier n'existe pas ou qu'on ne peut pas le lire
90
+ '''
91
+ if not os.path.exists(file_path):
92
+ raise Exception("File {file} doesn't exist !".format(file=file_path))
93
+ idc:int = int(Path(file_path).stem)
94
+ # print("Collection.from_disk, reading : ", idc)
95
+ try:
96
+ with open(file_path, "r") as f:
97
+ datas = json.load(f)
98
+ name:str = datas['name']
99
+ docs = []
100
+ idd: int = 1
101
+ for d in datas['docs']:
102
+ doc:Document = Document(d['c'], d['s'], d['v'], idd)
103
+ docs.append(doc)
104
+ idd += 1
105
+ return Collection(name, docs, idc)
106
+ except:
107
+ raise Exception("Unable to read {file_path} !".format(file_path=file_path))
108
+
109
+ def save(self, persist_dir:str):
110
+ '''
111
+ La collection est enregistrée avec le nom idc.col dans le persist_dir
112
+ Args:
113
+ persist_dir: le chemin du repertoire de la bdd
114
+ Exception:
115
+ Si on ne peut pas sauver sur le disque
116
+ '''
117
+
118
+ file_path:str = os.path.join(persist_dir, str(self.idc)) + ".col"
119
+ # print("Collection.save : ", file_path)
120
+ json_object = {
121
+ 'name':self.name,
122
+ 'docs':[]
123
+ }
124
+ for doc in self.docs:
125
+ json_object['docs'].append(doc.get_json())
126
+ json_object = json.dumps(json_object)
127
+ try:
128
+ with open(file_path, "w+") as f:
129
+ f.write(json_object)
130
+ except:
131
+ raise Exception("Unable to save the collection {name}, id={id} !".format(name=self.name, id=self.idc))
132
+
133
+ def delete(self, persist_dir:str)->None:
134
+ '''
135
+ Supprime la collection de la bdd
136
+ Args:
137
+ persist_dir: le chemin du repertoire de la bdd
138
+ Exception:
139
+ Si on ne peut pas supprimer du disque
140
+ '''
141
+ self.docs.clear()
142
+ file_path:str = os.path.join(persist_dir, str(self.idc)) + ".col"
143
+ try:
144
+ os.remove(file_path)
145
+ except:
146
+ raise Exception("Unable to delete the collection {name}, id={id} !".format(name=self.name, id=self.idc))
147
+
148
+ class Store(AStore):
149
+ '''
150
+ Un store est une liste de collections.
151
+ A chaque création, ajout ou suppression d'un élément, la base est sauvée si elle est persistante
152
+ Sur le disque, dans store_dir:
153
+ Un sous-repertoire par collection, portant le nom de la collection
154
+ Dans chaque sous-repertoire d'une collection : la liste des vecteurs
155
+
156
+ '''
157
+ def __init__(self, persist_dir:str):
158
+ ''' Constructeur de Store
159
+ Args:
160
+ dir_name: le répertoire persistant de la base de données ou None
161
+ Exception:
162
+ Dans le cas d'une base persistante:
163
+ Impossible de créer le répertoire persistant
164
+ Impossible de lire les collections
165
+ '''
166
+ self.persist_dir = persist_dir
167
+ self.collections = []
168
+ if persist_dir == None: # store éphémère
169
+ pass # Rien à faire
170
+ else:
171
+ # Charger la liste des collections
172
+ try:
173
+ self._create_persist_dir()
174
+ files = [os.path.join(persist_dir, f) for f in os.listdir(persist_dir) if os.path.isfile(os.path.join(persist_dir, f))]
175
+ for f in files:
176
+ col: Collection = Collection.from_disk(f)
177
+ self.collections.append(col)
178
+ except Exception as e:
179
+ raise
180
+
181
+ def reset(self)->None:
182
+ '''
183
+ Vide la base et l'efface du disque si elle est persistante
184
+ Exception:
185
+ Dans le cas d'une base persistante:
186
+ Impossible de créer le répertoire persistant
187
+ Impossible de lire les collections
188
+ '''
189
+ self.collections = []
190
+ if self.persist_dir == None: # store éphémère
191
+ pass
192
+ else:
193
+ try:
194
+ # Supprimer les fichiers du disque
195
+ if os.path.exists(self.persist_dir):
196
+ files = [os.path.join(self.persist_dir, f) for f in os.listdir(self.persist_dir) if os.path.isfile(os.path.join(self.persist_dir, f))]
197
+ # print(files)
198
+ for f in files:
199
+ os.remove(f)
200
+ os.rmdir(self.persist_dir)
201
+ except Exception as e:
202
+ raise
203
+
204
+ def get_collection_names(self)->list[str]:
205
+ return [col.name for col in self.collections]
206
+
207
+ def print_infos(self)->None:
208
+ ''' Affiche le nombre de collections et pour chaque collection, affiche son nom et son nombre de documents '''
209
+ print("-------- STORE INFOS ---------------")
210
+ for col in self.collections:
211
+ print(col.name)
212
+ # idds = [doc.idd for doc in col.docs]
213
+ # print("\t", idds)
214
+ print("\tdocuments:", len(col.docs))
215
+ print("-------- /STORE INFOS ---------------")
216
+
217
+ def get_collection(self, collection_name:str)->Collection:
218
+ '''
219
+ Renvoie la collection dont le nom est 'collection_name' ou None si elle n'existe pas
220
+ '''
221
+ for col in self.collections:
222
+ if col.name == collection_name:
223
+ return col
224
+ return None
225
+
226
+ def _create_persist_dir(self):
227
+ '''
228
+ Recrée le répertoir persistant s'il a disparu après un reset par exemple
229
+ Exception:
230
+ Si on ne peut pas créer le 'persist_dir'
231
+ '''
232
+ # Vérifier si le persist_dir existe, sinon le créer
233
+ print("Persist_dir:" + self.persist_dir)
234
+ try:
235
+ if not os.path.exists(self.persist_dir):
236
+ os.mkdir(self.persist_dir)
237
+ except:
238
+ raise Exception("Unable to create the persit directory: {dir}".format(dir=self.persist_dir))
239
+
240
+ def create_collection(self, name:str)->Collection:
241
+ '''
242
+ Crée et renvoie une nouvelle collection vide de documents
243
+ Args:
244
+ name: le nom de la création à créer
245
+ Exception:
246
+ Dans le cas d'une base persistante:
247
+ Impossible de créer le répertoire persistant
248
+ Impossible de sauver la collection
249
+ '''
250
+ idc:int = len(self.collections) + 1
251
+ col:Collection = Collection(name, [], idc)
252
+ if self.persist_dir != None:
253
+ try:
254
+ self._create_persist_dir()
255
+ col.save(self.persist_dir)
256
+ except:
257
+ raise
258
+ return col
259
+
260
+
261
+ def add_to_collection(self, collection_name:str, source:str, vectors:list[list[float]], chunks:list[str])->None:
262
+ '''
263
+ Ajoute une liste de vecteurs à la collection 'collection_name'
264
+ Args:
265
+ collection_name: le nom de la collection
266
+ source: la source unique des chunks, par exemple un nom de fichier, une url ...
267
+ vectors: la liste des vecteurs obtenus à l'aide d'un modèle d'embeddings
268
+ chunks: la liste des chunks (documents) correspondant aux vecteurs
269
+ Exception:
270
+ Dans le cas d'une base persistante:
271
+ Impossible de créer le répertoire persistant
272
+ Impossible de sauver la collection
273
+ '''
274
+ col:Collection = self.get_collection(collection_name)
275
+ if col == None:
276
+ col = self.create_collection(collection_name)
277
+ self.collections.append(col)
278
+ for i in range(len(chunks)):
279
+ col.add_document(chunks[i], source, vectors[i])
280
+ if self.persist_dir != None:
281
+ try:
282
+ self._create_persist_dir()
283
+ col.save(self.persist_dir)
284
+ except:
285
+ raise
286
+
287
+ def delete_collection(self, name:str)->None:
288
+ ''' Vide et supprime la collection dont le nom est 'name', et la supprime du disque si elle est persistante '''
289
+ col = self.get_collection(name)
290
+ if col != None:
291
+ self.collections.remove(col)
292
+ if self.persist_dir != None:
293
+ try:
294
+ self._create_persist_dir()
295
+ col.delete(self.persist_dir)
296
+ except:
297
+ raise
298
+
299
+ def normalize(self, v:list[float])->list[float]:
300
+ '''
301
+ Normalement les LLMs renvoient des vecteurs normalisés mais:
302
+ c'est pas sûr pour ceux que je n'ai pas testés
303
+ c'est pratique d'avoir cette méthode pour 'test_store.py'
304
+ Args:
305
+ v: le vecteur à normaliser
306
+ Returns:
307
+ le vecteur normalisé
308
+ '''
309
+ norm = 0.0
310
+ for i in range(len(v)):
311
+ norm += v[i] * v[i]
312
+ norm = sqrt(norm)
313
+ if norm == 0.0:
314
+ return v.copy()
315
+ result = [None] * len(v)
316
+ for i in range(len(v)):
317
+ result[i] = v[i] / norm
318
+ return result
319
+
320
+ def dot_product(self, v1:list[float], v2:list[float])->float:
321
+ '''
322
+ Le produit scalaire est utilisé pour une similarité en cosinus:
323
+ cos(a) = (vecA dot vecB) / (A.B)
324
+ si les vecteurs A et B sont normalisés, le cos est simplement le produit scalaire
325
+ Args:
326
+ v1, v2: les deux vecteurs à multiplier
327
+ Returns:
328
+ Un float égal à v1 dot v2
329
+ '''
330
+ result = 0.0
331
+ for i in range(len(v1)):
332
+ result += v1[i] * v2[i]
333
+ return result
334
+
335
+ def get_similar_vector(self, vector:list[float], collection_name:str)->list[float]:
336
+ '''
337
+ Renvoie le vecteur de 'collection' le pus similaire à 'vector'.
338
+ Args:
339
+ vector: un vecteur obtenu avec le même modèle d'embeddings que les vecteurs de la 'collection'
340
+ collection_name: le nom de la collection de la base dans laquelle on cherche une similarité
341
+ Return:
342
+ Le vecteur le plus similaire 'vector'
343
+ '''
344
+ col:Collection = self.get_collection(collection_name)
345
+ best_doc:Document = None
346
+ best_dp: float = -20.0
347
+ if col != None:
348
+ for doc in col.docs:
349
+ dp:float = self.dot_product(vector, doc.vec)
350
+ if dp > best_dp:
351
+ best_dp = dp
352
+ best_doc = doc
353
+ return best_doc.vec
354
+ else:
355
+ return None
356
+
357
+ def get_similar_chunk(self, query_vector:list[float], collection_name:str)->tuple[str, str]:
358
+ '''
359
+ Renvoie le document de la 'collection' le plus similaire à 'query_vector'.
360
+ Args:
361
+ query_vector: un vecteur obtenu avec le même modèle d'embeddings que les vecteurs de la 'collection'
362
+ collection: la collection de la base dans laquelle on cherche une similarité
363
+ Returns:
364
+ Un tuple contenant:
365
+ le document
366
+ la source du document
367
+ '''
368
+ col:Collection = self.get_collection(collection_name)
369
+ best_doc:Document = None
370
+ best_dp: float = -20.0
371
+ if col != None:
372
+ for doc in col.docs:
373
+ dp:float = self.dot_product(query_vector, doc.vec)
374
+ print(dp)
375
+ if dp > best_dp:
376
+ best_dp = dp
377
+ best_doc = doc
378
+ return best_doc.chunk, best_doc.source
379
+ else:
380
+ return None, None
381
+
382
+ def get_similar_chunks(self, query_vector:list[float], count:int, collection_name:str):
383
+ '''
384
+ Returns:
385
+ Un tuple contenant:
386
+ les documents
387
+ la source des documents
388
+ les ids des documents
389
+ a[0:count-1]
390
+ '''
391
+ # start:int = time.time()
392
+ col:Collection = self.get_collection(collection_name)
393
+ if col == None:
394
+ return None, None, None
395
+ bests:list[dict] = []
396
+ # Ajouter tous les docs avec leur dotproduct à la liste bests
397
+ for doc in col.docs:
398
+ dp:float = self.dot_product(query_vector, doc.vec)
399
+ bests.append({'doc':doc, 'dp':dp})
400
+ # Trier la liste en reverse à partir de la clé 'dp'
401
+ bests.sort(key=operator.itemgetter('dp'), reverse=True)
402
+ # Adapter le nombre de documents à renvoyer s'il n'y a pas assez de chunks
403
+ n:int = count if len(bests) >= count else len(bests)
404
+ # print("get_similar_chunks, count=", count, ", n=", n)
405
+ # Créer les variables de retour
406
+ docs = [b['doc'].chunk for b in bests[0:n]]
407
+ source = bests[0]['doc'].source if n > 0 else None
408
+ ids = [b['doc'].idd for b in bests[0:n]]
409
+ # print("my_store.get_similar_chunks:", time.time() - start, "s")
410
+ return docs, source, ids
411
+