wangshuai6 commited on
Commit
52d009c
·
1 Parent(s): 819ea47
Files changed (1) hide show
  1. app.py +2 -0
app.py CHANGED
@@ -33,6 +33,7 @@
33
  # step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
34
  import os
35
  import torch
 
36
  import argparse
37
  from omegaconf import OmegaConf
38
  from src.models.vae import fp2uint8
@@ -69,6 +70,7 @@ class Pipeline:
69
  self.diffusion_sampler = diffusion_sampler
70
  self.resolution = resolution
71
 
 
72
  @torch.no_grad()
73
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
74
  def __call__(self, y, num_images, seed, num_steps, guidance, state_refresh_rate, guidance_interval_min, guidance_interval_max, timeshift):
 
33
  # step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
34
  import os
35
  import torch
36
+ import spaces
37
  import argparse
38
  from omegaconf import OmegaConf
39
  from src.models.vae import fp2uint8
 
70
  self.diffusion_sampler = diffusion_sampler
71
  self.resolution = resolution
72
 
73
+ @spaces.GPU
74
  @torch.no_grad()
75
  @torch.autocast(device_type="cuda", dtype=torch.bfloat16)
76
  def __call__(self, y, num_images, seed, num_steps, guidance, state_refresh_rate, guidance_interval_min, guidance_interval_max, timeshift):