shapermindai's picture
Synced repo using 'sync_with_huggingface' Github Action
a3ffd31 verified
raw
history blame
15.5 kB
import threading
import chromadb
import posthog
import torch
import math
import numpy as np
import extensions.superboogav2.parameters as parameters
from chromadb.config import Settings
from sentence_transformers import SentenceTransformer
from modules.logging_colors import logger
from modules.text_generation import encode, decode
logger.debug('Intercepting all calls to posthog.')
# NOTE(review): presumably posthog is the analytics client pulled in by chromadb.
# Rebinding ``posthog.capture`` to a do-nothing function means no capture() call
# made anywhere in this process can send an event.
def _discard_posthog_event(*args, **kwargs):
    """Accept any capture() arguments and drop them."""
    return None
posthog.capture = _discard_posthog_event
class Collecter:
    """Base interface for a searchable collection of text chunks.

    Every method here is a placeholder whose body does nothing and returns
    ``None``; concrete collectors override them.
    """

    def __init__(self):
        """Set up nothing; subclasses create their own storage."""

    def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int]):
        """Store ``texts`` together with their context and start indices (no-op here)."""

    def get(self, search_strings: list[str], n_results: int) -> list[str]:
        """Return up to ``n_results`` chunks similar to ``search_strings`` (returns None here)."""

    def clear(self):
        """Remove every stored chunk (no-op here)."""
class Embedder:
    """Interface for objects that turn text into embedding vectors.

    The base ``embed`` is a placeholder that returns ``None``; subclasses supply
    a real encoder.
    """

    def __init__(self):
        """Nothing to initialise in the base class."""

    def embed(self, text: str) -> list[torch.Tensor]:
        """Return embeddings for ``text`` (placeholder: returns None).

        NOTE(review): the annotation says ``str``, but ChromaCollector.add in this
        file passes a list of strings and calls ``.tolist()`` on the result --
        confirm the intended input/return contract.
        """
class Info:
    """A retrieved chunk: its start index, its text with context, its distance and its id.

    Two Info objects whose text spans touch or overlap can be combined with
    ``merge_with``; the combined distance follows the strategy selected in
    ``parameters``.
    """

    def __init__(self, start_index, text_with_context, distance, id):
        # ``id`` shadows the builtin, but the parameter name is kept because
        # callers construct Info(..., id=...) by keyword.
        self.text_with_context = text_with_context
        self.start_index = start_index
        self.distance = distance
        self.id = id

    @staticmethod
    def _harmonic_mean(a, b):
        """Return the harmonic mean of ``a`` and ``b``, or ``min(a, b)`` when ``a + b == 0``."""
        total = a + b
        if total == 0:
            # The harmonic mean is undefined here (e.g. two exact matches with
            # distance 0). min() is the default strategy and gives 0 for two zero
            # distances, which is also the limit of the harmonic mean.
            return min(a, b)
        return 2 * (a * b) / total

    def calculate_distance(self, other_info):
        """Combine ``self.distance`` with ``other_info.distance``.

        The strategy comes from ``parameters.get_new_dist_strategy()``: min,
        harmonic mean, geometric mean or arithmetic mean. Unrecognised strategies
        fall back to min.
        """
        strategy = parameters.get_new_dist_strategy()
        d1, d2 = self.distance, other_info.distance
        if strategy == parameters.DIST_MIN_STRATEGY:
            return min(d1, d2)
        if strategy == parameters.DIST_HARMONIC_STRATEGY:
            return self._harmonic_mean(d1, d2)
        if strategy == parameters.DIST_GEOMETRIC_STRATEGY:
            return (d1 * d2) ** 0.5
        if strategy == parameters.DIST_ARITHMETIC_STRATEGY:
            return (d1 + d2) / 2
        # Min is the default.
        return min(d1, d2)

    def merge_with(self, other_info):
        """Merge this chunk with ``other_info`` when their spans touch or overlap.

        Returns:
            A new Info starting at the smaller start index, whose text is the union
            of both spans. It carries the id of the chunk that starts first (ties
            go to ``self``) and the distance from ``calculate_distance``. Returns
            None when the spans are disjoint.
        """
        s1 = self.text_with_context
        s2 = other_info.text_with_context
        s1_start = self.start_index
        s2_start = other_info.start_index
        if not self.should_merge(s1, s2, s1_start, s2_start):
            return None
        # Combine distances only when a merge actually happens.
        new_dist = self.calculate_distance(other_info)
        if s1_start <= s2_start:
            if s1_start + len(s1) >= s2_start + len(s2):  # s1 completely covers s2
                return Info(s1_start, s1, new_dist, self.id)
            # Drop the part of s2 that s1 already contains.
            overlap = max(0, s1_start + len(s1) - s2_start)
            return Info(s1_start, s1 + s2[overlap:], new_dist, self.id)
        if s2_start + len(s2) >= s1_start + len(s1):  # s2 completely covers s1
            return Info(s2_start, s2, new_dist, other_info.id)
        overlap = max(0, s2_start + len(s2) - s1_start)
        return Info(s2_start, s2 + s1[overlap:], new_dist, other_info.id)

    @staticmethod
    def should_merge(s1, s2, s1_start, s2_start):
        """Return True if spans [s1_start, s1_start+len(s1)] and [s2_start, s2_start+len(s2)] touch or overlap."""
        s1_end = s1_start + len(s1)
        s2_end = s2_start + len(s2)
        return not (s1_end < s2_start or s2_end < s1_start)
