import torch from torch import nn from PIL import Image from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, T5EncoderModel, T5TokenizerFast from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler from flux.transformer_flux import FluxTransformer2DModel from flux.pipeline_flux_chameleon import FluxPipeline from flux.pipeline_flux_img2img import FluxImg2ImgPipeline from flux.pipeline_flux_inpaint import FluxInpaintPipeline from flux.pipeline_flux_controlnet import FluxControlNetPipeline, FluxControlNetModel from flux.pipeline_flux_controlnet_img2img import FluxControlNetImg2ImgPipeline from flux.controlnet_flux import FluxMultiControlNetModel from flux.pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel import os import cv2 import numpy as np import math def get_model_path(model_name): """Get the full path for a model based on the checkpoints directory.""" base_dir = os.getenv('CHECKPOINT_DIR', 'checkpoints') # Allow environment variable override return os.path.join(base_dir, model_name) # Model paths configuration MODEL_PATHS = { 'flux': get_model_path('flux'), 'qwen2vl': get_model_path('qwen2-vl'), 'controlnet': get_model_path('controlnet'), 'depth_anything': { 'path': get_model_path('depth-anything-v2'), 'weights': 'depth_anything_v2_vitl.pth' }, 'anyline': { 'path': get_model_path('anyline'), 'weights': 'MTEED.pth' }, 'sam2': { 'path': get_model_path('segment-anything-2'), 'weights': 'sam2_hiera_large.pt', 'config': 'sam2_hiera_l.yaml' } } ASPECT_RATIOS = { "1:1": (1024, 1024), "16:9": (1344, 768), "9:16": (768, 1344), "2.4:1": (1536, 640), "3:4": (896, 1152), "4:3": (1152, 896), } class Qwen2Connector(nn.Module): def __init__(self, input_dim=3584, output_dim=4096): super().__init__() self.linear = nn.Linear(input_dim, output_dim) def forward(self, x): return self.linear(x) class FluxModel: def __init__(self, is_turbo=False, device="cuda", required_features=None): """ Initialize FluxModel with specified features Args: is_turbo: Enable turbo mode for faster inference device: Device to run the model on required_features: List of required features ['controlnet', 'depth', 'line', 'sam'] """ self.device = torch.device(device) self.dtype = torch.bfloat16 if required_features is None: required_features = [] self._line_detector_imported = False self._depth_model_imported = False self._sam_imported = False self._turbo_imported = False # Initialize base models (always required) self._init_base_models() # Initialize optional models based on requirements if 'controlnet' in required_features or any(f in required_features for f in ['depth', 'line']): self._init_controlnet() if 'depth' in required_features: self._init_depth_model() if 'line' in required_features: self._init_line_detector() if 'sam' in required_features: self._init_sam() if is_turbo: self._enable_turbo() def _init_base_models(self): """Initialize the core models that are always needed""" # Qwen2VL and connector initialization self.qwen2vl = Qwen2VLSimplifiedModel.from_pretrained( MODEL_PATHS['qwen2vl'], torch_dtype=self.dtype ) self.qwen2vl.requires_grad_(False).to(self.device) self.connector = Qwen2Connector(input_dim=3584, output_dim=4096) connector_path = os.path.join(MODEL_PATHS['qwen2vl'], "connector.pt") if os.path.exists(connector_path): connector_state_dict = torch.load(connector_path, map_location=self.device, weights_only=True) connector_state_dict = {k.replace('module.', ''): v for k, v in connector_state_dict.items()} self.connector.load_state_dict(connector_state_dict) self.connector.to(self.dtype).to(self.device) # Text encoders initialization self.tokenizer = CLIPTokenizer.from_pretrained(MODEL_PATHS['flux'], subfolder="tokenizer") self.text_encoder = CLIPTextModel.from_pretrained(MODEL_PATHS['flux'], subfolder="text_encoder") self.text_encoder_two = T5EncoderModel.from_pretrained(MODEL_PATHS['flux'], subfolder="text_encoder_2") self.tokenizer_two = T5TokenizerFast.from_pretrained(MODEL_PATHS['flux'], subfolder="tokenizer_2") self.text_encoder.requires_grad_(False).to(self.dtype).to(self.device) self.text_encoder_two.requires_grad_(False).to(self.dtype).to(self.device) # T5 context embedder self.t5_context_embedder = nn.Linear(4096, 3072) t5_embedder_path = os.path.join(MODEL_PATHS['qwen2vl'], "t5_embedder.pt") t5_embedder_state_dict = torch.load(t5_embedder_path, map_location=self.device, weights_only=True) self.t5_context_embedder.load_state_dict(t5_embedder_state_dict) self.t5_context_embedder.to(self.dtype).to(self.device) # Basic components self.noise_scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(MODEL_PATHS['flux'], subfolder="scheduler", shift=1) self.vae = AutoencoderKL.from_pretrained(MODEL_PATHS['flux'], subfolder="vae") self.transformer = FluxTransformer2DModel.from_pretrained(MODEL_PATHS['flux'], subfolder="transformer") self.vae.requires_grad_(False).to(self.dtype).to(self.device) self.transformer.requires_grad_(False).to(self.dtype).to(self.device) def _init_controlnet(self): """Initialize ControlNet model""" self.controlnet_union = FluxControlNetModel.from_pretrained( MODEL_PATHS['controlnet'], torch_dtype=torch.bfloat16 ) self.controlnet_union.requires_grad_(False).to(self.device) self.controlnet = FluxMultiControlNetModel([self.controlnet_union]) def _init_depth_model(self): """Initialize Depth Anything V2 model""" if not self._depth_model_imported: from depth_anything_v2.dpt import DepthAnythingV2 self._depth_model_imported = True self.depth_model = DepthAnythingV2( encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024] ) depth_weights = os.path.join(MODEL_PATHS['depth_anything']['path'], MODEL_PATHS['depth_anything']['weights']) self.depth_model.load_state_dict(torch.load(depth_weights, map_location=self.device)) self.depth_model.requires_grad_(False).to(self.device) def _init_line_detector(self): """Initialize line detection model""" if not self._line_detector_imported: from controlnet_aux import AnylineDetector self._line_detector_imported = True self.anyline = AnylineDetector.from_pretrained( MODEL_PATHS['anyline']['path'], filename=MODEL_PATHS['anyline']['weights'] ) self.anyline.to(self.device) def _init_sam(self): """Initialize SAM2 model""" if not self._sam_imported: from sam2.build_sam import build_sam2 from sam2.sam2_image_predictor import SAM2ImagePredictor self._sam_imported = True sam2_checkpoint = os.path.join(MODEL_PATHS['sam2']['path'], MODEL_PATHS['sam2']['weights']) model_cfg = os.path.join(MODEL_PATHS['sam2']['path'], MODEL_PATHS['sam2']['config']) self.sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=self.device) self.sam2_predictor = SAM2ImagePredictor(self.sam2_model) def _enable_turbo(self): """Enable turbo mode for faster inference""" if not self._turbo_imported: from optimum.quanto import freeze, qfloat8, quantize self._turbo_imported = True quantize( self.transformer, weights=qfloat8, exclude=[ "*.norm", "*.norm1", "*.norm2", "*.norm2_context", "proj_out", "x_embedder", "norm_out", "context_embedder", ], ) freeze(self.transformer) def generate_mask(self, image, input_points, input_labels): """ 使用SAM2生成分割mask Args: image: PIL Image或numpy数组 input_points: numpy数组,形状为(N, 2),包含点的坐标 input_labels: numpy数组,形状为(N,),1表示前景点,0表示背景点 Returns: PIL Image: 最高分数的mask """ try: # 确保图像是numpy数组 if isinstance(image, Image.Image): image_array = np.array(image) else: image_array = image # 设置图像 self.sam2_predictor.set_image(image_array) # 进行预测 with torch.inference_mode(): masks, scores, logits = self.sam2_predictor.predict( point_coords=input_points, point_labels=input_labels, multimask_output=True, ) # 返回得分最高的mask best_mask_idx = scores.argmax() mask = masks[best_mask_idx] mask_image = Image.fromarray((mask * 255).astype(np.uint8)) return mask_image except Exception as e: print(f"Mask generation failed: {str(e)}") raise def recover_2d_shape(self, image_hidden_state, grid_thw): batch_size, num_tokens, hidden_dim = image_hidden_state.shape _, h, w = grid_thw h_out = h // 2 w_out = w // 2 # 重塑为 (batch_size, height, width, hidden_dim) reshaped = image_hidden_state.view(batch_size, h_out, w_out, hidden_dim) return reshaped def generate_attention_matrix(self, center_x, center_y, radius, image_shape): height, width = image_shape y, x = np.ogrid[:height, :width] center_y, center_x = center_y * height, center_x * width distances = np.sqrt((x - center_x)**2 + (y - center_y)**2) attention = np.clip(1 - distances / (radius * min(height, width)), 0, 1) return attention def apply_attention(self, image_hidden_state, image_grid_thw, center_x, center_y, radius): qwen2_2d_image_embedding = self.recover_2d_shape(image_hidden_state, tuple(image_grid_thw.tolist()[0])) attention_matrix = self.generate_attention_matrix( center_x, center_y, radius, (qwen2_2d_image_embedding.size(1), qwen2_2d_image_embedding.size(2)) ) attention_tensor = torch.from_numpy(attention_matrix).to(self.dtype).unsqueeze(0).unsqueeze(-1) qwen2_2d_image_embedding = qwen2_2d_image_embedding * attention_tensor.to(self.device) return qwen2_2d_image_embedding.view(1, -1, qwen2_2d_image_embedding.size(3)) def compute_text_embeddings(self, prompt): with torch.no_grad(): text_inputs = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt") text_input_ids = text_inputs.input_ids.to(self.device) prompt_embeds = self.text_encoder(text_input_ids, output_hidden_states=False) pooled_prompt_embeds = prompt_embeds.pooler_output return pooled_prompt_embeds.to(self.dtype) def compute_t5_text_embeddings( self, max_sequence_length=256, prompt=None, num_images_per_prompt=1, device=None, ): prompt = [prompt] if isinstance(prompt, str) else prompt batch_size = len(prompt) text_inputs = self.tokenizer_two( prompt, padding="max_length", max_length=max_sequence_length, truncation=True, return_length=False, return_overflowing_tokens=False, return_tensors="pt", ) text_input_ids = text_inputs.input_ids prompt_embeds = self.text_encoder_two(text_input_ids.to(device))[0] dtype = self.text_encoder_two.dtype prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) _, seq_len, _ = prompt_embeds.shape # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) return prompt_embeds def process_image(self, image): message = [ { "role": "user", "content": [ {"type": "image", "image": image}, {"type": "text", "text": "Describe this image."}, ] } ] text = self.qwen2vl_processor.apply_chat_template(message, tokenize=False, add_generation_prompt=True) with torch.no_grad(): inputs = self.qwen2vl_processor(text=[text], images=[image], padding=True, return_tensors="pt").to(self.device) output_hidden_state, image_token_mask, image_grid_thw = self.qwen2vl(**inputs) image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1)) return image_hidden_state, image_grid_thw def resize_image(self, img, max_pixels=1050000): # 确保输入是 PIL Image if not isinstance(img, Image.Image): img = Image.fromarray(img) width, height = img.size num_pixels = width * height if num_pixels > max_pixels: scale = math.sqrt(max_pixels / num_pixels) new_width = int(width * scale) new_height = int(height * scale) # 调整宽度和高度,使其能被8整除 new_width = new_width - (new_width % 8) new_height = new_height - (new_height % 8) img = img.resize((new_width, new_height), Image.LANCZOS) else: # 如果图片不需要缩小,仍然需要确保尺寸能被8整除 new_width = width - (width % 8) new_height = height - (height % 8) if new_width != width or new_height != height: img = img.resize((new_width, new_height), Image.LANCZOS) return img def generate_depth_map(self, image): """Generate depth map using Depth Anything V2""" # Convert PIL to numpy array image_np = np.array(image) # Convert RGB to BGR for cv2 image_bgr = cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR) # Generate depth map with torch.no_grad(): depth = self.depth_model.infer_image(image_bgr) # Normalize depth to 0-1 range depth_norm = (depth - depth.min()) / (depth.max() - depth.min()) # Convert to RGB image depth_rgb = (depth_norm * 255).astype(np.uint8) depth_rgb = cv2.cvtColor(depth_rgb, cv2.COLOR_GRAY2RGB) return Image.fromarray(depth_rgb) def generate(self, input_image_a, input_image_b=None, prompt="", guidance_scale=3.5, num_inference_steps=28, aspect_ratio="1:1", center_x=None, center_y=None, radius=None, mode="variation", denoise_strength=0.8, mask_image=None, imageCount=2, line_mode=True, depth_mode=True, line_strength=0.4, depth_strength=0.2): batch_size = imageCount if aspect_ratio not in ASPECT_RATIOS: raise ValueError(f"Invalid aspect ratio. Choose from {list(ASPECT_RATIOS.keys())}") width, height = ASPECT_RATIOS[aspect_ratio] pooled_prompt_embeds = self.compute_text_embeddings(prompt="") t5_prompt_embeds = None if prompt != "": self.qwen2vl_processor = AutoProcessor.from_pretrained(MODEL_PATHS['qwen2vl'], min_pixels=256*28*28, max_pixels=256*28*28) t5_prompt_embeds = self.compute_t5_text_embeddings(prompt=prompt, device=self.device) t5_prompt_embeds = self.t5_context_embedder(t5_prompt_embeds) else: self.qwen2vl_processor = AutoProcessor.from_pretrained(MODEL_PATHS['qwen2vl'], min_pixels=512*28*28, max_pixels=512*28*28) qwen2_hidden_state_a, image_grid_thw_a = self.process_image(input_image_a) # 只有当所有注意力参数都被提供时,才应用注意力机制 if mode == "variation": if center_x is not None and center_y is not None and radius is not None: qwen2_hidden_state_a = self.apply_attention(qwen2_hidden_state_a, image_grid_thw_a, center_x, center_y, radius) qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a) if mode == "img2img" or mode == "inpaint": if input_image_b: qwen2_hidden_state_b, image_grid_thw_b = self.process_image(input_image_b) if center_x is not None and center_y is not None and radius is not None: qwen2_hidden_state_b = self.apply_attention(qwen2_hidden_state_b, image_grid_thw_b, center_x, center_y, radius) qwen2_hidden_state_b = self.connector(qwen2_hidden_state_b) else: qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a) qwen2_hidden_state_b = None if mode == "controlnet" or mode == "controlnet-inpaint": qwen2_hidden_state_b = None if input_image_b: qwen2_hidden_state_b, image_grid_thw_b = self.process_image(input_image_b) if center_x is not None and center_y is not None and radius is not None: qwen2_hidden_state_b = self.apply_attention(qwen2_hidden_state_b, image_grid_thw_b, center_x, center_y, radius) qwen2_hidden_state_b = self.connector(qwen2_hidden_state_b) qwen2_hidden_state_a = self.connector(qwen2_hidden_state_a) ############################# # IMAGE GENERATION ############################# if mode == "variation": # Initialize different pipelines pipeline = FluxPipeline( transformer=self.transformer, scheduler=self.noise_scheduler, vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, ) gen_images = pipeline( prompt_embeds=qwen2_hidden_state_a.repeat(batch_size, 1, 1), t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None, pooled_prompt_embeds=pooled_prompt_embeds, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, height=height, width=width, ).images ############################# # IMAGE-TO-IMAGE ############################# elif mode == "img2img": input_image_a = self.resize_image(input_image_a) width, height = input_image_a.size img2img_pipeline = FluxImg2ImgPipeline( transformer=self.transformer, scheduler=self.noise_scheduler, vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, ) gen_images = img2img_pipeline( image=input_image_a, strength=denoise_strength, prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1), t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None, pooled_prompt_embeds=pooled_prompt_embeds, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, height=height, width=width, ).images ############################# # INPAINTING ############################# elif mode == "inpaint": if mask_image is None: raise ValueError("Mask image is required for inpainting mode") input_image_a = self.resize_image(input_image_a) mask_image = self.resize_image(mask_image) width, height = input_image_a.size inpaint_pipeline = FluxInpaintPipeline( transformer=self.transformer, scheduler=self.noise_scheduler, vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, ) gen_images = inpaint_pipeline( image=input_image_a, mask_image=mask_image, strength=denoise_strength, prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1), t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None, pooled_prompt_embeds=pooled_prompt_embeds, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, height=height, width=width, ).images ############################# # CONTROLNET ############################# elif mode == "controlnet": input_image_a = self.resize_image(input_image_a) width, height = input_image_a.size controlnet_pipeline = FluxControlNetImg2ImgPipeline( transformer=self.transformer, scheduler=self.noise_scheduler, vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, controlnet=self.controlnet, ) # 准备控制图像和模式列表 control_images = [] control_modes = [] conditioning_scales = [] # 根据用户选择添加控制模式 if depth_mode: control_image_depth = self.generate_depth_map(input_image_a) control_images.append(control_image_depth) control_modes.append(2) # depth mode conditioning_scales.append(depth_strength) if line_mode: control_image_canny = self.anyline(input_image_a, detect_resolution=1280) control_images.append(control_image_canny) control_modes.append(0) # line mode conditioning_scales.append(line_strength) # 如果没有启用任何模式,默认使用line+depth模式 if not line_mode and not depth_mode: control_image_depth = self.generate_depth_map(input_image_a) control_image_canny = self.anyline(input_image_a, detect_resolution=1280) control_images = [control_image_depth, control_image_canny] control_modes = [2, 0] conditioning_scales = [0.2, 0.4] if qwen2_hidden_state_b is not None: qwen2_hidden_state_b = qwen2_hidden_state_b[:, :qwen2_hidden_state_a.shape[1], :] qwen2_hidden_state_a = qwen2_hidden_state_a[:, :qwen2_hidden_state_b.shape[1], :] gen_images = controlnet_pipeline( image=input_image_a, strength=denoise_strength, control_image=control_images, control_mode=control_modes, controlnet_conditioning_scale=conditioning_scales, prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1), t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None, prompt_embeds_control=qwen2_hidden_state_a.repeat(batch_size, 1, 1), pooled_prompt_embeds=pooled_prompt_embeds, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, height=height, width=width, ).images ############################# # CONTROLNET INPAINT ############################# elif mode == "controlnet-inpaint": input_image_a = self.resize_image(input_image_a) mask_image = self.resize_image(mask_image) width, height = input_image_a.size controlnet_pipeline = FluxControlNetInpaintPipeline( transformer=self.transformer, scheduler=self.noise_scheduler, vae=self.vae, text_encoder=self.text_encoder, tokenizer=self.tokenizer, controlnet=self.controlnet, ) # 准备控制图像和模式列表 control_images = [] control_modes = [] conditioning_scales = [] # 根据用户选择添加控制模式 if depth_mode: control_image_depth = self.generate_depth_map(input_image_a) control_images.append(control_image_depth) control_modes.append(2) # depth mode conditioning_scales.append(depth_strength) if line_mode: control_image_canny = self.anyline(input_image_a, detect_resolution=1280) control_images.append(control_image_canny) control_modes.append(0) # line mode conditioning_scales.append(line_strength) # 如果没有启用任何模式,默认使用line+depth模式 if not line_mode and not depth_mode: control_image_depth = self.generate_depth_map(input_image_a) control_image_canny = self.anyline(input_image_a, detect_resolution=1280) control_images = [control_image_depth, control_image_canny] control_modes = [2, 0] conditioning_scales = [0.2, 0.4] if qwen2_hidden_state_b is not None: qwen2_hidden_state_b = qwen2_hidden_state_b[:, :qwen2_hidden_state_a.shape[1], :] qwen2_hidden_state_a = qwen2_hidden_state_a[:, :qwen2_hidden_state_b.shape[1], :] gen_images = controlnet_pipeline( image=input_image_a, mask_image=mask_image, control_image=control_images, control_mode=control_modes, controlnet_conditioning_scale=conditioning_scales, strength=denoise_strength, prompt_embeds=qwen2_hidden_state_b.repeat(batch_size, 1, 1) if qwen2_hidden_state_b is not None else qwen2_hidden_state_a.repeat(batch_size, 1, 1), t5_prompt_embeds=t5_prompt_embeds.repeat(batch_size, 1, 1) if t5_prompt_embeds is not None else None, prompt_embeds_control=qwen2_hidden_state_a.repeat(batch_size, 1, 1), pooled_prompt_embeds=pooled_prompt_embeds, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale, height=height, width=width, ).images else: raise ValueError(f"Invalid mode: {mode}") return gen_images