Fix
Browse files- inference.py +7 -4
inference.py
CHANGED
@@ -55,12 +55,15 @@ class InferencePipeline:
|
|
55 |
if model_id == self.model_id:
|
56 |
return
|
57 |
base_model_id = self.get_base_model_info(model_id, self.hf_token)
|
58 |
-
unet = UNet3DConditionModel.from_pretrained(
|
59 |
-
|
60 |
-
|
|
|
|
|
61 |
pipe = TuneAVideoPipeline.from_pretrained(base_model_id,
|
62 |
unet=unet,
|
63 |
-
torch_dtype=torch.float16
|
|
|
64 |
pipe = pipe.to(self.device)
|
65 |
self.pipe = pipe
|
66 |
self.model_id = model_id # type: ignore
|
|
|
55 |
if model_id == self.model_id:
|
56 |
return
|
57 |
base_model_id = self.get_base_model_info(model_id, self.hf_token)
|
58 |
+
unet = UNet3DConditionModel.from_pretrained(
|
59 |
+
model_id,
|
60 |
+
subfolder='unet',
|
61 |
+
torch_dtype=torch.float16,
|
62 |
+
use_auth_token=self.hf_token)
|
63 |
pipe = TuneAVideoPipeline.from_pretrained(base_model_id,
|
64 |
unet=unet,
|
65 |
+
torch_dtype=torch.float16,
|
66 |
+
use_auth_token=self.hf_token)
|
67 |
pipe = pipe.to(self.device)
|
68 |
self.pipe = pipe
|
69 |
self.model_id = model_id # type: ignore
|