from clip import CLIP
from encoder import VAE_Encoder
from decoder import VAE_Decoder
from diffusion import Diffusion

import model_converter
import torch


def load_models(ckpt_path, device, use_se=False):
    state_dict = model_converter.load_from_standard_weights(ckpt_path, device)

    encoder = VAE_Encoder().to(device)
    encoder.load_state_dict(state_dict['encoder'], strict=True)

    decoder = VAE_Decoder().to(device)
    decoder.load_state_dict(state_dict['decoder'], strict=True)

    # Initialize the diffusion model with SE blocks disabled so that the
    # pre-trained checkpoint, which contains no SE parameters, loads cleanly
    # under strict=True.
    diffusion = Diffusion(use_se=False).to(device)
    diffusion.load_state_dict(state_dict['diffusion'], strict=True)

    # If SE blocks are requested, rebuild the model with them enabled and copy
    # the pre-trained weights over manually; strict loading would fail here
    # because the checkpoint has no entries for the SE parameters.
    if use_se:
        diffusion = Diffusion(use_se=True).to(device)
        with torch.no_grad():
            for name, param in diffusion.named_parameters():
                # Skip SE block parameters (this assumes only SE modules carry
                # 'se' in their parameter names); they keep their fresh
                # initialization instead of pre-trained weights.
                if 'se' not in name and name in state_dict['diffusion']:
                    param.copy_(state_dict['diffusion'][name])

    clip = CLIP().to(device)
    clip.load_state_dict(state_dict['clip'], strict=True)

    return {
        'clip': clip,
        'encoder': encoder,
        'decoder': decoder,
        'diffusion': diffusion,
    }
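
# Minimal usage sketch (not part of the original module): the checkpoint path
# and device selection below are illustrative placeholders; any Stable
# Diffusion v1.5-style .ckpt that model_converter.load_from_standard_weights
# accepts should work here.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    models = load_models('data/v1-5-pruned-emaonly.ckpt', device, use_se=True)

    # Report parameter counts as a quick sanity check that everything loaded.
    for name, model in models.items():
        n_params = sum(p.numel() for p in model.parameters())
        print(f'{name}: {n_params / 1e6:.1f}M parameters')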