feat: initial commit for porting the webapp
- __init__.py +0 -0
- app.py +222 -0
- data/04_models/pipeline/webapp/installed/cropped_images/0495c348-a87c-4e70-8a1f-9e07e8510977.jpg +0 -0
- data/04_models/pipeline/webapp/installed/cropped_images/4d7bd307-1224-41e6-ba89-82dd85e69ea9.jpg +0 -0
- data/04_models/pipeline/webapp/installed/cropped_images/8042e02f-1d73-4251-8b76-3f634ffe0196.jpg +0 -0
- data/04_models/pipeline/webapp/installed/cropped_images/8ed739d7-d24f-4919-bd1d-339e34536a9b.jpg +0 -0
- data/04_models/pipeline/webapp/installed/cropped_images/ec1c6fe4-b2b5-491a-9d7f-542b85e2b078.jpg +0 -0
- data/images/2023-08-03_1157_7_900088000908738_F.jpg +0 -0
- data/images/2023-08-16_1547_7_989001006037192_F.jpg +0 -0
- data/images/2023-10-20_1604_7_900088000909142_F.jpg +0 -0
- data/images/2023-10-20_1625_7_900088000913636_F.jpg +0 -0
- data/images/989,001,006,004,046_F.jpg +0 -0
- data/pipeline/config.yaml +13 -0
- data/pipeline/db/db.csv +6 -0
- data/pipeline/models/identification/config.yaml +2 -0
- data/pipeline/models/identification/features.pt +3 -0
- data/pipeline/models/pose/weights.pt +3 -0
- data/pipeline/models/segmentation/weights.pt +3 -0
- data/summary.csv +6 -0
- identification.py +241 -0
- pipeline.py +348 -0
- pose.py +166 -0
- requirements.txt +8 -0
- segmentation.py +103 -0
- ui.py +0 -0
- utils.py +659 -0
- viz2d.py +326 -0
- yolo.py +68 -0
__init__.py
ADDED
File without changes
app.py
ADDED
@@ -0,0 +1,222 @@
"""
Simple Gradio Interface to showcase the ML outputs in a UI and webapp.
"""

import math
from pathlib import Path

import gradio as gr
from PIL import Image
from ultralytics import YOLO

import pipeline
from identification import IdentificationModel, generate_visualization
from utils import bgr_to_rgb, select_best_device

DEFAULT_IMAGE_INDEX = 0

DIR_INSTALLED_PIPELINE = Path("./data/pipeline/")
DIR_EXAMPLES = Path("./data/images/")
FILEPATH_IDENTIFICATION_LIGHTGLUE_CONFIG = (
    DIR_INSTALLED_PIPELINE / "models/identification/config.yaml"
)
FILEPATH_IDENTIFICATION_DB = DIR_INSTALLED_PIPELINE / "db/db.csv"
FILEPATH_IDENTIFICATION_LIGHTGLUE_FEATURES = (
    DIR_INSTALLED_PIPELINE / "models/identification/features.pt"
)
FILEPATH_WEIGHTS_SEGMENTATION_MODEL = (
    DIR_INSTALLED_PIPELINE / "models/segmentation/weights.pt"
)
FILEPATH_WEIGHTS_POSE_MODEL = DIR_INSTALLED_PIPELINE / "models/pose/weights.pt"


def examples(dir_examples: Path) -> list[Path]:
    """
    Function to retrieve the default example images.

    Returns:
        examples (list[Path]): list of image filepaths.
    """
    return list(dir_examples.glob("*.jpg"))


def make_ui(loaded_models: dict[str, YOLO | IdentificationModel]):
    """
    Main entrypoint to wire up the Gradio interface.

    Args:
        loaded_models (dict[str, YOLO | IdentificationModel]): loaded models ready to run inference with.

    Returns:
        gradio_ui
    """
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(
                    type="pil",
                    value=default_value_input,
                    label="input image",
                    sources=["upload", "clipboard"],
                )
                gr.Examples(
                    examples=example_filepaths,
                    inputs=image_input,
                )
                submit_btn = gr.Button(value="Identify", variant="primary")

            with gr.Column():
                with gr.Tab("Prediction"):
                    with gr.Row():
                        pit_prediction = gr.Text(label="predicted individual")
                        name_prediction = gr.Text(label="fish name", visible=False)
                    image_feature_matching = gr.Image(
                        label="pattern matching", visible=False
                    )
                    image_extracted_keypoints = gr.Image(
                        label="extracted keypoints", visible=False
                    )

                with gr.Tab("Details", visible=False) as tab_details:
                    with gr.Column():
                        with gr.Row():
                            text_rotation_angle = gr.Text(
                                label="correction angle (degrees)"
                            )
                            text_side = gr.Text(label="predicted side")

                        image_pose_keypoints = gr.Image(
                            type="pil", label="pose keypoints"
                        )
                        image_rotated_keypoints = gr.Image(
                            type="pil", label="rotated keypoints"
                        )
                        image_segmentation_mask = gr.Image(type="pil", label="mask")
                        image_masked = gr.Image(type="pil", label="masked")

        def submit_fn(
            loaded_models: dict[str, YOLO | IdentificationModel],
            orig_image: Image.Image,
        ):
            """
            Main function used for the Gradio interface.

            Args:
                loaded_models (dict[str, YOLO]): loaded models.
                orig_image (PIL): original image picked by the user

            Returns:
                fish side (str): predicted fish side
                correction angle (str): rotation to do in degrees to re-align the image.
                keypoints image (PIL): image displaying the bbox and keypoints from the
                    pose estimation model.
                rotated image (PIL): rotated image after applying the correction angle.
                segmentation mask (PIL): segmentation mask predicted by the segmentation model.
                segmented image (PIL): segmented orig_image using the segmentation mask
                    and the crop.
                predicted_individual (str): The identified individual.
                pil_image_extracted_keypoints (PIL): The extracted keypoints overlaid on the image.
                feature_matching_image (PIL): The matching of the source with the identified individual.
            """
            # return {}
            results = pipeline.run(loaded_models=loaded_models, pil_image=orig_image)
            side = results["stages"]["pose"]["output"]["side"]
            theta = results["stages"]["pose"]["output"]["theta"]
            pil_image_keypoints = Image.fromarray(
                bgr_to_rgb(results["stages"]["pose"]["output"]["prediction"].plot())
            )
            pil_image_rotated = Image.fromarray(
                results["stages"]["rotation"]["output"]["array_image"]
            )
            pil_image_mask = results["stages"]["segmentation"]["output"]["mask"]
            pil_image_masked_cropped = results["stages"]["crop"]["output"]["pil_image"]

            viz_dict = generate_visualization(
                pil_image=pil_image_masked_cropped,
                prediction=results["stages"]["identification"]["output"],
            )

            is_new_individual = (
                results["stages"]["identification"]["output"]["type"] == "new"
            )

            return {
                text_rotation_angle: f"{math.degrees(theta):0.1f}",
                text_side: side.value,
                image_pose_keypoints: pil_image_keypoints,
                image_rotated_keypoints: pil_image_rotated,
                image_segmentation_mask: pil_image_mask,
                image_masked: pil_image_masked_cropped,
                pit_prediction: (
                    "New Fish!"
                    if is_new_individual
                    else gr.Text(
                        results["stages"]["identification"]["output"]["match"]["pit"],
                        visible=True,
                    )
                ),
                name_prediction: (
                    gr.Text(visible=False)
                    if is_new_individual
                    else gr.Text(
                        results["stages"]["identification"]["output"]["match"]["name"],
                        visible=True,
                    )
                ),
                tab_details: gr.Column(visible=True),
                image_extracted_keypoints: gr.Image(
                    viz_dict["keypoints_source"], visible=True
                ),
                image_feature_matching: (
                    gr.Image(visible=False)
                    if is_new_individual
                    else gr.Image(viz_dict["matches"], visible=True)
                ),
            }

        submit_btn.click(
            fn=lambda pil_image: submit_fn(
                loaded_models=loaded_models,
                orig_image=pil_image,
            ),
            inputs=image_input,
            outputs=[
                text_rotation_angle,
                text_side,
                image_pose_keypoints,
                image_rotated_keypoints,
                image_feature_matching,
                image_segmentation_mask,
                image_masked,
                pit_prediction,
                name_prediction,
                tab_details,
                image_feature_matching,
                image_extracted_keypoints,
            ],
        )

    return demo


if __name__ == "__main__":
    device = select_best_device()
    # FIXME: get this from the config instead
    extractor_type = "aliked"
    n_keypoints = 1024
    threshold_wasserstein = 0.084
    loaded_models = pipeline.load_models(
        device=device,
        filepath_weights_segmentation_model=FILEPATH_WEIGHTS_SEGMENTATION_MODEL,
        filepath_weights_pose_model=FILEPATH_WEIGHTS_POSE_MODEL,
        filepath_identification_lightglue_features=FILEPATH_IDENTIFICATION_LIGHTGLUE_FEATURES,
        filepath_identification_db=FILEPATH_IDENTIFICATION_DB,
        extractor_type=extractor_type,
        n_keypoints=n_keypoints,
        threshold_wasserstein=threshold_wasserstein,
    )
    model_segmentation = loaded_models["segmentation"]
    example_filepaths = examples(dir_examples=DIR_EXAMPLES)
    default_value_input = Image.open(example_filepaths[DEFAULT_IMAGE_INDEX])
    demo = make_ui(loaded_models=loaded_models)
    demo.launch()
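Note: the snippet below is not part of the commit; it is a minimal smoke-test sketch for the wiring above, assuming it is run from the repository root so the relative ./data/pipeline/ paths resolve. It loads the models exactly as the __main__ block does and checks the dictionary keys that make_ui and submit_fn rely on.

from pathlib import Path

import pipeline
from utils import select_best_device

# Same arguments as the __main__ block of app.py above.
loaded = pipeline.load_models(
    device=select_best_device(),
    filepath_weights_segmentation_model=Path("./data/pipeline/models/segmentation/weights.pt"),
    filepath_weights_pose_model=Path("./data/pipeline/models/pose/weights.pt"),
    filepath_identification_lightglue_features=Path("./data/pipeline/models/identification/features.pt"),
    filepath_identification_db=Path("./data/pipeline/db/db.csv"),
    extractor_type="aliked",
    n_keypoints=1024,
    threshold_wasserstein=0.084,
)
# submit_fn expects exactly these three entries.
assert set(loaded) == {"segmentation", "pose", "identification"}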
data/04_models/pipeline/webapp/installed/cropped_images/0495c348-a87c-4e70-8a1f-9e07e8510977.jpg
ADDED
data/04_models/pipeline/webapp/installed/cropped_images/4d7bd307-1224-41e6-ba89-82dd85e69ea9.jpg
ADDED
data/04_models/pipeline/webapp/installed/cropped_images/8042e02f-1d73-4251-8b76-3f634ffe0196.jpg
ADDED
data/04_models/pipeline/webapp/installed/cropped_images/8ed739d7-d24f-4919-bd1d-339e34536a9b.jpg
ADDED
data/04_models/pipeline/webapp/installed/cropped_images/ec1c6fe4-b2b5-491a-9d7f-542b85e2b078.jpg
ADDED
data/images/2023-08-03_1157_7_900088000908738_F.jpg
ADDED
data/images/2023-08-16_1547_7_989001006037192_F.jpg
ADDED
data/images/2023-10-20_1604_7_900088000909142_F.jpg
ADDED
data/images/2023-10-20_1625_7_900088000913636_F.jpg
ADDED
data/images/989,001,006,004,046_F.jpg
ADDED
data/pipeline/config.yaml
ADDED
@@ -0,0 +1,13 @@
filepath_db: db/db.csv
filepath_model_segmentation_weights: models/segmentation/weights.pt
filepath_model_pose_weights: models/pose/weights.pt
filepath_model_identification_features: models/identification/features.pt
filepath_model_identification_config: models/identification/config.yaml
filepath_config: config.yaml
root_dir: .
dir_cropped_images: cropped_images
dir_models: models
dir_models_segmentation: models/segmentation
dir_models_pose: models/pose
dir_models_identification: models/identification
dir_db: db
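Note: app.py does not read this config yet (see its FIXME about the hard-coded extractor settings). A possible loading sketch, assuming PyYAML is available (it is pulled in transitively by ultralytics) and that relative entries resolve against the installed pipeline directory:

from pathlib import Path

import yaml  # assumption: PyYAML, not pinned in requirements.txt itself

root = Path("./data/pipeline")
with open(root / "config.yaml") as f:
    cfg = yaml.safe_load(f)

# Relative entries are resolved against the installed pipeline directory.
filepath_db = root / cfg["filepath_db"]                    # data/pipeline/db/db.csv
filepath_pose = root / cfg["filepath_model_pose_weights"]  # data/pipeline/models/pose/weights.pt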
data/pipeline/db/db.csv
ADDED
@@ -0,0 +1,6 @@
,Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,filepath,year,pit,created_at_filepath,created_at_exif,is_electrofishing,is_guide_angling,exif_make,exif_model,exif_focal_length,exif_image_width,exif_image_height,exif_shutter_speed,exif_aperture,exif_brightness,coordinates_lat,coordinates_lon,uuid,success,filepath_crop,pose_theta,pose_fish_side,name
0,0,0,1311,1312,data/01_raw/elk-river/2023/Guide/Fish/Nupqu 2/2023-09-15_1040_2_900088000908738_F.jpg,2023,900088000908738,2023-09-15 10:40:00,,False,True,,,,4032.0,3024.0,,,,,,8042e02f-1d73-4251-8b76-3f634ffe0196,True,data/04_models/pipeline/webapp/installed/cropped_images/8042e02f-1d73-4251-8b76-3f634ffe0196.jpg,2.9707219143490304,left,Norma Fisher
1,1,1,2306,2308,"data/01_raw/elk-river/2022/Boat_Electrofishing/Fish/August/989,001,006,037,192_F.jpg",2022,989001006037192,,2022-08-10 14:24:34,True,False,Apple,iPhone 12 Pro Max,5.1,4032.0,3024.0,10.17001733102253,1.3561438092556088,7.527949339379251,49.447309527777776,-115.05773913888888,0495c348-a87c-4e70-8a1f-9e07e8510977,True,data/04_models/pipeline/webapp/installed/cropped_images/0495c348-a87c-4e70-8a1f-9e07e8510977.jpg,3.0887346928329578,left,Jorge Sullivan
2,2,2,142,142,data/01_raw/elk-river/2023/Boat Electrofishing/Fish/Aug 16 2023/2023-08-16_1146_7_900088000913914_F.jpg,2023,900088000913914,2023-08-16 11:46:00,2023-08-16 11:47:39,True,False,Apple,iPhone 13 Pro Max,5.7,4032.0,3024.0,9.47009,1.169925,7.41103,49.479367833333335,-115.06903196666666,4d7bd307-1224-41e6-ba89-82dd85e69ea9,True,data/04_models/pipeline/webapp/installed/cropped_images/4d7bd307-1224-41e6-ba89-82dd85e69ea9.jpg,-0.0302701344557054,left,Elizabeth Woods
3,3,3,2264,2266,"data/01_raw/elk-river/2022/Boat_Electrofishing/Fish/August/989,001,006,004,046_F_2.jpg",2022,989001006004046,,2022-08-05 14:00:47,True,False,Apple,iPhone 13 Pro Max,5.7,4032.0,3024.0,10.445739257101238,1.1699250021066825,8.260268601117538,49.374336702777775,-115.01017866666666,8ed739d7-d24f-4919-bd1d-339e34536a9b,True,data/04_models/pipeline/webapp/installed/cropped_images/8ed739d7-d24f-4919-bd1d-339e34536a9b.jpg,-0.1349107606696157,left,Susan Wagner
4,4,4,1825,1826,"data/01_raw/elk-river/2022/Boat_Electrofishing/Fish/October/900,088,000,909,142_F.jpg",2022,900088000909142,,2022-10-05 14:14:45,True,False,Apple,iPhone 12 Pro Max,5.1,4032.0,3024.0,6.922439780250739,1.3561438092556088,5.0997288781417165,49.60164128611111,-114.96475872222224,ec1c6fe4-b2b5-491a-9d7f-542b85e2b078,True,data/04_models/pipeline/webapp/installed/cropped_images/ec1c6fe4-b2b5-491a-9d7f-542b85e2b078.jpg,-3.138177921264415,left,Peter Montgomery
data/pipeline/models/identification/config.yaml
ADDED
@@ -0,0 +1,2 @@
n_keypoints: 1024
extractor_type: aliked
data/pipeline/models/identification/features.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:053c31af351b7aa685c003a26435eadabfcece6d046f10060f8ad5f90ab9e30e
size 2689138
data/pipeline/models/pose/weights.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4081cc996b6d436bc0131e5756b076cdd58c372a319df3aaa688abf559b286c8
size 5701249
data/pipeline/models/segmentation/weights.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6bc54dad5a4d62746ddf4e7edb4402642e905663048da1ff61af969a2d2db604
size 6015837
data/summary.csv
ADDED
@@ -0,0 +1,6 @@
,Unnamed: 0.1,Unnamed: 0,filepath,year,pit,created_at_filepath,created_at_exif,is_electrofishing,is_guide_angling,exif_make,exif_model,exif_focal_length,exif_image_width,exif_image_height,exif_shutter_speed,exif_aperture,exif_brightness,coordinates_lat,coordinates_lon,uuid,success,filepath_crop,pose_theta,pose_fish_side
0,1311,1312,data/01_raw/elk-river/2023/Guide/Fish/Nupqu 2/2023-09-15_1040_2_900088000908738_F.jpg,2023,900088000908738,2023-09-15 10:40:00,,False,True,,,,4032.0,3024.0,,,,,,8042e02f-1d73-4251-8b76-3f634ffe0196,True,data/03_processed/identification/input/images/8042e02f-1d73-4251-8b76-3f634ffe0196.jpg,2.9707219143490304,left
1,2306,2308,"data/01_raw/elk-river/2022/Boat_Electrofishing/Fish/August/989,001,006,037,192_F.jpg",2022,989001006037192,,2022-08-10 14:24:34,True,False,Apple,iPhone 12 Pro Max,5.1,4032.0,3024.0,10.17001733102253,1.3561438092556088,7.527949339379251,49.447309527777776,-115.05773913888888,0495c348-a87c-4e70-8a1f-9e07e8510977,True,data/03_processed/identification/input/images/0495c348-a87c-4e70-8a1f-9e07e8510977.jpg,3.0887346928329578,left
2,142,142,data/01_raw/elk-river/2023/Boat Electrofishing/Fish/Aug 16 2023/2023-08-16_1146_7_900088000913914_F.jpg,2023,900088000913914,2023-08-16 11:46:00,2023-08-16 11:47:39,True,False,Apple,iPhone 13 Pro Max,5.7,4032.0,3024.0,9.47009,1.169925,7.41103,49.479367833333335,-115.06903196666666,4d7bd307-1224-41e6-ba89-82dd85e69ea9,True,data/03_processed/identification/input/images/4d7bd307-1224-41e6-ba89-82dd85e69ea9.jpg,-0.0302701344557054,left
3,2264,2266,"data/01_raw/elk-river/2022/Boat_Electrofishing/Fish/August/989,001,006,004,046_F_2.jpg",2022,989001006004046,,2022-08-05 14:00:47,True,False,Apple,iPhone 13 Pro Max,5.7,4032.0,3024.0,10.445739257101238,1.1699250021066825,8.260268601117538,49.374336702777775,-115.01017866666666,8ed739d7-d24f-4919-bd1d-339e34536a9b,True,data/03_processed/identification/input/images/8ed739d7-d24f-4919-bd1d-339e34536a9b.jpg,-0.1349107606696157,left
4,1825,1826,"data/01_raw/elk-river/2022/Boat_Electrofishing/Fish/October/900,088,000,909,142_F.jpg",2022,900088000909142,,2022-10-05 14:14:45,True,False,Apple,iPhone 12 Pro Max,5.1,4032.0,3024.0,6.922439780250739,1.3561438092556088,5.0997288781417165,49.60164128611111,-114.96475872222224,ec1c6fe4-b2b5-491a-9d7f-542b85e2b078,True,data/03_processed/identification/input/images/ec1c6fe4-b2b5-491a-9d7f-542b85e2b078.jpg,-3.138177921264415,left
identification.py
ADDED
@@ -0,0 +1,241 @@
"""
Module to manage the identification model. One can load and run inference on a
new image.
"""

from dataclasses import dataclass
from pathlib import Path
from typing import Any

import numpy as np
import pandas as pd
import torch
from lightglue import ALIKED, DISK, SIFT, LightGlue, SuperPoint
from lightglue.utils import numpy_image_to_torch, rbd
from PIL import Image

from utils import (
    extractor_type_to_extractor,
    extractor_type_to_matcher,
    get_scores,
    wasserstein,
)
from viz2d import keypoints_as_pil_image, matches_as_pil_image


@dataclass
class IdentificationModel:
    extractor: SIFT | ALIKED | DISK | SuperPoint
    extractor_type: str
    threshold_wasserstein: float
    n_keypoints: int
    matcher: LightGlue
    features_dict: dict[str, torch.Tensor]
    df_db: pd.DataFrame


def load(
    device: torch.device,
    filepath_features: Path,
    filepath_db: Path,
    extractor_type: str,
    n_keypoints: int,
    threshold_wasserstein: float,
) -> IdentificationModel:
    """
    Load the IdentificationModel provided the arguments.

    Args:
        device (torch.device): cpu|cuda
        filepath_features (Path): filepath to the torch cached features on the
            dataset one wants to predict on.
        filepath_db (Path): filepath to the csv file containing the dataset to
            compare with.
        extractor_type (str): in {sift, disk, aliked, superpoint}.
        n_keypoints (int): maximum number of keypoints to extract with the extractor.
        threshold_wasserstein (float): threshold for the wasserstein distance to consider it a match.

    Returns:
        IdentificationModel: an IdentificationModel instance.

    Raises:
        AssertionError when the extractor_type or n_keypoints are not valid.
    """

    allowed_extractor_types = ["sift", "disk", "aliked", "superpoint"]
    assert (
        extractor_type in allowed_extractor_types
    ), f"extractor_type should be in {allowed_extractor_types}"
    assert 1 <= n_keypoints <= 5000, "n_keypoints should be in range 1..5000"
    assert (
        0.0 <= threshold_wasserstein <= 1.0
    ), "threshold_wasserstein should be in 0..1"

    extractor = extractor_type_to_extractor(
        device=device,
        extractor_type=extractor_type,
        n_keypoints=n_keypoints,
    )
    matcher = extractor_type_to_matcher(
        device=device,
        extractor_type=extractor_type,
    )
    features_dict = torch.load(filepath_features)
    df_db = pd.read_csv(filepath_db)

    return IdentificationModel(
        extractor_type=extractor_type,
        n_keypoints=n_keypoints,
        extractor=extractor,
        matcher=matcher,
        features_dict=features_dict,
        df_db=df_db,
        threshold_wasserstein=threshold_wasserstein,
    )


def _make_prediction_dict(
    model: IdentificationModel,
    indexed_matches: dict[str, dict[str, torch.Tensor]],
) -> dict[str, Any]:
    """
    Return the prediction dict. Two types of predictions can be made:
    1. A new individual
    2. A match from the dataset

    Returns:
        type (str): new|match
        match (dict): dict containing the following keys if type==match.
            pit (str): the PIT of the matched individual.
            name (str): the name of the matched individual.
            filepath_crop (Path): the filepath to the matched individual.
            features (torch.Tensor): LightGlue Features of the matched individual.
            matches (torch.Tensor): LightGlue Matches of the matched individual.
    """
    indexed_scores = {k: get_scores(v) for k, v in indexed_matches.items()}
    indexed_wasserstein = {k: wasserstein(v) for k, v in indexed_scores.items()}
    sorted_wasserstein = sorted(
        indexed_wasserstein.items(), key=lambda item: item[1], reverse=True
    )
    shared_record = {
        "indexed_matches": indexed_matches,
        "indexed_scores": indexed_scores,
        "indexed_wasserstein": indexed_wasserstein,
        "sorted_wasserstein": sorted_wasserstein,
    }
    if not sorted_wasserstein:
        return {"type": "new", **shared_record}
    elif model.threshold_wasserstein > sorted_wasserstein[0][1]:
        return {"type": "new", **shared_record}
    else:
        prediction_uuid = sorted_wasserstein[0][0]
        db_row = model.df_db[model.df_db["uuid"] == prediction_uuid].iloc[0]
        return {
            "type": "match",
            "match": {
                "pit": db_row["pit"],
                "name": db_row["name"],
                "filepath_crop": db_row["filepath_crop"],
                "features": model.features_dict[prediction_uuid],
                "matches": indexed_matches[prediction_uuid],
            },
            **shared_record,
        }


# FIXME: Properly run a batch inference here to make it fast on GPU.
def _batch_inference(
    model: IdentificationModel,
    feats0: dict,
) -> dict[str, dict[str, torch.Tensor]]:
    """
    Run batch inference on feats0 with the IdentificationModel.
    Returns an indexed_matches datastructure containing the results of each run
    for the given uuid in the features_dict.
    """
    indexed_matches = {}
    for uuid in model.features_dict.keys():
        matches01 = model.matcher(
            {"image0": feats0, "image1": model.features_dict[uuid]}
        )
        indexed_matches[uuid] = matches01
    return indexed_matches


def predict(model: IdentificationModel, pil_image: Image.Image) -> dict:
    """
    Run inference on the pil_image on all the features_dict entries from the
    IdentificationModel.

    Note: It will try to optimize inference depending on the available device
    (cpu|gpu).

    Args:
        model (IdentificationModel): identification model to run inference with.
        pil_image (PIL): input image to run the inference on.

    Returns:
        type (str): new|match.
        source (dict): contains the `features` of the input image.
        match (dict): dict containing the following keys if type==match.
            pit (str): the PIT of the matched individual.
            name (str): the name of the matched individual.
            filepath_crop (Path): the filepath to the matched individual.
            features (torch.Tensor): LightGlue Features of the matched individual.
            matches (torch.Tensor): LightGlue Matches of the matched individual.
    """
    # Disable gradient accumulation to make inference faster
    torch.set_grad_enabled(False)
    torch_image = numpy_image_to_torch(np.array(pil_image))
    feats0 = model.extractor.extract(torch_image)
    indexed_matches = _batch_inference(model=model, feats0=feats0)
    prediction_dict = _make_prediction_dict(
        model=model,
        indexed_matches=indexed_matches,
    )
    return {"source": {"features": feats0}, **prediction_dict}


def generate_visualization(pil_image: Image.Image, prediction: dict) -> dict:
    if "type" not in prediction:
        return {}
    elif prediction["type"] == "match":
        pil_image_masked_closest = Image.open(prediction["match"]["filepath_crop"])
        torch_image0 = np.array(pil_image)
        torch_image1 = np.array(pil_image_masked_closest)
        torch_images = [torch_image0, torch_image1]
        feats0 = prediction["source"]["features"]
        feats1 = prediction["match"]["features"]
        matches01 = prediction["match"]["matches"]

        feats0, feats1, matches01 = [
            rbd(x) for x in [feats0, feats1, matches01]
        ]  # remove batch dimension
        pil_image_matches = matches_as_pil_image(
            torch_images=torch_images,
            feats0=feats0,
            feats1=feats1,
            matches01=matches01,
            mode="column",
        )
        pil_image_keypoints_source = keypoints_as_pil_image(
            torch_image=torch_image0,
            feats=feats0,
            ps=23,
        )
        return {
            "matches": pil_image_matches,
            "keypoints_source": pil_image_keypoints_source,
        }
    elif prediction["type"] == "new":
        torch_image0 = np.array(pil_image)
        feats0 = prediction["source"]["features"]
        feats0 = rbd(feats0)  # remove the batch dimension
        pil_image_keypoints_source = keypoints_as_pil_image(
            torch_image=torch_image0,
            feats=feats0,
            ps=23,
        )
        return {"keypoints_source": pil_image_keypoints_source}
    else:
        return {}
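Note: a usage sketch for the module above, not part of the commit. It reuses one of the cropped images shipped with this commit, so the expected outcome is a match against its own database entry; the device choice and threshold mirror the values hard-coded in app.py.

from pathlib import Path

import torch
from PIL import Image

import identification

model = identification.load(
    device=torch.device("cpu"),  # illustrative; select_best_device() is used in app.py
    filepath_features=Path("./data/pipeline/models/identification/features.pt"),
    filepath_db=Path("./data/pipeline/db/db.csv"),
    extractor_type="aliked",
    n_keypoints=1024,
    threshold_wasserstein=0.084,
)
result = identification.predict(
    model=model,
    pil_image=Image.open(
        "data/04_models/pipeline/webapp/installed/cropped_images/0495c348-a87c-4e70-8a1f-9e07e8510977.jpg"
    ),
)
if result["type"] == "match":
    print(result["match"]["pit"], result["match"]["name"])
else:
    print("New individual")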
pipeline.py
ADDED
@@ -0,0 +1,348 @@
from pathlib import Path
from typing import Any

import numpy as np
import torch
from PIL import Image
from ultralytics import YOLO

import identification
import pose
import segmentation
from identification import IdentificationModel
from utils import (
    PictureLayout,
    crop,
    get_picture_layout,
    get_segmentation_mask_crop_box,
)


def load_pose_and_segmentation_models(
    filepath_weights_segmentation_model: Path,
    filepath_weights_pose_model: Path,
) -> dict[str, YOLO]:
    """
    Load into memory the models used by the pipeline.

    Returns:
        segmentation (YOLO): segmentation model.
        pose (YOLO): pose estimation model.
    """
    model_segmentation = segmentation.load_pretrained_model(
        str(filepath_weights_segmentation_model)
    )

    model_pose = pose.load_pretrained_model(str(filepath_weights_pose_model))
    return {
        "segmentation": model_segmentation,
        "pose": model_pose,
    }


def load_models(
    filepath_weights_segmentation_model: Path,
    filepath_weights_pose_model: Path,
    device: torch.device,
    filepath_identification_lightglue_features: Path,
    filepath_identification_db: Path,
    extractor_type: str,
    n_keypoints: int,
    threshold_wasserstein: float,
) -> dict[str, YOLO | IdentificationModel]:
    """
    Load into memory the models used by the pipeline.

    Returns:
        segmentation (YOLO): segmentation model.
        pose (YOLO): pose estimation model.
        identification (IdentificationModel): identification model.
    """
    loaded_pose_seg_models = load_pose_and_segmentation_models(
        filepath_weights_segmentation_model=filepath_weights_segmentation_model,
        filepath_weights_pose_model=filepath_weights_pose_model,
    )

    model_identification = identification.load(
        device=device,
        filepath_features=filepath_identification_lightglue_features,
        filepath_db=filepath_identification_db,
        n_keypoints=n_keypoints,
        extractor_type=extractor_type,
        threshold_wasserstein=threshold_wasserstein,
    )

    return {**loaded_pose_seg_models, "identification": model_identification}


def run_preprocess(pil_image: Image.Image) -> dict[str, Any]:
    """
    Run the preprocess stage of the pipeline.

    Args:
        pil_image (PIL): original image.

    Returns:
        pil_image (PIL): rotated image to make it a landscape.
        layout (PictureLayout): layout type of the input image.
    """
    picture_layout = get_picture_layout(pil_image=pil_image)

    # If the image is in Portrait Mode, we turn it into Landscape
    pil_image_preprocessed = (
        pil_image.rotate(angle=90, expand=True)
        if picture_layout == PictureLayout.PORTRAIT
        else pil_image
    )
    return {
        "pil_image": pil_image_preprocessed,
        "layout": picture_layout,
    }


def run_pose(model: YOLO, pil_image: Image.Image) -> dict[str, Any]:
    """
    Run the pose stage of the pipeline.

    Args:
        model (YOLO): loaded pose estimation model.
        pil_image (PIL): Image to run the model on.

    Returns:
        prediction: Raw prediction from the model.
        orig_image: original image used for inference after the preprocessing
            stages applied by ultralytics.
        keypoints_xy (np.ndarray): keypoints in xy format.
        keypoints_xyn (np.ndarray): keypoints in xyn format.
        theta (float): angle in radians to rotate the image to re-align it
            horizontally.
        side (FishSide): Predicted side of the fish.
    """
    return pose.predict(model=model, pil_image=pil_image)


def run_crop(
    pil_image_mask: Image.Image,
    pil_image_masked: Image.Image,
    padding: int,
) -> dict[str, Any]:
    """
    Run the crop on the mask and masked images.

    Args:
        pil_image_mask (PIL): Image containing the segmentation mask.
        pil_image_masked (PIL): Image containing the applied pil_image_mask on
            the original image.
        padding (int): by how much do we want to pad the result image?

    Returns:
        box (Tuple[int, int, int, int]): 4 tuple representing a rectangle (x1,
            y1, x2, y2) with the upper left corner given first.
        pil_image (PIL): cropped masked image.
    """

    box_crop = get_segmentation_mask_crop_box(
        pil_image_mask=pil_image_mask,
        padding=padding,
    )
    pil_image_masked_cropped = crop(
        pil_image=pil_image_masked,
        box=box_crop,
    )
    return {
        "box": box_crop,
        "pil_image": pil_image_masked_cropped,
    }


def run_rotation(
    pil_image: Image.Image,
    angle_rad: float,
    keypoints_xy: np.ndarray,
) -> dict[str, Any]:
    """
    Run the rotation stage of the pipeline.

    Args:
        pil_image (PIL): image to run the rotation on.
        angle_rad (float): angle in radian to rotate the image.
        keypoints_xy (np.ndarray): keypoints from the pose estimation
            prediction in xy format.

    Returns:
        array_image (np.ndarray): rotated array_image as a 2D numpy array.
        keypoints_xy (np.ndarray): rotated keypoints in xy format.
        pil_image (PIL): rotated PIL image.
    """
    results_rotation = pose.rotate_image_and_keypoints_xy(
        angle_rad=angle_rad,
        array_image=np.array(pil_image),
        keypoints_xy=keypoints_xy,
    )
    pil_image_rotated = Image.fromarray(results_rotation["array_image"])

    return {
        "pil_image": pil_image_rotated,
        "array_image": results_rotation["array_image"],
        "keypoints_xy": results_rotation["keypoints_xy"],
    }


def run_segmentation(model: YOLO, pil_image: Image.Image) -> dict[str, Any]:
    """
    Run the segmentation stage of the pipeline.

    Args:
        pil_image (PIL): image to run the segmentation on.
        model (YOLO): segmentation model.

    Returns:
        prediction: Raw prediction from the model.
        orig_image: original image used for inference after preprocessing
            stages applied by ultralytics.
        mask (PIL): postprocessed mask in white and black format - used for visualization
        mask_raw (np.ndarray): Raw mask not postprocessed
        masked (PIL): mask applied to the pil_image.
    """
    results_segmentation = segmentation.predict(
        model=model,
        pil_image=pil_image,
    )
    return results_segmentation


def run_pre_identification_stages(
    loaded_models: dict[str, YOLO],
    pil_image: Image.Image,
    param_crop_padding: int = 0,
) -> dict[str, Any]:
    """
    Run the partial ML pipeline on `pil_image` up to identifying the fish. It
    prepares the input image `pil_image` to make it possible to identify it.

    Args:
        loaded_models (dict[str, YOLO]): result of calling `load_models`.
        pil_image (PIL): Image to run the pipeline on.
        param_crop_padding (int): how much to pad the resulting segmented
            image when cropped.

    Returns:
        order (list[str]): the stages and their order.
        stages (dict[str, Any]): the description of each stage, its
            input and output.
    """

    # Unpacking the loaded models
    model_pose = loaded_models["pose"]
    model_segmentation = loaded_models["segmentation"]

    # Stage: Preprocess
    results_preprocess = run_preprocess(pil_image=pil_image)

    # Stage: Pose estimation
    pil_image_preprocessed = results_preprocess["pil_image"]
    results_pose = run_pose(model=model_pose, pil_image=pil_image_preprocessed)

    # Stage: Rotation
    results_rotation = run_rotation(
        pil_image=pil_image_preprocessed,
        keypoints_xy=results_pose["keypoints_xy"],
        angle_rad=results_pose["theta"],
    )

    # Stage: Segmentation
    pil_image_rotated = Image.fromarray(results_rotation["array_image"])
    results_segmentation = run_segmentation(
        model=model_segmentation, pil_image=pil_image_rotated
    )

    # Stage: Crop
    results_crop = run_crop(
        pil_image_mask=results_segmentation["mask"],
        pil_image_masked=results_segmentation["masked"],
        padding=param_crop_padding,
    )

    return {
        "order": [
            "preprocess",
            "pose",
            "rotation",
            "segmentation",
            "crop",
        ],
        "stages": {
            "preprocess": {
                "input": {"pil_image": pil_image},
                "output": results_preprocess,
            },
            "pose": {
                "input": {"pil_image": pil_image_preprocessed},
                "output": results_pose,
            },
            "rotation": {
                "input": {
                    "pil_image": pil_image_preprocessed,
                    "angle_rad": results_pose["theta"],
                    "keypoints_xy": results_pose["keypoints_xy"],
                },
                "output": results_rotation,
            },
            "segmentation": {
                "input": {"pil_image": pil_image_rotated},
                "output": results_segmentation,
            },
            "crop": {
                "input": {
                    "pil_image_mask": results_segmentation["mask"],
                    "pil_image_masked": results_segmentation["masked"],
                    "padding": param_crop_padding,
                },
                "output": results_crop,
            },
        },
    }


def run(
    loaded_models: dict[str, YOLO | IdentificationModel],
    pil_image: Image.Image,
    param_crop_padding: int = 0,
) -> dict[str, Any]:
    """
    Run the ML pipeline on `pil_image`.

    Args:
        loaded_models (dict[str, YOLO]): result of calling `load_models`.
        pil_image (PIL): Image to run the pipeline on.
        param_crop_padding (int): how much to pad the resulting segmented
            image when cropped.

    Returns:
        order (list[str]): the stages and their order.
        stages (dict[str, Any]): the description of each stage, its
            input and output.
    """
    model_identification = loaded_models["identification"]

    results = run_pre_identification_stages(
        loaded_models=loaded_models,
        pil_image=pil_image,
        param_crop_padding=param_crop_padding,
    )

    results_crop = results["stages"]["crop"]["output"]
    results_identification = identification.predict(
        model=model_identification,
        pil_image=results_crop["pil_image"],
    )

    results["order"].append("identification")
    results["stages"]["identification"] = {
        "input": {"pil_image": results_crop["pil_image"]},
        "output": results_identification,
    }

    return results
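Note: a short sketch, not part of the commit, showing how the nested result of run() is navigated; this mirrors what submit_fn in app.py does. It reuses the `loaded` dict built by pipeline.load_models in the sketch inserted after app.py above and one of the example images shipped with the Space.

from PIL import Image

import pipeline

results = pipeline.run(
    loaded_models=loaded,  # built by pipeline.load_models, see the earlier sketch
    pil_image=Image.open("./data/images/989,001,006,004,046_F.jpg"),
)

# Stages run in this order: preprocess, pose, rotation, segmentation, crop, identification.
for stage in results["order"]:
    print(stage, sorted(results["stages"][stage]["output"].keys()))

theta = results["stages"]["pose"]["output"]["theta"]  # correction angle, radians
side = results["stages"]["pose"]["output"]["side"]    # FishSide.LEFT or FishSide.RIGHT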
pose.py
ADDED
@@ -0,0 +1,166 @@
"""
Module to manage the pose detection model.
"""

from enum import Enum
from pathlib import Path
from typing import Any

import numpy as np
from PIL import Image
from ultralytics import YOLO

import yolo
from utils import get_angle_correction, get_keypoint, rotate_image_and_keypoints_xy


class FishSide(Enum):
    """
    Represents the Side of the Fish.
    """

    RIGHT = "right"
    LEFT = "left"


def predict_fish_side(
    array_image: np.ndarray,
    keypoints_xy: np.ndarray,
    classes_dictionnary: dict[int, str],
) -> FishSide:
    """
    Predict which side of the fish is displayed on the image.

    Args:
        array_image (np.ndarray): numpy array representing the image.
        keypoints_xy (np.ndarray): detected keypoints on array_image in
            xy format.
        classes_dictionnary (dict[int, str]): mapping of class instance to
            class name.

    Returns:
        FishSide: Predicted side of the fish.
    """

    theta = get_angle_correction(
        keypoints_xy=keypoints_xy,
        array_image=array_image,
        classes_dictionnary=classes_dictionnary,
    )
    rotation_results = rotate_image_and_keypoints_xy(
        angle_rad=theta, array_image=array_image, keypoints_xy=keypoints_xy
    )

    # We check if the eye is on the left/right of one of the fins.
    k_eye = get_keypoint(
        class_name="eye",
        keypoints=rotation_results["keypoints_xy"],
        classes_dictionnary=classes_dictionnary,
    )
    k_anal_fin_base = get_keypoint(
        class_name="anal_fin_base",
        keypoints=rotation_results["keypoints_xy"],
        classes_dictionnary=classes_dictionnary,
    )

    if k_eye[0] <= k_anal_fin_base[0]:
        return FishSide.LEFT
    else:
        return FishSide.RIGHT


# Model prediction classes
CLASSES_DICTIONNARY = {
    0: "eye",
    1: "front_fin_base",
    2: "tail_bottom_tip",
    3: "tail_top_tip",
    4: "dorsal_fin_base",
    5: "pelvic_fin_base",
    6: "anal_fin_base",
}


def load_pretrained_model(model_str: str) -> YOLO:
    """
    Load the pretrained model.
    """
    return yolo.load_pretrained_model(model_str)


def train(
    model: YOLO,
    data_yaml_path: Path,
    params: dict,
    project: Path = Path("data/04_models/yolo/"),
    experiment_name: str = "train",
):
    """Main function for running a train run. It saves the results
    under `project / experiment_name`.

    Args:
        model (YOLO): result of `load_pretrained_model`.
        data_yaml_path (Path): filepath to the data.yaml file that specifies the split and classes to train on.
        params (dict): parameters to override when running the training. See https://docs.ultralytics.com/modes/train/#train-settings for a complete list of parameters.
        project (Path): root path to store the run artifacts and results.
        experiment_name (str): name of the experiment, that is added to the project root path to store the run.
    """
    return yolo.train(
        model=model,
        data_yaml_path=data_yaml_path,
        params=params,
        project=project,
        experiment_name=experiment_name,
    )


def predict(
    model: YOLO,
    pil_image: Image.Image,
    classes_dictionnary: dict[int, str] = CLASSES_DICTIONNARY,
) -> dict[str, Any]:
    """
    Given a loaded model and a PIL image, it returns a map containing the
    keypoints predictions.

    Args:
        model (YOLO): loaded YOLO model for pose estimation.
        pil_image (PIL): image to run the model on.
        classes_dictionnary (dict[int, str]): mapping of class instance to
            class name.

    Returns:
        prediction: Raw prediction from the model.
        orig_image: original image used for inference after the preprocessing
            stages applied by ultralytics.
        keypoints_xy (np.ndarray): keypoints in xy format.
        keypoints_xyn (np.ndarray): keypoints in xyn format.
        theta (float): angle in radians to rotate the image to re-align it
            horizontally.
        side (FishSide): Predicted side of the fish.
    """
    predictions = model(pil_image)
    print(predictions)
    orig_image = predictions[0].orig_img
    keypoints_xy = predictions[0].keypoints.xy.cpu().numpy().squeeze()

    theta = get_angle_correction(
        keypoints_xy=keypoints_xy,
        array_image=orig_image,
        classes_dictionnary=classes_dictionnary,
    )

    side = predict_fish_side(
        array_image=orig_image,
        keypoints_xy=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )

    return {
        "prediction": predictions[0],
        "orig_image": orig_image,
        "keypoints_xy": keypoints_xy,
        "keypoints_xyn": predictions[0].keypoints.xyn.cpu().numpy().squeeze(),
        "theta": theta,
        "side": side,
    }
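Note: a usage sketch for the pose module, not part of the commit; the weights path is the one installed by this commit and the example image ships with the Space.

import math

from PIL import Image

import pose

model = pose.load_pretrained_model("./data/pipeline/models/pose/weights.pt")
out = pose.predict(model=model, pil_image=Image.open("./data/images/989,001,006,004,046_F.jpg"))

print(out["side"])                              # FishSide.LEFT or FishSide.RIGHT
print(f"{math.degrees(out['theta']):.1f} deg")  # correction angle applied downstream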
requirements.txt
ADDED
@@ -0,0 +1,8 @@
gradio==5.4.*
pandas==2.2.*
torch==2.5.*
numpy==2.1.*
tqdm==4.66.*
ultralytics==8.3.*
matplotlib==3.9.*
lightglue @ git+https://github.com/cvg/LightGlue.git@edb2b838efb2ecfe3f88097c5fad9887d95aedad
segmentation.py
ADDED
@@ -0,0 +1,103 @@
"""
Module to manage the segmentation YOLO model.
"""

from pathlib import Path
from typing import Any

import cv2
import numpy as np
from PIL import Image
from ultralytics import YOLO

import yolo


def load_pretrained_model(model_str: str) -> YOLO:
    """
    Load the pretrained model.
    """
    return yolo.load_pretrained_model(model_str)


def train(
    model: YOLO,
    data_yaml_path: Path,
    params: dict,
    project: Path = Path("data/04_models/yolo/"),
    experiment_name: str = "train",
):
    """Main function for running a train run. It saves the results
    under `project / experiment_name`.

    Args:
        model (YOLO): result of `load_pretrained_model`.
        data_yaml_path (Path): filepath to the data.yaml file that specifies the split and classes to train on.
        params (dict): parameters to override when running the training. See https://docs.ultralytics.com/modes/train/#train-settings for a complete list of parameters.
        project (Path): root path to store the run artifacts and results.
        experiment_name (str): name of the experiment, that is added to the project root path to store the run.
    """
    return yolo.train(
        model=model,
        data_yaml_path=data_yaml_path,
        params=params,
        project=project,
        experiment_name=experiment_name,
    )


def predict(model: YOLO, pil_image: Image.Image) -> dict[str, Any]:
    """
    Given a loaded model and a PIL image, it returns a map
    containing the segmentation predictions.

    Args:
        model (YOLO): loaded YOLO model for segmentation.
        pil_image (PIL): image to run the model on.

    Returns:
        prediction: Raw prediction from the model.
        orig_image: original image used for inference after preprocessing
            stages applied by ultralytics.
        mask (PIL): postprocessed mask in white and black format - used for visualization
        mask_raw (np.ndarray): Raw mask not postprocessed
        masked (PIL): mask applied to the pil_image.
    """
    predictions = model(pil_image)
    mask_raw = predictions[0].masks[0].data.cpu().numpy().transpose(1, 2, 0).squeeze()
    # Convert single channel grayscale to 3 channel image
    mask_3channel = cv2.merge((mask_raw, mask_raw, mask_raw))
    # Get the size of the original image (height, width, channels)
    h2, w2, c2 = predictions[0].orig_img.shape
    # Resize the mask to the same size as the image (can probably be removed if image is the same size as the model)
    mask = cv2.resize(mask_3channel, (w2, h2))
    # Convert BGR to HSV
    hsv = cv2.cvtColor(mask, cv2.COLOR_BGR2HSV)

    # Define range of brightness in HSV
    lower_black = np.array([0, 0, 0])
    upper_black = np.array([0, 0, 1])

    # Create a mask. Threshold the HSV image to get everything black
    mask = cv2.inRange(mask, lower_black, upper_black)

    # Invert the mask to get everything but black
    mask = cv2.bitwise_not(mask)

    # Apply the mask to the original image
    masked = cv2.bitwise_and(
        predictions[0].orig_img,
        predictions[0].orig_img,
        mask=mask,
    )

    # bgr to rgb and PIL conversion
    image_output2 = Image.fromarray(masked[:, :, ::-1])
    # return Image.fromarray(mask), image_output2
    return {
        "prediction": predictions[0],
        "mask": Image.fromarray(mask),
        "mask_raw": mask_raw,
        "masked": Image.fromarray(masked[:, :, ::-1]),
    }
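Note: a usage sketch for the segmentation module, not part of the commit, using the weights installed by this commit and one of the bundled example images; the output filenames are illustrative.

from PIL import Image

import segmentation

model = segmentation.load_pretrained_model("./data/pipeline/models/segmentation/weights.pt")
out = segmentation.predict(
    model=model,
    pil_image=Image.open("./data/images/989,001,006,004,046_F.jpg"),
)

out["mask"].save("mask.png")      # black/white mask used for visualization
out["masked"].save("masked.png")  # original image with the background removed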
ui.py
ADDED
File without changes
utils.py
ADDED
@@ -0,0 +1,659 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math
from enum import Enum
from pathlib import Path
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
from lightglue import ALIKED, DISK, SIFT, LightGlue, SuperPoint
from PIL import Image
from scipy.stats import wasserstein_distance


def select_best_device() -> torch.device:
    """
    Select the best available device (cpu or cuda) based on availability.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    else:
        return torch.device("cpu")


def bgr_to_rgb(a: np.ndarray) -> np.ndarray:
    """
    Turn a BGR numpy array into an RGB numpy array.
    """
    return a[:, :, ::-1]


ALLOWED_EXTRACTOR_TYPES = ["sift", "disk", "superpoint", "aliked"]


def extractor_type_to_extractor(
    device: torch.device,
    extractor_type: str,
    n_keypoints: int = 1024,
):
    """
    Given an extractor_type in {'sift', 'superpoint', 'aliked', 'disk'},
    return a LightGlue extractor.

    Args:
        device (torch.device): cpu or cuda
        extractor_type (str): in {sift, superpoint, aliked, disk}
        n_keypoints (int): maximum number of keypoints generated by the
          extractor. Higher values improve accuracy but increase runtime.

    Returns:
        LightGlue extractor: ALIKED | DISK | SIFT | SuperPoint

    Raises:
        AssertionError: when n_keypoints is outside the valid range 0..5000
        AssertionError: when extractor_type is not valid
    """
    assert 0 <= n_keypoints <= 5000, "n_keypoints should be in range 0..5000"
    assert (
        extractor_type in ALLOWED_EXTRACTOR_TYPES
    ), f"extractor type {extractor_type} should be in {ALLOWED_EXTRACTOR_TYPES}."

    if extractor_type == "sift":
        return SIFT(max_num_keypoints=n_keypoints).eval().to(device)
    elif extractor_type == "superpoint":
        return SuperPoint(max_num_keypoints=n_keypoints).eval().to(device)
    elif extractor_type == "disk":
        return DISK(max_num_keypoints=n_keypoints).eval().to(device)
    elif extractor_type == "aliked":
        return ALIKED(max_num_keypoints=n_keypoints).eval().to(device)
    else:
        raise Exception("extractor_type is not valid")


def extractor_type_to_matcher(device: torch.device, extractor_type: str) -> LightGlue:
    """
    Return the LightGlue matcher for the given `extractor_type`.

    Args:
        device (torch.device): cpu or cuda
        extractor_type (str): in {sift, superpoint, aliked, disk}

    Returns:
        LightGlue matcher
    """
    assert (
        extractor_type in ALLOWED_EXTRACTOR_TYPES
    ), f"extractor type {extractor_type} should be in {ALLOWED_EXTRACTOR_TYPES}."
    return LightGlue(features=extractor_type).eval().to(device)


def get_scores(matches: dict[str, torch.Tensor]) -> np.ndarray:
    """
    Given a `matches` dict from the LightGlue matcher output, return the
    scores as a numpy array.
    """
    return matches["matching_scores0"][0].to("cpu").numpy()


def wasserstein(scores: np.ndarray) -> float:
    """
    Return the Wasserstein distance of the scores against the null
    distribution.
    The greater the distance, the farther away it is from the null
    distribution.
    """
    x_null_distribution = [0.0] * 1024
    return wasserstein_distance(x_null_distribution, scores).item()


class PictureLayout(Enum):
    """
    Layout of a picture.
    """

    PORTRAIT = "portrait"
    LANDSCAPE = "landscape"
    SQUARE = "square"


def crop(
    pil_image: Image.Image,
    box: Tuple[int, int, int, int],
) -> Image.Image:
    """
    Crop a pil_image based on the provided rectangle in (x1, y1, x2, y2)
    format - with the upper left corner given first.
    """
    return pil_image.crop(box=box)


def get_picture_layout(pil_image: Image.Image) -> PictureLayout:
    """
    Return the picture layout.
    """
    width, height = pil_image.size

    if width > height:
        return PictureLayout.LANDSCAPE
    elif width == height:
        return PictureLayout.SQUARE
    else:
        return PictureLayout.PORTRAIT


def get_segmentation_mask_crop_box(
    pil_image_mask: Image.Image,
    padding: int = 0,
) -> Tuple[int, int, int, int]:
    """
    Return a crop box for the given pil_image that contains the segmentation
    mask (black and white).

    Args:
        pil_image_mask (PIL): image containing the segmentation mask
        padding (int): how much to pad around the segmentation mask.

    Returns:
        Rectangle (Tuple[int, int, int, int]): 4-tuple representing a
          rectangle (x1, y1, x2, y2) with the upper left corner given first.
    """
    array_image_mask = np.array(pil_image_mask)
    a = np.where(array_image_mask != 0)
    y_min = np.min(a[0]).item()
    y_max = np.max(a[0]).item()
    x_min = np.min(a[1]).item()
    x_max = np.max(a[1]).item()
    box = (x_min, y_min, x_max, y_max)
    box_with_padding = (
        box[0] - padding,
        box[1] - padding,
        box[2] + padding,
        box[3] + padding,
    )
    return box_with_padding


def scale_keypoints_to_image_size(
    image_width: int,
    image_height: int,
    keypoints_xyn: np.ndarray,
) -> np.ndarray:
    """
    Given keypoints in xyn format, return new keypoints in xy format.

    Args:
        image_width (int): width of the image
        image_height (int): height of the image
        keypoints_xyn (np.ndarray): 2D numpy array representing the keypoints
          in xyn format.

    Returns:
        keypoints_xy (np.ndarray): 2D numpy array representing the keypoints
          in xy format.
    """
    keypoints_xy = keypoints_xyn.copy()
    keypoints_xy[:, 0] = keypoints_xyn[:, 0] * image_width
    keypoints_xy[:, 1] = keypoints_xyn[:, 1] * image_height
    return keypoints_xy


def normalize_keypoints_to_image_size(
    image_width: int,
    image_height: int,
    keypoints_xy: np.ndarray,
) -> np.ndarray:
    """
    Given keypoints in xy format, return new keypoints in xyn format.

    Args:
        image_width (int): width of the image
        image_height (int): height of the image
        keypoints_xy (np.ndarray): 2D numpy array representing the keypoints
          in xy format.

    Returns:
        keypoints_xyn (np.ndarray): 2D numpy array representing the keypoints
          in xyn format.
    """
    keypoints_xyn = keypoints_xy.copy()
    keypoints_xyn[:, 0] = keypoints_xy[:, 0] / image_width
    keypoints_xyn[:, 1] = keypoints_xy[:, 1] / image_height
    return keypoints_xyn


def show_keypoints_xy(
    array_image: np.ndarray,
    keypoints_xy: np.ndarray,
    classes_dictionnary: dict[int, str],
    verbose: bool = True,
) -> None:
    """
    Show keypoints on top of an `array_image`, useful in jupyter notebooks
    for instance.

    Args:
        array_image (np.ndarray): numpy array representing an image.
        keypoints_xy (np.ndarray): 2D numpy array representing the keypoints
          in xy format.
        classes_dictionnary (dict[int, str]): Model prediction classes.
        verbose (bool): should we annotate the image with a label for each
          keypoint?
    """

    colors = ["r", "g", "b", "c", "m", "y", "w"]
    plt.imshow(array_image)
    label_margin = 20
    height, width, _ = array_image.shape

    for class_inst, class_name in classes_dictionnary.items():
        color = colors[class_inst]
        x, y = keypoints_xy[class_inst]
        plt.scatter(x=[x], y=[y], c=color)
        if verbose:
            plt.annotate(class_name, (x - label_margin, y - label_margin), c="w")


def draw_keypoints_xy_on_ax(
    ax,
    array_image: np.ndarray,
    keypoints_xy: np.ndarray,
    classes_dictionnary: dict,
    verbose: bool = True,
) -> None:
    """
    Draw keypoints on top of an `array_image`, useful in jupyter notebooks
    for instance.

    Args:
        ax: matplotlib axis to draw on.
        array_image (np.ndarray): numpy array representing an image.
        keypoints_xy (np.ndarray): 2D numpy array representing the keypoints
          in xy format.
        classes_dictionnary (dict[int, str]): Model prediction classes.
        verbose (bool): should we annotate the image with a label for each
          keypoint?
    """

    colors = ["r", "g", "b", "c", "m", "y", "w"]
    ax.imshow(array_image)
    label_margin = 20
    height, width, _ = array_image.shape

    for class_inst, class_name in classes_dictionnary.items():
        color = colors[class_inst]
        x, y = keypoints_xy[class_inst]
        ax.scatter(x=[x], y=[y], c=color)

        if verbose:
            ax.annotate(class_name, (x - label_margin, y - label_margin), c="w")

    k_pelvic_fin_base = get_keypoint(
        class_name="pelvic_fin_base",
        keypoints=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )
    k_anal_fin_base = get_keypoint(
        class_name="anal_fin_base",
        keypoints=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )

    ax.axline(k_pelvic_fin_base, k_anal_fin_base, c="lime")


def show_keypoints_xyn(
    array_image: np.ndarray,
    keypoints_xyn: np.ndarray,
    classes_dictionnary: dict,
    verbose: bool = True,
) -> None:
    """
    Draw keypoints (in xyn format) on top of an `array_image`, useful in
    jupyter notebooks for instance.

    Args:
        array_image (np.ndarray): numpy array representing an image.
        keypoints_xyn (np.ndarray): 2D numpy array representing the keypoints
          in xyn format.
        classes_dictionnary (dict[int, str]): Model prediction classes.
        verbose (bool): should we annotate the image with a label for each
          keypoint?
    """
    height, width, _ = array_image.shape
    keypoints_xy = scale_keypoints_to_image_size(
        image_height=height,
        image_width=width,
        keypoints_xyn=keypoints_xyn,
    )
    show_keypoints_xy(
        array_image=array_image,
        keypoints_xy=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
        verbose=verbose,
    )


def rotate_point(
    clockwise: bool,
    origin: Tuple[float, float],
    point: Tuple[float, float],
    angle: float,
) -> Tuple[float, float]:
    """
    Rotate a point clockwise or counterclockwise by a given angle around a
    given origin.

    Args:
        clockwise (bool): should the rotation be clockwise?
        origin (Tuple[float, float]): origin 2D point to perform the rotation.
        point (Tuple[float, float]): 2D point to rotate.
        angle (float): angle in radians.

    Returns:
        rotated_point (Tuple[float, float]): rotated point after applying the
          2D transformation.
    """
    if clockwise:
        angle = 0 - angle

    ox, oy = origin
    px, py = point

    qx = ox + math.cos(angle) * (px - ox) - math.sin(angle) * (py - oy)
    qy = oy + math.sin(angle) * (px - ox) + math.cos(angle) * (py - oy)

    return qx, qy


def rotate_image(angle_rad: float, array_image: np.ndarray, expand=False) -> np.ndarray:
    """
    Rotate an `array_image` clockwise by an angle given in radians, using the
    center as origin.

    Args:
        angle_rad (float): angle in radians.
        array_image (np.ndarray): numpy array representing the image to rotate.
        expand (bool): should the output be expanded during the rotation so
          that no part of a non-square image gets truncated?
    """
    angle_degrees = math.degrees(angle_rad)
    return np.array(Image.fromarray(array_image).rotate(angle_degrees, expand=expand))


def rotate_keypoints_xy(
    angle_rad: float,
    keypoints_xy: np.ndarray,
    origin: Tuple[float, float],
    clockwise: bool = True,
) -> np.ndarray:
    """
    Rotate keypoints by an angle given in radians, clockwise or
    counterclockwise, around the `origin` point.

    Args:
        angle_rad (float): angle in radians.
        origin (Tuple[float, float]): origin 2D point to perform the rotation.
        keypoints_xy (np.ndarray): 2D numpy array representing the keypoints
          in xy format.
        clockwise (bool): should the rotation be clockwise?

    Returns:
        rotated_keypoints_xy (np.ndarray): rotated keypoints in xy format.
    """
    return np.array(
        [
            rotate_point(
                clockwise=clockwise,
                origin=origin,
                point=(kp[0].item(), kp[1].item()),
                angle=angle_rad,
            )
            for kp in keypoints_xy
        ]
    )


def rotate_image_and_keypoints_xy(
    angle_rad: float,
    array_image: np.ndarray,
    keypoints_xy: np.ndarray,
) -> dict[str, np.ndarray]:
    """
    Rotate the image and its keypoints with the provided parameters.

    Args:
        angle_rad (float): angle in radians.
        array_image (np.ndarray): numpy array representing the image to rotate.
        keypoints_xy (np.ndarray): 2D numpy array representing the keypoints
          in xy format.

    Returns:
        array_image (np.ndarray): rotated array_image as a 2D numpy array.
        keypoints_xy (np.ndarray): rotated keypoints in xy format.
    """
    height, width, _ = array_image.shape
    center_x, center_y = int(width / 2), int(height / 2)
    origin = (center_x, center_y)
    image_rotated = rotate_image(angle_rad=angle_rad, array_image=array_image)
    keypoints_xy_rotated = rotate_keypoints_xy(
        angle_rad=angle_rad, keypoints_xy=keypoints_xy, origin=origin, clockwise=True
    )

    return {
        "array_image": image_rotated,
        "keypoints_xy": keypoints_xy_rotated,
    }


def get_keypoint(
    class_name: str,
    keypoints: np.ndarray,
    classes_dictionnary: dict[int, str],
) -> np.ndarray:
    """
    Return the keypoint for the provided `class_name` (eg. eye,
    front_fin_base, etc).

    Raises:
        AssertionError: when the provided class_name is not compatible or when
          the number of keypoints does not match the number of class names.
    """
    assert (
        class_name in classes_dictionnary.values()
    ), f"class_name should be in {classes_dictionnary.values()}"
    assert len(classes_dictionnary) == len(
        keypoints
    ), "Number of provided keypoints does not match the number of class names"

    class_name_to_class_inst = {v: k for k, v in classes_dictionnary.items()}
    return keypoints[class_name_to_class_inst[class_name]]


def to_direction_vector(p1: np.ndarray, p2: np.ndarray) -> np.ndarray:
    """
    Return the direction vector between two points p1 and p2.
    """
    assert len(p1) == len(p2), "p1 and p2 should have the same length"
    return p2 - p1


def is_upside_down(
    keypoints_xy: np.ndarray,
    classes_dictionnary: dict[int, str],
) -> bool:
    """
    Is the fish upside down?
    """
    k_pelvic_fin_base = get_keypoint(
        class_name="pelvic_fin_base",
        keypoints=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )
    k_anal_fin_base = get_keypoint(
        class_name="anal_fin_base",
        keypoints=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )
    k_dorsal_fin_base = get_keypoint(
        class_name="dorsal_fin_base",
        keypoints=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )

    print(f"dorsal_fin_base: {k_dorsal_fin_base}")
    print(f"pelvic_fin_base: {k_pelvic_fin_base}")
    print(f"anal_fin_base: {k_anal_fin_base}")
    return (k_dorsal_fin_base[1] > k_pelvic_fin_base[1]).item()


def get_direction_vector(
    keypoints_xy: np.ndarray, classes_dictionnary: dict[int, str]
) -> np.ndarray:
    """
    Get the direction vector for the realignment.
    """
    # Align the fish horizontally based on its pelvic fin base and its anal fin base
    k_pelvic_fin_base = get_keypoint(
        class_name="pelvic_fin_base",
        keypoints=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )
    k_anal_fin_base = get_keypoint(
        class_name="anal_fin_base",
        keypoints=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )

    return to_direction_vector(
        p1=k_pelvic_fin_base, p2=k_anal_fin_base
    )  # line between the pelvic and anal fins


def get_reference_vector() -> np.ndarray:
    """
    Get the reference vector to align the direction vector to.
    """
    return np.array([1, 0])  # horizontal axis


def get_angle(v1: np.ndarray, v2: np.ndarray) -> float:
    """
    Return the angle (counterclockwise) in radians between vectors v1 and v2.
    """
    cos_theta = (
        np.dot(v1, v2) / (np.linalg.norm(v1, ord=2) * np.linalg.norm(v2, ord=2))
    ).item()
    return -math.acos(cos_theta)


def is_aligned(keypoints_xy: np.ndarray, classes_dictionnary: dict[int, str]) -> bool:
    """
    Return whether the keypoints are now properly aligned with the direction
    vector used to make the rotation.
    """
    v1 = get_direction_vector(
        keypoints_xy=keypoints_xy, classes_dictionnary=classes_dictionnary
    )
    v_ref = get_reference_vector()
    theta = get_angle(v1, v_ref)
    return abs(theta) <= 0.001


def get_angle_correction_sign(
    angle_rad: float,
    array_image: np.ndarray,
    keypoints_xy: np.ndarray,
    classes_dictionnary: dict[int, str],
) -> int:
    """
    Return 1 or -1 depending on the sign to apply to the correction angle.
    """
    rotation_results = rotate_image_and_keypoints_xy(
        angle_rad=angle_rad, array_image=array_image, keypoints_xy=keypoints_xy
    )
    if not is_aligned(
        keypoints_xy=rotation_results["keypoints_xy"],
        classes_dictionnary=classes_dictionnary,
    ):
        return -1
    else:
        return 1


def get_angle_correction(
    keypoints_xy: np.ndarray,
    array_image: np.ndarray,
    classes_dictionnary: dict[int, str],
) -> float:
    """
    Get the angle correction in radians that aligns the fish (based on the
    keypoints) horizontally.
    """
    v1 = get_direction_vector(
        keypoints_xy=keypoints_xy, classes_dictionnary=classes_dictionnary
    )
    v_ref = get_reference_vector()
    theta = get_angle(v1, v_ref)

    angle_sign = get_angle_correction_sign(
        angle_rad=theta,
        array_image=array_image,
        keypoints_xy=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )
    theta = angle_sign * theta
    rotation_results = rotate_image_and_keypoints_xy(
        angle_rad=theta, array_image=array_image, keypoints_xy=keypoints_xy
    )

    # Check whether the fish is upside down
    if is_upside_down(
        keypoints_xy=rotation_results["keypoints_xy"],
        classes_dictionnary=classes_dictionnary,
    ):
        print("the fish is upside down...")
        return theta + math.pi
    else:
        print("The fish is not upside down")
        return theta  # No need to rotate the fish more


def show_algorithm_steps(
    image_filepath: Path,
    keypoints_xy: np.ndarray,
    rotation_results: dict,
    theta: float,
    classes_dictionnary: dict,
) -> None:
    """
    Display a matplotlib figure that details, step by step, the result of the
    rotation. Keypoints can be overlayed with the images.
    """
    array_image = np.array(Image.open(image_filepath))
    array_image_final = np.array(
        Image.open(image_filepath).rotate(math.degrees(theta), expand=True)
    )

    fig, axs = plt.subplots(1, 4, figsize=(20, 4))
    fig.suptitle(f"{image_filepath.name}")
    print(f"image_filepath: {image_filepath}")

    # Hide the x and y axis ticks
    for ax in axs:
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)

    axs[0].set_title("original")
    axs[0].imshow(array_image)
    axs[1].set_title("predicted keypoints")
    draw_keypoints_xy_on_ax(
        ax=axs[1],
        array_image=array_image,
        keypoints_xy=keypoints_xy,
        classes_dictionnary=classes_dictionnary,
    )
    axs[2].set_title(f"rotation of {math.degrees(theta):.1f} degrees")
    draw_keypoints_xy_on_ax(
        ax=axs[2],
        array_image=rotation_results["array_image"],
        keypoints_xy=rotation_results["keypoints_xy"],
        classes_dictionnary=classes_dictionnary,
    )
    axs[3].set_title("final")
    axs[3].imshow(array_image_final)
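
For context, here is a minimal usage sketch of the LightGlue helpers above. It is not part of the committed files; it assumes the upstream lightglue API (`load_image`, `extractor.extract(...)`, and the dict-based matcher call), and the image paths are placeholders.

# Usage sketch (assumed lightglue API; image paths are placeholders).
from lightglue.utils import load_image

from utils import (
    extractor_type_to_extractor,
    extractor_type_to_matcher,
    get_scores,
    select_best_device,
    wasserstein,
)

device = select_best_device()
extractor = extractor_type_to_extractor(device=device, extractor_type="aliked")
matcher = extractor_type_to_matcher(device=device, extractor_type="aliked")

# Two cropped fish images to compare (placeholder paths).
image0 = load_image("data/images/fish_a.jpg").to(device)
image1 = load_image("data/images/fish_b.jpg").to(device)

# Extract keypoints and descriptors, then match them.
feats0 = extractor.extract(image0)
feats1 = extractor.extract(image1)
matches01 = matcher({"image0": feats0, "image1": feats1})

# Score the pair: distance of the matching scores from the null distribution.
scores = get_scores(matches01)
print(f"wasserstein score: {wasserstein(scores):.4f}")

The larger the Wasserstein distance, the more the matching scores deviate from the all-zero null distribution, which the identification step uses as a signal that the two crops are likely to show the same individual.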
viz2d.py
ADDED
@@ -0,0 +1,326 @@
"""
2D visualization primitives based on Matplotlib.
1) Plot images with `plot_images`.
2) Call `plot_keypoints` or `plot_matches` any number of times.
3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
"""

import io
from typing import Callable

import matplotlib
import matplotlib.patheffects as path_effects
import matplotlib.pyplot as plt
import numpy as np
import torch
from PIL import Image


def pyplot_to_pil_image(plot_fn: Callable[..., None]) -> Image.Image:
    """
    Run a side-effectful `plot_fn` that draws with pyplot and turn the
    resulting figure into a pil_image by writing it to an in-memory buffer.
    """
    plot_fn()
    buf = io.BytesIO()
    plt.savefig(buf, format="png")
    buf.seek(0)  # Move to the beginning of the BytesIO buffer
    return Image.open(buf)


def cm_RdGn(x):
    """Custom colormap: red (0) -> yellow (0.5) -> green (1)."""
    x = np.clip(x, 0, 1)[..., None] * 2
    c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]])
    return np.clip(c, 0, 1)


def cm_BlRdGn(x_):
    """Custom colormap: blue (-1) -> red (0.0) -> green (1)."""
    x = np.clip(x_, 0, 1)[..., None] * 2
    c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]])

    xn = -np.clip(x_, -1, 0)[..., None] * 2
    cn = xn * np.array([[0, 0.1, 1, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]])
    out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1)
    return out


def cm_prune(x_):
    """Custom colormap to visualize pruning"""
    if isinstance(x_, torch.Tensor):
        x_ = x_.cpu().numpy()
    max_i = max(x_)
    norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9)
    return cm_BlRdGn(norm_x)


def matches_as_pil_image(
    torch_images,
    feats0,
    feats1,
    matches01,
    mode: str = "column",
) -> Image.Image:
    """
    Generate a PIL image outlining the keypoints from `feats0` and `feats1`
    and how they match.
    Overlay it on the torch_images.
    """

    def plot_fn():
        kpts0, kpts1, matches = (
            feats0["keypoints"],
            feats1["keypoints"],
            matches01["matches"],
        )
        m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]

        plot_images(imgs=torch_images, mode=mode)
        plot_matches(m_kpts0, m_kpts1, color="lime", lw=0.2)

    return pyplot_to_pil_image(plot_fn=plot_fn)


def keypoints_as_pil_image(
    torch_image: torch.Tensor,
    feats: dict[str, torch.Tensor],
    color: str = "blue",
    ps: int = 10,
) -> Image.Image:
    """
    Generate a PIL image outlining the keypoints from `feats` and overlay it
    on the torch_image.
    """

    def plot_fn():
        kpts = feats["keypoints"]
        plot_images([torch_image])
        plot_keypoints(kpts=[kpts], colors=color, ps=ps)

    return pyplot_to_pil_image(plot_fn=plot_fn)


def matching_keypoints_as_pil_image(
    torch_images,
    feats0,
    feats1,
    matches01,
    mode: str = "column",
) -> Image.Image:
    """
    Generate a PIL image outlining the keypoints from `feats0` and `feats1`.
    Overlay it on the torch_images.
    """

    def plot_fn():
        kpts0, kpts1 = (
            feats0["keypoints"],
            feats1["keypoints"],
        )
        kpc0, kpc1 = cm_prune(matches01["prune0"]), cm_prune(matches01["prune1"])
        plot_images(torch_images, mode=mode)
        plot_keypoints([kpts0, kpts1], colors=[kpc0, kpc1], ps=10)

    return pyplot_to_pil_image(plot_fn=plot_fn)


def as_pil_image(
    torch_images,
    feats0,
    feats1,
    matches01,
    mode: str = "column",
) -> Image.Image:
    """
    Render the matched keypoints over `torch_images` and return the current
    pyplot figure as a PIL image.
    """
    kpts0, kpts1, matches = (
        feats0["keypoints"],
        feats1["keypoints"],
        matches01["matches"],
    )
    m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]

    plot_images(imgs=torch_images, mode=mode)
    plot_matches(m_kpts0, m_kpts1, color="lime", lw=0.2)

    buf = io.BytesIO()
    plt.savefig(buf, format="png")

    buf.seek(0)  # Move to the beginning of the BytesIO buffer

    # Open the image with PIL
    image = Image.open(buf)
    return image


def plot_images(
    imgs,
    titles=None,
    cmaps="gray",
    dpi=100,
    pad=0.5,
    adaptive=True,
    mode: str = "column",
):
    """Plot a set of images in a row or a column.
    Args:
        imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) or mono (H, W).
        titles: a list of strings, as titles for each image.
        cmaps: colormaps for monochrome images.
        adaptive: whether the figure size should fit the image aspect ratios.
        mode (str): value in {column, row}
    """
    assert mode in [
        "column",
        "row",
    ], "mode is not valid, should be in ['column', 'row']."

    # conversion to (H, W, 3) for torch.Tensor
    imgs = [
        (
            img.permute(1, 2, 0).cpu().numpy()
            if (isinstance(img, torch.Tensor) and img.dim() == 3)
            else img
        )
        for img in imgs
    ]

    n = len(imgs)
    if not isinstance(cmaps, (list, tuple)):
        cmaps = [cmaps] * n

    if adaptive:
        ratios = [i.shape[1] / i.shape[0] for i in imgs]  # W / H
    elif mode == "row":
        ratios = [4 / 3] * n
    elif mode == "column":
        ratios = [1 / 3] * n
    else:
        ratios = [4 / 3] * n

    if mode == "column":
        figsize = [10, 5]
        fig, ax = plt.subplots(
            n, 1, figsize=figsize, dpi=dpi, gridspec_kw={"height_ratios": ratios}
        )
    elif mode == "row":
        figsize = [sum(ratios) * 4.5, 4.5]
        fig, ax = plt.subplots(
            1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
        )

    if n == 1:
        ax = [ax]
    for i in range(n):
        ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
        ax[i].get_yaxis().set_ticks([])
        ax[i].get_xaxis().set_ticks([])
        ax[i].set_axis_off()
        for spine in ax[i].spines.values():  # remove frame
            spine.set_visible(False)
        if titles:
            ax[i].set_title(titles[i])

    fig.tight_layout(pad=pad)


def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0):
    """Plot keypoints for existing images.
    Args:
        kpts: list of ndarrays of size (N, 2).
        colors: string, or list of list of tuples (one for each keypoints).
        ps: size of the keypoints as float.
    """
    if not isinstance(colors, list):
        colors = [colors] * len(kpts)
    if not isinstance(a, list):
        a = [a] * len(kpts)
    if axes is None:
        axes = plt.gcf().axes
    for ax, k, c, alpha in zip(axes, kpts, colors, a):
        if isinstance(k, torch.Tensor):
            k = k.cpu().numpy()
        ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha)


def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None):
    """Plot matches for a pair of existing images.
    Args:
        kpts0, kpts1: corresponding keypoints of size (N, 2).
        color: color of each match, string or RGB tuple. Random if not given.
        lw: width of the lines.
        ps: size of the end points (no endpoint if ps=0)
        labels: optional labels for the match lines.
        axes: the two axes to draw on (defaults to the current figure's axes).
        a: alpha opacity of the match lines.
    """
    fig = plt.gcf()
    if axes is None:
        ax = fig.axes
        ax0, ax1 = ax[0], ax[1]
    else:
        ax0, ax1 = axes
    if isinstance(kpts0, torch.Tensor):
        kpts0 = kpts0.cpu().numpy()
    if isinstance(kpts1, torch.Tensor):
        kpts1 = kpts1.cpu().numpy()
    assert len(kpts0) == len(kpts1)
    if color is None:
        color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
    elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
        color = [color] * len(kpts0)

    if lw > 0:
        for i in range(len(kpts0)):
            line = matplotlib.patches.ConnectionPatch(
                xyA=(kpts0[i, 0], kpts0[i, 1]),
                xyB=(kpts1[i, 0], kpts1[i, 1]),
                coordsA=ax0.transData,
                coordsB=ax1.transData,
                axesA=ax0,
                axesB=ax1,
                zorder=1,
                color=color[i],
                linewidth=lw,
                clip_on=True,
                alpha=a,
                label=None if labels is None else labels[i],
                picker=5.0,
            )
            line.set_annotation_clip(True)
            fig.add_artist(line)

    # freeze the axes to prevent the transform to change
    ax0.autoscale(enable=False)
    ax1.autoscale(enable=False)

    if ps > 0:
        ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
        ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)


def add_text(
    idx,
    text,
    pos=(0.01, 0.99),
    fs=15,
    color="w",
    lcolor="k",
    lwidth=2,
    ha="left",
    va="top",
):
    """Add outlined text on the axis at index `idx` of the current figure."""
    ax = plt.gcf().axes[idx]
    t = ax.text(
        *pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes
    )
    if lcolor is not None:
        t.set_path_effects(
            [
                path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
                path_effects.Normal(),
            ]
        )


def save_plot(path, **kw):
    """Save the current figure without any white margin."""
    plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
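
As a usage note for the helpers above, here is a short sketch (not part of the commit). It assumes LightGlue-style `feats` whose `keypoints` tensor has shape (N, 2), e.g. after `lightglue.utils.rbd`, and the image path is a placeholder.

# Usage sketch (assumed lightglue API; image path is a placeholder).
from lightglue.utils import load_image, rbd

from utils import extractor_type_to_extractor, select_best_device
from viz2d import keypoints_as_pil_image

device = select_best_device()
extractor = extractor_type_to_extractor(device=device, extractor_type="aliked")

torch_image = load_image("data/images/fish_a.jpg").to(device)
feats = rbd(extractor.extract(torch_image))  # drop the batch dimension

# Render the detected keypoints on top of the image and get a PIL image back,
# which is convenient to display directly in the Gradio UI.
pil_overlay = keypoints_as_pil_image(torch_image=torch_image, feats=feats, color="blue")
pil_overlay.save("keypoints_overlay.png")

The `pyplot_to_pil_image` pattern is what lets the webapp reuse these notebook-style matplotlib helpers: the figure is written to an in-memory PNG buffer and handed back as a PIL image instead of being shown interactively.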
yolo.py
ADDED
@@ -0,0 +1,68 @@
"""
Generic helper functions to interact with the ultralytics yolo
models.
"""

from pathlib import Path

from ultralytics import YOLO


def load_pretrained_model(model_str: str) -> YOLO:
    """Loads the pretrained `model`"""
    return YOLO(model_str)


DEFAULT_TRAIN_PARAMS = {
    "batch": 16,
    "epochs": 100,
    "patience": 100,
    "imgsz": 640,
    "lr0": 0.01,
    "lrf": 0.01,
    "optimizer": "auto",
    # data augmentation
    "mixup": 0.0,
    "close_mosaic": 10,
    "degrees": 0.0,
    "translate": 0.1,
    "flipud": 0.0,
    "fliplr": 0.5,
}


def train(
    model: YOLO,
    data_yaml_path: Path,
    params: dict,
    project: Path = Path("data/04_models/yolo/"),
    experiment_name: str = "train",
):
    """Main function for running a train run. It saves the results
    under `project / experiment_name`.

    Args:
        model (YOLO): result of `load_pretrained_model`.
        data_yaml_path (Path): filepath to the data.yaml file that specifies the split and classes to train on
        params (dict): parameters to override when running the training. See https://docs.ultralytics.com/modes/train/#train-settings for a complete list of parameters.
        project (Path): root path to store the run artifacts and results.
        experiment_name (str): name of the experiment, that is added to the project root path to store the run.
    """
    assert data_yaml_path.exists(), f"data_yaml_path does not exist, {data_yaml_path}"
    params = {**DEFAULT_TRAIN_PARAMS, **params}
    model.train(
        project=str(project),
        name=experiment_name,
        data=data_yaml_path.absolute(),
        epochs=params["epochs"],
        lr0=params["lr0"],
        lrf=params["lrf"],
        optimizer=params["optimizer"],
        imgsz=params["imgsz"],
        close_mosaic=params["close_mosaic"],
        # Data Augmentation parameters
        mixup=params["mixup"],
        degrees=params["degrees"],
        flipud=params["flipud"],
        translate=params["translate"],
    )
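
A possible training invocation for these helpers, as a sketch only: the checkpoint name and data.yaml path below are placeholders, not files from this commit. Note that `train` only forwards the subset of parameters listed in its body to `model.train`; entries such as `batch` and `patience` in `DEFAULT_TRAIN_PARAMS` are currently not passed through.

# Usage sketch (checkpoint name and data.yaml path are placeholders).
from pathlib import Path

from yolo import load_pretrained_model, train

model = load_pretrained_model("yolov8n-pose.pt")
train(
    model=model,
    data_yaml_path=Path("data/pose/data.yaml"),
    params={"epochs": 10, "imgsz": 640},
    experiment_name="pose_baseline",
)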