import torch
import torch.nn as nn

from transformers import CLIPVisionModel, PreTrainedModel

from .config import CSDConfig


class CSDModel(PreTrainedModel):
    """CLIP vision backbone with separate style and content projection heads."""

    config_class = CSDConfig

    def __init__(self, config: CSDConfig):
        super().__init__(config)
        self.backbone = CLIPVisionModel(config)
        # Project the pooled backbone features into style and content embedding spaces.
        self.out_style = nn.Linear(config.hidden_size, config.style_projection_dim, bias=False)
        self.out_content = nn.Linear(config.hidden_size, config.content_projection_dim, bias=False)

    @torch.inference_mode()
    def forward(self, pixel_values):
        # Index 1 of the tuple output is the pooled (CLS) representation.
        features = self.backbone(pixel_values, return_dict=False)[1]
        style_embeds = self.out_style(features)
        content_embeds = self.out_content(features)
        return features, style_embeds, content_embeds
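

# --- Usage sketch (not part of the original file) ---
# A minimal, hypothetical example of running the model, assuming CSDConfig
# also exposes the standard CLIP vision fields (num_channels, image_size)
# alongside hidden_size, style_projection_dim, and content_projection_dim.
# Random pixel values stand in for a real preprocessed image batch.
if __name__ == "__main__":
    config = CSDConfig()  # assumed to provide sensible defaults
    model = CSDModel(config).eval()

    # Fake batch shaped (batch, channels, height, width) to match the backbone.
    pixel_values = torch.randn(1, config.num_channels, config.image_size, config.image_size)

    features, style_embeds, content_embeds = model(pixel_values)
    print(features.shape, style_embeds.shape, content_embeds.shape)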