# dewey_en_beta / custom_st.py
# Uploaded by infgrad in "Upload 9 files" (commit a41c6a1, verified); 2.47 kB.
import torch
from typing import Optional
from pydantic import BaseModel
from sentence_transformers.models import Transformer as BaseTransformer
class TextSpan(BaseModel):
    """A token range of one input row plus the name of the head that embeds it.

    Field declaration order is significant for pydantic (it drives ``repr``
    and dump ordering), so it is kept as ``s, e, module_name, text``.

    Fields:
        s: Start token index of the span.
        e: End token index of the span. In this file's usage ``s=0, e=1`` is
            used to select only the first token, which suggests ``e`` is
            exclusive -- NOTE(review): confirm against the consuming model.
        module_name: Label identifying which part of the model should embed
            the span; this file uses ``"cls_linear"`` and ``"chunk_linear"``
            (presumably sub-module names -- verify in the model code).
        text: Optional text for the span; ``None`` unless provided.
    """

    s: int  # span start index
    e: int  # span end index (see class docstring)
    module_name: str  # target head label, e.g. "cls_linear"
    text: Optional[str] = None  # optional, not set anywhere in this file
class DeweyTransformer(BaseTransformer):
    """Sentence-Transformers ``Transformer`` that pools span vectors from the model.

    Instead of handing token embeddings to a separate pooling module, this
    module builds a list of ``TextSpan`` objects for every input row, passes
    them to ``self.auto_model`` as ``batch_text_spans``, and averages the span
    vectors the model returns into one ``sentence_embedding`` per row.

    The strategy is selected with ``config_args["single_vector_type"]``:

    * ``"cls"`` -- one span ``s=0, e=1`` routed to ``"cls_linear"``.
    * ``"mean"`` (default) -- one span from just after the leading special
      token (and any prompt) up to ``row_length - 1``, routed to
      ``"chunk_linear"``.
    * ``"cls_add_mean"`` -- both spans above, and ``normalize_embeddings=True``
      is passed to the model.
    """

    def __init__(
        self,
        model_name_or_path: str,
        **kwargs,
    ):
        """Read the pooling strategy, then defer to the base ``Transformer``.

        Args:
            model_name_or_path: Forwarded unchanged to the base class.
            **kwargs: Forwarded unchanged to the base class. If
                ``config_args`` is a mapping, its ``single_vector_type`` entry
                (default ``"mean"``) selects the pooling strategy; the entry is
                left in ``config_args`` and so also reaches the base class.
        """
        # ``config_args`` may be missing or explicitly ``None``; both mean
        # "no overrides" rather than an AttributeError on ``None.get``.
        config_args = kwargs.get("config_args") or {}
        self.single_vector_type = config_args.get("single_vector_type", "mean")
        super().__init__(model_name_or_path, **kwargs)

    def _spans_for_row(self, row_length, prompt_length):
        """Build the ``TextSpan`` list for one row with ``row_length`` attended tokens.

        Raises:
            ValueError: If ``self.single_vector_type`` is not supported.
        """
        if self.single_vector_type == "cls":
            return [TextSpan(s=0, e=1, module_name="cls_linear")]
        if self.single_vector_type == "mean":
            return [
                TextSpan(s=1 + prompt_length, e=row_length - 1, module_name="chunk_linear")
            ]
        if self.single_vector_type == "cls_add_mean":
            return [
                TextSpan(s=0, e=1, module_name="cls_linear"),
                TextSpan(s=1 + prompt_length, e=row_length - 1, module_name="chunk_linear"),
            ]
        raise ValueError("single_vector_type should be in {cls, mean or cls_add_mean}")

    def forward(
        self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        """Add ``features["sentence_embedding"]`` computed from per-row span vectors.

        Args:
            features: Tokenized batch; must contain ``input_ids`` and
                ``attention_mask`` and may contain ``prompt_length``.
            **kwargs: Forwarded to ``self.auto_model``.

        Returns:
            ``features`` (updated in place) with ``"sentence_embedding"``: the
            mean over each row's span vectors, concatenated along dim 0.

        Raises:
            ValueError: If ``single_vector_type`` is not a supported strategy
                (raised when the first row is processed).
        """
        prompt_length = features.get("prompt_length", 0)
        if prompt_length > 0:
            # In ModernBERT, text is surrounded by [CLS] and [SEP]; shrink the
            # reported prompt length by one so ``1 + prompt_length`` indexes the
            # first token after [CLS] and the prompt. NOTE(review): this depends
            # on how the caller counts prompt_length -- confirm per ST version.
            # Rebind (not ``-=``) so a tensor inside ``features`` is not mutated.
            prompt_length = prompt_length - 1

        # Per-row count of attended tokens as plain ints: one device->host copy
        # instead of one per row. Assumes right padding, so the last attended
        # token is the trailing [SEP] -- TODO confirm with the tokenizer.
        row_lengths = features["attention_mask"].sum(dim=1).tolist()
        batch_text_spans = [
            self._spans_for_row(row_length, prompt_length) for row_length in row_lengths
        ]

        trans_features = {
            "input_ids": features["input_ids"],
            "attention_mask": features["attention_mask"],
            "batch_text_spans": batch_text_spans,
            # Only the combined strategy asks the model to normalize span vectors.
            "normalize_embeddings": self.single_vector_type == "cls_add_mean",
        }
        # Presumably one tensor per row shaped (num_spans, dim) -- verify against
        # the remote model's forward.
        vectors_list = self.auto_model(**trans_features, **kwargs)
        # Average each row's span vectors, then join rows along dim 0.
        sentence_embedding = torch.cat(
            [vecs.mean(dim=0, keepdim=True) for vecs in vectors_list],
            dim=0,
        )
        features.update({"sentence_embedding": sentence_embedding})
        return features