arad1367 commited on
Commit
9b1a390
·
verified ·
1 Parent(s): dba43e6

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +162 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import matplotlib.pyplot as plt
3
+ import matplotlib.patches as patches
4
+ from PIL import Image
5
+ import gradio as gr
6
+ from transformers import AutoProcessor, AutoModelForCausalLM
7
+ import torch
8
+ import numpy as np
9
+ import spaces
10
+ import subprocess
11
+
12
+ subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
13
+
14
+ # Initialize Florence-2-large model and processor
15
+ model_id = 'microsoft/Florence-2-large'
16
+ model = AutoModelForCausalLM.from_pretrained(model_id, trust_remote_code=True).to("cuda").eval()
17
+ processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
18
+
19
+ # Function to resize and preprocess image
20
+ def preprocess_image(image_path, max_size=(800, 800)):
21
+ image = Image.open(image_path).convert('RGB')
22
+ if image.size[0] > max_size[0] or image.size[1] > max_size[1]:
23
+ image.thumbnail(max_size, Image.LANCZOS)
24
+
25
+ # Convert image to numpy array
26
+ image_np = np.array(image)
27
+
28
+ # Ensure the image is in the format [height, width, channels]
29
+ if image_np.ndim == 2: # Grayscale image
30
+ image_np = np.expand_dims(image_np, axis=-1)
31
+ elif image_np.shape[0] == 3: # Image in [channels, height, width] format
32
+ image_np = np.transpose(image_np, (1, 2, 0))
33
+
34
+ return image_np, image.size
35
+
36
+ # Function to run Florence-2-large model
37
+ @spaces.GPU
38
+ def run_florence_model(image_np, image_size, task_prompt, text_input=None):
39
+ if text_input is None:
40
+ prompt = task_prompt
41
+ else:
42
+ prompt = task_prompt + text_input
43
+
44
+ inputs = processor(text=prompt, images=image_np, return_tensors="pt")
45
+
46
+ with torch.no_grad():
47
+ outputs = model.generate(
48
+ input_ids=inputs["input_ids"].cuda(),
49
+ pixel_values=inputs["pixel_values"].cuda(),
50
+ max_new_tokens=1024,
51
+ early_stopping=False,
52
+ do_sample=False,
53
+ num_beams=3,
54
+ )
55
+
56
+ generated_text = processor.batch_decode(outputs, skip_special_tokens=False)[0]
57
+ parsed_answer = processor.post_process_generation(
58
+ generated_text,
59
+ task=task_prompt,
60
+ image_size=image_size
61
+ )
62
+
63
+ return parsed_answer, generated_text
64
+
65
+ # Function to plot image with bounding boxes
66
+ def plot_image_with_bboxes(image_np, bboxes, labels=None):
67
+ fig, ax = plt.subplots(1)
68
+ ax.imshow(image_np)
69
+ colors = ['red', 'blue', 'green', 'yellow', 'purple', 'cyan']
70
+ for i, bbox in enumerate(bboxes):
71
+ color = colors[i % len(colors)]
72
+ x, y, width, height = bbox[0], bbox[1], bbox[2] - bbox[0], bbox[3] - bbox[1]
73
+ rect = patches.Rectangle((x, y), width, height, linewidth=2, edgecolor=color, facecolor='none')
74
+ ax.add_patch(rect)
75
+ if labels and i < len(labels):
76
+ ax.text(x, y, labels[i], color=color, fontsize=8, bbox=dict(facecolor='white', alpha=0.7))
77
+ plt.axis('off')
78
+ return fig
79
+
80
+ # Gradio function to process uploaded images
81
+ @spaces.GPU
82
+ def process_image(image_path):
83
+ image_np, image_size = preprocess_image(image_path)
84
+
85
+ # Convert image_np to float32
86
+ image_np = image_np.astype(np.float32)
87
+
88
+ # Image Captioning
89
+ caption_result, _ = run_florence_model(image_np, image_size, '<CAPTION>')
90
+ detailed_caption_result, _ = run_florence_model(image_np, image_size, '<DETAILED_CAPTION>')
91
+
92
+ # Object Detection
93
+ od_result, _ = run_florence_model(image_np, image_size, '<OD>')
94
+ od_bboxes = od_result['<OD>'].get('bboxes', [])
95
+ od_labels = od_result['<OD>'].get('labels', [])
96
+
97
+ # OCR
98
+ ocr_result, _ = run_florence_model(image_np, image_size, '<OCR>')
99
+
100
+ # Phrase Grounding
101
+ pg_result, _ = run_florence_model(image_np, image_size, '<CAPTION_TO_PHRASE_GROUNDING>', text_input=caption_result['<CAPTION>'])
102
+ pg_bboxes = pg_result['<CAPTION_TO_PHRASE_GROUNDING>'].get('bboxes', [])
103
+ pg_labels = pg_result['<CAPTION_TO_PHRASE_GROUNDING>'].get('labels', [])
104
+
105
+ # Cascaded Tasks (Detailed Caption + Phrase Grounding)
106
+ cascaded_result, _ = run_florence_model(image_np, image_size, '<CAPTION_TO_PHRASE_GROUNDING>', text_input=detailed_caption_result['<DETAILED_CAPTION>'])
107
+ cascaded_bboxes = cascaded_result['<CAPTION_TO_PHRASE_GROUNDING>'].get('bboxes', [])
108
+ cascaded_labels = cascaded_result['<CAPTION_TO_PHRASE_GROUNDING>'].get('labels', [])
109
+
110
+ # Create plots
111
+ od_fig = plot_image_with_bboxes(image_np, od_bboxes, od_labels)
112
+ pg_fig = plot_image_with_bboxes(image_np, pg_bboxes, pg_labels)
113
+ cascaded_fig = plot_image_with_bboxes(image_np, cascaded_bboxes, cascaded_labels)
114
+
115
+ # Prepare response
116
+ response = f"""
117
+ Image Captioning:
118
+ - Simple Caption: {caption_result['<CAPTION>']}
119
+ - Detailed Caption: {detailed_caption_result['<DETAILED_CAPTION>']}
120
+
121
+ Object Detection:
122
+ - Detected {len(od_bboxes)} objects
123
+
124
+ OCR:
125
+ {ocr_result['<OCR>']}
126
+
127
+ Phrase Grounding:
128
+ - Grounded {len(pg_bboxes)} phrases from the simple caption
129
+
130
+ Cascaded Tasks:
131
+ - Grounded {len(cascaded_bboxes)} phrases from the detailed caption
132
+ """
133
+
134
+ return response, od_fig, pg_fig, cascaded_fig
135
+
136
+ # Gradio interface
137
+ with gr.Blocks(theme='NoCrypt/miku') as demo:
138
+ gr.Markdown("""
139
+ # Image Processing with Florence-2-large
140
+ Upload an image to perform image captioning, object detection, OCR, phrase grounding, and cascaded tasks.
141
+ """)
142
+
143
+ image_input = gr.Image(type="filepath")
144
+ text_output = gr.Textbox()
145
+ plot_output_1 = gr.Plot()
146
+ plot_output_2 = gr.Plot()
147
+ plot_output_3 = gr.Plot()
148
+
149
+ image_input.upload(process_image, inputs=[image_input], outputs=[text_output, plot_output_1, plot_output_2, plot_output_3])
150
+
151
+ footer = """
152
+ <div style="text-align: center; margin-top: 20px;">
153
+ <a href="https://www.linkedin.com/in/pejman-ebrahimi-4a60151a7/" target="_blank">LinkedIn</a> |
154
+ <a href="https://github.com/arad1367" target="_blank">GitHub</a> |
155
+ <a href="https://arad1367.pythonanywhere.com/" target="_blank">Live demo of my PhD defense</a>
156
+ <br>
157
+ Made with 💖 by Pejman Ebrahimi
158
+ </div>
159
+ """
160
+ gr.HTML(footer)
161
+
162
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ pip>=23.0.0
2
+ transformers
3
+ gradio
4
+ Pillow
5
+ matplotlib
6
+ torch
7
+ timm
8
+ einops
9
+ spaces