class ChromaCollector(Collecter):
    """Collecter backed by a chromadb collection named "context".

    Alongside the chroma collection, this class keeps its own bookkeeping:

    * ``ids`` -- numeric string ids produced by ``_get_new_ids``, in insertion order;
    * ``id_to_info`` -- id -> {'text_with_context': ..., 'start_index': ...};
    * ``embeddings_cache`` -- text -> embedding list, reused by later ``add``
      calls (not reset by ``clear``).

    Every public method holds ``self.lock`` for its whole duration.
    """

    def __init__(self, embedder: Embedder):
        super().__init__()
        self.chroma_client = chromadb.Client(Settings(anonymized_telemetry=False))
        self.embedder = embedder
        self.collection = self.chroma_client.create_collection(name="context", embedding_function=self.embedder.embed)
        self.ids = []
        self.id_to_info = {}
        self.embeddings_cache = {}
        self.lock = threading.Lock()  # Locking so the server doesn't break.

    def add(self, texts: list[str], texts_with_context: list[str], starting_indices: list[int], metadatas: list[dict] = None):
        """Add chunks to the collection, reusing cached embeddings where possible.

        Args:
            texts: Chunks stored as chroma documents and embedded.
            texts_with_context: Per-chunk text returned by the getters; parallel to ``texts``.
            starting_indices: Per-chunk start index used for time weighing and merging.
            metadatas: Optional per-chunk chroma metadata; must match ``texts`` in length.
        """
        with self.lock:
            assert metadatas is None or len(metadatas) == len(texts), "metadatas must be None or have the same length as texts"
            if len(texts) == 0:
                return
            new_ids = self._get_new_ids(len(texts))
            (existing_texts, existing_embeddings, existing_ids, existing_metas), \
                (non_existing_texts, non_existing_ids, non_existing_metas) = self._split_texts_by_cache_hit(texts, new_ids, metadatas)
            # If there are any already existing texts, add them all at once.
            if existing_texts:
                logger.info(f'Adding {len(existing_embeddings)} cached embeddings.')
                args = {'embeddings': existing_embeddings, 'documents': existing_texts, 'ids': existing_ids}
                if metadatas is not None:
                    args['metadatas'] = existing_metas
                self.collection.add(**args)
            # If there are any non-existing texts, compute their embeddings all at once.
            # Each call to embed has significant overhead.
            if non_existing_texts:
                non_existing_embeddings = self.embedder.embed(non_existing_texts).tolist()
                for text, embedding in zip(non_existing_texts, non_existing_embeddings):
                    self.embeddings_cache[text] = embedding
                logger.info(f'Adding {len(non_existing_embeddings)} new embeddings.')
                args = {'embeddings': non_existing_embeddings, 'documents': non_existing_texts, 'ids': non_existing_ids}
                if metadatas is not None:
                    args['metadatas'] = non_existing_metas
                self.collection.add(**args)
            # Map each ID to its context and starting index.
            new_info = {
                id_: {'text_with_context': context, 'start_index': start_index}
                for id_, context, start_index in zip(new_ids, texts_with_context, starting_indices)
            }
            self.id_to_info.update(new_info)
            self.ids.extend(new_ids)

    def _split_texts_by_cache_hit(self, texts: list[str], new_ids: list[str], metadatas: list[dict]):
        """Partition texts (with their ids/metadata) into cache hits and misses.

        Returns:
            ((hit_texts, hit_embeddings, hit_ids, hit_metas),
             (miss_texts, miss_ids, miss_metas)), with order preserved.
        """
        existing_texts, non_existing_texts = [], []
        existing_embeddings = []
        existing_ids, non_existing_ids = [], []
        existing_metas, non_existing_metas = [], []
        for i, text in enumerate(texts):
            id_ = new_ids[i]
            metadata = metadatas[i] if metadatas is not None else None
            embedding = self.embeddings_cache.get(text)
            # Explicit None check: a cached embedding is a hit regardless of its truthiness.
            if embedding is not None:
                existing_texts.append(text)
                existing_embeddings.append(embedding)
                existing_ids.append(id_)
                existing_metas.append(metadata)
            else:
                non_existing_texts.append(text)
                non_existing_ids.append(id_)
                non_existing_metas.append(metadata)
        return (existing_texts, existing_embeddings, existing_ids, existing_metas), \
            (non_existing_texts, non_existing_ids, non_existing_metas)

    def _get_new_ids(self, num_new_ids: int):
        """Return ``num_new_ids`` consecutive numeric string ids following the current maximum."""
        if self.ids:
            max_existing_id = max(int(id_) for id_ in self.ids)
        else:
            max_existing_id = -1
        return [str(i + max_existing_id + 1) for i in range(num_new_ids)]

    def _find_min_max_start_index(self):
        """Return (min, max) start_index over ``id_to_info``; (inf, 0) when it is empty."""
        max_index, min_index = 0, float('inf')
        for val in self.id_to_info.values():
            if val['start_index'] > max_index:
                max_index = val['start_index']
            if val['start_index'] < min_index:
                min_index = val['start_index']
        return min_index, max_index

    # NB: Does not make sense to weigh excerpts from different documents.
    # But let's say that's the user's problem. Perfect world scenario:
    # Apply time weighing to different documents. For each document, then, add
    # separate time weighing.
    def _apply_sigmoid_time_weighing(self, infos: list[Info], document_len: int, time_steepness: float, time_power: float, min_start_index: int = 0):
        """Scale each info's distance in place by a position-dependent sigmoid weight.

        ``document_len`` positions are mapped onto a sigmoid. The weights are
        rescaled to [1 - time_power, 1] and reversed, so position 0 (the smallest
        start index) keeps its distance and the last position has it multiplied
        by 1 - time_power. Assuming time_steepness > 0 and 0 <= time_power <= 1
        (not checked here), later chunks therefore rank closer.

        Args:
            infos: Infos to re-weigh; ``distance`` is mutated.
            document_len: Number of positions, i.e. max_start - min_start + 1.
            time_steepness: Slope multiplier for the sigmoid input.
            time_power: Maximum relative reduction of a distance.
            min_start_index: Start index that corresponds to position 0. Start
                indices are offset by it so that they index ``weights`` correctly.
        """
        sigmoid = lambda x: 1 / (1 + np.exp(-x))
        weights = sigmoid(time_steepness * np.linspace(-10, 10, document_len))
        weights = weights - weights.min()
        spread = weights.max()
        if spread > 0:
            # Scale to [0, time_power] and shift it up to [1 - time_power, 1].
            weights = weights * (time_power / spread)
            weights = weights + (1 - time_power)
        else:
            # A single position or time_steepness == 0 leaves no spread; dividing by
            # it would turn every weight (and every distance) into NaN.
            weights = np.ones_like(weights)
        # Reverse the weights.
        weights = weights[::-1]
        for info in infos:
            info.distance *= weights[info.start_index - min_start_index]

    def _filter_outliers_by_median_distance(self, infos: list[Info], significant_level: float):
        """Keep infos whose distance is <= significant_level * median distance.

        The closest info is always kept. An empty input returns an empty list.
        """
        if not infos:
            return []
        min_info = min(infos, key=lambda x: x.distance)
        median_distance = np.median([inf.distance for inf in infos])
        filtered_infos = [inf for inf in infos if inf.distance <= significant_level * median_distance]
        if min_info not in filtered_infos:
            filtered_infos.append(min_info)
        return filtered_infos

    def _merge_infos(self, infos: list[Info]):
        """Greedily merge neighbouring infos (expects them sorted by start_index).

        An empty input returns an empty list.
        """
        if not infos:
            return []
        merged_infos = []
        current_info = infos[0]
        for next_info in infos[1:]:
            merged = current_info.merge_with(next_info)
            if merged is not None:
                current_info = merged
            else:
                merged_infos.append(current_info)
                current_info = next_info
        merged_infos.append(current_info)
        return merged_infos

    def _get_documents_ids_distances(self, search_strings: list[str], n_results: int):
        """Main retrieval routine: query, time weighing, median outlier filtering, merging.

        Each search string is queried for ceil(n_results / len(search_strings))
        results. The combined infos are sorted by start_index and merged.

        Returns:
            Three parallel lists: texts_with_context, ids, distances.
        """
        n_results = min(len(self.ids), n_results)
        if n_results == 0:
            return [], [], []
        if isinstance(search_strings, str):
            search_strings = [search_strings]
        infos = []
        min_start_index, max_start_index = self._find_min_max_start_index()
        for search_string in search_strings:
            result = self.collection.query(query_texts=search_string, n_results=math.ceil(n_results / len(search_strings)), include=['distances'])
            curr_infos = [Info(start_index=self.id_to_info[id_]['start_index'],
                               text_with_context=self.id_to_info[id_]['text_with_context'],
                               distance=distance, id=id_)
                          for id_, distance in zip(result['ids'][0], result['distances'][0])]
            self._apply_sigmoid_time_weighing(infos=curr_infos,
                                              document_len=max_start_index - min_start_index + 1,
                                              time_steepness=parameters.get_time_steepness(),
                                              time_power=parameters.get_time_power(),
                                              min_start_index=min_start_index)
            curr_infos = self._filter_outliers_by_median_distance(curr_infos, parameters.get_significant_level())
            infos.extend(curr_infos)
        infos.sort(key=lambda x: x.start_index)
        infos = self._merge_infos(infos)
        texts_with_context = [inf.text_with_context for inf in infos]
        ids = [inf.id for inf in infos]
        distances = [inf.distance for inf in infos]
        return texts_with_context, ids, distances

    def get(self, search_strings: list[str], n_results: int) -> list[str]:
        """Return chunks by similarity, ordered by start index."""
        with self.lock:
            documents, _, _ = self._get_documents_ids_distances(search_strings, n_results)
            return documents

    def get_ids(self, search_strings: list[str], n_results: int) -> list[str]:
        """Return the ids of chunks by similarity, ordered by start index."""
        with self.lock:
            _, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
            return ids

    def _get_documents_up_to_token_count(self, documents: list[str], max_token_count: int):
        """Take documents in order until ``max_token_count`` tokens are used.

        The first document that would overflow is truncated to the remaining
        budget. It is dropped when no budget remains, and iteration then stops.
        """
        # TODO: Move to caller; We add delimiters there which might go over the limit.
        current_token_count = 0
        return_documents = []
        for doc in documents:
            doc_tokens = encode(doc)[0]
            doc_token_count = len(doc_tokens)
            if current_token_count + doc_token_count > max_token_count:
                remaining_tokens = max_token_count - current_token_count
                # With no budget left, slicing [:0] would yield an empty excerpt, and a
                # negative budget would slice from the end of the document instead.
                if remaining_tokens > 0:
                    truncated_doc = decode(doc_tokens[:remaining_tokens], skip_special_tokens=True)
                    return_documents.append(truncated_doc)
                break
            return_documents.append(doc)
            current_token_count += doc_token_count
        return return_documents

    def get_sorted_by_ids(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
        """Return chunks by similarity, ordered by id (insertion order), within a token budget."""
        with self.lock:
            documents, ids, _ = self._get_documents_ids_distances(search_strings, n_results)
            # Ids are numeric strings (see _get_new_ids); compare them as integers so
            # that "10" sorts after "2".
            sorted_docs = [doc for _, doc in sorted(zip(ids, documents), key=lambda pair: int(pair[0]))]
            return self._get_documents_up_to_token_count(sorted_docs, max_token_count)

    def get_sorted_by_dist(self, search_strings: list[str], n_results: int, max_token_count: int) -> list[str]:
        """Return chunks by similarity within a token budget, lowest distance last."""
        with self.lock:
            documents, _, distances = self._get_documents_ids_distances(search_strings, n_results)
            sorted_docs = [doc for doc, _ in sorted(zip(documents, distances), key=lambda x: x[1])]  # sorted lowest -> highest
            # If a document is truncated or completely skipped, it would be the one with high distance.
            return_documents = self._get_documents_up_to_token_count(sorted_docs, max_token_count)
            return_documents.reverse()  # highest -> lowest
            return return_documents

    def delete(self, ids_to_delete: list[str], where: dict):
        """Delete matching records from chroma and from the local bookkeeping."""
        with self.lock:
            ids_to_delete = self.collection.get(ids=ids_to_delete, where=where)['ids']
            self.collection.delete(ids=ids_to_delete, where=where)
            # Remove the deleted ids from self.ids and self.id_to_info.
            ids_set = set(ids_to_delete)
            self.ids = [id_ for id_ in self.ids if id_ not in ids_set]
            for id_ in ids_to_delete:
                self.id_to_info.pop(id_, None)
            logger.info(f'Successfully deleted {len(ids_to_delete)} records from chromaDB.')

    def clear(self):
        """Reset the chroma client, recreate the collection and drop the id bookkeeping."""
        with self.lock:
            self.chroma_client.reset()
            self.collection = self.chroma_client.create_collection("context", embedding_function=self.embedder.embed)
            self.ids = []
            self.id_to_info = {}
            logger.info('Successfully cleared all records and reset chromaDB.')
class SentenceTransformerEmbedder(Embedder):
    """Embedder backed by a sentence-transformers model.

    ``embed`` is bound directly to the model's ``encode`` method.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-mpnet-base-v2") -> None:
        """Load the sentence-transformers model ``model_name``.

        Args:
            model_name: Model identifier passed to ``SentenceTransformer``; the
                default keeps the previous hard-coded model.
        """
        # Call the base initialiser, as ChromaCollector does for its base class.
        super().__init__()
        logger.debug('Creating Sentence Embedder...')
        self.model = SentenceTransformer(model_name)
        self.embed = self.model.encode
def make_collector():
    """Build the default collector: a ChromaCollector using a SentenceTransformerEmbedder."""
    embedder = SentenceTransformerEmbedder()
    collector = ChromaCollector(embedder)
    return collector