Jofthomas HF staff commited on
Commit
660e224
·
1 Parent(s): 52487eb

switch to partialState

Browse files
Files changed (1) hide show
  1. text_to_video.py +4 -3
text_to_video.py CHANGED
@@ -1,10 +1,11 @@
1
  import torch
2
 
3
- from transformers.tools.base import Tool, get_default_device
4
  from transformers.utils import is_accelerate_available
5
 
6
  from diffusers import DiffusionPipeline
7
-
 
8
 
9
  TEXT_TO_VIDEO_DESCRIPTION = (
10
  "This is a tool that creates a video according to a text description. It takes an input named `prompt` which "
@@ -31,7 +32,7 @@ class TextToVideoTool(Tool):
31
 
32
  def setup(self):
33
  if self.device is None:
34
- self.device = get_default_device()
35
 
36
  self.pipeline = DiffusionPipeline.from_pretrained(
37
  self.default_checkpoint, variant="fp16"
 
1
  import torch
2
 
3
+ from transformers.tools.base import Tool
4
  from transformers.utils import is_accelerate_available
5
 
6
  from diffusers import DiffusionPipeline
7
+ if is_accelerate_available():
8
+ from accelerate import PartialState
9
 
10
  TEXT_TO_VIDEO_DESCRIPTION = (
11
  "This is a tool that creates a video according to a text description. It takes an input named `prompt` which "
 
32
 
33
  def setup(self):
34
  if self.device is None:
35
+ self.device = PartialState().default_device
36
 
37
  self.pipeline = DiffusionPipeline.from_pretrained(
38
  self.default_checkpoint, variant="fp16"