ighoshsubho commited on
Commit
9aecc37
0 Parent(s):

Florence sam flux first commit

Browse files
Files changed (6) hide show
  1. .gitignore +3 -0
  2. README.md +12 -0
  3. app.py +121 -0
  4. requirements.txt +13 -0
  5. utils/florence.py +58 -0
  6. utils/sam.py +45 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ /venv
2
+ /.idea
3
+ /tmp
README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Florence2 + SAM2 + FLUX
3
+ emoji: 🔥
4
+ colorFrom: purple
5
+ colorTo: green
6
+ sdk: gradio
7
+ sdk_version: 4.40.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: apache-2.0
11
+ ---
12
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from diffusers import FluxInpaintPipeline
5
+ from utils.florence import load_florence_model, run_florence_inference, FLORENCE_OPEN_VOCABULARY_DETECTION_TASK
6
+ from utils.sam import load_sam_image_model, run_sam_inference
7
+ import gradio as gr
8
+ import supervision as sv
9
+
10
+ # Load models
11
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ FLUX_PIPE = FluxInpaintPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16).to(
13
+ DEVICE)
14
+ FLORENCE_MODEL, FLORENCE_PROCESSOR = load_florence_model(device=DEVICE)
15
+ SAM_MODEL = load_sam_image_model(device=DEVICE)
16
+
17
+ COLORS = ['#FF1493', '#00BFFF', '#FF6347', '#FFD700', '#32CD32', '#8A2BE2']
18
+ COLOR_PALETTE = sv.ColorPalette.from_hex(COLORS)
19
+ BOX_ANNOTATOR = sv.BoxAnnotator(color=COLOR_PALETTE, color_lookup=sv.ColorLookup.INDEX)
20
+ LABEL_ANNOTATOR = sv.LabelAnnotator(
21
+ color=COLOR_PALETTE,
22
+ color_lookup=sv.ColorLookup.INDEX,
23
+ text_position=sv.Position.CENTER_OF_MASS,
24
+ text_color=sv.Color.from_hex("#000000"),
25
+ border_radius=5
26
+ )
27
+ MASK_ANNOTATOR = sv.MaskAnnotator(
28
+ color=COLOR_PALETTE,
29
+ color_lookup=sv.ColorLookup.INDEX
30
+ )
31
+
32
+
33
+ def visualize_detections(image, detections):
34
+ output_image = image.copy()
35
+ output_image = MASK_ANNOTATOR.annotate(output_image, detections)
36
+ output_image = BOX_ANNOTATOR.annotate(output_image, detections)
37
+ output_image = LABEL_ANNOTATOR.annotate(output_image, detections)
38
+ return output_image
39
+
40
+
41
+ def detect_objects(image, text_prompt):
42
+ # Use Florence for object detection
43
+ _, result = run_florence_inference(
44
+ model=FLORENCE_MODEL,
45
+ processor=FLORENCE_PROCESSOR,
46
+ device=DEVICE,
47
+ image=image,
48
+ task=FLORENCE_OPEN_VOCABULARY_DETECTION_TASK,
49
+ text=text_prompt
50
+ )
51
+ detections = sv.Detections.from_lmm(
52
+ lmm=sv.LMM.FLORENCE_2,
53
+ result=result,
54
+ resolution_wh=image.size
55
+ )
56
+
57
+ # Use SAM to refine masks
58
+ detections = run_sam_inference(SAM_MODEL, image, detections)
59
+ return detections
60
+
61
+
62
+ def inpaint_selected_objects(image, detections, selected_indices, inpaint_prompt):
63
+ mask = np.zeros(image.size[::-1], dtype=np.uint8)
64
+ for idx in selected_indices:
65
+ mask |= detections.mask[idx]
66
+
67
+ mask_image = Image.fromarray(mask * 255)
68
+
69
+ result = FLUX_PIPE(
70
+ prompt=inpaint_prompt,
71
+ image=image,
72
+ mask_image=mask_image,
73
+ num_inference_steps=30,
74
+ strength=0.85,
75
+ ).images[0]
76
+
77
+ return result
78
+
79
+
80
+ def process_image(input_image, detection_prompt, inpaint_prompt, selected_objects):
81
+ detections = detect_objects(input_image, detection_prompt)
82
+
83
+ # Visualize detected objects
84
+ detected_image = visualize_detections(input_image, detections)
85
+
86
+ if selected_objects:
87
+ selected_indices = [int(idx) for idx in selected_objects.split(',')]
88
+ inpainted_image = inpaint_selected_objects(input_image, detections, selected_indices, inpaint_prompt)
89
+ return detected_image, inpainted_image
90
+ else:
91
+ return detected_image, None
92
+
93
+
94
+ # Gradio interface
95
+ with gr.Blocks() as demo:
96
+ gr.Markdown("# Object Detection and Inpainting with FLUX, Florence, and SAM")
97
+ with gr.Row():
98
+ with gr.Column():
99
+ input_image = gr.Image(type="pil", label="Input Image")
100
+ detection_prompt = gr.Textbox(label="Detection Prompt", placeholder="Enter objects to detect")
101
+ detect_button = gr.Button("Detect Objects")
102
+ with gr.Column():
103
+ detected_image = gr.Image(type="pil", label="Detected Objects")
104
+ selected_objects = gr.Textbox(label="Selected Objects",
105
+ placeholder="Enter indices of objects to inpaint (comma-separated)")
106
+ inpaint_prompt = gr.Textbox(label="Inpainting Prompt", placeholder="Describe what to inpaint")
107
+ inpaint_button = gr.Button("Inpaint Selected Objects")
108
+ output_image = gr.Image(type="pil", label="Inpainted Result")
109
+
110
+ detect_button.click(
111
+ fn=lambda img, prompt: process_image(img, prompt, "", "")[0],
112
+ inputs=[input_image, detection_prompt],
113
+ outputs=detected_image
114
+ )
115
+ inpaint_button.click(
116
+ fn=process_image,
117
+ inputs=[input_image, detection_prompt, inpaint_prompt, selected_objects],
118
+ outputs=[detected_image, output_image]
119
+ )
120
+
121
+ demo.launch(debug=False, show_error=True)
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tqdm
2
+ einops
3
+ spaces
4
+ timm
5
+ transformers
6
+ samv2
7
+ gradio
8
+ supervision
9
+ opencv-python
10
+ pytest
11
+ torch
12
+ numpy
13
+ diffusers
utils/florence.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Union, Any, Tuple, Dict
3
+ from unittest.mock import patch
4
+
5
+ import torch
6
+ from PIL import Image
7
+ from transformers import AutoModelForCausalLM, AutoProcessor
8
+ from transformers.dynamic_module_utils import get_imports
9
+
10
+ FLORENCE_CHECKPOINT = "microsoft/Florence-2-base"
11
+ FLORENCE_OBJECT_DETECTION_TASK = '<OD>'
12
+ FLORENCE_DETAILED_CAPTION_TASK = '<MORE_DETAILED_CAPTION>'
13
+ FLORENCE_CAPTION_TO_PHRASE_GROUNDING_TASK = '<CAPTION_TO_PHRASE_GROUNDING>'
14
+ FLORENCE_OPEN_VOCABULARY_DETECTION_TASK = '<OPEN_VOCABULARY_DETECTION>'
15
+ FLORENCE_DENSE_REGION_CAPTION_TASK = '<DENSE_REGION_CAPTION>'
16
+
17
+
18
+ def fixed_get_imports(filename: Union[str, os.PathLike]) -> list[str]:
19
+ """Work around for https://huggingface.co/microsoft/phi-1_5/discussions/72."""
20
+ if not str(filename).endswith("/modeling_florence2.py"):
21
+ return get_imports(filename)
22
+ imports = get_imports(filename)
23
+ imports.remove("flash_attn")
24
+ return imports
25
+
26
+
27
+ def load_florence_model(
28
+ device: torch.device, checkpoint: str = FLORENCE_CHECKPOINT
29
+ ) -> Tuple[Any, Any]:
30
+ with patch("transformers.dynamic_module_utils.get_imports", fixed_get_imports):
31
+ model = AutoModelForCausalLM.from_pretrained(
32
+ checkpoint, trust_remote_code=True).to(device).eval()
33
+ processor = AutoProcessor.from_pretrained(
34
+ checkpoint, trust_remote_code=True)
35
+ return model, processor
36
+
37
+
38
+ def run_florence_inference(
39
+ model: Any,
40
+ processor: Any,
41
+ device: torch.device,
42
+ image: Image,
43
+ task: str,
44
+ text: str = ""
45
+ ) -> Tuple[str, Dict]:
46
+ prompt = task + text
47
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
48
+ generated_ids = model.generate(
49
+ input_ids=inputs["input_ids"],
50
+ pixel_values=inputs["pixel_values"],
51
+ max_new_tokens=1024,
52
+ num_beams=3
53
+ )
54
+ generated_text = processor.batch_decode(
55
+ generated_ids, skip_special_tokens=False)[0]
56
+ response = processor.post_process_generation(
57
+ generated_text, task=task, image_size=image.size)
58
+ return generated_text, response
utils/sam.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any
2
+
3
+ import numpy as np
4
+ import supervision as sv
5
+ import torch
6
+ from PIL import Image
7
+ from sam2.build_sam import build_sam2, build_sam2_video_predictor
8
+ from sam2.sam2_image_predictor import SAM2ImagePredictor
9
+
10
+ SAM_CHECKPOINT = "checkpoints/sam2_hiera_small.pt"
11
+ SAM_CONFIG = "sam2_hiera_s.yaml"
12
+
13
+
14
+ def load_sam_image_model(
15
+ device: torch.device,
16
+ config: str = SAM_CONFIG,
17
+ checkpoint: str = SAM_CHECKPOINT
18
+ ) -> SAM2ImagePredictor:
19
+ model = build_sam2(config, checkpoint, device=device)
20
+ return SAM2ImagePredictor(sam_model=model)
21
+
22
+
23
+ def load_sam_video_model(
24
+ device: torch.device,
25
+ config: str = SAM_CONFIG,
26
+ checkpoint: str = SAM_CHECKPOINT
27
+ ) -> Any:
28
+ return build_sam2_video_predictor(config, checkpoint, device=device)
29
+
30
+
31
+ def run_sam_inference(
32
+ model: Any,
33
+ image: Image,
34
+ detections: sv.Detections
35
+ ) -> sv.Detections:
36
+ image = np.array(image.convert("RGB"))
37
+ model.set_image(image)
38
+ mask, score, _ = model.predict(box=detections.xyxy, multimask_output=False)
39
+
40
+ # dirty fix; remove this later
41
+ if len(mask.shape) == 4:
42
+ mask = np.squeeze(mask)
43
+
44
+ detections.mask = mask.astype(bool)
45
+ return detections