| import colorgram |
| import cv2 |
| import numpy as np |
| from PIL import Image |
| import json |
| import torch |
| from transformers import Pix2StructForConditionalGeneration, Pix2StructProcessor |
| import functools |
|
|
|
|
class DesignTokenExtractor:
    """Extract design tokens (colors, spacing, typography, components) from a UI image.

    The Pix2Struct processor and model are loaded once per process and shared
    by every instance. If loading fails, ``pix2struct_model`` and
    ``pix2struct_processor`` stay ``None`` and ``analyze_components`` returns
    a placeholder result instead of raising.
    """

    _MODEL_NAME = "google/pix2struct-screen2words-base"
    # The first extracted color is only called "background" above this share.
    _BACKGROUND_MIN_PROPORTION = 0.3
    _DEFAULT_SPACING = {"small": "8px", "medium": "16px", "large": "32px"}
    # (processor, model) shared by all instances once loading has succeeded.
    _model_cache = None

    def __init__(self):
        self.pix2struct_model = None
        self.pix2struct_processor = None
        self._load_models()

    def _load_models(self):
        """Attach the shared Pix2Struct processor/model, loading them on first use.

        A class-level cache is used instead of ``functools.lru_cache`` on the
        method: an ``lru_cache`` keyed on ``self`` misses for every new
        instance (so the weights were reloaded each time) and keeps the
        instance alive. Failures are reported and not cached, so a later
        instance tries again.
        """
        if DesignTokenExtractor._model_cache is None:
            try:
                processor = Pix2StructProcessor.from_pretrained(self._MODEL_NAME)
                use_cuda = torch.cuda.is_available()
                model = Pix2StructForConditionalGeneration.from_pretrained(
                    self._MODEL_NAME,
                    torch_dtype=torch.float16 if use_cuda else torch.float32,
                )
                if use_cuda:
                    # Half precision is only requested for the GPU, so the
                    # weights have to actually live there.
                    model = model.to("cuda")
            except Exception as e:
                print(f"Warning: Could not load Pix2Struct model: {e}")
                return
            DesignTokenExtractor._model_cache = (processor, model)

        self.pix2struct_processor, self.pix2struct_model = (
            DesignTokenExtractor._model_cache
        )

    def extract_colors(self, image_path, num_colors=8):
        """Extract a named color palette from the image at ``image_path``.

        Args:
            image_path: Image location passed straight to ``colorgram.extract``.
            num_colors: Number of colors requested from colorgram.

        Returns:
            A dict mapping role names to ``{"hex", "rgb", "proportion"}``
            entries. On any error the default palette is returned instead.
        """
        try:
            colors = colorgram.extract(image_path, num_colors)
            return self._build_palette(colors)
        except Exception as e:
            print(f"Error extracting colors: {e}")
            return self._get_default_colors()

    def _build_palette(self, colors):
        """Assign role names to ``colors`` in the order given.

        The first color is named "background" only when its proportion exceeds
        the background threshold; otherwise the roles start at "primary".
        Remaining colors become "primary", "secondary", then "accent-1",
        "accent-2", ... This avoids the malformed "accent--2" key that was
        produced whenever the first color was not dominant.

        NOTE(review): assumes ``colors`` is ordered most-prominent first, as
        returned by colorgram -- confirm against the colorgram documentation.
        """
        colors = list(colors)
        has_background = (
            bool(colors)
            and colors[0].proportion > self._BACKGROUND_MIN_PROPORTION
        )
        roles = (["background"] if has_background else []) + ["primary", "secondary"]

        palette = {}
        for index, color in enumerate(colors):
            if index < len(roles):
                name = roles[index]
            else:
                name = f"accent-{index - len(roles) + 1}"

            r, g, b = color.rgb.r, color.rgb.g, color.rgb.b
            palette[name] = {
                "hex": f"#{r:02x}{g:02x}{b:02x}",
                "rgb": f"rgb({r}, {g}, {b})",
                "proportion": round(color.proportion, 3),
            }
        return palette

    def detect_spacing(self, image):
        """Estimate a spacing scale from vertical gaps between detected shapes.

        Runs Canny edge detection, takes the bounding boxes of external
        contours with area above 100, sorts them by top edge and collects the
        positive gaps between consecutive boxes.

        Args:
            image: A PIL image or anything ``np.array`` turns into an RGB array.

        Returns:
            A spacing dict from ``_cluster_spacing_values``, or the default
            small/medium/large scale when fewer than two boxes are found or
            an error occurs.
        """
        try:
            if isinstance(image, Image.Image) and image.mode != "RGB":
                # COLOR_RGB2GRAY needs colour channels; grayscale or palette
                # images would otherwise fail and silently yield the defaults.
                image = image.convert("RGB")

            gray = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray, 50, 150)
            contours, _ = cv2.findContours(
                edges, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
            )
            bounding_boxes = [
                cv2.boundingRect(c) for c in contours if cv2.contourArea(c) > 100
            ]

            if len(bounding_boxes) > 1:
                # Each box is (x, y, w, h); order them top to bottom.
                bounding_boxes.sort(key=lambda box: box[1])

                vertical_gaps = []
                for upper, lower in zip(bounding_boxes, bounding_boxes[1:]):
                    gap = lower[1] - (upper[1] + upper[3])
                    if gap > 0:  # overlapping boxes carry no spacing signal
                        vertical_gaps.append(gap)

                return self._cluster_spacing_values(vertical_gaps)
        except Exception as e:
            print(f"Error detecting spacing: {e}")

        return dict(self._DEFAULT_SPACING)

    def _cluster_spacing_values(self, gaps):
        """Reduce raw gap sizes to a small named spacing scale.

        Args:
            gaps: Iterable of pixel gaps; it is not modified.

        Returns:
            The default scale when ``gaps`` is empty; otherwise
            small/medium/large (three or more distinct values), small/large
            (two distinct values) or a single "base" entry.
        """
        if not gaps:
            return dict(self._DEFAULT_SPACING)

        # Deduplicate first and sort afterwards: a set has no defined order,
        # so sorting before building the set would be lost.
        unique_gaps = sorted(set(gaps))

        if len(unique_gaps) >= 3:
            return {
                "small": f"{unique_gaps[0]}px",
                "medium": f"{unique_gaps[len(unique_gaps) // 2]}px",
                "large": f"{unique_gaps[-1]}px",
            }
        if len(unique_gaps) == 2:
            return {
                "small": f"{unique_gaps[0]}px",
                "large": f"{unique_gaps[1]}px",
            }
        return {"base": f"{unique_gaps[0]}px"}

    def analyze_components(self, image):
        """Describe the screenshot with Pix2Struct.

        Args:
            image: Image accepted by the Pix2Struct processor.

        Returns:
            ``{"detected_elements": <text>, "layout": "responsive" | "fixed"}``.
            A placeholder dict is returned when the model is unavailable or
            generation fails.
        """
        if self.pix2struct_model is None or self.pix2struct_processor is None:
            return {
                "detected_elements": "Model not available - basic extraction only",
                "layout": "responsive",
            }

        try:
            inputs = self.pix2struct_processor(images=image, return_tensors="pt")

            # The processor emits float32 CPU tensors; match the model's
            # device and precision so a float16/CUDA model can consume them.
            reference = next(self.pix2struct_model.parameters())
            inputs = {
                key: (
                    value.to(device=reference.device, dtype=reference.dtype)
                    if torch.is_floating_point(value)
                    else value.to(reference.device)
                )
                for key, value in inputs.items()
            }

            with torch.no_grad():
                generated_ids = self.pix2struct_model.generate(
                    **inputs, max_length=100
                )

            description = self.pix2struct_processor.batch_decode(
                generated_ids, skip_special_tokens=True
            )[0]

            return {
                "detected_elements": description,
                "layout": (
                    "responsive" if "responsive" in description.lower() else "fixed"
                ),
            }
        except Exception as e:
            print(f"Error analyzing components: {e}")
            return {
                "detected_elements": "Error during analysis",
                "layout": "responsive",
            }

    def detect_typography(self, image):
        """Return a fixed typography scale.

        This is a placeholder: ``image`` is accepted for interface
        compatibility but is not inspected.
        """
        return {
            "heading": {
                "family": "sans-serif",
                "size": "32px",
                "weight": "700",
            },
            "body": {
                "family": "sans-serif",
                "size": "16px",
                "weight": "400",
            },
            "caption": {
                "family": "sans-serif",
                "size": "14px",
                "weight": "400",
            },
        }

    def _get_default_colors(self):
        """Return the fallback palette used when color extraction fails."""
        return {
            "primary": {"hex": "#3B82F6", "rgb": "rgb(59, 130, 246)", "proportion": 0.25},
            "secondary": {"hex": "#8B5CF6", "rgb": "rgb(139, 92, 246)", "proportion": 0.15},
            "background": {"hex": "#FFFFFF", "rgb": "rgb(255, 255, 255)", "proportion": 0.40},
            "text": {"hex": "#1F2937", "rgb": "rgb(31, 41, 55)", "proportion": 0.20},
        }

    @staticmethod
    def _scaled_size(size, max_dimension):
        """Scale ``size`` so its longest side equals ``max_dimension``.

        Integer arithmetic avoids float truncation landing one pixel short,
        and every side is kept at 1 pixel or more so extreme aspect ratios
        never produce a zero-sized image.
        """
        longest = max(size)
        return tuple(max(1, int(dim * max_dimension // longest)) for dim in size)

    def resize_for_processing(self, image, max_dimension=1024):
        """Downscale ``image`` so neither side exceeds ``max_dimension``.

        Args:
            image: PIL image (only ``size`` and ``resize`` are used).
            max_dimension: Maximum allowed width or height in pixels.

        Returns:
            A resized copy that keeps the aspect ratio, or ``image`` itself
            when it already fits.
        """
        if max(image.size) > max_dimension:
            new_size = self._scaled_size(image.size, max_dimension)
            return image.resize(new_size, Image.Resampling.LANCZOS)
        return image