"""CSD model: a CLIP vision backbone with linear style and content
projection heads on top of the pooled image features."""

import torch
import torch.nn as nn
from transformers import CLIPVisionModel, PreTrainedModel

from .config import CSDConfig

class CSDModel(PreTrainedModel):
    config_class = CSDConfig

    def __init__(self, config: CSDConfig):
        super().__init__(config)
        # The backbone is built from the CSD config directly, so CSDConfig is
        # expected to be compatible with CLIPVisionConfig (e.g. a subclass).
        self.backbone = CLIPVisionModel(config)
        # Bias-free projection heads mapping the pooled CLIP features into
        # separate style and content embedding spaces.
        self.out_style = nn.Linear(config.hidden_size, config.style_projection_dim, bias=False)
        self.out_content = nn.Linear(config.hidden_size, config.content_projection_dim, bias=False)

    @torch.inference_mode()  # forward is inference-only: autograd is disabled
    def forward(self, pixel_values):
        # With return_dict=False the backbone returns
        # (last_hidden_state, pooled_output); index 1 selects the pooled features.
        features = self.backbone(pixel_values, return_dict=False)[1]
        style_embeds = self.out_style(features)
        content_embeds = self.out_content(features)
        return features, style_embeds, content_embeds
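

if __name__ == "__main__":
    # Illustrative smoke test, not part of the module API: run a randomly
    # initialized model on a dummy batch. This assumes CSDConfig exposes
    # CLIP-style defaults (in particular an `image_size` attribute); for real
    # embeddings, load a trained checkpoint via CSDModel.from_pretrained(...)
    # and preprocess images with the matching CLIP image processor.
    config = CSDConfig()
    model = CSDModel(config).eval()

    dummy_pixels = torch.randn(2, 3, config.image_size, config.image_size)
    features, style_embeds, content_embeds = model(dummy_pixels)
    print(features.shape, style_embeds.shape, content_embeds.shape)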