refoundd committed (verified)
Commit 0f93b31 · Parent: 2ccac96

Update handler.py

Files changed (1)
  1. handler.py +67 -47
handler.py CHANGED
@@ -1,15 +1,20 @@
 import os
-from typing import Any, Dict
+from typing import Any, Dict, Union
 from PIL import Image
 import torch
 from diffusers import FluxPipeline
 from huggingface_inference_toolkit.logging import logger
 from para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe
+from torchao.quantization import autoquant
 import time
-import torch.distributed as dist
-from para_attn.context_parallel import init_context_parallel_mesh
-from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
-from para_attn.parallel_vae.diffusers_adapters import parallelize_vae
+import gc
+
+# Set high precision for float32 matrix multiplications.
+# This setting optimizes performance on NVIDIA GPUs with Ampere architecture (e.g., A100, RTX 30 series) or newer.
+torch.set_float32_matmul_precision("high")
+
+import torch._dynamo
+torch._dynamo.config.suppress_errors = False # for debugging

 class EndpointHandler:
     def __init__(self, path=""):
@@ -17,60 +22,75 @@ class EndpointHandler:
             "NoMoreCopyrightOrg/flux-dev",
             torch_dtype=torch.bfloat16,
         ).to("cuda")
-        mesh = init_context_parallel_mesh(
-            self.pipe.device.type,
-            max_ring_dim_size=2,
-        )
-        parallelize_pipe(
-            self.pipe,
-            mesh=mesh,
-        )
-        parallelize_vae(self.pipe.vae, mesh=mesh._flatten())
+        self.pipe.enable_vae_slicing()
+        self.pipe.enable_vae_tiling()
+        self.pipe.transformer.fuse_qkv_projections()
+        self.pipe.vae.fuse_qkv_projections()
+        self.pipe.transformer.to(memory_format=torch.channels_last)
+        self.pipe.vae.to(memory_format=torch.channels_last)
         apply_cache_on_pipe(self.pipe, residual_diff_threshold=0.12)
-        torch._inductor.config.reorder_for_compute_comm_overlap = True
         self.pipe.transformer = torch.compile(
             self.pipe.transformer, mode="max-autotune-no-cudagraphs",
         )
         self.pipe.vae = torch.compile(
             self.pipe.vae, mode="max-autotune-no-cudagraphs",
         )
+        self.pipe.transformer = autoquant(self.pipe.transformer, error_on_unseen=False)
+        self.pipe.vae = autoquant(self.pipe.vae, error_on_unseen=False)
+
+        gc.collect()
+        torch.cuda.empty_cache()

-    def __call__(self, data: Dict[str, Any]) -> str:
-        logger.info(f"Received incoming request with {data=}")
+        start_time = time.time()
+        print("Start warming-up pipeline")
+        self.pipe("Hello world!") # Warm-up for compiling
+        end_time = time.time()
+        time_taken = end_time - start_time
+        print(f"Time taken: {time_taken:.2f} seconds")
+        self.record=0

-        if "inputs" in data and isinstance(data["inputs"], str):
-            prompt = data.pop("inputs")
-        elif "prompt" in data and isinstance(data["prompt"], str):
-            prompt = data.pop("prompt")
-        else:
-            raise ValueError(
-                "Provided input body must contain either the key `inputs` or `prompt` with the"
-                " prompt to use for the image generation, and it needs to be a non-empty string."
-            )
+    def __call__(self, data: Dict[str, Any]) -> Union[Image.Image, None]:
+        try:
+            logger.info(f"Received incoming request with {data=}")

-        parameters = data.pop("parameters", {})
+            if "inputs" in data and isinstance(data["inputs"], str):
+                prompt = data.pop("inputs")
+            elif "prompt" in data and isinstance(data["prompt"], str):
+                prompt = data.pop("prompt")
+            else:
+                raise ValueError(
+                    "Provided input body must contain either the key `inputs` or `prompt` with the"
+                    " prompt to use for the image generation, and it needs to be a non-empty string."
+                )
+            if prompt=="get_queue":
+                return self.record
+            parameters = data.pop("parameters", {})

-        num_inference_steps = parameters.get("num_inference_steps", 28)
-        width = parameters.get("width", 1024)
-        height = parameters.get("height", 1024)
-        guidance_scale = parameters.get("guidance_scale", 3.5)
+            num_inference_steps = parameters.get("num_inference_steps", 28)
+            width = parameters.get("width", 1024)
+            height = parameters.get("height", 1024)
+            #guidance_scale = parameters.get("guidance_scale", 3.5)
+            guidance_scale = parameters.get("guidance", 3.5)

-        # seed generator (seed cannot be provided as is but via a generator)
-        seed = parameters.get("seed", 0)
-        generator = torch.manual_seed(seed)
-        start_time = time.time()
-        result = self.pipe( # type: ignore
-            prompt,
-            height=height,
-            width=width,
-            guidance_scale=guidance_scale,
-            num_inference_steps=num_inference_steps,
-            generator=generator,
-            output_type="pil" if dist.get_rank() == 0 else "pt",
-        ).images[0]
-        end_time = time.time()
-        if dist.get_rank() == 0:
+            # seed generator (seed cannot be provided as is but via a generator)
+            seed = parameters.get("seed", 0)
+            generator = torch.manual_seed(seed)
+            self.record+=1
+            start_time = time.time()
+            result = self.pipe( # type: ignore
+                prompt,
+                height=height,
+                width=width,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_inference_steps,
+                generator=generator,
+            ).images[0]
+            end_time = time.time()
             time_taken = end_time - start_time
             print(f"Time taken: {time_taken:.2f} seconds")
+            self.record-=1
+
             return result
-        return "123"
+        except Exception as e:
+            print(e)
+            return None
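
For reference, here is a minimal local smoke test for the updated handler. It is a sketch, not part of the commit: the payload keys (`inputs`/`prompt`, `parameters` with `num_inference_steps`, `width`, `height`, `guidance`, `seed`) and the special `get_queue` prompt are taken from the diff above, while importing and calling the handler directly assumes a CUDA machine with diffusers, para_attn, and torchao installed (in production, huggingface_inference_toolkit instantiates and calls EndpointHandler).

# Hypothetical smoke test -- assumes handler.py is importable and a CUDA GPU
# is available; in production the HF inference toolkit drives EndpointHandler.
from handler import EndpointHandler

handler = EndpointHandler()

# Keys mirror the diff; note the new code reads parameters["guidance"],
# not "guidance_scale" (the old key is left commented out above).
image = handler({
    "inputs": "An astronaut riding a horse, watercolor",
    "parameters": {
        "num_inference_steps": 28,
        "width": 1024,
        "height": 1024,
        "guidance": 3.5,
        "seed": 0,
    },
})
if image is not None:  # __call__ returns None when an exception was caught
    image.save("output.png")

# The special prompt "get_queue" returns the in-flight request counter
# (self.record) instead of generating an image.
print(handler({"prompt": "get_queue"}))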