import logging
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from langchain_text_splitters import RecursiveCharacterTextSplitter
from pydantic import BaseModel
from tqdm import tqdm
from transformers import ModernBertConfig, ModernBertModel, ModernBertPreTrainedModel

class TextSpan(BaseModel):
    """A span of tokens to pool into a single vector."""
    s: int  # start token index (inclusive)
    e: int  # end token index (exclusive)
    module_name: str  # which projection head to apply: "cls_linear" or "chunk_linear"
    text: Optional[str] = None  # raw chunk text, kept for reference only


class Instance(BaseModel):
    original_text: str
    text_spans: List[TextSpan]

def recursive_split(text, chunk_size=256, chunk_overlap=32):
    """Recursively split a text with langchain's RecursiveCharacterTextSplitter.

    chunk_size and chunk_overlap are measured in whitespace-separated words
    (see length_function below). Returns the chunks and their [start, end)
    character spans in the original text.
    """
    splitter = RecursiveCharacterTextSplitter(
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
        length_function=lambda x: len(x.split()),
        separators=["\n\n", "\n", ". ", "? ", "! ", "; "],
    )
    chunks = splitter.split_text(text)
    if not chunks:
        logging.error(f"Error, chunks is empty, text: {text}")
        return [text], [[0, len(text)]]
    # Locate each chunk in the original text so that the spans index back into it.
    chunk_span = [
        [text.find(chunk), text.find(chunk) + len(chunk)]
        for chunk in chunks
    ]
    assert chunk_span[0][0] == 0
    assert all(span[0] >= 0 for span in chunk_span)
    return chunks, chunk_span

def make_batch_input_for_prediction(
    texts: List[str],
    tokenizer,
    max_seq_length: int,
    chunk_size=256,
    chunk_overlap=32,
    prompt: str = "",
    fast_chunk: bool = False,
    batch_text_spans: Optional[List[List[TextSpan]]] = None,
):
    """Tokenize a batch of texts and build the per-text list of TextSpans to pool."""
    if batch_text_spans is not None:
        # Spans were supplied by the caller: only tokenize and validate them.
        ipt = tokenizer(
            [prompt + i for i in texts],
            padding="longest",
            truncation=True,
            max_length=max_seq_length,
            return_tensors="pt"
        )
        for text_spans, data_len in zip(batch_text_spans, ipt["attention_mask"].sum(dim=1).tolist()):
            for text_span in text_spans:
                assert -1 < text_span.s < text_span.e <= data_len
        ipt["batch_text_spans"] = batch_text_spans
        return ipt
    prompt_len = len(tokenizer.tokenize(prompt))
    # Pre-truncate each text so that prompt + text (plus special tokens) fits in max_seq_length.
    truncated_texts = [
        tokenizer.decode(
            tokenizer.encode(text)[:max_seq_length - prompt_len - 2],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True
        ).strip()
        for text in texts
    ]
    ipt = tokenizer(
        [prompt + i for i in truncated_texts],
        padding="longest",
        truncation=True,
        max_length=max_seq_length,
        return_tensors="pt"
    )
    batch_text_spans = []
    for text, data_len in zip(truncated_texts, ipt["attention_mask"].sum(dim=1).tolist()):
        # Every text gets a [CLS] span and a whole-text span; chunk spans are appended below.
        text_spans = [
            TextSpan(
                s=0,
                e=1,
                module_name="cls_linear",
            ),
            TextSpan(
                s=1 + prompt_len,
                e=data_len - 1,
                module_name="chunk_linear",
            ),
        ]

        if chunk_size > 1 and chunk_overlap >= 0:
            if fast_chunk:
                # Chunk directly on token positions, skipping the prompt and the final special token.
                start_pos, end_pos = 1 + prompt_len, data_len - 1
                for s in range(start_pos, end_pos, chunk_size):
                    s = max(s - chunk_overlap, start_pos)
                    e = min(s + chunk_size, end_pos)
                    # Skip empty spans and a span that would duplicate the whole-text span.
                    if e - s > 0 and not (s == start_pos and e == end_pos):
                        text_spans.append(
                            TextSpan(
                                s=s,
                                e=e,
                                module_name="chunk_linear",
                            )
                        )
            else:
                # Chunk on characters, then map character offsets to token offsets.
                chunks, chunk_span = recursive_split(text, chunk_size=chunk_size, chunk_overlap=chunk_overlap)
                if len(chunks) > 1:
                    for (s, e), chunk in zip(chunk_span, chunks):
                        s = len(tokenizer.tokenize(text[:s])) + 1 + prompt_len
                        e = len(tokenizer.tokenize(text[:e])) + 1 + prompt_len
                        if s >= e:
                            continue
                        text_spans.append(
                            TextSpan(
                                s=s,
                                e=e,
                                module_name="chunk_linear",
                                text=chunk
                            )
                        )

        batch_text_spans.append(text_spans)
    ipt["batch_text_spans"] = batch_text_spans
    return ipt

class DeweyV1(ModernBertPreTrainedModel):
    def __init__(self, config: ModernBertConfig):
        super().__init__(config)
        self.config = config
        self.model = ModernBertModel(config)
        hidden_size = config.hidden_size
        # `vector_size` is an extra attribute carried by the checkpoint's config.
        vector_size = config.vector_size
        # Separate projection heads for the [CLS] vector and for chunk vectors.
        self.linear_dict = nn.ModuleDict(
            {
                "cls_linear": nn.Linear(hidden_size, vector_size, bias=True),
                "chunk_linear": nn.Linear(hidden_size, vector_size, bias=True),
            }
        )
        self.post_init()

    def get_multi_vectors(
        self,
        batch_token_embeddings: torch.Tensor,
        batch_text_spans: List[List[TextSpan]],
        normalize_embeddings: bool = True
    ) -> List[torch.Tensor]:
        multi_vectors = []
        for token_embeddings, text_spans in zip(batch_token_embeddings, batch_text_spans):
            chunk_vectors = []
            for text_span in text_spans:
                s, e = text_span.s, text_span.e
                if s >= token_embeddings.shape[0] or s >= e:
                    logging.warning(
                        f"Invalid span, falling back to [0, 1); s, e, token_embeddings.shape: "
                        f"{s, e, token_embeddings.shape}"
                    )
                    s, e = 0, 1
                # Mean-pool the span's token embeddings, then project with the span's head.
                mean_tokens_embs = token_embeddings[s:e, :].mean(dim=0, keepdim=True)
                chunk_vectors.append(
                    self.linear_dict[text_span.module_name](mean_tokens_embs)
                )
            chunk_vectors = torch.cat(chunk_vectors, dim=0)
            if normalize_embeddings:
                multi_vectors.append(F.normalize(chunk_vectors, p=2, dim=-1))
            else:
                multi_vectors.append(chunk_vectors)
        return multi_vectors

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        batch_text_spans: List[List[TextSpan]],
        normalize_embeddings: bool = True,
        *args,
        **kwargs
    ) -> List[torch.Tensor]:
        # [0] is the last hidden state: (batch_size, seq_len, hidden_size).
        batch_token_embeddings = self.model(input_ids=input_ids, attention_mask=attention_mask)[0]
        multi_vectors = self.get_multi_vectors(
            batch_token_embeddings=batch_token_embeddings,
            batch_text_spans=batch_text_spans,
            normalize_embeddings=normalize_embeddings
        )
        return multi_vectors

    @torch.no_grad()
    def encode(
        self,
        sentences: str | list[str],
        batch_size: int = 32,
        use_cuda: bool = True,
        show_progress_bar: bool = True,
        chunk_size: int = 256,
        chunk_overlap: int = 32,
        convert_to_tensor: bool = False,
        max_seq_length: int = 8192,
        normalize_embeddings: bool = True,
        prompt: str = "",
        fast_chunk: bool = False,
        batch_text_spans: Optional[List[List[TextSpan]]] = None,
        *args,
        **kwargs
    ) -> Tuple[List[Union[np.ndarray, torch.Tensor]], List[List[TextSpan]]]:
        """
        Encode sentences into multi-vector representations.

        Args:
            sentences: str | list[str], the sentences to embed
            batch_size: int, number of sentences encoded per forward pass
            use_cuda: bool, whether to move input tensors to GPU for inference
            show_progress_bar: bool, whether to display a progress bar
            chunk_size: int, target chunk length (tokens when fast_chunk is True, whitespace-separated
                words otherwise). The recommended range is 64-1024: larger values are faster but quality
                may drop, while very small values are slower and may also hurt quality.
            chunk_overlap: int, overlap between adjacent chunks, in the same unit as chunk_size
            convert_to_tensor: bool, if True return fp32 torch tensors, otherwise fp32 ndarrays
            max_seq_length: int, maximum token length of a text
            normalize_embeddings: bool, whether to L2-normalize the vectors
            prompt: str, the prompt for the text; the encoded sequence is "[CLS]{prompt}{sentence}[SEP]".
                Note: you CANNOT prepend the prompt to the sentence yourself, as this would break the
                length calculation.
            fast_chunk: bool, if True chunk directly on input ids, otherwise use RecursiveCharacterTextSplitter
            batch_text_spans: List[List[TextSpan]], default None; if provided, the model does not chunk the text
            *args:
            **kwargs:

        Returns:
            A tuple of (each text's multi vectors as a list of tensors/ndarrays, each text's TextSpans)
        """
        self.eval()
        # NOTE: `self.tokenizer` is not created in __init__; a tokenizer must be attached
        # to the model instance before calling encode().

        if isinstance(sentences, str):
            sentences = [sentences]
        # Deduplicate and sort by length (longest first) to reduce padding inside each batch.
        deduplicate_sentences = list(set(sentences))
        deduplicate_sentences.sort(key=lambda x: len(x), reverse=True)

        vectors_list, text_spans = [], []
        for start in tqdm(
            range(0, len(deduplicate_sentences), batch_size),
            desc="encoding text...",
            disable=not show_progress_bar
        ):
            batch = deduplicate_sentences[start:start + batch_size]
            ipt = make_batch_input_for_prediction(
                batch,
                tokenizer=self.tokenizer,
                max_seq_length=max_seq_length,
                chunk_size=chunk_size,
                chunk_overlap=chunk_overlap,
                prompt=prompt,
                fast_chunk=fast_chunk,
                batch_text_spans=batch_text_spans
            )
            text_spans.extend(ipt["batch_text_spans"])
            ipt = {k: v.cuda() if use_cuda and isinstance(v, torch.Tensor) else v for k, v in ipt.items()}
            vectors_list.extend(self(**ipt, normalize_embeddings=normalize_embeddings))

        assert len(deduplicate_sentences) == len(vectors_list)
        # Map results back to the original (possibly duplicated) input order.
        sen2vecs = dict(zip(deduplicate_sentences, vectors_list))
        sen2spans = dict(zip(deduplicate_sentences, text_spans))

        text_spans = [sen2spans[sen] for sen in sentences]
        if convert_to_tensor:
            result = [sen2vecs[sen].cpu().float() for sen in sentences]
        else:
            result = [sen2vecs[sen].cpu().float().numpy() for sen in sentences]
        return result, text_spans
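

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the model implementation).
    # The checkpoint path below is a placeholder: it assumes a ModernBERT-style checkpoint
    # whose config defines `vector_size` and whose weights include the cls_linear /
    # chunk_linear heads expected by DeweyV1.
    from transformers import AutoTokenizer

    checkpoint = "path/to/dewey-checkpoint"  # placeholder, replace with a real checkpoint
    model = DeweyV1.from_pretrained(checkpoint)
    # `encode` reads the tokenizer from `self.tokenizer`, so attach it explicitly.
    model.tokenizer = AutoTokenizer.from_pretrained(checkpoint)

    document = "First paragraph of a document.\n\nSecond paragraph, which talks about something else."

    # Standalone chunking: spans index back into the original text by construction.
    chunks, spans = recursive_split(document, chunk_size=8, chunk_overlap=2)
    for (s, e), chunk in zip(spans, chunks):
        assert document[s:e] == chunk

    # Multi-vector encoding: one (num_spans, vector_size) array per input text, plus the
    # TextSpans describing which token range each row was pooled from.
    vectors, text_spans = model.encode(
        [document, "A short query."],
        use_cuda=False,  # set True (and move the model to GPU) when CUDA is available
        show_progress_bar=False,
        chunk_size=256,
        chunk_overlap=32,
    )
    for vec, span_list in zip(vectors, text_spans):
        print(vec.shape, [(sp.s, sp.e, sp.module_name) for sp in span_list])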