First commit
Browse files- .gitignore +6 -0
- .python-version +1 -0
- files/drane.jpg +0 -0
- requirements.txt +5 -0
- src/amodel.py +58 -0
- src/astore.py +39 -0
- src/chunker.py +36 -0
- src/model_huggingface.py +61 -0
- src/model_mistral.py +63 -0
- src/model_ollama.py +49 -0
- src/model_openai.py +65 -0
- src/rag.py +249 -0
- src/store.py +411 -0
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
venv/
|
2 |
+
__pycache__/
|
3 |
+
.vscode/
|
4 |
+
.gradio/
|
5 |
+
.env
|
6 |
+
files/rag_app/
|
.python-version
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
3.11
|
files/drane.jpg
ADDED
![]() |
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pypdf
|
2 |
+
openai
|
3 |
+
huggingface-hub
|
4 |
+
ollama
|
5 |
+
mistralai
|
src/amodel.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
from enum import Enum
|
3 |
+
|
4 |
+
class ModelType(Enum):
|
5 |
+
''' Les différentes technos de models '''
|
6 |
+
MTOPENAI = 1
|
7 |
+
MTOLLAMA = 2
|
8 |
+
MTHUGGINGFACE = 3
|
9 |
+
MTMISTRAL = 4
|
10 |
+
|
11 |
+
@classmethod
|
12 |
+
def to_str(self, mt:int)->str:
|
13 |
+
match mt:
|
14 |
+
case 1: return "MTOPENAI"
|
15 |
+
case 2: return "MTOLLAMA"
|
16 |
+
case 3: return "MTHUGGINGFACE"
|
17 |
+
case 4: return "MTMISTRAL"
|
18 |
+
case _: return "UNKNOWN"
|
19 |
+
|
20 |
+
class AModel(ABC):
|
21 |
+
'''
|
22 |
+
Classe abstraite de base pour tous les models :
|
23 |
+
Ollama en local
|
24 |
+
OpenAI distant
|
25 |
+
HuggingFace distant
|
26 |
+
HuggingFace dans une app
|
27 |
+
...
|
28 |
+
'''
|
29 |
+
|
30 |
+
@abstractmethod
|
31 |
+
def ask_llm(self, question:str)->str:
|
32 |
+
pass
|
33 |
+
|
34 |
+
@abstractmethod
|
35 |
+
def create_vector(self, chunk:str)->list[float]:
|
36 |
+
pass
|
37 |
+
|
38 |
+
@abstractmethod
|
39 |
+
def create_vectors(self, chunks:list[str])->list[list[float]]:
|
40 |
+
pass
|
41 |
+
|
42 |
+
def get_llm_name(self):
|
43 |
+
return self.llm_name
|
44 |
+
|
45 |
+
def set_llm_name(self, llm_name:str):
|
46 |
+
self.llm_name = llm_name
|
47 |
+
|
48 |
+
def get_feature_name(self):
|
49 |
+
return self.feature_name
|
50 |
+
|
51 |
+
def set_feature_name(self, feature_name:str):
|
52 |
+
self.feature_name = feature_name
|
53 |
+
|
54 |
+
def get_temperature(self):
|
55 |
+
return self.temperature
|
56 |
+
|
57 |
+
def set_temperature(self, temperature:float):
|
58 |
+
self.temperature = temperature
|
src/astore.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC, abstractmethod
|
2 |
+
|
3 |
+
|
4 |
+
class AStore(ABC):
|
5 |
+
'''
|
6 |
+
Classe abstraite de base pour tous les stores :
|
7 |
+
Chroma
|
8 |
+
Perso
|
9 |
+
...
|
10 |
+
'''
|
11 |
+
|
12 |
+
@abstractmethod
|
13 |
+
def reset(self)->None:
|
14 |
+
pass
|
15 |
+
|
16 |
+
@abstractmethod
|
17 |
+
def print_infos(self)->None:
|
18 |
+
pass
|
19 |
+
|
20 |
+
@abstractmethod
|
21 |
+
def add_to_collection(self, collection_name:str, source:str, vectors:list[list[float]], chunks:list[str])->None:
|
22 |
+
pass
|
23 |
+
|
24 |
+
@abstractmethod
|
25 |
+
def delete_collection(self, name:str)->None:
|
26 |
+
pass
|
27 |
+
|
28 |
+
@abstractmethod
|
29 |
+
def get_similar_vector(self, vector:list[float], collection_name:str)->list[float]:
|
30 |
+
pass
|
31 |
+
|
32 |
+
@abstractmethod
|
33 |
+
def get_similar_chunk(self, query_vector:list[float], collection_name:str)->tuple[str, str]:
|
34 |
+
pass
|
35 |
+
|
36 |
+
@abstractmethod
|
37 |
+
def get_similar_chunks(self, query_vector:list[float], count:int, collection_name:str):
|
38 |
+
pass
|
39 |
+
|
src/chunker.py
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
|
3 |
+
class Chunker:
|
4 |
+
'''
|
5 |
+
Tronçonnage d'un texte en chunks
|
6 |
+
'''
|
7 |
+
|
8 |
+
def __init__(self):
|
9 |
+
pass
|
10 |
+
|
11 |
+
def split_basic(self, text:str, char_count:int, overlap:int)->list[str]:
|
12 |
+
'''
|
13 |
+
Découpe le texte avec des '\n'.
|
14 |
+
La taille d'un chunk est de max count + 2 * overlap
|
15 |
+
chunk = o1-c-o2
|
16 |
+
o1: les mots du chunk précédent ajoutés, il y en a 'overlap' ou 0 pour le premier chunk
|
17 |
+
c: partie centrale du chunk
|
18 |
+
o2: les mots du chunk suivant ajoutés, il y en a 'overlap' ou 0 pour le dernier chunk
|
19 |
+
Args:
|
20 |
+
char_count: le nombre de caractères dans un chunk (sans compter les mots ajoutés par recouvrement)
|
21 |
+
overlap: le nombre de caractères du chunk précédent (et suivant) ajoutés au début (et à la fin) du chunk
|
22 |
+
Return:
|
23 |
+
La liste des chunks
|
24 |
+
'''
|
25 |
+
# La liste qui sera renvoyée
|
26 |
+
chunks:list[str] = [] # la liste qui sera renvoyée
|
27 |
+
# Découpage du texte en morceaux de 'char_count' caractères
|
28 |
+
n:int = len(text)
|
29 |
+
size:int = n // char_count + 1 # nombre de chunks
|
30 |
+
for i in range(size):
|
31 |
+
start = i*char_count if i == 0 else i*char_count - overlap
|
32 |
+
stop = (i+1)*char_count if i == size - 1 else (i+1)*char_count + overlap
|
33 |
+
s = slice(start, stop)
|
34 |
+
chunk:str = text[s]
|
35 |
+
chunks.append(chunk)
|
36 |
+
return chunks
|
src/model_huggingface.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from .amodel import AModel
|
3 |
+
from huggingface_hub import InferenceClient
|
4 |
+
import numpy as np # feature_extraction renvoie un array numpy...
|
5 |
+
|
6 |
+
|
7 |
+
class HuggingFaceModel(AModel):
|
8 |
+
|
9 |
+
def __init__(self, llm_name:str, feature_name:str, temperature:float=0.0):
|
10 |
+
self.llm_name:str = llm_name
|
11 |
+
self.feature_name:str = feature_name
|
12 |
+
self.temperature = temperature
|
13 |
+
# La variable HF_ACTIVE a été créée dans les settings de l'app sur HuggingFace
|
14 |
+
if (os.getenv("HF_ACTIVE")): # Lancement depuis l'app sur HuggingFace
|
15 |
+
api_token = os.getenv("HF_TOKEN")
|
16 |
+
else: # Lancement depuis mon ordi
|
17 |
+
# print("Launch Rag in HuggingFace local")
|
18 |
+
from dotenv import load_dotenv # Trick: ne passe pas dans une app sur HuggingFace
|
19 |
+
load_dotenv()
|
20 |
+
api_token = os.getenv("HUGGINGFACEHUB_API_TOKEN")
|
21 |
+
try:
|
22 |
+
self.model = InferenceClient(api_key=api_token)
|
23 |
+
except:
|
24 |
+
raise
|
25 |
+
|
26 |
+
def ask_llm(self, question:str)->str:
|
27 |
+
messages = [{"role": "user", "content": question}]
|
28 |
+
try:
|
29 |
+
resp = self.model.chat.completions.create(
|
30 |
+
model=self.llm_name,
|
31 |
+
messages=messages,
|
32 |
+
max_tokens=500,
|
33 |
+
temperature=self.temperature,
|
34 |
+
# stream=True
|
35 |
+
)
|
36 |
+
return resp.choices[0].message.content
|
37 |
+
except:
|
38 |
+
raise
|
39 |
+
|
40 |
+
def create_vector(self, chunk:str)->list[float]:
|
41 |
+
resp = self.model.feature_extraction(
|
42 |
+
text=chunk,
|
43 |
+
# normalize=True, # Only available on server powered by Text-Embedding-Inference.
|
44 |
+
model=self.feature_name, # normalisé ??
|
45 |
+
)
|
46 |
+
return resp
|
47 |
+
|
48 |
+
def create_vectors(self, chunks:list[str])->list[list[float]]:
|
49 |
+
'''
|
50 |
+
Pas de batch pour la création de vectors sur HuggingFace, on les passe un par un
|
51 |
+
'''
|
52 |
+
vectors = []
|
53 |
+
try:
|
54 |
+
for chunk in chunks:
|
55 |
+
v = self.create_vector(chunk)
|
56 |
+
if not isinstance(v, np.ndarray):
|
57 |
+
raise
|
58 |
+
vectors.append(v.tolist())
|
59 |
+
return vectors
|
60 |
+
except:
|
61 |
+
raise
|
src/model_mistral.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
from .amodel import AModel
|
5 |
+
from mistralai import Mistral
|
6 |
+
|
7 |
+
class MistralModel(AModel):
|
8 |
+
'''
|
9 |
+
https://docs.mistral.ai/capabilities/completion/
|
10 |
+
https://docs.mistral.ai/capabilities/embeddings/
|
11 |
+
temperature entre 0.0 et 0.7
|
12 |
+
'''
|
13 |
+
|
14 |
+
def __init__(self, llm_name:str, feature_name:str, temperature:float=0.0):
|
15 |
+
self.llm_name:str = llm_name
|
16 |
+
self.feature_name:str = feature_name
|
17 |
+
self.temperature = temperature
|
18 |
+
load_dotenv()
|
19 |
+
try:
|
20 |
+
self.model = Mistral(api_key=os.getenv("MISTRAL_API_KEY"))
|
21 |
+
except:
|
22 |
+
raise
|
23 |
+
|
24 |
+
|
25 |
+
def ask_llm(self, question:str)->str:
|
26 |
+
try:
|
27 |
+
response = self.model.chat.complete(
|
28 |
+
model=self.llm_name,
|
29 |
+
messages = [{ "role": "user", "content": question, },],
|
30 |
+
temperature=self.temperature
|
31 |
+
)
|
32 |
+
return response.choices[0].message.content
|
33 |
+
except:
|
34 |
+
raise
|
35 |
+
|
36 |
+
def create_vector(self, chunk:str)->list[float]:
|
37 |
+
'''
|
38 |
+
Renvoie un vecteur de taille 1024 à partir de chunk
|
39 |
+
'''
|
40 |
+
try:
|
41 |
+
response = self.model.embeddings.create(
|
42 |
+
model=self.feature_name,
|
43 |
+
# inputs=["Embed this sentence.", "As well as this one."],
|
44 |
+
inputs=[chunk]
|
45 |
+
)
|
46 |
+
return response.data[0].embedding
|
47 |
+
except:
|
48 |
+
raise
|
49 |
+
|
50 |
+
def create_vectors(self, chunks:list[str])->list[list[float]]:
|
51 |
+
'''
|
52 |
+
Renvoie n vecteurs de taille 1024 à partir de la liste chunks
|
53 |
+
'''
|
54 |
+
try:
|
55 |
+
response = self.model.embeddings.create(
|
56 |
+
model=self.feature_name,
|
57 |
+
inputs=chunks,
|
58 |
+
)
|
59 |
+
n:int = len(chunks)
|
60 |
+
result = [response.data[i].embedding for i in range(n)]
|
61 |
+
return result
|
62 |
+
except:
|
63 |
+
raise
|
src/model_ollama.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .amodel import AModel
|
2 |
+
import ollama
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
class OllamaModel(AModel):
|
6 |
+
|
7 |
+
def __init__(self, llm_name:str, feature_name:str, temperature:float=0.0):
|
8 |
+
self.llm_name:str = llm_name
|
9 |
+
self.feature_name:str = feature_name
|
10 |
+
self.temperature = temperature
|
11 |
+
|
12 |
+
def ask_llm(self, question:str)->str:
|
13 |
+
try:
|
14 |
+
resp = ollama.chat(
|
15 |
+
model=self.llm_name,
|
16 |
+
messages=[{'role':'user', 'content':question}],
|
17 |
+
stream=False,
|
18 |
+
options={"temperature":self.temperature})
|
19 |
+
return resp.message.content
|
20 |
+
except:
|
21 |
+
raise
|
22 |
+
|
23 |
+
def create_vector(self, chunk:str)->list[float]:
|
24 |
+
'''
|
25 |
+
TODO: Vérifier s'il ne faut pas utiliser 'embed' plutôt que 'embeddings'
|
26 |
+
'''
|
27 |
+
try:
|
28 |
+
resp = ollama.embeddings(
|
29 |
+
model=self.feature_name,
|
30 |
+
prompt=chunk)
|
31 |
+
return self.normalize(resp.embedding).tolist()
|
32 |
+
except:
|
33 |
+
raise
|
34 |
+
|
35 |
+
def normalize(self, v:list[float]):
|
36 |
+
norm = np.linalg.norm(v)
|
37 |
+
if norm == 0:
|
38 |
+
return v
|
39 |
+
return v / norm
|
40 |
+
|
41 |
+
def create_vectors(self, chunks:list[str])->list[list[float]]:
|
42 |
+
try:
|
43 |
+
resp = ollama.embed(
|
44 |
+
model=self.feature_name,
|
45 |
+
input=chunks)
|
46 |
+
# print(resp.embeddings)
|
47 |
+
return resp.embeddings
|
48 |
+
except:
|
49 |
+
raise
|
src/model_openai.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
from dotenv import load_dotenv
|
4 |
+
from .amodel import AModel
|
5 |
+
from openai import OpenAI
|
6 |
+
|
7 |
+
class OpenAIModel(AModel):
|
8 |
+
'''
|
9 |
+
https://platform.openai.com/docs/guides/text-generation
|
10 |
+
|
11 |
+
'''
|
12 |
+
|
13 |
+
def __init__(self, llm_name:str, feature_name:str, temperature:float=0.0):
|
14 |
+
self.llm_name:str = llm_name
|
15 |
+
self.feature_name:str = feature_name
|
16 |
+
self.temperature = temperature
|
17 |
+
load_dotenv()
|
18 |
+
try:
|
19 |
+
self.model = OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
20 |
+
except:
|
21 |
+
raise
|
22 |
+
|
23 |
+
def ask_llm(self, question:str)->str:
|
24 |
+
try:
|
25 |
+
response = self.model.chat.completions.create(
|
26 |
+
# model="gpt-4o-mini",
|
27 |
+
model=self.llm_name,
|
28 |
+
messages=[
|
29 |
+
{"role":"system", "content":""},
|
30 |
+
{"role":"user", "content":question},
|
31 |
+
],
|
32 |
+
temperature=self.temperature
|
33 |
+
)
|
34 |
+
return response.choices[0].message.content
|
35 |
+
except:
|
36 |
+
raise
|
37 |
+
|
38 |
+
def create_vector(self, chunk:str)->list[float]:
|
39 |
+
'''
|
40 |
+
8192 tokens max
|
41 |
+
'''
|
42 |
+
# les embeddings d'OpenAI sont normalisés à 1
|
43 |
+
try:
|
44 |
+
response = self.model.embeddings.create(
|
45 |
+
input=chunk,
|
46 |
+
model=self.feature_name
|
47 |
+
)
|
48 |
+
return response.data[0].embedding
|
49 |
+
except:
|
50 |
+
raise
|
51 |
+
|
52 |
+
def create_vectors(self, chunks:list[str])->list[list[float]]:
|
53 |
+
'''
|
54 |
+
Pas plus de 2048 chunks
|
55 |
+
'''
|
56 |
+
try:
|
57 |
+
response = self.model.embeddings.create(
|
58 |
+
input=chunks,
|
59 |
+
model=self.feature_name
|
60 |
+
)
|
61 |
+
n:int = len(chunks)
|
62 |
+
result = [response.data[i].embedding for i in range(n)]
|
63 |
+
return result
|
64 |
+
except:
|
65 |
+
raise
|
src/rag.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
from pypdf import PdfReader
|
4 |
+
|
5 |
+
from .chunker import Chunker
|
6 |
+
from .amodel import ModelType
|
7 |
+
from .model_openai import OpenAIModel
|
8 |
+
from .model_huggingface import HuggingFaceModel
|
9 |
+
from .model_ollama import OllamaModel
|
10 |
+
from .model_mistral import MistralModel
|
11 |
+
from .store import Store
|
12 |
+
|
13 |
+
|
14 |
+
class Rag:
|
15 |
+
'''
|
16 |
+
Classe qui s'occupe de toute la chaine du RAG.
|
17 |
+
Elle permet :
|
18 |
+
d'interroger un llm directement (sans RAG) avec ask_llm()
|
19 |
+
d'interroger le RAG lui même avec ask_rag()
|
20 |
+
d'ajouter des documents à la base de données du RAG
|
21 |
+
de remettre la base à zéro
|
22 |
+
de créer des vecteurs
|
23 |
+
de charger des pdf
|
24 |
+
'''
|
25 |
+
|
26 |
+
# Le prompt qui sera utilisé uniquement avec le RAG
|
27 |
+
prompt_template = """
|
28 |
+
En vous basant **uniquement** sur les informations fournies dans le contexte
|
29 |
+
ci-dessous, répondez à la question posée.
|
30 |
+
Les équations seront écrites en latex.
|
31 |
+
Si vous ne trouvez pas la réponse dans le contexte, répondez "Je ne sais pas".
|
32 |
+
Contexte : {context}
|
33 |
+
Question : {question}
|
34 |
+
"""
|
35 |
+
|
36 |
+
def __init__(self, model_type:ModelType, store_dir:str) -> None:
|
37 |
+
'''
|
38 |
+
Constructeur du Rag
|
39 |
+
Args:
|
40 |
+
model_type: la techno utilisée
|
41 |
+
store_dir: le répertoire de persistance de la base de données ou None
|
42 |
+
Exception:
|
43 |
+
Si le model ne peut pas être créé
|
44 |
+
Si le type de model est inconnu
|
45 |
+
'''
|
46 |
+
self.model_type = model_type
|
47 |
+
try:
|
48 |
+
match model_type:
|
49 |
+
case ModelType.MTOPENAI:
|
50 |
+
self.model = OpenAIModel("gpt-4o-mini", "text-embedding-3-small", 0)
|
51 |
+
case ModelType.MTHUGGINGFACE:
|
52 |
+
self.model = HuggingFaceModel("meta-llama/Meta-Llama-3-8B-Instruct", "sentence-transformers/all-MiniLM-l6-v2", 0)
|
53 |
+
case ModelType.MTOLLAMA:
|
54 |
+
self.model = OllamaModel("llama3.2:1b", "nomic-embed-text", 0.0)
|
55 |
+
case ModelType.MTMISTRAL:
|
56 |
+
self.model = MistralModel("mistral-large-latest", "mistral-embed", 0.0)
|
57 |
+
case _:
|
58 |
+
raise Exception("Rag.__init__: Unknown model type: {mt} : {v}".format(mt=ModelType.to_str(model_type), v=model_type))
|
59 |
+
self.emb_store = Store(store_dir) # persistant
|
60 |
+
# self.emb_store = Store(None) # éphémère
|
61 |
+
except Exception as e:
|
62 |
+
raise
|
63 |
+
|
64 |
+
def get_llm_name(self):
|
65 |
+
return self.model.get_llm_name()
|
66 |
+
|
67 |
+
def get_feature_name(self):
|
68 |
+
return self.model.get_feature_name()
|
69 |
+
|
70 |
+
def get_temperature(self):
|
71 |
+
return self.model.get_temperature()
|
72 |
+
|
73 |
+
def set_temperature(self, temperature:float):
|
74 |
+
self.model.set_temperature(temperature)
|
75 |
+
|
76 |
+
def reset_store(self):
|
77 |
+
self.emb_store.reset()
|
78 |
+
|
79 |
+
def delete_collection(self, name:str)->None:
|
80 |
+
self.emb_store.delete_collection(name)
|
81 |
+
|
82 |
+
def create_vectors(self, chunks:list[str])->list[list[float]]:
|
83 |
+
'''
|
84 |
+
Renvoie les vecteurs correspondant à 'chunks', calculés par 'emb_model'
|
85 |
+
Args:
|
86 |
+
chunks: les extraits de texte à calculer
|
87 |
+
Return:
|
88 |
+
la liste des vecteurs calculés
|
89 |
+
'''
|
90 |
+
vectors:list = []
|
91 |
+
tokens:int = 0
|
92 |
+
vectors:list[list[float]] = self.model.create_vectors(chunks) # batch si le model le permet
|
93 |
+
# for chunk in chunks:
|
94 |
+
# vector:list[float] = self.model.create_vector(chunk=chunk)
|
95 |
+
# vectors.append(vector)
|
96 |
+
return vectors
|
97 |
+
|
98 |
+
def load_pdf(self, file_name:str)->str:
|
99 |
+
''' Charge le fichier 'file_name' et renvoie son contenu sous forme de texte. '''
|
100 |
+
reader = PdfReader(file_name)
|
101 |
+
content = ""
|
102 |
+
for page in reader.pages:
|
103 |
+
content += page.extract_text() + "\n"
|
104 |
+
return content
|
105 |
+
|
106 |
+
def get_chunks(self, text:str)->list:
|
107 |
+
'''
|
108 |
+
Découpe le 'text' en chunks de taille chunk_size avec un recouvrement
|
109 |
+
Args:
|
110 |
+
text: Le texte à découper
|
111 |
+
Return:
|
112 |
+
La liste des chunks
|
113 |
+
'''
|
114 |
+
# splitter = RecursiveCharacterTextSplitter(
|
115 |
+
# # separator="\n",
|
116 |
+
# chunk_size=1000,
|
117 |
+
# chunk_overlap=200,
|
118 |
+
# length_function=len,
|
119 |
+
# is_separator_regex=False
|
120 |
+
# )
|
121 |
+
# chunks = splitter.split_text(text)
|
122 |
+
# print("get_chunks: " + str(len(chunks)))
|
123 |
+
chunker = Chunker()
|
124 |
+
chunks = chunker.split_basic(text=text, char_count=1000, overlap=200)
|
125 |
+
return chunks
|
126 |
+
|
127 |
+
def add_pdf_to_store(self, file_name:str, collection_name:str)->None:
|
128 |
+
'''
|
129 |
+
Ajoute un pdf à la base de données du RAG.
|
130 |
+
Args:
|
131 |
+
file_name: le chemin vers le fichier à ajouter
|
132 |
+
collection_name: Le nom de la collection dans laquelle il faut ajouter les chunks
|
133 |
+
La collection est créée si elle n'existe pas.
|
134 |
+
'''
|
135 |
+
text:str = self.load_pdf(file_name)
|
136 |
+
chunks:list[str] = self.get_chunks(text)
|
137 |
+
self.add_chunks_to_store(chunks=chunks, collection_name=collection_name, source=file_name)
|
138 |
+
|
139 |
+
def add_pdf_stream_to_store(self, stream, collection_name:str)->None:
|
140 |
+
'''
|
141 |
+
Ajoute un stream provenant de file_uploader de streamlit par exemple
|
142 |
+
'''
|
143 |
+
text:str = self.load_pdf(stream)
|
144 |
+
chunks:list[str] = self.get_chunks(text)
|
145 |
+
self.add_chunks_to_store(chunks=chunks, collection_name=collection_name, source="stream")
|
146 |
+
|
147 |
+
def add_chunks_to_store(self, chunks:list[str], collection_name:str, source:str)->None:
|
148 |
+
'''
|
149 |
+
Ajoute des chunks à la base de données du RAG.
|
150 |
+
Args:
|
151 |
+
chunks: les chunks à ajouter
|
152 |
+
collection_name: Le nom de la collection dans laquelle il faut ajouter les chunks
|
153 |
+
La collection est créée si elle n'existe pas.
|
154 |
+
source: la source des chunks (nom du fichier, url ...)
|
155 |
+
'''
|
156 |
+
vectors = self.create_vectors(chunks=chunks)
|
157 |
+
self.emb_store.add_to_collection(
|
158 |
+
collection_name=collection_name,
|
159 |
+
source=source,
|
160 |
+
vectors=vectors,
|
161 |
+
chunks=chunks
|
162 |
+
)
|
163 |
+
|
164 |
+
|
165 |
+
def ask_llm(self, question:str)->str:
|
166 |
+
'''
|
167 |
+
Pose une question au llm_model, attend sa réponse et la renvoie.
|
168 |
+
Args:
|
169 |
+
question: La question qu'on veut lui poser
|
170 |
+
Returns:
|
171 |
+
La réponse du llm_model
|
172 |
+
'''
|
173 |
+
return self.model.ask_llm(question=question)
|
174 |
+
|
175 |
+
def ask_rag(self, question:str, collection_name:str)->tuple[str, str, list[str], list[str]]:
|
176 |
+
'''
|
177 |
+
Pose une question au RAG, attend sa réponse et la renvoie.
|
178 |
+
Args:
|
179 |
+
question: La question qu'on veut lui poser
|
180 |
+
collection_name: le nom de la collection que l'on veut interroger
|
181 |
+
Returns:
|
182 |
+
Le prompt effectivement donné au llm_model
|
183 |
+
La réponse du llm_model
|
184 |
+
Les sources du RAG utilisées
|
185 |
+
Les ids des documents du RAG
|
186 |
+
'''
|
187 |
+
if not question:
|
188 |
+
return "", "Error: No question !", [], []
|
189 |
+
if not collection_name:
|
190 |
+
return "", "Error: No collection specified !", [], []
|
191 |
+
if not collection_name in self.emb_store.get_collection_names():
|
192 |
+
return "", "Error: {name} is no more in the database !".format(name=collection_name), [], []
|
193 |
+
# Transformer la 'question' en vecteur avec emb_model
|
194 |
+
query_vector:list[float] = self.model.create_vector(question)
|
195 |
+
# Récupérer les chunks du store similaires à la question
|
196 |
+
chunks, sources, ids = self.emb_store.get_similar_chunks(
|
197 |
+
query_vector=query_vector,
|
198 |
+
count=2,
|
199 |
+
collection_name=collection_name
|
200 |
+
)
|
201 |
+
# Préparer le prompt final à partir du prompt_template
|
202 |
+
prompt:str = self.prompt_template.format(
|
203 |
+
context="\n\n\n".join(chunks),
|
204 |
+
question=question
|
205 |
+
)
|
206 |
+
# demander au llm_model de répondre
|
207 |
+
resp:str = self.ask_llm(question=prompt)
|
208 |
+
|
209 |
+
return prompt, resp, sources, ids
|
210 |
+
|
211 |
+
def test_cours_TSTL()->None:
|
212 |
+
# Test placé ici pendant la mise au point
|
213 |
+
STORE_DIR = "./db/chroma_vectors"
|
214 |
+
# rag = Rag(ModelType.MTOPENAI, store_dir=STORE_DIR)
|
215 |
+
rag = Rag(ModelType.MTHUGGINGFACE, store_dir=STORE_DIR)
|
216 |
+
# rag = Rag(llm_type=ModelType.MTHUGGINGFACE, emb_type=ModelType.MTHUGGINGFACE, store_dir=STORE_DIR)
|
217 |
+
|
218 |
+
rag.reset_store()
|
219 |
+
rag.add_pdf_to_store("chap-1-Statique.pdf", "T_SPCL")
|
220 |
+
# rag.add_pdf_to_store("chap-2-Regulation.pdf", "T_SPCL")
|
221 |
+
# rag.add_pdf_to_store("chap-3-Dynamique.pdf", "T_SPCL")
|
222 |
+
# rag.add_pdf_to_store("chap-4-Echangeurs.pdf", "T_SPCL")
|
223 |
+
|
224 |
+
rag.emb_store.print_infos()
|
225 |
+
|
226 |
+
prompt, resp, sources, ids = rag.ask_rag(
|
227 |
+
question="Quelle est la différence entre une pression relative et une pression absolue?",
|
228 |
+
# question="Qu'est-ce qu'un échangeur à contre-courant?",
|
229 |
+
# question="Quelle est la formule de la résistance thermique? Réponds brièvement",
|
230 |
+
# question="Quelle est l'équation de Bernouilli avec les termes de pompe et pertes de charges? Réponds brièvement",
|
231 |
+
# question="Que signifie le terme de vitesse dans l'équation de Bernouilli ?",
|
232 |
+
# question="Transforme 1 bar en mètre de colonne d'eau",
|
233 |
+
# question="A quoi correspond HMT d'une pompe?",
|
234 |
+
collection_name="T_SPCL"
|
235 |
+
)
|
236 |
+
print(prompt)
|
237 |
+
print("---------------------------")
|
238 |
+
print(resp)
|
239 |
+
print("---------------------------")
|
240 |
+
print("sources:", sources)
|
241 |
+
print("ids=", ids)
|
242 |
+
|
243 |
+
# print(rag.ask_llm("Quelle est l'équation de Bernouilli avec les termes de pompe et pertes de charges? Réponds brièvement"))
|
244 |
+
|
245 |
+
if __name__ == "__main__":
|
246 |
+
test_cours_TSTL()
|
247 |
+
|
248 |
+
|
249 |
+
|
src/store.py
ADDED
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from math import sqrt
|
2 |
+
import operator
|
3 |
+
import json
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
from .astore import AStore
|
7 |
+
|
8 |
+
class Document:
|
9 |
+
'''
|
10 |
+
Un document est :
|
11 |
+
une chaîne de caractère, le chunk
|
12 |
+
une source, livre ou page ou chapitre
|
13 |
+
un vecteur issu d'un modèle d'embedding
|
14 |
+
un id, calculé automatiquement par la collection
|
15 |
+
'''
|
16 |
+
def __init__(self, chunk:str, source:str, vec:list[float], idd:int):
|
17 |
+
self.chunk = chunk
|
18 |
+
self.source = source
|
19 |
+
self.vec = vec
|
20 |
+
self.idd = idd
|
21 |
+
|
22 |
+
def get_json(self):
|
23 |
+
json = {
|
24 |
+
'c':self.chunk,
|
25 |
+
's':self.source,
|
26 |
+
'v':self.vec
|
27 |
+
}
|
28 |
+
|
29 |
+
return json
|
30 |
+
|
31 |
+
class Collection:
|
32 |
+
'''
|
33 |
+
Une Collection est :
|
34 |
+
un nom
|
35 |
+
une liste de documents
|
36 |
+
un id, calculé automatiquement par le Store
|
37 |
+
Une collection est sauvée dans un fichier idc.col
|
38 |
+
le nom de la collection
|
39 |
+
la liste des documents:
|
40 |
+
chunk
|
41 |
+
source
|
42 |
+
vector
|
43 |
+
'''
|
44 |
+
def __init__(self,
|
45 |
+
name:str,
|
46 |
+
docs:list[Document],
|
47 |
+
idc:int):
|
48 |
+
self.name = name
|
49 |
+
self.docs = docs
|
50 |
+
self.idc = idc
|
51 |
+
|
52 |
+
def add_document(self, chunk:str, source:str, vec:list[float])->Document:
|
53 |
+
'''
|
54 |
+
Ajoute un document à la collection
|
55 |
+
Args:
|
56 |
+
chunk: le texte du document
|
57 |
+
source: la source du document (livre, chap...)
|
58 |
+
vec: la représentation vectorielle du document
|
59 |
+
Returns:
|
60 |
+
un Document ou None si problème rencontré
|
61 |
+
Raise:
|
62 |
+
si un des paramètres n'est pas défini
|
63 |
+
'''
|
64 |
+
if chunk == None or source == None or vec == None:
|
65 |
+
raise Exception("Document error: chunk, source or vec is None !")
|
66 |
+
idd:int = len(self.docs) + 1
|
67 |
+
doc:Document = Document(chunk, source, vec, idd)
|
68 |
+
self.docs.append(doc)
|
69 |
+
return doc
|
70 |
+
|
71 |
+
def get_length_octets(self)->int:
|
72 |
+
'''
|
73 |
+
Return la taille en octets de la collection
|
74 |
+
'''
|
75 |
+
if len(self.docs) == 0:
|
76 |
+
return 0
|
77 |
+
vector_size = len(self.docs[0])
|
78 |
+
return len(self.docs) * vector_size * 4 # un float sur 4 octets
|
79 |
+
|
80 |
+
@classmethod
|
81 |
+
def from_disk(self, file_path:str):
|
82 |
+
'''
|
83 |
+
Méthode de classe qui renvoie une Collection à partir d'un fichier de la base
|
84 |
+
Args:
|
85 |
+
file_path: le chemin vers le fichier
|
86 |
+
Return:
|
87 |
+
la Collection
|
88 |
+
Exception:
|
89 |
+
si le fichier n'existe pas ou qu'on ne peut pas le lire
|
90 |
+
'''
|
91 |
+
if not os.path.exists(file_path):
|
92 |
+
raise Exception("File {file} doesn't exist !".format(file=file_path))
|
93 |
+
idc:int = int(Path(file_path).stem)
|
94 |
+
# print("Collection.from_disk, reading : ", idc)
|
95 |
+
try:
|
96 |
+
with open(file_path, "r") as f:
|
97 |
+
datas = json.load(f)
|
98 |
+
name:str = datas['name']
|
99 |
+
docs = []
|
100 |
+
idd: int = 1
|
101 |
+
for d in datas['docs']:
|
102 |
+
doc:Document = Document(d['c'], d['s'], d['v'], idd)
|
103 |
+
docs.append(doc)
|
104 |
+
idd += 1
|
105 |
+
return Collection(name, docs, idc)
|
106 |
+
except:
|
107 |
+
raise Exception("Unable to read {file_path} !".format(file_path=file_path))
|
108 |
+
|
109 |
+
def save(self, persist_dir:str):
|
110 |
+
'''
|
111 |
+
La collection est enregistrée avec le nom idc.col dans le persist_dir
|
112 |
+
Args:
|
113 |
+
persist_dir: le chemin du repertoire de la bdd
|
114 |
+
Exception:
|
115 |
+
Si on ne peut pas sauver sur le disque
|
116 |
+
'''
|
117 |
+
|
118 |
+
file_path:str = os.path.join(persist_dir, str(self.idc)) + ".col"
|
119 |
+
# print("Collection.save : ", file_path)
|
120 |
+
json_object = {
|
121 |
+
'name':self.name,
|
122 |
+
'docs':[]
|
123 |
+
}
|
124 |
+
for doc in self.docs:
|
125 |
+
json_object['docs'].append(doc.get_json())
|
126 |
+
json_object = json.dumps(json_object)
|
127 |
+
try:
|
128 |
+
with open(file_path, "w+") as f:
|
129 |
+
f.write(json_object)
|
130 |
+
except:
|
131 |
+
raise Exception("Unable to save the collection {name}, id={id} !".format(name=self.name, id=self.idc))
|
132 |
+
|
133 |
+
def delete(self, persist_dir:str)->None:
|
134 |
+
'''
|
135 |
+
Supprime la collection de la bdd
|
136 |
+
Args:
|
137 |
+
persist_dir: le chemin du repertoire de la bdd
|
138 |
+
Exception:
|
139 |
+
Si on ne peut pas supprimer du disque
|
140 |
+
'''
|
141 |
+
self.docs.clear()
|
142 |
+
file_path:str = os.path.join(persist_dir, str(self.idc)) + ".col"
|
143 |
+
try:
|
144 |
+
os.remove(file_path)
|
145 |
+
except:
|
146 |
+
raise Exception("Unable to delete the collection {name}, id={id} !".format(name=self.name, id=self.idc))
|
147 |
+
|
148 |
+
class Store(AStore):
|
149 |
+
'''
|
150 |
+
Un store est une liste de collections.
|
151 |
+
A chaque création, ajout ou suppression d'un élément, la base est sauvée si elle est persistante
|
152 |
+
Sur le disque, dans store_dir:
|
153 |
+
Un sous-repertoire par collection, portant le nom de la collection
|
154 |
+
Dans chaque sous-repertoire d'une collection : la liste des vecteurs
|
155 |
+
|
156 |
+
'''
|
157 |
+
def __init__(self, persist_dir:str):
|
158 |
+
''' Constructeur de Store
|
159 |
+
Args:
|
160 |
+
dir_name: le répertoire persistant de la base de données ou None
|
161 |
+
Exception:
|
162 |
+
Dans le cas d'une base persistante:
|
163 |
+
Impossible de créer le répertoire persistant
|
164 |
+
Impossible de lire les collections
|
165 |
+
'''
|
166 |
+
self.persist_dir = persist_dir
|
167 |
+
self.collections = []
|
168 |
+
if persist_dir == None: # store éphémère
|
169 |
+
pass # Rien à faire
|
170 |
+
else:
|
171 |
+
# Charger la liste des collections
|
172 |
+
try:
|
173 |
+
self._create_persist_dir()
|
174 |
+
files = [os.path.join(persist_dir, f) for f in os.listdir(persist_dir) if os.path.isfile(os.path.join(persist_dir, f))]
|
175 |
+
for f in files:
|
176 |
+
col: Collection = Collection.from_disk(f)
|
177 |
+
self.collections.append(col)
|
178 |
+
except Exception as e:
|
179 |
+
raise
|
180 |
+
|
181 |
+
def reset(self)->None:
|
182 |
+
'''
|
183 |
+
Vide la base et l'efface du disque si elle est persistante
|
184 |
+
Exception:
|
185 |
+
Dans le cas d'une base persistante:
|
186 |
+
Impossible de créer le répertoire persistant
|
187 |
+
Impossible de lire les collections
|
188 |
+
'''
|
189 |
+
self.collections = []
|
190 |
+
if self.persist_dir == None: # store éphémère
|
191 |
+
pass
|
192 |
+
else:
|
193 |
+
try:
|
194 |
+
# Supprimer les fichiers du disque
|
195 |
+
if os.path.exists(self.persist_dir):
|
196 |
+
files = [os.path.join(self.persist_dir, f) for f in os.listdir(self.persist_dir) if os.path.isfile(os.path.join(self.persist_dir, f))]
|
197 |
+
# print(files)
|
198 |
+
for f in files:
|
199 |
+
os.remove(f)
|
200 |
+
os.rmdir(self.persist_dir)
|
201 |
+
except Exception as e:
|
202 |
+
raise
|
203 |
+
|
204 |
+
def get_collection_names(self)->list[str]:
|
205 |
+
return [col.name for col in self.collections]
|
206 |
+
|
207 |
+
def print_infos(self)->None:
|
208 |
+
''' Affiche le nombre de collections et pour chaque collection, affiche son nom et son nombre de documents '''
|
209 |
+
print("-------- STORE INFOS ---------------")
|
210 |
+
for col in self.collections:
|
211 |
+
print(col.name)
|
212 |
+
# idds = [doc.idd for doc in col.docs]
|
213 |
+
# print("\t", idds)
|
214 |
+
print("\tdocuments:", len(col.docs))
|
215 |
+
print("-------- /STORE INFOS ---------------")
|
216 |
+
|
217 |
+
def get_collection(self, collection_name:str)->Collection:
|
218 |
+
'''
|
219 |
+
Renvoie la collection dont le nom est 'collection_name' ou None si elle n'existe pas
|
220 |
+
'''
|
221 |
+
for col in self.collections:
|
222 |
+
if col.name == collection_name:
|
223 |
+
return col
|
224 |
+
return None
|
225 |
+
|
226 |
+
def _create_persist_dir(self):
|
227 |
+
'''
|
228 |
+
Recrée le répertoir persistant s'il a disparu après un reset par exemple
|
229 |
+
Exception:
|
230 |
+
Si on ne peut pas créer le 'persist_dir'
|
231 |
+
'''
|
232 |
+
# Vérifier si le persist_dir existe, sinon le créer
|
233 |
+
print("Persist_dir:" + self.persist_dir)
|
234 |
+
try:
|
235 |
+
if not os.path.exists(self.persist_dir):
|
236 |
+
os.mkdir(self.persist_dir)
|
237 |
+
except:
|
238 |
+
raise Exception("Unable to create the persit directory: {dir}".format(dir=self.persist_dir))
|
239 |
+
|
240 |
+
def create_collection(self, name:str)->Collection:
|
241 |
+
'''
|
242 |
+
Crée et renvoie une nouvelle collection vide de documents
|
243 |
+
Args:
|
244 |
+
name: le nom de la création à créer
|
245 |
+
Exception:
|
246 |
+
Dans le cas d'une base persistante:
|
247 |
+
Impossible de créer le répertoire persistant
|
248 |
+
Impossible de sauver la collection
|
249 |
+
'''
|
250 |
+
idc:int = len(self.collections) + 1
|
251 |
+
col:Collection = Collection(name, [], idc)
|
252 |
+
if self.persist_dir != None:
|
253 |
+
try:
|
254 |
+
self._create_persist_dir()
|
255 |
+
col.save(self.persist_dir)
|
256 |
+
except:
|
257 |
+
raise
|
258 |
+
return col
|
259 |
+
|
260 |
+
|
261 |
+
def add_to_collection(self, collection_name:str, source:str, vectors:list[list[float]], chunks:list[str])->None:
|
262 |
+
'''
|
263 |
+
Ajoute une liste de vecteurs à la collection 'collection_name'
|
264 |
+
Args:
|
265 |
+
collection_name: le nom de la collection
|
266 |
+
source: la source unique des chunks, par exemple un nom de fichier, une url ...
|
267 |
+
vectors: la liste des vecteurs obtenus à l'aide d'un modèle d'embeddings
|
268 |
+
chunks: la liste des chunks (documents) correspondant aux vecteurs
|
269 |
+
Exception:
|
270 |
+
Dans le cas d'une base persistante:
|
271 |
+
Impossible de créer le répertoire persistant
|
272 |
+
Impossible de sauver la collection
|
273 |
+
'''
|
274 |
+
col:Collection = self.get_collection(collection_name)
|
275 |
+
if col == None:
|
276 |
+
col = self.create_collection(collection_name)
|
277 |
+
self.collections.append(col)
|
278 |
+
for i in range(len(chunks)):
|
279 |
+
col.add_document(chunks[i], source, vectors[i])
|
280 |
+
if self.persist_dir != None:
|
281 |
+
try:
|
282 |
+
self._create_persist_dir()
|
283 |
+
col.save(self.persist_dir)
|
284 |
+
except:
|
285 |
+
raise
|
286 |
+
|
287 |
+
def delete_collection(self, name:str)->None:
|
288 |
+
''' Vide et supprime la collection dont le nom est 'name', et la supprime du disque si elle est persistante '''
|
289 |
+
col = self.get_collection(name)
|
290 |
+
if col != None:
|
291 |
+
self.collections.remove(col)
|
292 |
+
if self.persist_dir != None:
|
293 |
+
try:
|
294 |
+
self._create_persist_dir()
|
295 |
+
col.delete(self.persist_dir)
|
296 |
+
except:
|
297 |
+
raise
|
298 |
+
|
299 |
+
def normalize(self, v:list[float])->list[float]:
|
300 |
+
'''
|
301 |
+
Normalement les LLMs renvoient des vecteurs normalisés mais:
|
302 |
+
c'est pas sûr pour ceux que je n'ai pas testés
|
303 |
+
c'est pratique d'avoir cette méthode pour 'test_store.py'
|
304 |
+
Args:
|
305 |
+
v: le vecteur à normaliser
|
306 |
+
Returns:
|
307 |
+
le vecteur normalisé
|
308 |
+
'''
|
309 |
+
norm = 0.0
|
310 |
+
for i in range(len(v)):
|
311 |
+
norm += v[i] * v[i]
|
312 |
+
norm = sqrt(norm)
|
313 |
+
if norm == 0.0:
|
314 |
+
return v.copy()
|
315 |
+
result = [None] * len(v)
|
316 |
+
for i in range(len(v)):
|
317 |
+
result[i] = v[i] / norm
|
318 |
+
return result
|
319 |
+
|
320 |
+
def dot_product(self, v1:list[float], v2:list[float])->float:
|
321 |
+
'''
|
322 |
+
Le produit scalaire est utilisé pour une similarité en cosinus:
|
323 |
+
cos(a) = (vecA dot vecB) / (A.B)
|
324 |
+
si les vecteurs A et B sont normalisés, le cos est simplement le produit scalaire
|
325 |
+
Args:
|
326 |
+
v1, v2: les deux vecteurs à multiplier
|
327 |
+
Returns:
|
328 |
+
Un float égal à v1 dot v2
|
329 |
+
'''
|
330 |
+
result = 0.0
|
331 |
+
for i in range(len(v1)):
|
332 |
+
result += v1[i] * v2[i]
|
333 |
+
return result
|
334 |
+
|
335 |
+
def get_similar_vector(self, vector:list[float], collection_name:str)->list[float]:
|
336 |
+
'''
|
337 |
+
Renvoie le vecteur de 'collection' le pus similaire à 'vector'.
|
338 |
+
Args:
|
339 |
+
vector: un vecteur obtenu avec le même modèle d'embeddings que les vecteurs de la 'collection'
|
340 |
+
collection_name: le nom de la collection de la base dans laquelle on cherche une similarité
|
341 |
+
Return:
|
342 |
+
Le vecteur le plus similaire 'vector'
|
343 |
+
'''
|
344 |
+
col:Collection = self.get_collection(collection_name)
|
345 |
+
best_doc:Document = None
|
346 |
+
best_dp: float = -20.0
|
347 |
+
if col != None:
|
348 |
+
for doc in col.docs:
|
349 |
+
dp:float = self.dot_product(vector, doc.vec)
|
350 |
+
if dp > best_dp:
|
351 |
+
best_dp = dp
|
352 |
+
best_doc = doc
|
353 |
+
return best_doc.vec
|
354 |
+
else:
|
355 |
+
return None
|
356 |
+
|
357 |
+
def get_similar_chunk(self, query_vector:list[float], collection_name:str)->tuple[str, str]:
|
358 |
+
'''
|
359 |
+
Renvoie le document de la 'collection' le plus similaire à 'query_vector'.
|
360 |
+
Args:
|
361 |
+
query_vector: un vecteur obtenu avec le même modèle d'embeddings que les vecteurs de la 'collection'
|
362 |
+
collection: la collection de la base dans laquelle on cherche une similarité
|
363 |
+
Returns:
|
364 |
+
Un tuple contenant:
|
365 |
+
le document
|
366 |
+
la source du document
|
367 |
+
'''
|
368 |
+
col:Collection = self.get_collection(collection_name)
|
369 |
+
best_doc:Document = None
|
370 |
+
best_dp: float = -20.0
|
371 |
+
if col != None:
|
372 |
+
for doc in col.docs:
|
373 |
+
dp:float = self.dot_product(query_vector, doc.vec)
|
374 |
+
print(dp)
|
375 |
+
if dp > best_dp:
|
376 |
+
best_dp = dp
|
377 |
+
best_doc = doc
|
378 |
+
return best_doc.chunk, best_doc.source
|
379 |
+
else:
|
380 |
+
return None, None
|
381 |
+
|
382 |
+
def get_similar_chunks(self, query_vector:list[float], count:int, collection_name:str):
|
383 |
+
'''
|
384 |
+
Returns:
|
385 |
+
Un tuple contenant:
|
386 |
+
les documents
|
387 |
+
la source des documents
|
388 |
+
les ids des documents
|
389 |
+
a[0:count-1]
|
390 |
+
'''
|
391 |
+
# start:int = time.time()
|
392 |
+
col:Collection = self.get_collection(collection_name)
|
393 |
+
if col == None:
|
394 |
+
return None, None, None
|
395 |
+
bests:list[dict] = []
|
396 |
+
# Ajouter tous les docs avec leur dotproduct à la liste bests
|
397 |
+
for doc in col.docs:
|
398 |
+
dp:float = self.dot_product(query_vector, doc.vec)
|
399 |
+
bests.append({'doc':doc, 'dp':dp})
|
400 |
+
# Trier la liste en reverse à partir de la clé 'dp'
|
401 |
+
bests.sort(key=operator.itemgetter('dp'), reverse=True)
|
402 |
+
# Adapter le nombre de documents à renvoyer s'il n'y a pas assez de chunks
|
403 |
+
n:int = count if len(bests) >= count else len(bests)
|
404 |
+
# print("get_similar_chunks, count=", count, ", n=", n)
|
405 |
+
# Créer les variables de retour
|
406 |
+
docs = [b['doc'].chunk for b in bests[0:n]]
|
407 |
+
source = bests[0]['doc'].source if n > 0 else None
|
408 |
+
ids = [b['doc'].idd for b in bests[0:n]]
|
409 |
+
# print("my_store.get_similar_chunks:", time.time() - start, "s")
|
410 |
+
return docs, source, ids
|
411 |
+
|