barreloflube committed
Commit b6b27f8 · 1 Parent(s): ad31d2d

Refactor flux_helpers.py to use CUDA device for generator seed

Files changed (1)
  1. modules/helpers/flux_helpers.py +2 -1
modules/helpers/flux_helpers.py CHANGED
@@ -88,6 +88,7 @@ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
     positive_prompt_embeds, positive_prompt_pooled = get_prompt_attention(pipeline, request.prompt)
 
     # Common Args
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     args = {
         'prompt_embeds': positive_prompt_embeds,
         'pooled_prompt_embeds': positive_prompt_pooled,
@@ -96,7 +97,7 @@ def gen_img(request: BaseReq | BaseImg2ImgReq | BaseInpaintReq):
         'num_images_per_prompt': request.num_images_per_prompt,
         'num_inference_steps': request.num_inference_steps,
         'guidance_scale': request.guidance_scale,
-        'generator': [torch.Generator().manual_seed(request.seed + i) if not request.seed is any([None, 0, -1]) else torch.Generator().manual_seed(random.randint(0, 2**32 - 1)) for i in range(request.num_images_per_prompt)],
+        'generator': [torch.Generator(device=device).manual_seed(request.seed + i) if not request.seed is any([None, 0, -1]) else torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1)) for i in range(request.num_images_per_prompt)],
     }
 
     if request.controlnet_config:
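
For context, here is a self-contained sketch of the device-aware generator construction the diff switches to. It is an illustration rather than the repository's code: `make_generators`, `seed`, and `num_images` are hypothetical stand-ins for the request fields. Note that in the committed line, `any([None, 0, -1])` always evaluates to `True` (because `-1` is truthy), so `request.seed is any([...])` is really an identity check against `True`; the sketch assumes the likely intent, namely "seed deterministically unless the seed is None, 0, or -1".

```python
import random

import torch


def make_generators(seed, num_images):
    """Build one torch.Generator per image: deterministic when a usable seed is
    given, otherwise freshly randomized (assumed intent of the diff)."""
    # Match the commit: create generators on CUDA when available, else CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generators = []
    for i in range(num_images):
        if seed not in (None, 0, -1):
            # Offset the base seed per image so a batch yields distinct outputs.
            generators.append(torch.Generator(device=device).manual_seed(seed + i))
        else:
            # No usable seed: draw a random 32-bit seed for each image.
            generators.append(
                torch.Generator(device=device).manual_seed(random.randint(0, 2**32 - 1))
            )
    return generators
```

Offsetting the base seed by `i` keeps each image in a batch distinct while staying reproducible for a fixed request seed.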