import torch
import torch.nn as nn
from transformers import AutoTokenizer, T5EncoderModel, T5ForConditionalGeneration


class MT5Embedder(nn.Module):
    available_models = ["t5-v1_1-xxl"]

    def __init__(
        self,
        model_dir="t5-v1_1-xxl",
        model_kwargs=None,
        torch_dtype=None,
        use_tokenizer_only=False,
        conditional_generation=False,
        max_length=128,
    ):
        super().__init__()
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch_dtype or torch.bfloat16
        self.max_length = max_length
        if model_kwargs is None:
            model_kwargs = {"torch_dtype": self.torch_dtype}
        # Map the shared embedding table and the encoder stack onto the target
        # device; the decoder is not needed for embedding extraction.
        model_kwargs["device_map"] = {"shared": self.device, "encoder": self.device}
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir)
        if use_tokenizer_only:
            return
        if conditional_generation:
            # Load the full encoder-decoder model for generation instead of
            # the encoder-only embedding model.
            self.model = None
            self.generation_model = T5ForConditionalGeneration.from_pretrained(
                model_dir
            )
            return
        self.model = (
            T5EncoderModel.from_pretrained(model_dir, **model_kwargs)
            .eval()
            .to(self.torch_dtype)
        )

    def get_tokens_and_mask(self, texts):
        text_tokens_and_mask = self.tokenizer(
            texts,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        # Return the ids and mask for the first (or only) sequence in the batch.
        tokens = text_tokens_and_mask["input_ids"][0]
        mask = text_tokens_and_mask["attention_mask"][0]
        return tokens, mask

    def get_text_embeddings(self, texts, attention_mask=True, layer_index=-1):
        text_tokens_and_mask = self.tokenizer(
            texts,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        with torch.no_grad():
            outputs = self.model(
                input_ids=text_tokens_and_mask["input_ids"].to(self.device),
                attention_mask=text_tokens_and_mask["attention_mask"].to(self.device)
                if attention_mask
                else None,
                output_hidden_states=True,
            )
            # Hidden states of the requested encoder layer (default: the last).
            text_encoder_embs = outputs["hidden_states"][layer_index].detach()

        return text_encoder_embs, text_tokens_and_mask["attention_mask"].to(
            self.device
        )

    @torch.no_grad()
    def __call__(self, tokens, attention_mask, layer_index=-1):
        # Encode pre-tokenized input under automatic mixed precision.
        with torch.cuda.amp.autocast():
            outputs = self.model(
                input_ids=tokens,
                attention_mask=attention_mask,
                output_hidden_states=True,
            )
        z = outputs.hidden_states[layer_index].detach()
        return z

    def general(self, text: str):
        # Tokenize as a tensor batch and generate with the full encoder-decoder
        # model; requires the instance to be built with conditional_generation=True.
        input_ids = self.tokenizer(
            text,
            max_length=self.max_length,
            truncation=True,
            return_tensors="pt",
        ).input_ids
        outputs = self.generation_model.generate(input_ids)
        return outputs
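

# Minimal usage sketch (illustrative, not part of the original module). The
# checkpoint path "google/t5-v1_1-xxl" is an assumption; substitute whatever
# local or Hub checkpoint you actually use. Note the xxl variant requires
# substantial GPU memory.
if __name__ == "__main__":
    embedder = MT5Embedder(model_dir="google/t5-v1_1-xxl", max_length=128)
    embs, mask = embedder.get_text_embeddings(["a photo of a cat"])
    # Expect shapes (batch, max_length, hidden_dim) and (batch, max_length).
    print(embs.shape, mask.shape)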