# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import os
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union

import torch
from huggingface_hub.utils import validate_hf_hub_args

from ..utils import BaseOutput, PushToHubMixin


SCHEDULER_CONFIG_NAME = "scheduler_config.json"


# NOTE: We make this type an enum because it simplifies usage in docs and prevents
# circular imports when used for `_compatibles` within the schedulers module.
# When it's used as a type in pipelines, it really is a Union because the actual
# scheduler instance is passed in.
class KarrasDiffusionSchedulers(Enum):
    DDIMScheduler = 1
    DDPMScheduler = 2
    PNDMScheduler = 3
    LMSDiscreteScheduler = 4
    EulerDiscreteScheduler = 5
    HeunDiscreteScheduler = 6
    EulerAncestralDiscreteScheduler = 7
    DPMSolverMultistepScheduler = 8
    DPMSolverSinglestepScheduler = 9
    KDPM2DiscreteScheduler = 10
    KDPM2AncestralDiscreteScheduler = 11
    DEISMultistepScheduler = 12
    UniPCMultistepScheduler = 13
    DPMSolverSDEScheduler = 14
    EDMEulerScheduler = 15
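
# A minimal sketch (not part of this module) of how the enum feeds `_compatibles`:
# a scheduler class lists the *names* of its compatible peers, so this module
# never has to import the scheduler classes themselves.
#
#   class DDIMScheduler(SchedulerMixin, ConfigMixin):
#       _compatibles = [e.name for e in KarrasDiffusionSchedulers]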


@dataclass
class SchedulerOutput(BaseOutput):
    """
    Base class for the output of a scheduler's `step` function.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """

    prev_sample: torch.FloatTensor
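
# Illustrative only (`model` and its `.sample` attribute are assumed here, not
# defined in this module): a denoising loop feeds `prev_sample` back in as the
# next model input.
#
#   for t in scheduler.timesteps:
#       model_output = model(sample, t).sample
#       sample = scheduler.step(model_output, t, sample).prev_sample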


class SchedulerMixin(PushToHubMixin):
    """
    Base class for all schedulers.

    [`SchedulerMixin`] contains common functions shared by all schedulers such as general loading and saving
    functionalities.

    [`ConfigMixin`] takes care of storing the configuration attributes (like `num_train_timesteps`) that are passed to
    the scheduler's `__init__` function, and the attributes can be accessed by `scheduler.config.num_train_timesteps`.

    Class attributes:
        - **_compatibles** (`List[str]`) -- A list of scheduler classes that are compatible with the parent scheduler
          class. Use [`~ConfigMixin.from_config`] to load a different compatible scheduler class (should be overridden
          by parent class).
    """

    config_name = SCHEDULER_CONFIG_NAME
    _compatibles = []
    has_compatibles = True

    @classmethod
    @validate_hf_hub_args
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]] = None,
        subfolder: Optional[str] = None,
        return_unused_kwargs=False,
        **kwargs,
    ):
r""" | |
Instantiate a scheduler from a pre-defined JSON configuration file in a local directory or Hub repository. | |
Parameters: | |
pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*): | |
Can be either: | |
- A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on | |
the Hub. | |
- A path to a *directory* (for example `./my_model_directory`) containing the scheduler | |
configuration saved with [`~SchedulerMixin.save_pretrained`]. | |
subfolder (`str`, *optional*): | |
The subfolder location of a model file within a larger model repository on the Hub or locally. | |
return_unused_kwargs (`bool`, *optional*, defaults to `False`): | |
Whether kwargs that are not consumed by the Python class should be returned or not. | |
cache_dir (`Union[str, os.PathLike]`, *optional*): | |
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache | |
is not used. | |
force_download (`bool`, *optional*, defaults to `False`): | |
Whether or not to force the (re-)download of the model weights and configuration files, overriding the | |
cached versions if they exist. | |
resume_download (`bool`, *optional*, defaults to `False`): | |
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any | |
incompletely downloaded files are deleted. | |
proxies (`Dict[str, str]`, *optional*): | |
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128', | |
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request. | |
output_loading_info(`bool`, *optional*, defaults to `False`): | |
Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages. | |
local_files_only(`bool`, *optional*, defaults to `False`): | |
Whether to only load local model weights and configuration files or not. If set to `True`, the model | |
won't be downloaded from the Hub. | |
token (`str` or *bool*, *optional*): | |
The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from | |
`diffusers-cli login` (stored in `~/.huggingface`) is used. | |
revision (`str`, *optional*, defaults to `"main"`): | |
The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier | |
allowed by Git. | |
<Tip> | |
To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with | |
`huggingface-cli login`. You can also activate the special | |
["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a | |
firewalled environment. | |
</Tip> | |
""" | |
        config, kwargs, commit_hash = cls.load_config(
            pretrained_model_name_or_path=pretrained_model_name_or_path,
            subfolder=subfolder,
            return_unused_kwargs=True,
            return_commit_hash=True,
            **kwargs,
        )
        return cls.from_config(config, return_unused_kwargs=return_unused_kwargs, **kwargs)
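
    # Usage sketch (the model id is the one cited in the docstring above; any
    # concrete scheduler subclass works the same way):
    #
    #   from diffusers import DDPMScheduler
    #   scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")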

    def save_pretrained(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
        """
        Save a scheduler configuration object to a directory so that it can be reloaded using the
        [`~SchedulerMixin.from_pretrained`] class method.

        Args:
            save_directory (`str` or `os.PathLike`):
                Directory where the configuration JSON file will be saved (will be created if it does not exist).
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
                namespace).
            kwargs (`Dict[str, Any]`, *optional*):
                Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
        """
        self.save_config(save_directory=save_directory, push_to_hub=push_to_hub, **kwargs)
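
    # Round-trip sketch (the directory name is reused from the docstring above):
    #
    #   scheduler.save_pretrained("./my_model_directory")
    #   scheduler = DDPMScheduler.from_pretrained("./my_model_directory")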

    @property
    def compatibles(self):
        """
        Returns all schedulers that are compatible with this scheduler

        Returns:
            `List[SchedulerMixin]`: List of compatible schedulers
        """
        return self._get_compatibles()

    @classmethod
    def _get_compatibles(cls):
        # Deduplicate this class's own name plus everything listed in `_compatibles`.
        compatible_classes_str = list(set([cls.__name__] + cls._compatibles))
        # Resolve the names against the top-level `diffusers` package; this late
        # import is what lets `_compatibles` hold strings instead of classes.
        diffusers_library = importlib.import_module(__name__.split(".")[0])
        compatible_classes = [
            getattr(diffusers_library, c) for c in compatible_classes_str if hasattr(diffusers_library, c)
        ]
        return compatible_classes
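

# A hedged end-to-end sketch of the compatibility mechanism: inspect a loaded
# scheduler's peers, then swap one in via `from_config` (both classes below are
# members of the `KarrasDiffusionSchedulers` enum above).
#
#   from diffusers import DDIMScheduler, DDPMScheduler
#   scheduler = DDPMScheduler.from_pretrained("google/ddpm-celebahq-256")
#   print([c.__name__ for c in scheduler.compatibles])
#   scheduler = DDIMScheduler.from_config(scheduler.config)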