Spaces:
Paused
Paused
import torch | |
from transformers import PreTrainedModel, XLMRobertaConfig, XLMRobertaModel | |
class MCLIPConfig(XLMRobertaConfig):
    """XLM-RoBERTa configuration extended with two projection sizes.

    ``transformerDimensions`` and ``numDims`` are read by
    ``MultilingualCLIP`` as the input and output feature counts of its
    linear projection layer.
    """

    model_type = "M-CLIP"

    def __init__(
        self,
        transformerDimSize: int = 1024,
        imageDimSize: int = 768,
        **kwargs,
    ) -> None:
        """Store the projection sizes and defer the rest to the base config.

        Args:
            transformerDimSize: Value stored as ``transformerDimensions``.
            imageDimSize: Value stored as ``numDims``.
            **kwargs: Forwarded unchanged to ``XLMRobertaConfig.__init__``.
        """
        # NOTE(review): keep these assignments ahead of the base initializer.
        # The base config presumably applies leftover kwargs as attributes
        # (e.g. values restored from a serialized config), so it must run
        # last for those to take precedence -- confirm against transformers.
        self.transformerDimensions = transformerDimSize
        self.numDims = imageDimSize
        super().__init__(**kwargs)
class MultilingualCLIP(PreTrainedModel):
    """XLM-RoBERTa encoder followed by a linear projection.

    The transformer's per-token outputs are averaged over the sequence
    axis, weighted by the attention mask, and the pooled vector is mapped
    from ``config.transformerDimensions`` to ``config.numDims`` features.
    """

    config_class = MCLIPConfig

    def __init__(self, config, *args, **kwargs):
        """Build the encoder and the projection layer from ``config``.

        Args:
            config: Configuration passed to ``XLMRobertaModel``; must also
                provide ``transformerDimensions`` and ``numDims``.
            *args: Forwarded unchanged to ``PreTrainedModel.__init__``.
            **kwargs: Forwarded unchanged to ``PreTrainedModel.__init__``.
        """
        super().__init__(config, *args, **kwargs)
        self.transformer = XLMRobertaModel(config)
        self.LinearTransformation = torch.nn.Linear(
            in_features=config.transformerDimensions,
            out_features=config.numDims,
        )

    def forward(self, input_ids, attention_mask):
        """Encode tokens and return the projected, mask-weighted mean.

        Args:
            input_ids: Token ids passed straight to the transformer.
            attention_mask: Mask passed to the transformer and reused as
                pooling weights. Assumed to be a (batch, seq) tensor of
                0/1 values -- TODO confirm against callers.

        Returns:
            A tuple ``(projected, embs)``: the pooled vector after the
            linear projection, and the first element of the transformer
            output (presumably the per-token hidden states).
        """
        embs = self.transformer(
            input_ids=input_ids, attention_mask=attention_mask
        )[0]

        # Zero out masked positions, then sum over the sequence axis.
        masked_sum = (embs * attention_mask.unsqueeze(2)).sum(dim=1)

        # A row whose mask is entirely zero would divide 0 by 0 and push
        # NaN through the projection. Clamping the count to at least 1
        # yields a zero vector for such rows and, for a 0/1 mask, leaves
        # every row with one or more attended tokens unchanged.
        token_counts = attention_mask.sum(dim=1)[:, None].clamp(min=1)

        pooled = masked_sum / token_counts
        return self.LinearTransformation(pooled), embs