# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

import torch
from huggingface_hub.utils import validate_hf_hub_args

from ..utils import BaseOutput, PushToHubMixin


SCHEDULER_CONFIG_NAME = "scheduler_config.json"


# NOTE: We make this type an enum because it simplifies usage in docs and prevents
# circular imports when used for `_compatibles` within the schedulers module.
# When it's used as a type in pipelines, it really is a Union because the actual
# scheduler instance is passed in.
class KarrasDiffusionSchedulers(Enum):
    """Enumeration whose member names are scheduler class names (see the NOTE above for why it is an enum)."""

    DDIMScheduler = 1
    DDPMScheduler = 2
    PNDMScheduler = 3
    LMSDiscreteScheduler = 4
    EulerDiscreteScheduler = 5
    HeunDiscreteScheduler = 6
    EulerAncestralDiscreteScheduler = 7
    DPMSolverMultistepScheduler = 8
    DPMSolverSinglestepScheduler = 9
    KDPM2DiscreteScheduler = 10
    KDPM2AncestralDiscreteScheduler = 11
    DEISMultistepScheduler = 12
    UniPCMultistepScheduler = 13
    DPMSolverSDEScheduler = 14
    EDMEulerScheduler = 15


@dataclass
class SchedulerOutput(BaseOutput):
    """
    Base class for the output of a scheduler's `step` function.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor


class SchedulerMixin(PushToHubMixin):
    """
    Base class for all schedulers.

    [`SchedulerMixin`] contains common functions shared by all schedulers such as general loading and saving
    functionalities.

    [`ConfigMixin`] takes care of storing the configuration attributes (like `num_train_timesteps`) that are passed to
    the scheduler's `__init__` function, and the attributes can be accessed by `scheduler.config.num_train_timesteps`.

    Class attributes:
        - **_compatibles** (`List[str]`) -- A list of scheduler classes that are compatible with the parent scheduler
          class. Use [`~ConfigMixin.from_config`] to load a different compatible scheduler class (should be overridden
          by parent class).
    """

    config_name = SCHEDULER_CONFIG_NAME
    _compatibles = []
    has_compatibles = True

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        subfolder: Optional[str] = None,
        return_unused_kwargs=False,
        **kwargs,
    ):
        r"""
        Instantiate a scheduler from a pre-defined JSON configuration file in a local directory or Hub repository.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                      the Hub.
                    - A path to a *directory* (for example `./my_model_directory`) containing the scheduler
                      configuration saved with [`~SchedulerMixin.save_pretrained`].
            subfolder (`str`, *optional*):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only(`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.

        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
        `huggingface-cli login`. You can also activate the special
        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
        firewalled environment.

        Returns:
            Whatever `cls.from_config` returns for the loaded configuration, called with the caller's
            `return_unused_kwargs` flag.
        """
        # The unused kwargs are always requested from `load_config` so they can be forwarded to `from_config`,
        # which applies the caller's `return_unused_kwargs` choice. The commit hash is requested as well but is
        # not needed here, hence the throwaway name.
        config, kwargs, _ = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            **kwargs,
        )
        return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a scheduler configuration object to a directory so that it can be reloaded using the
        [`~SchedulerMixin.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)

    @property
    def compatibles(self):
        """
        Returns all schedulers that are compatible with this scheduler

        Returns:
            `List[SchedulerMixin]`: List of compatible schedulers
        """
        return self._get_compatibles()

    @classmethod
    def _get_compatibles(cls):
        """
        Resolve this class's own name and the names in `_compatibles` to attributes of the top-level package.

        The top-level package is the first component of this module's dotted name. Names that it does not expose as
        attributes are skipped.

        Returns:
            `list`: The resolved attributes without duplicates, in a deterministic order: the entry for this class's
            own name first, followed by the entries of `_compatibles` in the order they are listed.
        """
        # `dict.fromkeys` removes duplicate names while keeping first-seen order. A plain `set` would make the
        # order of the result depend on string hash randomization and therefore vary between interpreter runs.
        compatible_classes_str = list(dict.fromkeys([cls.__name__, *cls._compatibles]))
        diffusers_library = importlib.import_module(__name__.split(".")[0])
        compatible_classes = [
            getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
        ]
        return compatible_classes