Jofthomas HF staff commited on
Commit
e26163a
1 Parent(s): 29d7a26

Update image_transformation.py

Browse files
Files changed (1) hide show
  1. image_transformation.py +4 -3
image_transformation.py CHANGED
@@ -1,13 +1,14 @@
1
  import torch
2
 
3
- from transformers.tools.base import Tool, get_default_device
4
  from transformers.utils import (
5
  is_accelerate_available,
6
  is_vision_available,
7
  )
8
 
9
  from diffusers import DiffusionPipeline
10
-
 
11
 
12
  IMAGE_TRANSFORMATION_DESCRIPTION = (
13
  "This is a tool that transforms an image according to a prompt. It takes two inputs: `image`, which should be "
@@ -37,7 +38,7 @@ class ImageTransformationTool(Tool):
37
 
38
  def setup(self):
39
  if self.device is None:
40
- self.device = get_default_device()
41
 
42
  self.pipeline = DiffusionPipeline.from_pretrained(self.stable_diffusion)
43
 
 
1
  import torch
2
 
3
+ from transformers.tools.base import Tool
4
  from transformers.utils import (
5
  is_accelerate_available,
6
  is_vision_available,
7
  )
8
 
9
  from diffusers import DiffusionPipeline
10
+ if is_accelerate_available():
11
+ from accelerate import PartialState
12
 
13
  IMAGE_TRANSFORMATION_DESCRIPTION = (
14
  "This is a tool that transforms an image according to a prompt. It takes two inputs: `image`, which should be "
 
38
 
39
  def setup(self):
40
  if self.device is None:
41
+ self.device = PartialState().default_device
42
 
43
  self.pipeline = DiffusionPipeline.from_pretrained(self.stable_diffusion)
44