File size: 1,084 Bytes
520a6ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02bba55
520a6ec
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import torch

from typing import Tuple
from dataclasses import dataclass
from transformers import PretrainedConfig, PreTrainedModel

from .csd import CSD
from .config import CSDConfig


@dataclass(init=True, repr=True, eq=True)
class CSDOutput:
    """Container for the three embeddings produced by ``CSDModel.forward``.

    Fields are declared in the same order the wrapped ``CSD`` backbone
    returns them, so positional construction ``CSDOutput(a, b, c)`` maps
    (image, style, content) respectively.
    """

    # Embedding returned first by the backbone (overall image embedding).
    image_embeds: torch.Tensor
    # Embedding returned second by the backbone (style branch).
    style_embeds: torch.Tensor
    # Embedding returned third by the backbone (content branch).
    content_embeds: torch.Tensor


class CSDModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the ``CSD`` backbone.

    The backbone is built from the ``vit_*`` fields of a ``CSDConfig`` and
    stored as ``self.model``; ``forward`` packages its three outputs into a
    ``CSDOutput``.
    """

    # Tells the transformers loading machinery which config class to use.
    config_class = CSDConfig

    def __init__(self, config: CSDConfig) -> None:
        """Build the CSD backbone from ``config``.

        Args:
            config: Must expose every ``vit_*`` attribute listed below; each
                one is forwarded verbatim as the same-named ``CSD`` keyword.
        """
        super().__init__(config)

        # Config attribute names coincide with the CSD constructor keywords.
        backbone_fields = (
            "vit_input_resolution",
            "vit_patch_size",
            "vit_width",
            "vit_layers",
            "vit_heads",
            "vit_output_dim",
        )
        backbone_kwargs = {field: getattr(config, field) for field in backbone_fields}
        self.model = CSD(**backbone_kwargs)

    @torch.inference_mode()
    def forward(self, pixel_values: torch.Tensor, **kwargs) -> CSDOutput:
        """Embed ``pixel_values`` with the CSD backbone.

        Runs under ``torch.inference_mode()``, so no autograd graph is
        recorded and the returned tensors are inference tensors.
        NOTE(review): this makes the model unusable for fine-tuning through
        ``forward``; confirm inference-only use is intended.

        Args:
            pixel_values: Image batch passed straight to the backbone.
                Presumably (batch, 3, R, R) with R = ``vit_input_resolution``
                — TODO confirm against ``CSD``.
            **kwargs: Accepted for API compatibility and ignored.

        Returns:
            ``CSDOutput`` holding the backbone's three outputs in order
            (image, style, content).
        """
        backbone_out = self.model(pixel_values)
        image, style, content = backbone_out
        return CSDOutput(
            image_embeds=image,
            style_embeds=style,
            content_embeds=content,
        )