Anwar786 committed
Commit b754394 · verified · 1 Parent(s): ac20447

Update handler.py

Files changed (1)
  1. handler.py +59 -86
handler.py CHANGED
@@ -1,127 +1,100 @@
-from typing import Dict, List, Any
+from typing import List, Dict, Any
 import base64
 from PIL import Image
 from io import BytesIO
 from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
 import torch
-
-
-import numpy as np
-import cv2
 import controlnet_hinter

 # set device
 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 if device.type != 'cuda':
-    raise ValueError("need to run on GPU")
+    raise ValueError("Need to run on GPU")
 # set mixed precision dtype
 dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16

-# controlnet mapping for controlnet id and control hinter
+# controlnet mapping for depth controlnet
 CONTROLNET_MAPPING = {
-    "canny_edge": {
-        "model_id": "lllyasviel/sd-controlnet-canny",
-        "hinter": controlnet_hinter.hint_canny
-    },
-    "pose": {
-        "model_id": "lllyasviel/sd-controlnet-openpose",
-        "hinter": controlnet_hinter.hint_openpose
-    },
     "depth": {
         "model_id": "lllyasviel/sd-controlnet-depth",
         "hinter": controlnet_hinter.hint_depth
-    },
-    "scribble": {
-        "model_id": "lllyasviel/sd-controlnet-scribble",
-        "hinter": controlnet_hinter.hint_scribble,
-    },
-    "segmentation": {
-        "model_id": "lllyasviel/sd-controlnet-seg",
-        "hinter": controlnet_hinter.hint_segmentation,
-    },
-    "normal": {
-        "model_id": "lllyasviel/sd-controlnet-normal",
-        "hinter": controlnet_hinter.hint_normal,
-    },
-    "hed": {
-        "model_id": "lllyasviel/sd-controlnet-hed",
-        "hinter": controlnet_hinter.hint_hed,
-    },
-    "hough": {
-        "model_id": "lllyasviel/sd-controlnet-mlsd",
-        "hinter": controlnet_hinter.hint_hough,
     }
 }

-
 class EndpointHandler():
     def __init__(self, path=""):
         # define default controlnet id and load controlnet
-        self.control_type = "normal"
-        self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],torch_dtype=dtype).to(device)
-
-        # Load StableDiffusionControlNetPipeline
+        self.control_type = "depth"
+        self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"], torch_dtype=dtype).to(device)
+
+        # Load StableDiffusionControlNetPipeline
         self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
-        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
-                                                                      controlnet=self.controlnet,
+        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
+                                                                      controlnet=self.controlnet,
                                                                       torch_dtype=dtype,
                                                                       safety_checker=None).to(device)
         # Define Generator with seed
         self.generator = torch.Generator(device="cpu").manual_seed(3)

-    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
-        """
-        :param data: A dictionary contains `inputs` and optional `image` field.
-        :return: A dictionary with `image` field contains image in base64.
-        """
-        prompt = data.pop("inputs", None)
-        image = data.pop("image", None)
-        controlnet_type = data.pop("controlnet_type", None)
-
-        # Check if neither prompt nor image is provided
-        if prompt is None and image is None:
-            return {"error": "Please provide a prompt and base64 encoded image."}
-
-        # Check if a new controlnet is provided
-        if controlnet_type is not None and controlnet_type != self.control_type:
-            print(f"changing controlnet from {self.control_type} to {controlnet_type} using {CONTROLNET_MAPPING[controlnet_type]['model_id']} model")
-            self.control_type = controlnet_type
-            self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],
-                                                              torch_dtype=dtype).to(device)
-            self.pipe.controlnet = self.controlnet
-
-
-        # hyperparamters
-        num_inference_steps = data.pop("num_inference_steps", 30)
-        guidance_scale = data.pop("guidance_scale", 7.5)
-        negative_prompt = data.pop("negative_prompt", None)
-        height = data.pop("height", None)
-        width = data.pop("width", None)
-        controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)
+    def __call__(self, data: Any) -> Dict[str, str]:
+        # Extract parameters from the payload
+        prompt = data.get("prompt", None)
+        negative_prompt = data.get("negative_prompt", None)
+        width = data.get("width", None)
+        height = data.get("height", None)
+        num_inference_steps = data.get("steps", 30)
+        guidance_scale = data.get("cfg_scale", 7)
+        sampler_index = data.get("sampler_index", "DPM++ 2M Karras")  # Default to "DPM++ 2M Karras" if not provided

-        # process image
-        image = self.decode_base64_image(image)
-        control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
-
-        # run inference pipeline
+        # Check if prompt is provided
+        if prompt is None:
+            return {"error": "Please provide a prompt."}
+
+        # Extract controlnet configuration from payload
+        controlnet_config = data.get("alwayson_scripts", {}).get("controlnet", {}).get("args", [{}])[0]
+
+        # Run stable diffusion process
         out = self.pipe(
-            prompt=prompt,
+            prompt=prompt,
             negative_prompt=negative_prompt,
-            image=control_image,
-            num_inference_steps=num_inference_steps,
+            num_inference_steps=num_inference_steps,
             guidance_scale=guidance_scale,
             num_images_per_prompt=1,
             height=height,
             width=width,
-            controlnet_conditioning_scale=controlnet_conditioning_scale,
-            generator=self.generator
+            controlnet_conditioning_scale=1.0,
+            generator=self.generator,
+            sampler_index=sampler_index  # Pass the sampler_index to the stable diffusion process
         )

-
-        # return first generate PIL image
-        return out.images[0]
-
-    # helper to decode input image
+        # Get the generated image
+        generated_image = out.images[0]
+
+        # Process with controlnet if enabled
+        if controlnet_config.get("enabled", False):
+            input_image_base64 = controlnet_config.get("input_image", "")
+            input_image = self.decode_base64_image(input_image_base64)
+            controlnet_model = controlnet_config.get("model", "")
+            controlnet_control_mode = controlnet_config.get("control_mode", "")
+
+            processed_image = self.process_with_controlnet(generated_image, input_image, controlnet_model, controlnet_control_mode)
+        else:
+            processed_image = generated_image
+
+        # Return the final processed image as base64
+        return {"image": self.encode_base64_image(processed_image)}
+
+    def process_with_controlnet(self, generated_image, input_image, model, control_mode):
+        # Simulated controlnet processing (replace with actual implementation)
+        # Here, we're just using the input_image as-is. Replace this with your controlnet logic.
+        return input_image
+
+    def encode_base64_image(self, image):
+        # Encode the PIL Image to base64
+        buffer = BytesIO()
+        image.save(buffer, format="PNG")
+        return base64.b64encode(buffer.getvalue()).decode("utf-8")
+
     def decode_base64_image(self, image_string):
         base64_image = base64.b64decode(image_string)
         buffer = BytesIO(base64_image)
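
For reference, a minimal sketch of how the updated handler might be exercised — this is not part of the commit. The payload keys (prompt, steps, cfg_scale, and the alwayson_scripts.controlnet.args block) mirror what the new __call__ reads via data.get(); the file names and prompt text are placeholders. Two caveats when trying this: diffusers pipelines do not document a sampler_index keyword, so the self.pipe(..., sampler_index=...) call may raise a TypeError as committed, and StableDiffusionControlNetPipeline normally also expects a conditioning image argument, which the new code no longer passes.

# Local smoke test for the updated handler (a sketch, not part of the commit).
# Assumes a CUDA machine with diffusers, torch, Pillow and controlnet_hinter
# installed; "input.png" and "output.png" are placeholder file names.
import base64

from handler import EndpointHandler

# Base64-encode a conditioning image the way the endpoint expects it
with open("input.png", "rb") as f:
    input_image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "prompt": "a cozy cabin in a snowy forest",  # required; a missing prompt returns {"error": ...}
    "negative_prompt": "blurry, low quality",
    "width": 512,
    "height": 512,
    "steps": 30,        # read as num_inference_steps
    "cfg_scale": 7,     # read as guidance_scale
    # controlnet block matching the data.get("alwayson_scripts") lookup
    "alwayson_scripts": {
        "controlnet": {
            "args": [{
                "enabled": True,
                "input_image": input_image_b64,
                "model": "depth",
                "control_mode": "balanced",
            }]
        }
    },
}

handler = EndpointHandler()
# NOTE: as committed, the sampler_index kwarg may make this raise a TypeError (see above)
result = handler(payload)

if "image" in result:
    # The handler returns the final image as a base64-encoded PNG
    with open("output.png", "wb") as f:
        f.write(base64.b64decode(result["image"]))
else:
    print(result["error"])

The payload shape resembles the AUTOMATIC1111 txt2img API (steps, cfg_scale, sampler_index, alwayson_scripts), which the new handler appears to emulate. Note also that process_with_controlnet is stubbed to return the conditioning image unchanged, so a run with "enabled": True hands back the input image rather than a generated one.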