Spaces:
Running
on
Zero
Running
on
Zero
Update SDLens/hooked_sd_pipeline.py
Browse files
SDLens/hooked_sd_pipeline.py
CHANGED
@@ -3,6 +3,7 @@ from diffusers import StableDiffusionXLPipeline, IFPipeline
|
|
3 |
from typing import List, Dict, Callable, Union
|
4 |
import torch
|
5 |
from .hooked_scheduler import HookedNoiseScheduler
|
|
|
6 |
|
7 |
def retrieve(io):
|
8 |
if isinstance(io, tuple):
|
@@ -30,7 +31,7 @@ class HookedDiffusionAbstractPipeline:
|
|
30 |
def from_pretrained(cls, *args, **kwargs):
|
31 |
return cls(cls.parent_cls.from_pretrained(*args, **kwargs))
|
32 |
|
33 |
-
|
34 |
def run_with_hooks(self,
|
35 |
*args,
|
36 |
position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
|
@@ -69,6 +70,8 @@ class HookedDiffusionAbstractPipeline:
|
|
69 |
|
70 |
return output
|
71 |
|
|
|
|
|
72 |
def run_with_cache(self,
|
73 |
*args,
|
74 |
positions_to_cache: List[str],
|
@@ -123,6 +126,8 @@ class HookedDiffusionAbstractPipeline:
|
|
123 |
cache_dict['output'] = cache_output
|
124 |
return output, cache_dict
|
125 |
|
|
|
|
|
126 |
def run_with_hooks_and_cache(self,
|
127 |
*args,
|
128 |
position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
|
|
|
3 |
from typing import List, Dict, Callable, Union
|
4 |
import torch
|
5 |
from .hooked_scheduler import HookedNoiseScheduler
|
6 |
+
import spaces
|
7 |
|
8 |
def retrieve(io):
|
9 |
if isinstance(io, tuple):
|
|
|
31 |
def from_pretrained(cls, *args, **kwargs):
|
32 |
return cls(cls.parent_cls.from_pretrained(*args, **kwargs))
|
33 |
|
34 |
+
@spaces.GPU
|
35 |
def run_with_hooks(self,
|
36 |
*args,
|
37 |
position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
|
|
|
70 |
|
71 |
return output
|
72 |
|
73 |
+
|
74 |
+
@spaces.GPU
|
75 |
def run_with_cache(self,
|
76 |
*args,
|
77 |
positions_to_cache: List[str],
|
|
|
126 |
cache_dict['output'] = cache_output
|
127 |
return output, cache_dict
|
128 |
|
129 |
+
|
130 |
+
@spaces.GPU
|
131 |
def run_with_hooks_and_cache(self,
|
132 |
*args,
|
133 |
position_hook_dict: Dict[str, Union[Callable, List[Callable]]],
|