from typing import Optional

import torch
from pydantic import BaseModel
from sentence_transformers.models import Transformer as BaseTransformer


class TextSpan(BaseModel):
    # Start/end token positions of a span, plus the name of the model head
    # that handles it (e.g. "cls_linear" or "chunk_linear").
    s: int
    e: int
    module_name: str
    text: Optional[str] = None


class DeweyTransformer(BaseTransformer):
    def __init__(
        self,
        model_name_or_path: str,
        **kwargs,
    ):
        # Pooling strategy ("cls", "mean", or "cls_add_mean"); read from the
        # module's config_args, defaulting to "mean".
        self.single_vector_type = kwargs.get("config_args", {}).get("single_vector_type", "mean")
        super().__init__(model_name_or_path, **kwargs)

    def forward(
        self, features: dict[str, torch.Tensor], **kwargs
    ) -> dict[str, torch.Tensor]:
        prompt_length = features.get("prompt_length", 0)
        if prompt_length > 0:
            prompt_length -= 1

        # Build one list of TextSpans per sequence in the batch:
        #   "cls"          -> pool the first token through "cls_linear"
        #   "mean"         -> pool the post-prompt text span through "chunk_linear"
        #   "cls_add_mean" -> use both spans
        batch_text_spans = []
        for data_len in features["attention_mask"].sum(dim=1).tolist():  # plain Python ints per sequence
            if self.single_vector_type == "cls":
                batch_text_spans.append(
                    [TextSpan(s=0, e=1, module_name="cls_linear")]
                )
            elif self.single_vector_type == "mean":
                batch_text_spans.append(
                    [TextSpan(s=1 + prompt_length, e=data_len - 1, module_name="chunk_linear")]
                )
            elif self.single_vector_type == "cls_add_mean":
                batch_text_spans.append(
                    [
                        TextSpan(s=0, e=1, module_name="cls_linear"),
                        TextSpan(s=1 + prompt_length, e=data_len - 1, module_name="chunk_linear"),
                    ]
                )
            else:
                raise ValueError(
                    "single_vector_type must be one of 'cls', 'mean', or 'cls_add_mean'"
                )

        trans_features = {
            "input_ids": features["input_ids"],
            "attention_mask": features["attention_mask"],
            "batch_text_spans": batch_text_spans,
            # Normalize the per-span vectors only when cls and mean spans are combined.
            "normalize_embeddings": self.single_vector_type == "cls_add_mean",
        }

        # The wrapped model returns one tensor of span vectors per input; average
        # each over its spans and stack the results into a batch of embeddings.
        vectors_list = self.auto_model(**trans_features, **kwargs)
        sentence_embedding = torch.cat(
            [vecs.mean(dim=0, keepdim=True) for vecs in vectors_list],
            dim=0,
        )
        features.update({"sentence_embedding": sentence_embedding})
        return features
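

# Usage sketch (illustrative): load a checkpoint that registers DeweyTransformer as a
# custom sentence-transformers module (via its modules.json) and encode some text.
# The repo id below is a placeholder, and `single_vector_type` is assumed to come from
# the saved module's config_args rather than from this call.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(
        "org/dewey-style-checkpoint",  # placeholder id, replace with the real checkpoint
        trust_remote_code=True,
    )
    embeddings = model.encode(["A short example sentence."])
    print(embeddings.shape)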