import huggingface_hub | |
from .config import Config | |
from transformers import PreTrainedModel | |
from .dcgan import Generator | |
# Example (kept for reference): create a default Config and save it to "WGAN-GP".
# config = Config()
# config.save_pretrained("WGAN-GP")
class WGAN_GP(PreTrainedModel):
    """Expose the project's ``Generator`` through the ``PreTrainedModel`` API.

    The only sub-module built here is the generator; ``forward`` simply
    delegates to it.
    """

    # Configuration class paired with this model.
    config_class = Config

    def __init__(self, config):
        """Build the generator from the settings stored in ``config.cfg``.

        Args:
            config: Configuration object whose ``cfg`` mapping supplies
                ``"imsize"``, ``"img_ch"``, ``"zdim"``, plus the ``"g"``
                entries of ``"norm_type"`` and ``"final_activation"``.
                These five values are passed positionally, in that order,
                to ``Generator``.
        """
        super().__init__(config)
        cfg = config.cfg
        self.generator = Generator(
            cfg["imsize"],
            cfg["img_ch"],
            cfg["zdim"],
            cfg["norm_type"]["g"],
            cfg["final_activation"]["g"],
        )

    def forward(self, input):
        """Run ``input`` through the generator and return its output.

        Args:
            input: Value forwarded unchanged to ``self.generator``
                (presumably a latent batch sized by ``zdim`` -- TODO confirm
                against ``Generator``).

        Returns:
            Whatever ``self.generator(input)`` returns.
        """
        # NOTE: ``input`` shadows the builtin; the name is kept so existing
        # keyword callers (``model(input=...)``) keep working.
        return self.generator(input)