Spaces:
Sleeping
Sleeping
Amit Gazal
committed on
Commit
·
eb5c95c
1
Parent(s):
48da469
add text rectangle
Browse files- app.py +121 -17
- requirements.txt +3 -1
app.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import gradio as gr
|
2 |
-
from PIL import Image
|
3 |
import matplotlib.pyplot as plt
|
4 |
import torch
|
5 |
from torchvision import transforms
|
@@ -11,10 +11,14 @@ import io
|
|
11 |
import requests
|
12 |
import numpy as np
|
13 |
from scipy import ndimage
|
|
|
14 |
|
15 |
IDEOGRAM_API_KEY = os.getenv('IDEOGRAM_API_KEY')
|
16 |
IDEOGRAM_URL = "https://api.ideogram.ai/edit"
|
17 |
|
|
|
|
|
|
|
18 |
client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
|
19 |
# Constants should be in UPPERCASE
|
20 |
GPT_MODEL_NAME = "gpt-4o"
|
@@ -27,20 +31,19 @@ if torch.cuda.is_available():
|
|
27 |
model.eval()
|
28 |
|
29 |
GPT_PROMPT = '''
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
add christmas lights to some of the stuff in the background, maybe add a few elements like christmas tree, but take into considration the perspective and the logic of the image.
|
35 |
'''
|
36 |
|
37 |
-
def image_to_prompt(image: str) -> tuple[str, str]:
|
38 |
base64_image = encode_image(image)
|
39 |
|
40 |
messages = [{
|
41 |
"role": "user",
|
42 |
"content": [
|
43 |
-
{"type": "text", "text": GPT_PROMPT},
|
44 |
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
|
45 |
]
|
46 |
}]
|
@@ -151,17 +154,115 @@ def dilate_mask(mask: Image.Image) -> Image.Image:
|
|
151 |
# Convert back to PIL Image
|
152 |
return Image.fromarray(dilated_mask.astype(np.uint8))
|
153 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
154 |
def run_flow(input_image, holiday, message):
|
155 |
-
|
|
|
|
|
156 |
print(prompt)
|
157 |
result_image, only_background_image, mask = remove_background(input_image)
|
158 |
dilated_mask = dilate_mask(mask)
|
159 |
-
|
160 |
-
|
161 |
-
|
162 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
163 |
|
164 |
-
return first_output_image, second_output_image, third_output_image
|
165 |
|
166 |
# Replace the demo interface
|
167 |
demo = gr.Interface(
|
@@ -172,9 +273,12 @@ demo = gr.Interface(
|
|
172 |
gr.Text(label="Optional Message", placeholder="Enter your holiday message here...")
|
173 |
],
|
174 |
outputs=[
|
175 |
-
gr.Image(type="pil", label="
|
176 |
-
gr.Image(type="pil", label="
|
177 |
-
gr.
|
|
|
|
|
|
|
178 |
],
|
179 |
title="Holiday Card Generator",
|
180 |
description="Upload an image to generate a holiday card"
|
|
|
1 |
import gradio as gr
|
2 |
+
from PIL import Image, ImageDraw
|
3 |
import matplotlib.pyplot as plt
|
4 |
import torch
|
5 |
from torchvision import transforms
|
|
|
11 |
import requests
|
12 |
import numpy as np
|
13 |
from scipy import ndimage
|
14 |
+
from insightface.app import FaceAnalysis
|
15 |
|
16 |
IDEOGRAM_API_KEY = os.getenv('IDEOGRAM_API_KEY')
|
17 |
IDEOGRAM_URL = "https://api.ideogram.ai/edit"
|
18 |
|
19 |
+
face_detection_app = FaceAnalysis(allowed_modules=['detection']) # enable detection model only
|
20 |
+
face_detection_app.prepare(ctx_id=0, det_size=(640, 640))
|
21 |
+
|
22 |
client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))
|
23 |
# Constants should be in UPPERCASE
|
24 |
GPT_MODEL_NAME = "gpt-4o"
|
|
|
31 |
model.eval()
|
32 |
|
33 |
GPT_PROMPT = '''
|
34 |
+
You are a background editor.
|
35 |
+
Your job is to adjust the background of the image to be in a {{holiday}} vibes, but take into considration the perspective and the logic of the image.
|
36 |
+
Your output should be a prompt that can be used to edit the background of the image.
|
37 |
+
The background should be edited in a way that is consistent with the image.
|
|
|
38 |
'''
|
39 |
|
40 |
+
def image_to_prompt(image: str, holiday: str) -> tuple[str, str]:
|
41 |
base64_image = encode_image(image)
|
42 |
|
43 |
messages = [{
|
44 |
"role": "user",
|
45 |
"content": [
|
46 |
+
{"type": "text", "text": GPT_PROMPT.replace("{{holiday}}", holiday)},
|
47 |
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
|
48 |
]
|
49 |
}]
|
|
|
154 |
# Convert back to PIL Image
|
155 |
return Image.fromarray(dilated_mask.astype(np.uint8))
|
156 |
|
157 |
+
def detect_faces(image: Image.Image) -> list[dict]:
    """Run the module-level face detector on a PIL image.

    Args:
        image: Input picture in any PIL mode.

    Returns:
        Whatever ``face_detection_app.get`` returns for the image; callers in
        this file read a ``bbox`` attribute from each returned item.
    """
    # PIL hands back RGB (or RGBA / L / P, depending on the upload), whereas
    # InsightFace follows the OpenCV convention and expects a 3-channel BGR
    # array. Normalise the mode first, then reverse the channel axis.
    rgb = np.asarray(image.convert("RGB"))
    # The reversed slice is a negative-stride view; make it contiguous so the
    # downstream native code receives an ordinary array.
    bgr = np.ascontiguousarray(rgb[:, :, ::-1])
    return face_detection_app.get(bgr)
|
162 |
+
|
163 |
+
def check_text_position(x, y, text_rect_width, text_rect_height, face_rects, image_width, image_height):
    """Tell whether a text box centred on (x, y) is usable.

    The box is usable when it lies fully inside the image and does not
    intersect (or touch) any rectangle in ``face_rects``.

    Args:
        x, y: Centre of the candidate text box.
        text_rect_width, text_rect_height: Size of the candidate box.
        face_rects: Iterable of ``(x1, y1, x2, y2)`` rectangles to avoid.
        image_width, image_height: Size of the image.

    Returns:
        True if the box fits and is clear of every rectangle, else False.
    """
    half_w = text_rect_width // 2
    half_h = text_rect_height // 2
    left, right = x - half_w, x + half_w
    top, bottom = y - half_h, y + half_h

    # Reject boxes that spill over any image edge.
    out_of_bounds = (
        left < 0 or right > image_width
        or top < 0 or bottom > image_height
    )
    if out_of_bounds:
        return False

    # The box is acceptable only if it is strictly separated from every face
    # rectangle along at least one axis.
    return all(
        right < fx1 or left > fx2 or bottom < fy1 or top > fy2
        for fx1, fy1, fx2, fy2 in face_rects
    )
|
182 |
+
|
183 |
+
def find_place_to_add_text(image: Image.Image, faces: list[dict]) -> tuple[float, float, float, float]:
    """Pick a face-free rectangle for overlay text.

    Four anchor points (bottom centre, top centre, left middle, right middle)
    are tried with a text box that starts at 80% x 30% of the image and
    shrinks by 10% per round until one anchor passes ``check_text_position``.

    Args:
        image: Image whose ``size`` defines the coordinate space.
        faces: Detected faces; each item must expose a ``bbox`` holding
            ``(x1, y1, x2, y2)`` pixel coordinates.

    Returns:
        ``(left, top, width, height)`` of the chosen text box, each expressed
        as a fraction of the image width/height (not pixels). If no anchor
        fits at any size, a minimum-size box near the bottom centre is
        returned as a fallback.
    """
    image_width, image_height = image.size

    # Grow every face box by a fixed margin, clamped to the image, so text
    # does not sit flush against a face.
    face_padding = 20
    face_rects = []
    for face in faces:
        x1, y1, x2, y2 = map(int, face.bbox)
        face_rects.append((
            max(0, x1 - face_padding),
            max(0, y1 - face_padding),
            min(image_width, x2 + face_padding),
            min(image_height, y2 + face_padding),
        ))

    # Candidate centres for the text box, in order of preference.
    padding_x = int(0.1 * image_width)
    padding_y = int(0.1 * image_height)
    positions = [
        (image_width // 2, int(0.85 * image_height) - padding_y),  # Bottom center
        (image_width // 2, int(0.15 * image_height) + padding_y),  # Top center
        (int(0.15 * image_width) + padding_x, image_height // 2),  # Left middle
        (int(0.85 * image_width) - padding_x, image_height // 2),  # Right middle
    ]

    # Box size as a fraction of the image; shrunk geometrically until it fits.
    current_text_width = 0.8
    current_text_height = 0.3
    min_text_width = 0.1
    min_text_height = 0.03
    reduction_factor = 0.9  # Reduce size by 10% each iteration

    while current_text_width >= min_text_width and current_text_height >= min_text_height:
        text_rect_width = current_text_width * image_width
        text_rect_height = current_text_height * image_height

        for x, y in positions:
            if check_text_position(x, y, text_rect_width, text_rect_height,
                                   face_rects, image_width, image_height):
                # Convert the box centre to a top-left corner, using the same
                # floor-halving as the collision check, then to fractions.
                top_left_x_in_percent = (x - text_rect_width // 2) / image_width
                top_left_y_in_percent = (y - text_rect_height // 2) / image_height
                return (
                    top_left_x_in_percent,
                    top_left_y_in_percent,
                    current_text_width,
                    current_text_height,
                )

        current_text_width *= reduction_factor
        current_text_height *= reduction_factor

    # Nothing fit: fall back to a minimum-size box near the bottom centre,
    # even though it may overlap a face.
    print("Failed to find a suitable position")
    return (
        (image_width // 2 - (min_text_width * image_width) // 2) / image_width,
        (int(0.85 * image_height) - (min_text_height * image_height) // 2) / image_height,
        min_text_width,
        min_text_height,
    )
|
242 |
+
|
243 |
def run_flow(input_image, holiday, message):
    """Produce the holiday-styled image plus a suggested text placement.

    Args:
        input_image: Source image supplied by the UI.
        holiday: Holiday name forwarded to ``image_to_prompt``.
        message: Currently unused by this function.

    Returns:
        A 6-tuple: the edited image, a copy of it with the text rectangle
        outlined in red, and the rectangle's left, top, width and height as
        fractions of the image size.
    """
    faces = detect_faces(input_image)

    prompt = image_to_prompt(input_image, holiday)
    print(prompt)

    # Only the mask is needed here; the other two results are discarded.
    _foreground, _background_only, mask = remove_background(input_image)
    dilated_mask = dilate_mask(mask)
    output_image = modify_background(input_image, dilated_mask, prompt)

    # Draw on a copy so the clean edited image can be returned as well.
    output_image_with_text_rectangle = output_image.copy()

    placement = find_place_to_add_text(input_image, faces)
    text_x_in_percent, text_y_in_percent, text_width_in_percent, text_height_in_percent = placement

    # Placement is fractional, so scale it by the edited image's own size.
    text_x = text_x_in_percent * output_image.width
    text_y = text_y_in_percent * output_image.height
    text_width = text_width_in_percent * output_image.width
    text_height = text_height_in_percent * output_image.height
    print(text_x, text_y, text_width, text_height)

    box = (text_x, text_y, text_x + text_width, text_y + text_height)
    ImageDraw.Draw(output_image_with_text_rectangle).rectangle(box, outline="red")

    return (
        output_image,
        output_image_with_text_rectangle,
        text_x_in_percent,
        text_y_in_percent,
        text_width_in_percent,
        text_height_in_percent,
    )
|
265 |
|
|
|
266 |
|
267 |
# Replace the demo interface
|
268 |
demo = gr.Interface(
|
|
|
273 |
gr.Text(label="Optional Message", placeholder="Enter your holiday message here...")
|
274 |
],
|
275 |
outputs=[
|
276 |
+
gr.Image(type="pil", label="Output Image"),
|
277 |
+
gr.Image(type="pil", label="Output Image With Text Rectangle"),
|
278 |
+
gr.Number(label="Text Top Left X"),
|
279 |
+
gr.Number(label="Text Top Left Y"),
|
280 |
+
gr.Number(label="Text Width"),
|
281 |
+
gr.Number(label="Text Height")
|
282 |
],
|
283 |
title="Holiday Card Generator",
|
284 |
description="Upload an image to generate a holiday card"
|
requirements.txt
CHANGED
@@ -9,4 +9,6 @@ matplotlib
|
|
9 |
openai
|
10 |
requests
|
11 |
scipy
|
12 |
-
numpy
|
|
|
|
|
|
9 |
openai
|
10 |
requests
|
11 |
scipy
|
12 |
+
numpy
|
13 |
+
insightface
|
14 |
+
onnxruntime
|