swoyam-sarvam commited on
Commit
c48064c
·
1 Parent(s): cd8f732

added mask creation

Browse files
Files changed (3) hide show
  1. app.py +18 -10
  2. mask.py +45 -0
  3. requirements.txt +4 -0
app.py CHANGED
@@ -19,28 +19,36 @@ install_packages()
19
  import gradio as gr
20
  from control import main
21
  from diffusers.utils import load_image
 
22
 
23
-
24
- def process_images(image, mask, prompt):
25
  try:
26
- # Pass images directly to main function
 
 
 
 
 
 
27
  result = main(image, mask, prompt)
28
- return result
 
29
  except Exception as e:
30
- return str(e)
31
-
32
 
33
  # Create Gradio interface
34
  demo = gr.Interface(
35
- fn=process_images,
36
  inputs=[
37
  gr.Image(label="Input Image", type="pil"),
38
- gr.Image(label="Mask Image", type="pil"),
39
  gr.Textbox(label="Prompt"),
40
  ],
41
- outputs=gr.Image(label="Generated Image"),
 
 
 
42
  title="Image Inpainting with FLUX ControlNet",
43
- description="Upload an image and its mask, then provide a prompt to generate the inpainted result.",
44
  )
45
 
46
  if __name__ == "__main__":
 
19
  import gradio as gr
20
  from control import main
21
  from diffusers.utils import load_image
22
+ from mask import create_mask
23
 
24
+ def process_image(image, prompt):
 
25
  try:
26
+ # Create mask from input image
27
+ mask = create_mask(image)
28
+
29
+ # First show the generated mask
30
+ yield mask
31
+
32
+ # Then process image with mask
33
  result = main(image, mask, prompt)
34
+ yield result
35
+
36
  except Exception as e:
37
+ yield str(e)
 
38
 
39
  # Create Gradio interface
40
  demo = gr.Interface(
41
+ fn=process_image,
42
  inputs=[
43
  gr.Image(label="Input Image", type="pil"),
 
44
  gr.Textbox(label="Prompt"),
45
  ],
46
+ outputs=[
47
+ gr.Image(label="Generated Mask"),
48
+ gr.Image(label="Generated Image")
49
+ ],
50
  title="Image Inpainting with FLUX ControlNet",
51
+ description="Upload an image and provide a prompt. The system will first generate a mask and then create the inpainted result.",
52
  )
53
 
54
  if __name__ == "__main__":
mask.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+
5
+ def create_mask(image):
6
+ """
7
+ Creates a binary mask from an input image by thresholding and smoothing.
8
+
9
+ Args:
10
+ image: Input image (numpy array)
11
+
12
+ Returns:
13
+ Binary mask image (numpy array)
14
+ """
15
+ # Store original image for visualization
16
+ image_org = image.copy()
17
+
18
+ # Convert image to binary (0 or 255)
19
+ for i in range(len(image)):
20
+ for j in range(len(image[i])):
21
+ if image[i][j] != 255:
22
+ image[i][j] = 0
23
+
24
+ # Add padding of 50 pixels on all sides
25
+ padding = 0
26
+ image = cv2.copyMakeBorder(image, padding, padding, padding, padding,
27
+ cv2.BORDER_CONSTANT, value=255)
28
+
29
+ # Apply Gaussian blur for smoothening
30
+ image = cv2.GaussianBlur(image, (5,5), 50)
31
+
32
+ # Threshold to create binary mask
33
+ _, mask = cv2.threshold(image, 254, 255, cv2.THRESH_BINARY)
34
+
35
+ # Visualization (commented out for production use)
36
+ """
37
+ fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
38
+ ax1.imshow(image_org, cmap="gray")
39
+ ax1.set_title("Original Image")
40
+ ax2.imshow(mask, cmap="gray")
41
+ ax2.set_title("Mask Image")
42
+ plt.show()
43
+ """
44
+
45
+ return mask
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ diffusers==0.30.2
2
+ torch>=2.6.0
3
+ gradio>=5.17.1
4
+ transformers>=4.49.0