|
import torch |
|
|
|
from typing import Tuple |
|
from dataclasses import dataclass |
|
from transformers import PretrainedConfig, PreTrainedModel |
|
|
|
from .csd import CSD |
|
from .config import CSDConfig |
|
|
|
|
|
@dataclass
class CSDOutput:
    """Bundle of the three embedding tensors returned by ``CSDModel.forward``.

    All three tensors come straight out of the underlying ``CSD`` network's
    forward pass, in this order.
    """

    # Overall image embedding produced by the CSD backbone.
    image_embeds: torch.Tensor
    # Style embedding (presumably for style-similarity use — confirm against
    # the CSD implementation).
    style_embeds: torch.Tensor
    # Content embedding (presumably style-invariant content features —
    # confirm against the CSD implementation).
    content_embeds: torch.Tensor
|
|
|
|
|
class CSDModel(PreTrainedModel):
    """``transformers``-compatible wrapper around the CSD vision backbone.

    Instantiates a :class:`CSD` ViT from the hyperparameters carried by a
    :class:`CSDConfig`, exposing it through the standard
    :class:`~transformers.PreTrainedModel` interface so it can be loaded with
    ``from_pretrained`` and used in pipelines.
    """

    config_class = CSDConfig

    def __init__(self, config: CSDConfig) -> None:
        # Modern zero-argument super() — equivalent to the Py2-style
        # super(CSDModel, self) but not tied to the class name.
        super().__init__(config)

        # Build the underlying ViT-based CSD network from the config's
        # architecture hyperparameters.
        self.model = CSD(
            vit_input_resolution=config.vit_input_resolution,
            vit_patch_size=config.vit_patch_size,
            vit_width=config.vit_width,
            vit_layers=config.vit_layers,
            vit_heads=config.vit_heads,
            vit_output_dim=config.vit_output_dim,
        )

    # NOTE: inference_mode disables autograd for every call — this wrapper is
    # inference-only and cannot be fine-tuned through forward() as written.
    @torch.inference_mode()
    def forward(self, pixel_values: torch.Tensor, **kwargs) -> CSDOutput:
        """Embed a batch of images with the CSD backbone.

        Args:
            pixel_values: Preprocessed image tensor, forwarded unchanged to the
                underlying ``CSD`` network (presumably ``(batch, 3, H, W)`` —
                confirm against the paired image processor).
            **kwargs: Ignored; accepted for compatibility with generic
                ``transformers`` calling conventions.

        Returns:
            A :class:`CSDOutput` holding the backbone's image, style, and
            content embeddings.
        """
        image_embeds, style_embeds, content_embeds = self.model(pixel_values)
        return CSDOutput(
            image_embeds=image_embeds,
            style_embeds=style_embeds,
            content_embeds=content_embeds,
        )
|
|