import torch import torchvision.transforms.functional as F import io import os from typing import List import matplotlib matplotlib.use('Agg') import matplotlib.pyplot as plt import matplotlib.patches as patches from PIL import Image, ImageDraw, ImageColor, ImageFont import random import numpy as np import re #workaround for unnecessary flash_attn requirement from unittest.mock import patch from transformers.dynamic_module_utils import get_imports def fixed_get_imports(filename: str | os.PathLike) -> list[str]: if not str(filename).endswith("modeling_florence2.py"): return get_imports(filename) imports = get_imports(filename) imports.remove("flash_attn") return imports import comfy.model_management as mm from comfy.utils import ProgressBar import folder_paths script_directory = os.path.dirname(os.path.abspath(__file__)) from transformers import AutoModelForCausalLM, AutoProcessor class DownloadAndLoadFlorence2Model: @classmethod def INPUT_TYPES(s): return {"required": { "model": ( [ 'microsoft/Florence-2-base', 'microsoft/Florence-2-base-ft', 'microsoft/Florence-2-large', 'microsoft/Florence-2-large-ft', 'HuggingFaceM4/Florence-2-DocVQA' ], { "default": 'microsoft/Florence-2-base' }), "precision": ([ 'fp16','bf16','fp32'], { "default": 'fp16' }), "attention": ( [ 'flash_attention_2', 'sdpa', 'eager'], { "default": 'sdpa' }), }, } RETURN_TYPES = ("FL2MODEL",) RETURN_NAMES = ("florence2_model",) FUNCTION = "loadmodel" CATEGORY = "Florence2" def loadmodel(self, model, precision, attention): device = mm.get_torch_device() offload_device = mm.unet_offload_device() dtype = {"bf16": torch.bfloat16, "fp16": torch.float16, "fp32": torch.float32}[precision] model_name = model.rsplit('/', 1)[-1] model_path = os.path.join(folder_paths.models_dir, "LLM", model_name) if not os.path.exists(model_path): print(f"Downloading Lumina model to: {model_path}") from huggingface_hub import snapshot_download snapshot_download(repo_id=model, local_dir=model_path, local_dir_use_symlinks=False) print(f"using {attention} for attention") with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports): #workaround for unnecessary flash_attn requirement model = AutoModelForCausalLM.from_pretrained(model_path, attn_implementation=attention, device_map=device, torch_dtype=dtype,trust_remote_code=True) processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True) florence2_model = { 'model': model, 'processor': processor, 'dtype': dtype } return (florence2_model,) def calculate_bounding_box(width, height, flat_points) -> List[float]: """ Calculate the bounding box for a polygon. Args: flat_points (list of int): Flat list of x, y coordinates defining the polygon points. Returns: tuple: (min_x, min_y, max_x, max_y) defining the bounding box. """ if not flat_points or len(flat_points) % 2 != 0: raise ValueError("The list of points must be non-empty and have an even number of elements") x_coords = flat_points[0::2] y_coords = flat_points[1::2] min_x = min(x_coords) max_x = max(x_coords) min_y = min(y_coords) max_y = max(y_coords) return [min_x / width, min_y / height, max_x / width, max_y / height] class Florence2Run: @classmethod def INPUT_TYPES(s): return { "required": { "image": ("IMAGE", ), "florence2_model": ("FL2MODEL", ), "text_input": ("STRING", {"default": "", "multiline": True}), "task": ( [ 'region_caption', 'dense_region_caption', 'region_proposal', 'caption', 'detailed_caption', 'more_detailed_caption', 'caption_to_phrase_grounding', 'referring_expression_segmentation', 'ocr', 'ocr_with_region', 'docvqa' ], ), "fill_mask": ("BOOLEAN", {"default": True}), }, "optional": { "keep_model_loaded": ("BOOLEAN", {"default": False}), "max_new_tokens": ("INT", {"default": 1024, "min": 1, "max": 4096}), "num_beams": ("INT", {"default": 3, "min": 1, "max": 64}), "do_sample": ("BOOLEAN", {"default": True}), "output_mask_select": ("STRING", {"default": ""}), } } RETURN_TYPES = ("IMAGE", "MASK", "STRING", "JSON") RETURN_NAMES =("image", "mask", "caption", "data") FUNCTION = "encode" CATEGORY = "Florence2" def encode(self, image, text_input, florence2_model, task, fill_mask, keep_model_loaded=False, num_beams=3, max_new_tokens=1024, do_sample=True, output_mask_select=""): device = mm.get_torch_device() _, height, width, _ = image.shape offload_device = mm.unet_offload_device() annotated_image_tensor = None mask_tensor = None processor = florence2_model['processor'] model = florence2_model['model'] dtype = florence2_model['dtype'] model.to(device) colormap = ['blue','orange','green','purple','brown','pink','olive','cyan','red', 'lime','indigo','violet','aqua','magenta','gold','tan','skyblue'] prompts = { 'region_caption': '', 'dense_region_caption': '', 'region_proposal': '', 'caption': '', 'detailed_caption': '', 'more_detailed_caption': '', 'caption_to_phrase_grounding': '', 'referring_expression_segmentation': '', 'ocr': '', 'ocr_with_region': '', 'docvqa': '' } task_prompt = prompts.get(task, '') if (task not in ['referring_expression_segmentation', 'caption_to_phrase_grounding', 'docvqa']) and text_input: raise ValueError("Text input (prompt) is only supported for 'referring_expression_segmentation', 'caption_to_phrase_grounding', and 'docvqa'") if text_input != "": prompt = task_prompt + " " + text_input else: prompt = task_prompt image = image.permute(0, 3, 1, 2) out = [] out_masks = [] out_results = [] out_data = [] pbar = ProgressBar(len(image)) for img in image: image_pil = F.to_pil_image(img) inputs = processor(text=prompt, images=image_pil, return_tensors="pt", do_rescale=False).to(dtype).to(device) generated_ids = model.generate( input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=max_new_tokens, do_sample=do_sample, num_beams=num_beams, ) results = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] print(results) # cleanup the special tokens from the final list if task == 'ocr_with_region': clean_results = str(results) cleaned_string = re.sub(r'|<[^>]*>', '\n', clean_results) clean_results = re.sub(r'\n+', '\n', cleaned_string) else: clean_results = str(results) clean_results = clean_results.replace('', '') clean_results = clean_results.replace('', '') #return single string if only one image for compatibility with nodes that can't handle string lists if len(image) == 1: out_results = clean_results else: out_results.append(clean_results) W, H = image_pil.size parsed_answer = processor.post_process_generation(results, task=task_prompt, image_size=(W, H)) if task == 'region_caption' or task == 'dense_region_caption' or task == 'caption_to_phrase_grounding' or task == 'region_proposal': fig, ax = plt.subplots(figsize=(W / 100, H / 100), dpi=100) fig.subplots_adjust(left=0, right=1, top=1, bottom=0) ax.imshow(image_pil) bboxes = parsed_answer[task_prompt]['bboxes'] labels = parsed_answer[task_prompt]['labels'] mask_indexes = [] # Determine mask indexes outside the loop if output_mask_select != "": mask_indexes = [n for n in output_mask_select.split(",")] print(mask_indexes) else: mask_indexes = [str(i) for i in range(len(bboxes))] # Initialize mask_layer only if needed if fill_mask: mask_layer = Image.new('RGB', image_pil.size, (0, 0, 0)) mask_draw = ImageDraw.Draw(mask_layer) for index, (bbox, label) in enumerate(zip(bboxes, labels)): # Modify the label to include the index indexed_label = f"{index}.{label}" if fill_mask: if str(index) in mask_indexes: print("match index:", str(index), "in mask_indexes:", mask_indexes) mask_draw.rectangle([bbox[0], bbox[1], bbox[2], bbox[3]], fill=(255, 255, 255)) if label in mask_indexes: print("match label") mask_draw.rectangle([bbox[0], bbox[1], bbox[2], bbox[3]], fill=(255, 255, 255)) # Create a Rectangle patch rect = patches.Rectangle( (bbox[0], bbox[1]), # (x,y) - lower left corner bbox[2] - bbox[0], # Width bbox[3] - bbox[1], # Height linewidth=1, edgecolor='r', facecolor='none', label=indexed_label ) # Calculate text width with a rough estimation text_width = len(label) * 6 # Adjust multiplier based on your font size text_height = 12 # Adjust based on your font size # Initial text position text_x = bbox[0] text_y = bbox[1] - text_height # Position text above the top-left of the bbox # Adjust text_x if text is going off the left or right edge if text_x < 0: text_x = 0 elif text_x + text_width > W: text_x = W - text_width # Adjust text_y if text is going off the top edge if text_y < 0: text_y = bbox[3] # Move text below the bottom-left of the bbox if it doesn't overlap with bbox # Add the rectangle to the plot ax.add_patch(rect) facecolor = random.choice(colormap) if len(image) == 1 else 'red' # Add the label plt.text( text_x, text_y, indexed_label, color='white', fontsize=12, bbox=dict(facecolor=facecolor, alpha=0.5) ) if fill_mask: mask_tensor = F.to_tensor(mask_layer) mask_tensor = mask_tensor.unsqueeze(0).permute(0, 2, 3, 1).cpu().float() mask_tensor = mask_tensor.mean(dim=0, keepdim=True) mask_tensor = mask_tensor.repeat(1, 1, 1, 3) mask_tensor = mask_tensor[:, :, :, 0] out_masks.append(mask_tensor) # Remove axis and padding around the image ax.axis('off') ax.margins(0,0) ax.get_xaxis().set_major_locator(plt.NullLocator()) ax.get_yaxis().set_major_locator(plt.NullLocator()) fig.canvas.draw() buf = io.BytesIO() plt.savefig(buf, format='png', pad_inches=0) buf.seek(0) annotated_image_pil = Image.open(buf) annotated_image_tensor = F.to_tensor(annotated_image_pil) out_tensor = annotated_image_tensor[:3, :, :].unsqueeze(0).permute(0, 2, 3, 1).cpu().float() out.append(out_tensor) pbar.update(1) plt.close(fig) elif task == 'referring_expression_segmentation': # Create a new black image mask_image = Image.new('RGB', (W, H), 'black') mask_draw = ImageDraw.Draw(mask_image) predictions = parsed_answer[task_prompt] # Iterate over polygons and labels for polygons, label in zip(predictions['polygons'], predictions['labels']): color = random.choice(colormap) for _polygon in polygons: _polygon = np.array(_polygon).reshape(-1, 2) # Clamp polygon points to image boundaries _polygon = np.clip(_polygon, [0, 0], [W - 1, H - 1]) if len(_polygon) < 3: print('Invalid polygon:', _polygon) continue _polygon = _polygon.reshape(-1).tolist() # Draw the polygon if fill_mask: overlay = Image.new('RGBA', image_pil.size, (255, 255, 255, 0)) image_pil = image_pil.convert('RGBA') draw = ImageDraw.Draw(overlay) color_with_opacity = ImageColor.getrgb(color) + (180,) draw.polygon(_polygon, outline=color, fill=color_with_opacity, width=3) image_pil = Image.alpha_composite(image_pil, overlay) else: draw = ImageDraw.Draw(image_pil) draw.polygon(_polygon, outline=color, width=3) #draw mask mask_draw.polygon(_polygon, outline="white", fill="white") image_tensor = F.to_tensor(image_pil) image_tensor = image_tensor[:3, :, :].unsqueeze(0).permute(0, 2, 3, 1).cpu().float() out.append(image_tensor) mask_tensor = F.to_tensor(mask_image) mask_tensor = mask_tensor.unsqueeze(0).permute(0, 2, 3, 1).cpu().float() mask_tensor = mask_tensor.mean(dim=0, keepdim=True) mask_tensor = mask_tensor.repeat(1, 1, 1, 3) mask_tensor = mask_tensor[:, :, :, 0] out_masks.append(mask_tensor) pbar.update(1) elif task == 'ocr_with_region': try: font = ImageFont.load_default().font_variant(size=24) except: font = ImageFont.load_default() predictions = parsed_answer[task_prompt] scale = 1 draw = ImageDraw.Draw(image_pil) bboxes, labels = predictions['quad_boxes'], predictions['labels'] for box, label in zip(bboxes, labels): bbox = calculate_bounding_box(width, height, box) out_data.append({"label": label, "polygon": box, "box": bbox}) color = random.choice(colormap) new_box = (np.array(box) * scale).tolist() draw.polygon(new_box, width=3, outline=color) draw.text((new_box[0]+8, new_box[1]+2), "{}".format(label), align="right", font=font, fill=color) image_tensor = F.to_tensor(image_pil) image_tensor = image_tensor[:3, :, :].unsqueeze(0).permute(0, 2, 3, 1).cpu().float() out.append(image_tensor) elif task == 'docvqa': if text_input == "": raise ValueError("Text input (prompt) is required for 'docvqa'") prompt = " " + text_input inputs = processor(text=prompt, images=image_pil, return_tensors="pt", do_rescale=False).to(dtype).to(device) generated_ids = model.generate( input_ids=inputs["input_ids"], pixel_values=inputs["pixel_values"], max_new_tokens=max_new_tokens, do_sample=do_sample, num_beams=num_beams, ) results = processor.batch_decode(generated_ids, skip_special_tokens=False)[0] clean_results = results.replace('', '').replace('', '') if len(image) == 1: out_results = clean_results else: out_results.append(clean_results) out.append(F.to_tensor(image_pil).unsqueeze(0).permute(0, 2, 3, 1).cpu().float()) pbar.update(1) if len(out) > 0: out_tensor = torch.cat(out, dim=0) else: out_tensor = torch.zeros((1, 64,64, 3), dtype=torch.float32, device="cpu") if len(out_masks) > 0: out_mask_tensor = torch.cat(out_masks, dim=0) else: out_mask_tensor = torch.zeros((1,64,64), dtype=torch.float32, device="cpu") if not keep_model_loaded: print("Offloading model...") model.to(offload_device) mm.soft_empty_cache() return (out_tensor, out_mask_tensor, out_results, out_data) NODE_CLASS_MAPPINGS = { "DownloadAndLoadFlorence2Model": DownloadAndLoadFlorence2Model, "Florence2Run": Florence2Run, } NODE_DISPLAY_NAME_MAPPINGS = { "DownloadAndLoadFlorence2Model": "DownloadAndLoadFlorence2Model", "Florence2Run": "Florence2Run", }