import base64
import io
import os

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import requests
import torch
from insightface.app import FaceAnalysis
from openai import OpenAI
from PIL import Image, ImageDraw
from scipy import ndimage
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

IDEOGRAM_API_KEY = os.getenv('IDEOGRAM_API_KEY')
IDEOGRAM_URL = "https://api.ideogram.ai/edit"
# Timeouts (seconds) for outbound HTTP calls, so a stalled request cannot hang
# the Gradio worker indefinitely.
IDEOGRAM_EDIT_TIMEOUT = 120
IMAGE_DOWNLOAD_TIMEOUT = 60

# Face detector, used both to keep faces inside the crop and to keep the text
# box off of them.
face_detection_app = FaceAnalysis(allowed_modules=['detection'])  # enable detection model only
face_detection_app.prepare(ctx_id=0, det_size=(640, 640))

client = OpenAI(api_key=os.getenv('OPENAI_API_KEY'))

# Constants should be in UPPERCASE
GPT_MODEL_NAME = "gpt-4o"
GPT_MAX_TOKENS = 500

# Background-removal (foreground segmentation) model.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
torch.set_float32_matmul_precision('high')
model = model.to(DEVICE)
model.eval()

# Pre-processing for RMBG-2.0: fixed 1024x1024 input, ImageNet normalisation.
# Built once instead of on every call.
_RMBG_IMAGE_SIZE = (1024, 1024)
_RMBG_TRANSFORM = transforms.Compose([
    transforms.Resize(_RMBG_IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

GPT_PROMPT = '''
You are a background editor. Your job is to adjust the background of the image to have {{holiday}} vibes, but take into consideration the perspective and the logic of the image.
Your output should be a prompt that can be used to edit the background of the image.
The background should be edited in a way that is consistent with the image.
The prompt should not include any text or writing in the background.
'''


def image_to_prompt(image: Image.Image, holiday: str) -> str:
    """Ask the GPT vision model to write a background-edit prompt for *image*.

    Args:
        image: The photo to describe (the pipeline passes the cropped image).
        holiday: Holiday theme substituted for ``{{holiday}}`` in ``GPT_PROMPT``.

    Returns:
        The model's text reply, used verbatim as the Ideogram edit prompt.

    Raises:
        RuntimeError: If the model reply has no text content.
    """
    base64_image = encode_image(image)
    messages = [{
        "role": "user",
        "content": [
            {"type": "text", "text": GPT_PROMPT.replace("{{holiday}}", holiday)},
            # encode_image() emits PNG bytes, so the data URL must declare image/png.
            {"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64_image}"}},
        ],
    }]
    response = client.chat.completions.create(
        model=GPT_MODEL_NAME,
        messages=messages,
        max_tokens=GPT_MAX_TOKENS,
    )
    full_response = response.choices[0].message.content
    if not full_response:
        # The message content is optional in the API response; never forward
        # "None" or an empty string as the edit prompt.
        raise RuntimeError("The GPT model returned an empty prompt")
    return full_response


def _to_png_bytes(image: Image.Image) -> bytes:
    """Serialise a PIL image to PNG-encoded bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format='PNG')
    return buffer.getvalue()


def encode_image(image: Image.Image) -> str:
    """Convert a PIL Image to a base64-encoded PNG string.

    Args:
        image (PIL.Image.Image): The PIL Image to encode

    Returns:
        str: Base64 encoded PNG image string
    """
    return base64.b64encode(_to_png_bytes(image)).decode('utf-8')


def remove_background(input_image):
    """Segment the foreground of *input_image* with the RMBG-2.0 model.

    Args:
        input_image: PIL image of any mode. It is converted to RGB, because
            the normalisation step expects exactly three channels.

    Returns:
        A tuple ``(result_image, only_background_image, mask)``:
        - ``result_image``: the RGB image with ``mask`` as its alpha channel
          (foreground kept).
        - ``only_background_image``: the RGB image with the inverted mask as
          alpha (background kept).
        - ``mask``: single-channel mask resized to the input size.
    """
    rgb_image = input_image.convert('RGB')

    input_tensor = _RMBG_TRANSFORM(rgb_image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        # The last output holds the final prediction logits; sigmoid -> [0, 1].
        preds = model(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()
    mask = transforms.ToPILImage()(pred).resize(rgb_image.size)

    # Image without background
    result_image = rgb_image.copy()
    result_image.putalpha(mask)

    # Image with only the background
    only_background_image = rgb_image.copy()
    only_background_image.putalpha(Image.eval(mask, lambda x: 255 - x))

    return result_image, only_background_image, mask


def modify_background(image: Image.Image, mask: Image.Image, prompt: str) -> Image.Image:
    """Regenerate the background of *image* with Ideogram's edit endpoint.

    Args:
        image: The image to edit.
        mask: Mask sent alongside the image. The pipeline passes the (dilated)
            foreground mask, so the foreground is white and the background dark.
        prompt: Edit prompt. A "no text" instruction is appended to it.

    Returns:
        The first edited image returned by the API, downloaded and opened with PIL.

    Raises:
        RuntimeError: If ``IDEOGRAM_API_KEY`` is not set.
        Exception: If the edit request fails or the response has no image URL.
        requests.HTTPError: If downloading the result image fails.
    """
    if not IDEOGRAM_API_KEY:
        raise RuntimeError("IDEOGRAM_API_KEY environment variable is not set")

    files = {
        "image_file": ("image.png", _to_png_bytes(image), "image/png"),
        "mask": ("mask.png", _to_png_bytes(mask), "image/png"),
    }

    prevent_text_in_background = "Do not include any text or writing in the background."
    payload = {
        "prompt": f"{prompt} {prevent_text_in_background}",
        "model": "V_2",
        "magic_prompt_option": "ON",
        "num_images": 1,
        "style_type": "REALISTIC",
    }
    headers = {"Api-Key": IDEOGRAM_API_KEY}

    response = requests.post(
        IDEOGRAM_URL, data=payload, files=files, headers=headers, timeout=IDEOGRAM_EDIT_TIMEOUT
    )
    if response.status_code != 200:
        raise Exception(f"Failed to modify background: {response.text}")

    # Expected shape: {"data": [{"url": ...}, ...]}. Guard against a missing or
    # empty "data" list instead of crashing with a TypeError/IndexError.
    results = response.json().get('data') or []
    result_image_url = results[0].get('url') if results else None
    if not result_image_url:
        raise Exception(f"Failed to modify background: {response.text}")

    result_response = requests.get(result_image_url, timeout=IMAGE_DOWNLOAD_TIMEOUT)
    result_response.raise_for_status()
    return Image.open(io.BytesIO(result_response.content))


def dilate_mask(mask: Image.Image, size: int = 20) -> Image.Image:
    """Grow the bright (foreground) regions of *mask* with a max filter.

    Args:
        mask: Single-channel mask image.
        size: Side length of the square max-filter window (default 20 px).

    Returns:
        The dilated mask as an 8-bit PIL image.
    """
    dilated_mask = ndimage.maximum_filter(np.array(mask), size=size)
    return Image.fromarray(dilated_mask.astype(np.uint8))


def detect_faces(image: Image.Image) -> list[dict]:
    """Detect faces in *image* with insightface.

    insightface follows the OpenCV convention and expects a 3-channel BGR
    array, while PIL gives RGB (or RGBA/L). The image is therefore normalised
    to RGB and its channels are flipped before detection.

    Returns:
        The detector's face records. Each supports ``face['bbox']`` as
        ``(x1, y1, x2, y2)`` in *image* pixel coordinates.
    """
    rgb_array = np.array(image.convert('RGB'))
    bgr_array = np.ascontiguousarray(rgb_array[:, :, ::-1])
    return face_detection_app.get(bgr_array)


def check_text_position(x, y, text_rect_width, text_rect_height, face_rects, image_width, image_height):
    """Return True if a text box centred at ``(x, y)`` fits and avoids all faces.

    Args:
        x, y: Centre of the candidate text box, in pixels.
        text_rect_width, text_rect_height: Size of the text box, in pixels.
        face_rects: Iterable of ``(x1, y1, x2, y2)`` rectangles to avoid.
        image_width, image_height: Image bounds the box must stay inside.
    """
    text_x1 = x - text_rect_width // 2
    text_y1 = y - text_rect_height // 2
    text_x2 = x + text_rect_width // 2
    text_y2 = y + text_rect_height // 2

    # Text must be entirely within the image
    if text_x1 < 0 or text_x2 > image_width or text_y1 < 0 or text_y2 > image_height:
        return False

    # Text must not overlap any face rectangle (touching edges is allowed)
    for fx1, fy1, fx2, fy2 in face_rects:
        if not (text_x2 < fx1 or text_x1 > fx2 or text_y2 < fy1 or text_y1 > fy2):
            return False

    return True


def find_place_to_add_text(image: Image.Image, faces: list[dict]) -> tuple[float, float, float, float]:
    """Find a text box on *image* that avoids the given faces.

    The search tries four candidate positions (bottom, top, left, right) with
    a large box. If none fits, it shrinks the box by 10% and retries, down to
    a minimum size.

    Args:
        image: The image the text will be placed on. Only its size is used.
        faces: Face records whose ``face['bbox']`` is given in *image*
            coordinates.

    Returns:
        ``(x, y, width, height)`` of the text box. Each value is a fraction of
        the image width or height, and ``(x, y)`` is the top-left corner.
        If no candidate fits, a minimum-size bottom-centre box is returned.
    """
    image_width, image_height = image.size

    # Convert face coordinates to padded rectangles for collision detection
    face_rects = []
    padding = 20  # Padding around faces
    for face in faces:
        x1, y1, x2, y2 = map(int, face['bbox'])
        face_rects.append((
            max(0, x1 - padding),
            max(0, y1 - padding),
            min(image_width, x2 + padding),
            min(image_height, y2 + padding),
        ))

    # Candidate text-box centres
    padding_x = int(0.1 * image_width)
    padding_y = int(0.1 * image_height)
    positions = [
        (image_width // 2, int(0.85 * image_height) - padding_y),  # Bottom center
        (image_width // 2, int(0.15 * image_height) + padding_y),  # Top center
        (int(0.15 * image_width) + padding_x, image_height // 2),  # Left middle
        (int(0.85 * image_width) - padding_x, image_height // 2),  # Right middle
    ]

    # Start with the largest desired text size and gradually reduce it
    current_text_width = 0.8
    current_text_height = 0.3
    min_text_width = 0.1
    min_text_height = 0.03
    reduction_factor = 0.9  # Reduce size by 10% each iteration

    while current_text_width >= min_text_width and current_text_height >= min_text_height:
        text_rect_width = current_text_width * image_width
        text_rect_height = current_text_height * image_height

        for x, y in positions:
            if check_text_position(x, y, text_rect_width, text_rect_height, face_rects, image_width, image_height):
                top_left_x_in_percent = (x - text_rect_width // 2) / image_width
                top_left_y_in_percent = (y - text_rect_height // 2) / image_height
                return top_left_x_in_percent, top_left_y_in_percent, current_text_width, current_text_height

        current_text_width *= reduction_factor
        current_text_height *= reduction_factor

    print("Failed to find a suitable position")
    # Fallback: bottom center with minimum size
    return (
        (image_width // 2 - (min_text_width * image_width) // 2) / image_width,  # x position
        (int(0.85 * image_height) - (min_text_height * image_height) // 2) / image_height,  # y position
        min_text_width,  # width
        min_text_height,  # height
    )


def crop_to_ratio_while_preventing_faces(image: Image.Image, faces: list[dict]) -> Image.Image:
    """Crop *image* to 5:7 (portrait) or 7:5 (landscape/square), keeping faces in frame.

    Images narrower than they are tall are treated as portrait; all others as
    landscape. With no faces, a centre crop is used. Otherwise the crop starts
    at the padded face-group box and is shifted so that the right and bottom
    edges of that box are included where possible.

    Args:
        image: Image to crop.
        faces: Face records whose ``face['bbox']`` is in *image* coordinates.

    Returns:
        The cropped image.
    """
    ASPECT_RATIO_PORTRAIT = 5 / 7
    ASPECT_RATIO_LANDSCAPE = 7 / 5

    image_width, image_height = image.size
    current_ratio = image_width / image_height
    is_portrait = current_ratio < 1
    target_ratio = ASPECT_RATIO_PORTRAIT if is_portrait else ASPECT_RATIO_LANDSCAPE

    # Largest crop of the target ratio that fits in the image
    if current_ratio > target_ratio:
        new_width = int(image_height * target_ratio)
        new_height = image_height
    else:
        new_width = image_width
        new_height = int(image_width / target_ratio)

    if not faces:
        x = (image_width - new_width) // 2
        y = (image_height - new_height) // 2
        return image.crop((x, y, x + new_width, y + new_height))

    # Bounding box containing all faces, plus padding, clamped to the image
    padding = 50
    face_x1 = max(0, min(int(face['bbox'][0]) for face in faces) - padding)
    face_y1 = max(0, min(int(face['bbox'][1]) for face in faces) - padding)
    face_x2 = min(image_width, max(int(face['bbox'][2]) for face in faces) + padding)
    face_y2 = min(image_height, max(int(face['bbox'][3]) for face in faces) + padding)

    # Start the crop at the faces' top-left, without running past the image
    x = max(0, min(face_x1, image_width - new_width))
    y = max(0, min(face_y1, image_height - new_height))

    # Shift so the faces' right/bottom edges are not cut off
    if x + new_width < face_x2:
        x = max(0, face_x2 - new_width)
    if y + new_height < face_y2:
        y = max(0, face_y2 - new_height)

    return image.crop((x, y, x + new_width, y + new_height))


def run_flow(input_image, holiday, message):
    """Gradio handler: build a holiday-themed card from an uploaded photo.

    Steps: crop to the card ratio around faces, ask GPT for a background
    prompt, segment the foreground, regenerate the background with Ideogram,
    and pick a face-free text box.

    Args:
        input_image: Uploaded PIL image (``None`` if nothing was uploaded).
        holiday: Holiday theme text.
        message: Optional message from the UI. Currently unused by this
            function.

    Returns:
        ``(output_image, output_image_with_text_rectangle, x, y, width, height)``.
        The last four values describe the text box as fractions of the image
        size.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if input_image is None:
        raise gr.Error("Please upload an image.")

    faces = detect_faces(input_image)
    cropped_image = crop_to_ratio_while_preventing_faces(input_image, faces)
    # `faces` holds boxes in the *uncropped* image's coordinates. Re-detect on the
    # crop so text placement avoids the faces where they actually are now.
    cropped_faces = detect_faces(cropped_image)

    prompt = image_to_prompt(cropped_image, holiday)
    print(prompt)

    _result_image, _only_background_image, mask = remove_background(cropped_image)
    # Grow the foreground mask so the edit does not eat into the subject's edges.
    dilated_mask = dilate_mask(mask)
    output_image = modify_background(cropped_image, dilated_mask, prompt)

    text_x_in_percent, text_y_in_percent, text_width_in_percent, text_height_in_percent = (
        find_place_to_add_text(cropped_image, cropped_faces)
    )

    # Draw the chosen text box on a copy, scaled to the output image's size
    output_image_with_text_rectangle = output_image.copy()
    text_x = text_x_in_percent * output_image.width
    text_y = text_y_in_percent * output_image.height
    text_width = text_width_in_percent * output_image.width
    text_height = text_height_in_percent * output_image.height
    draw = ImageDraw.Draw(output_image_with_text_rectangle)
    draw.rectangle((text_x, text_y, text_x + text_width, text_y + text_height), outline="red")

    return (
        output_image,
        output_image_with_text_rectangle,
        text_x_in_percent,
        text_y_in_percent,
        text_width_in_percent,
        text_height_in_percent,
    )


demo = gr.Interface(
    fn=run_flow,
    inputs=[
        gr.Image(type="pil"),
        gr.Text(label="Holiday (e.g. Christmas, New Year's, etc.)"),
        gr.Text(label="Optional Message", placeholder="Enter your holiday message here..."),
    ],
    outputs=[
        gr.Image(type="pil", label="Output Image"),
        gr.Image(type="pil", label="Output Image With Text Rectangle"),
        gr.Number(label="Text Top Left X"),
        gr.Number(label="Text Top Left Y"),
        gr.Number(label="Text Width"),
        gr.Number(label="Text Height"),
    ],
    title="Holiday Card Generator",
    description="Upload an image to generate a holiday card",
)

demo.launch()