import torch
from typing import Optional

from pydantic import BaseModel
from sentence_transformers.models import Transformer as BaseTransformer


class TextSpan(BaseModel):
    """A token span ``[s, e)`` to be pooled by the projection module ``module_name``.

    Spans built in this file treat ``e`` as exclusive (``s=0, e=1`` selects only
    the first token). ``text`` is optional and is never populated here.
    """

    s: int  # start token index
    e: int  # end token index (exclusive, as used below)
    module_name: str  # name of the pooling/projection head in ``auto_model``
    text: Optional[str] = None


# Accepted values of ``single_vector_type``.
_SINGLE_VECTOR_TYPES = ("cls", "mean", "cls_add_mean")


def _build_text_spans(vector_type: str, text_start: int, seq_len: int) -> list[TextSpan]:
    """Return the spans to pool for one sample of unpadded length ``seq_len``.

    "cls" yields the first-token span, "mean" the text span ``[text_start, seq_len - 1)``
    (which excludes the final token), and "cls_add_mean" both, in that order.
    """
    spans = []
    if vector_type in ("cls", "cls_add_mean"):
        spans.append(TextSpan(s=0, e=1, module_name="cls_linear"))
    if vector_type in ("mean", "cls_add_mean"):
        spans.append(TextSpan(s=text_start, e=seq_len - 1, module_name="chunk_linear"))
    return spans


class DeweyTransformer(BaseTransformer):
    """sentence-transformers ``Transformer`` that pools a single vector per input.

    The pooling strategy is read from ``config_args["single_vector_type"]``
    (default ``"mean"``); see ``_SINGLE_VECTOR_TYPES`` for accepted values.
    """

    def __init__(
        self,
        model_name_or_path: str,
        **kwargs,
    ):
        # ``config_args`` may be missing or explicitly None; treat both as empty.
        config_args = kwargs.get("config_args") or {}
        self.single_vector_type = config_args.get("single_vector_type", "mean")
        super().__init__(model_name_or_path, **kwargs)

    def forward(
        self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        """Compute ``features["sentence_embedding"]`` and return ``features``.

        Builds per-sample ``TextSpan`` lists from ``attention_mask`` and the
        optional ``prompt_length``, passes them to ``self.auto_model``, and
        averages the vectors it returns for each sample.

        Raises:
            ValueError: if ``single_vector_type`` is not one of
                ``cls``, ``mean`` or ``cls_add_mean``.
        """
        vector_type = self.single_vector_type
        if vector_type not in _SINGLE_VECTOR_TYPES:
            raise ValueError("single_vector_type should be in {cls, mean or cls_add_mean}")

        # prompt_length may arrive as a tensor; TextSpan needs plain ints.
        prompt_length = int(features.get("prompt_length", 0))
        if prompt_length > 0:
            # In ModernBERT, text is surrounded by [CLS] and [SEP]. The incoming
            # prompt_length is assumed to include the leading [CLS], which is
            # re-added by the ``1 +`` offset below.
            # NOTE(review): confirm against the sentence-transformers version in use.
            prompt_length -= 1
        text_start = 1 + prompt_length

        # One host transfer for all rows instead of a 0-d tensor per sample.
        seq_lens = [int(n) for n in features["attention_mask"].sum(dim=1).tolist()]
        batch_text_spans = [
            _build_text_spans(vector_type, text_start, seq_len) for seq_len in seq_lens
        ]

        trans_features = {
            "input_ids": features["input_ids"],
            "attention_mask": features["attention_mask"],
            "batch_text_spans": batch_text_spans,
            # Only the combined cls+mean mode asks the model to normalize span vectors.
            "normalize_embeddings": vector_type == "cls_add_mean",
        }
        vectors_list = self.auto_model(**trans_features, **kwargs)

        # Assumes one (num_spans, dim) tensor per sample -- TODO confirm against
        # auto_model. Averaging over spans gives one row per sample.
        sentence_embedding = torch.cat(
            [vecs.mean(dim=0, keepdim=True) for vecs in vectors_list], dim=0
        )
        features["sentence_embedding"] = sentence_embedding
        return features