Spaces:
Runtime error
Runtime error
import torch | |
from huggingface_hub import hf_hub_download | |
from safetensors.torch import load_file | |
def load_unet_model(base, repo, ckpt, device="cpu", dtype=torch.float16):
    """
    Build a UNet from a base model's config and load checkpoint weights from the Hub.

    The architecture comes from the ``unet`` subfolder config of ``base``. The
    weights are then loaded from the safetensors file ``ckpt``, downloaded from
    ``repo``.

    Args:
        base (str): Repo id or local path (anything diffusers ``load_config``
            accepts) of the base model whose ``unet`` subfolder config defines
            the architecture.
        repo (str): Hub repository id that hosts the checkpoint file.
        ckpt (str): Checkpoint filename inside ``repo`` (a safetensors file).
        device (str | torch.device): Device to place the model and the loaded
            weights on.
        dtype (torch.dtype): Parameter dtype for the model. Defaults to
            ``torch.float16``, the previously hard-coded value.
            NOTE(review): half precision on CPU may not be supported by every op
            in some torch versions. Consider ``torch.float32`` when ``device`` is
            "cpu".

    Returns:
        UNet2DConditionModel: UNet with the checkpoint weights loaded.

    Raises:
        RuntimeError: ``load_state_dict`` is called with its default
            ``strict=True``, so it raises this error if the checkpoint's keys or
            shapes do not match the architecture.
    """
    from diffusers import UNet2DConditionModel

    # Passing a repo id/path straight to ``from_config`` is deprecated in
    # diffusers. Load the config dict first, then build the model from it.
    config = UNet2DConditionModel.load_config(base, subfolder="unet")
    unet = UNet2DConditionModel.from_config(config).to(device, dtype)

    # safetensors takes a device string or index, not a torch.device object.
    load_device = str(device) if isinstance(device, torch.device) else device
    state_dict = load_file(hf_hub_download(repo, ckpt), device=load_device)
    unet.load_state_dict(state_dict)
    return unet