yamildiego commited on
Commit
041d81f
1 Parent(s): 69481a1
Files changed (1) hide show
  1. handler.py +3 -3
handler.py CHANGED
@@ -16,16 +16,16 @@ if device.type != 'cuda':
16
  dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
17
 
18
  class EndpointHandler():
19
- # def __init__(self, path=""):
20
  # self.stable_diffusion_id = "Lykon/dreamshaper-8"
21
 
22
  # self.prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype)#.to(device)
23
  # self.decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)#.to(device)
24
 
25
 
26
- # self.generator = torch.Generator(device=device.type).manual_seed(3)
27
 
28
- def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
29
  # import torch
30
 
31
  device = "cuda"
 
16
  dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
17
 
18
  class EndpointHandler():
19
+ def __init__(self, path=""):
20
  # self.stable_diffusion_id = "Lykon/dreamshaper-8"
21
 
22
  # self.prior_pipeline = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior", torch_dtype=dtype)#.to(device)
23
  # self.decoder_pipeline = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade", torch_dtype=dtype)#.to(device)
24
 
25
 
26
+ self.generator = torch.Generator(device=device.type).manual_seed(3)
27
 
28
+ def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
29
  # import torch
30
 
31
  device = "cuda"