CSD / model.py
vvmatorin's picture
fix: add kwargs to model to ignore additional arguments
02bba55 verified
import torch
from typing import Tuple
from dataclasses import dataclass
from transformers import PretrainedConfig, PreTrainedModel
from .csd import CSD
from .config import CSDConfig
@dataclass
class CSDOutput:
    """Container for the three tensors returned by ``CSDModel.forward``.

    The fields are filled, in order, from the three values the wrapped
    ``CSD`` network returns. Tensor shapes are not fixed by this module.
    """

    # First value returned by the CSD network (presumably the backbone's
    # image-level embedding -- confirm against the CSD implementation).
    image_embeds: torch.Tensor
    # Second value returned by the CSD network (presumably the style embedding).
    style_embeds: torch.Tensor
    # Third value returned by the CSD network (presumably the content embedding).
    content_embeds: torch.Tensor
class CSDModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the ``CSD`` network.

    The wrapped network is built from the ViT hyper-parameters stored on a
    ``CSDConfig`` and is exposed as ``self.model``. ``forward`` repackages
    the network's three outputs into a ``CSDOutput``.
    """

    config_class = CSDConfig
    # forward() consumes images rather than token ids, so override the
    # PreTrainedModel default ("input_ids") for utilities that look up the
    # model's principal input by name.
    main_input_name = "pixel_values"

    def __init__(self, config: CSDConfig) -> None:
        """Build the CSD network from the ViT settings in ``config``.

        Args:
            config: Configuration providing ``vit_input_resolution``,
                ``vit_patch_size``, ``vit_width``, ``vit_layers``,
                ``vit_heads`` and ``vit_output_dim``.
        """
        super().__init__(config)
        self.model = CSD(
            vit_input_resolution=config.vit_input_resolution,
            vit_patch_size=config.vit_patch_size,
            vit_width=config.vit_width,
            vit_layers=config.vit_layers,
            vit_heads=config.vit_heads,
            vit_output_dim=config.vit_output_dim,
        )

    @torch.inference_mode()
    def forward(self, pixel_values: torch.Tensor, **kwargs) -> CSDOutput:
        """Run the CSD network on ``pixel_values`` and wrap its outputs.

        The call runs under ``torch.inference_mode()``, so no autograd graph
        is recorded and the returned tensors cannot take part in
        back-propagation.

        Args:
            pixel_values: Image tensor passed unchanged to the CSD network.
            **kwargs: Accepted and ignored, so callers that pass additional
                arguments do not raise.

        Returns:
            A ``CSDOutput`` holding the image, style and content embeddings,
            in the order the network returns them.
        """
        image_embeds, style_embeds, content_embeds = self.model(pixel_values)
        return CSDOutput(
            image_embeds=image_embeds,
            style_embeds=style_embeds,
            content_embeds=content_embeds,
        )