File size: 1,084 Bytes
520a6ec 02bba55 520a6ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 |
import torch
from typing import Tuple
from dataclasses import dataclass
from transformers import PretrainedConfig, PreTrainedModel
from .csd import CSD
from .config import CSDConfig
@dataclass()
class CSDOutput:
    """Result container returned by ``CSDModel.forward``.

    Holds the three tensors produced by the wrapped ``CSD`` network, in the
    same order ``CSD`` returns them. Shapes are whatever ``CSD`` emits
    (presumably ``(batch, dim)`` per field -- TODO confirm against ``CSD``).
    """

    # First value returned by ``CSD``.
    image_embeds: torch.Tensor
    # Second value returned by ``CSD``.
    style_embeds: torch.Tensor
    # Third value returned by ``CSD``.
    content_embeds: torch.Tensor
class CSDModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the project's ``CSD`` network.

    The wrapped network is built from the ViT hyper-parameters stored on a
    ``CSDConfig``; ``forward`` repackages its three outputs into a
    ``CSDOutput``.
    """

    config_class = CSDConfig
    # ``forward`` takes ``pixel_values`` rather than the PreTrainedModel default
    # ``input_ids``; Hugging Face utilities that look up a model's primary input
    # by name read this attribute, so it must match the forward signature.
    main_input_name = "pixel_values"

    def __init__(self, config: CSDConfig) -> None:
        """Build the wrapped ``CSD`` network from ``config``.

        Args:
            config: Configuration whose ``vit_*`` attributes are forwarded
                unchanged, by keyword, to the ``CSD`` constructor.
        """
        super().__init__(config)
        self.model = CSD(
            vit_input_resolution=config.vit_input_resolution,
            vit_patch_size=config.vit_patch_size,
            vit_width=config.vit_width,
            vit_layers=config.vit_layers,
            vit_heads=config.vit_heads,
            vit_output_dim=config.vit_output_dim,
        )

    @torch.inference_mode()
    def forward(self, pixel_values: torch.Tensor, **kwargs) -> CSDOutput:
        """Run the wrapped ``CSD`` network on ``pixel_values``.

        Executes under ``torch.inference_mode()``: no autograd graph is
        recorded and the returned tensors are inference tensors, so this
        method cannot be used where gradients are needed (e.g. as a loss).

        Args:
            pixel_values: Image tensor passed straight to ``CSD``; its layout
                and normalization are whatever ``CSD`` expects (presumably
                ``(batch, 3, H, W)`` -- TODO confirm).
            **kwargs: Accepted for Hugging Face call-signature compatibility
                and ignored.

        Returns:
            A ``CSDOutput`` built from the three values ``CSD`` returns,
            unpacked in (image, style, content) order.
        """
        image_embeds, style_embeds, content_embeds = self.model(pixel_values)
        return CSDOutput(
            image_embeds=image_embeds,
            style_embeds=style_embeds,
            content_embeds=content_embeds,
        )
|