Jordan Legg commited on
Commit
51e970b
Β·
1 Parent(s): 242b4ef

return to normal

Browse files
Files changed (1) hide show
  1. app.py +11 -27
app.py CHANGED
@@ -3,25 +3,12 @@ import numpy as np
3
  import random
4
  import spaces
5
  import torch
6
- from diffusers import FluxPipeline
7
 
8
- # Enable cuDNN benchmarking for potential performance improvement
9
- torch.backends.cudnn.benchmark = True
10
-
11
- # Set up device and data types
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
- DTYPE = torch.float16
14
-
15
- # Load the model
16
- pipe = FluxPipeline.from_pretrained(
17
- "black-forest-labs/FLUX.1-schnell",
18
- torch_dtype=torch.bfloat16,
19
- )
20
 
21
- # Configure the pipeline
22
- pipe.enable_sequential_cpu_offload()
23
- pipe.vae.enable_tiling()
24
- pipe = pipe.to(DTYPE)
25
 
26
  MAX_SEED = np.iinfo(np.int32).max
27
  MAX_IMAGE_SIZE = 2048
@@ -30,18 +17,15 @@ MAX_IMAGE_SIZE = 2048
30
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
31
  if randomize_seed:
32
  seed = random.randint(0, MAX_SEED)
33
- generator = torch.Generator(device=device).manual_seed(seed)
34
-
35
  image = pipe(
36
- prompt,
37
- num_inference_steps=num_inference_steps,
38
- num_images_per_prompt=1,
39
- guidance_scale=0.0,
40
- height=height,
41
- width=width,
42
- generator=generator,
43
- ).images[0]
44
-
45
  return image, seed
46
 
47
  # Gradio interface
 
3
  import random
4
  import spaces
5
  import torch
6
+ from diffusers import DiffusionPipeline
7
 
8
+ dtype = torch.bfloat16
 
 
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
 
 
 
 
 
 
10
 
11
+ pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype).to(device)
 
 
 
12
 
13
  MAX_SEED = np.iinfo(np.int32).max
14
  MAX_IMAGE_SIZE = 2048
 
17
  def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024, num_inference_steps=4, progress=gr.Progress(track_tqdm=True)):
18
  if randomize_seed:
19
  seed = random.randint(0, MAX_SEED)
20
+ generator = torch.Generator().manual_seed(seed)
 
21
  image = pipe(
22
+ prompt = prompt,
23
+ width = width,
24
+ height = height,
25
+ num_inference_steps = num_inference_steps,
26
+ generator = generator,
27
+ guidance_scale=0.0
28
+ ).images[0]
 
 
29
  return image, seed
30
 
31
  # Gradio interface