Spaces:
Runtime error
Runtime error
Update models/__init__.py
Browse files- models/__init__.py +1 -21
models/__init__.py
CHANGED
@@ -2,8 +2,6 @@ import os
|
|
2 |
import sys
|
3 |
sys.path.append(os.path.split(sys.path[0])[0])
|
4 |
|
5 |
-
from .dit import DiT_models
|
6 |
-
from .uvit import UViT_models
|
7 |
from .unet import UNet3DConditionModel
|
8 |
from torch.optim.lr_scheduler import LambdaLR
|
9 |
|
@@ -28,25 +26,7 @@ def get_lr_scheduler(optimizer, name, **kwargs):
|
|
28 |
|
29 |
def get_models(args):
|
30 |
|
31 |
-
if '
|
32 |
-
return DiT_models[args.model](
|
33 |
-
input_size=args.latent_size,
|
34 |
-
num_classes=args.num_classes,
|
35 |
-
class_guided=args.class_guided,
|
36 |
-
num_frames=args.num_frames,
|
37 |
-
use_lora=args.use_lora,
|
38 |
-
attention_mode=args.attention_mode
|
39 |
-
)
|
40 |
-
elif 'UViT' in args.model:
|
41 |
-
return UViT_models[args.model](
|
42 |
-
input_size=args.latent_size,
|
43 |
-
num_classes=args.num_classes,
|
44 |
-
class_guided=args.class_guided,
|
45 |
-
num_frames=args.num_frames,
|
46 |
-
use_lora=args.use_lora,
|
47 |
-
attention_mode=args.attention_mode
|
48 |
-
)
|
49 |
-
elif 'TAV' in args.model:
|
50 |
pretrained_model_path = args.pretrained_model_path
|
51 |
return UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", use_concat=args.use_mask)
|
52 |
else:
|
|
|
2 |
import sys
|
3 |
sys.path.append(os.path.split(sys.path[0])[0])
|
4 |
|
|
|
|
|
5 |
from .unet import UNet3DConditionModel
|
6 |
from torch.optim.lr_scheduler import LambdaLR
|
7 |
|
|
|
26 |
|
27 |
def get_models(args):
|
28 |
|
29 |
+
if 'TAV' in args.model:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
pretrained_model_path = args.pretrained_model_path
|
31 |
return UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet", use_concat=args.use_mask)
|
32 |
else:
|