# Source: MuseVSpace / MuseV / musev / models / text_model.py
# Last commit: "add musev" (96d7ad8), author: anchorxia
from typing import Any, Dict
from torch import nn
class TextEmbExtractor(nn.Module):
    """Tokenize text and run it through a text encoder in one call.

    The module only stores the two collaborators it is given; it does not
    create parameters of its own.

    NOTE(review): the call patterns below (``model_max_length``,
    ``return_tensors="pt"``, ``config.use_attention_mask``, ``.device``)
    look like the Hugging Face tokenizer / text-encoder interfaces --
    confirm against the caller that constructs this module.
    """

    def __init__(self, tokenizer, text_encoder) -> None:
        """Store the tokenizer and the text encoder used by ``forward``.

        Args:
            tokenizer: Callable that turns ``texts`` into an object exposing
                ``input_ids`` and ``attention_mask``; must also expose
                ``model_max_length``.
            text_encoder: Callable encoder exposing ``config`` and ``device``.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.text_encoder = text_encoder

    def forward(
        self,
        texts,
        text_params: Dict = None,
    ):
        """Encode ``texts`` and return the encoder's raw output.

        Args:
            texts: Input passed unchanged to the tokenizer.
            text_params: Extra keyword arguments forwarded to the text
                encoder. ``None`` (the default) means no extra arguments.

        Returns:
            Whatever ``self.text_encoder(...)`` returns; it is not
            post-processed here.
        """
        # A fresh dict per call avoids sharing a mutable default.
        encoder_kwargs = {} if text_params is None else text_params

        # Pad and truncate every prompt to the tokenizer's maximum length.
        tokenized = self.tokenizer(
            texts,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # Read the target device once; both tensors are moved to it.
        device = self.text_encoder.device

        # The mask is forwarded only when the encoder config opts in; a
        # missing flag is treated the same as a falsy one.
        if getattr(self.text_encoder.config, "use_attention_mask", False):
            attention_mask = tokenized.attention_mask.to(device)
        else:
            attention_mask = None

        return self.text_encoder(
            tokenized.input_ids.to(device),
            attention_mask=attention_mask,
            **encoder_kwargs,
        )