Commit 1726cb2 (verified) · Parent(s): f3a57dd · authored by achouffe

feat: initial commit for porting the webapp

__init__.py ADDED
File without changes
app.py ADDED
@@ -0,0 +1,222 @@
1
+ """
2
+ Simple Gradio Interface to showcase the ML outputs in a UI and webapp.
3
+ """
4
+
5
+ import math
6
+ from pathlib import Path
7
+
8
+ import gradio as gr
9
+ from PIL import Image
10
+ from ultralytics import YOLO
11
+
12
+ import pipeline
13
+ from identification import IdentificationModel, generate_visualization
14
+ from utils import bgr_to_rgb, select_best_device
15
+
16
+ DEFAULT_IMAGE_INDEX = 0
17
+
18
+ DIR_INSTALLED_PIPELINE = Path("./data/pipeline/")
19
+ DIR_EXAMPLES = Path("./data/images/")
20
+ FILEPATH_IDENTIFICATION_LIGHTGLUE_CONFIG = (
21
+ DIR_INSTALLED_PIPELINE / "models/identification/config.yaml"
22
+ )
23
+ FILEPATH_IDENTIFICATION_DB = DIR_INSTALLED_PIPELINE / "db/db.csv"
24
+ FILEPATH_IDENTIFICATION_LIGHTGLUE_FEATURES = (
25
+ DIR_INSTALLED_PIPELINE / "models/identification/features.pt"
26
+ )
27
+ FILEPATH_WEIGHTS_SEGMENTATION_MODEL = (
28
+ DIR_INSTALLED_PIPELINE / "models/segmentation/weights.pt"
29
+ )
30
+ FILEPATH_WEIGHTS_POSE_MODEL = DIR_INSTALLED_PIPELINE / "models/pose/weights.pt"
31
+
32
+
33
+ def examples(dir_examples: Path) -> list[Path]:
34
+ """
35
+ Retrieve the default example image filepaths from `dir_examples`.
36
+
37
+ Returns:
38
+ examples (list[Path]): list of image filepaths.
39
+ """
40
+ return list(dir_examples.glob("*.jpg"))
41
+
42
+
43
+ def make_ui(loaded_models: dict[str, YOLO | IdentificationModel]):
44
+ """
45
+ Main entrypoint to wire up the Gradio interface.
46
+
47
+ Args:
48
+ loaded_models (dict[str, YOLO | IdentificationModel]): loaded models ready to run inference with.
49
+
50
+ Returns:
51
+ demo (gr.Blocks): the assembled Gradio interface.
52
+ """
53
+ with gr.Blocks() as demo:
54
+ with gr.Row():
55
+ with gr.Column():
56
+ image_input = gr.Image(
57
+ type="pil",
58
+ value=default_value_input,
59
+ label="input image",
60
+ sources=["upload", "clipboard"],
61
+ )
62
+ gr.Examples(
63
+ examples=example_filepaths,
64
+ inputs=image_input,
65
+ )
66
+ submit_btn = gr.Button(value="Identify", variant="primary")
67
+
68
+ with gr.Column():
69
+ with gr.Tab("Prediction"):
70
+ with gr.Row():
71
+ pit_prediction = gr.Text(label="predicted individual")
72
+ name_prediction = gr.Text(label="fish name", visible=False)
73
+ image_feature_matching = gr.Image(
74
+ label="pattern matching", visible=False
75
+ )
76
+ image_extracted_keypoints = gr.Image(
77
+ label="extracted keypoints", visible=False
78
+ )
79
+
80
+ with gr.Tab("Details", visible=False) as tab_details:
81
+ with gr.Column():
82
+ with gr.Row():
83
+ text_rotation_angle = gr.Text(
84
+ label="correction angle (degrees)"
85
+ )
86
+ text_side = gr.Text(label="predicted side")
87
+
88
+ image_pose_keypoints = gr.Image(
89
+ type="pil", label="pose keypoints"
90
+ )
91
+ image_rotated_keypoints = gr.Image(
92
+ type="pil", label="rotated keypoints"
93
+ )
94
+ image_segmentation_mask = gr.Image(type="pil", label="mask")
95
+ image_masked = gr.Image(type="pil", label="masked")
96
+
97
+ def submit_fn(
98
+ loaded_models: dict[str, YOLO | IdentificationModel],
99
+ orig_image: Image.Image,
100
+ ):
101
+ """
102
+ Main function used for the Gradio interface.
103
+
104
+ Args:
105
+ loaded_models (dict[str, YOLO | IdentificationModel]): loaded models.
106
+ orig_image (PIL): original image picked by the user.
107
+
108
+ Returns:
109
+ fish side (str): predicted fish side.
110
+ correction angle (str): rotation to apply, in degrees, to re-align the image.
111
+ keypoints image (PIL): image displaying the bbox and keypoints from the
112
+ pose estimation model.
113
+ rotated image (PIL): rotated image after applying the correction angle.
114
+ segmentation mask (PIL): segmentation mask predicted by the segmentation model.
115
+ segmented image (PIL): segmented orig_image using the segmentation mask
116
+ and the crop.
117
+ predicted_individual (str): The identified individual.
118
+ pil_image_extracted_keypoints (PIL): The extracted keypoints overlaid on the image.
119
+ feature_matching_image (PIL): The matching of the source with the identified individual.
120
+ """
122
+ results = pipeline.run(loaded_models=loaded_models, pil_image=orig_image)
123
+ side = results["stages"]["pose"]["output"]["side"]
124
+ theta = results["stages"]["pose"]["output"]["theta"]
125
+ pil_image_keypoints = Image.fromarray(
126
+ bgr_to_rgb(results["stages"]["pose"]["output"]["prediction"].plot())
127
+ )
128
+ pil_image_rotated = Image.fromarray(
129
+ results["stages"]["rotation"]["output"]["array_image"]
130
+ )
131
+ pil_image_mask = results["stages"]["segmentation"]["output"]["mask"]
132
+ pil_image_masked_cropped = results["stages"]["crop"]["output"]["pil_image"]
133
+
134
+ viz_dict = generate_visualization(
135
+ pil_image=pil_image_masked_cropped,
136
+ prediction=results["stages"]["identification"]["output"],
137
+ )
138
+
139
+ is_new_individual = (
140
+ results["stages"]["identification"]["output"]["type"] == "new"
141
+ )
142
+
143
+ return {
144
+ text_rotation_angle: f"{math.degrees(theta):0.1f}",
145
+ text_side: side.value,
146
+ image_pose_keypoints: pil_image_keypoints,
147
+ image_rotated_keypoints: pil_image_rotated,
148
+ image_segmentation_mask: pil_image_mask,
149
+ image_masked: pil_image_masked_cropped,
150
+ pit_prediction: (
151
+ "New Fish!"
152
+ if is_new_individual
153
+ else gr.Text(
154
+ results["stages"]["identification"]["output"]["match"]["pit"],
155
+ visible=True,
156
+ )
157
+ ),
158
+ name_prediction: (
159
+ gr.Text(visible=False)
160
+ if is_new_individual
161
+ else gr.Text(
162
+ results["stages"]["identification"]["output"]["match"]["name"],
163
+ visible=True,
164
+ )
165
+ ),
166
+ tab_details: gr.Column(visible=True),
167
+ image_extracted_keypoints: gr.Image(
168
+ viz_dict["keypoints_source"], visible=True
169
+ ),
170
+ image_feature_matching: (
171
+ gr.Image(visible=False)
172
+ if is_new_individual
173
+ else gr.Image(viz_dict["matches"], visible=True)
174
+ ),
175
+ }
176
+
177
+ submit_btn.click(
178
+ fn=lambda pil_image: submit_fn(
179
+ loaded_models=loaded_models,
180
+ orig_image=pil_image,
181
+ ),
182
+ inputs=image_input,
183
+ outputs=[
184
+ text_rotation_angle,
185
+ text_side,
186
+ image_pose_keypoints,
187
+ image_rotated_keypoints,
188
+ image_feature_matching,
189
+ image_segmentation_mask,
190
+ image_masked,
191
+ pit_prediction,
192
+ name_prediction,
193
+ tab_details,
195
+ image_extracted_keypoints,
196
+ ],
197
+ )
198
+
199
+ return demo
200
+
201
+
202
+ if __name__ == "__main__":
203
+ device = select_best_device()
204
+ # FIXME: get this from the config instead
205
+ extractor_type = "aliked"
206
+ n_keypoints = 1024
207
+ threshold_wasserstein = 0.084
208
+ loaded_models = pipeline.load_models(
209
+ device=device,
210
+ filepath_weights_segmentation_model=FILEPATH_WEIGHTS_SEGMENTATION_MODEL,
211
+ filepath_weights_pose_model=FILEPATH_WEIGHTS_POSE_MODEL,
212
+ filepath_identification_lightglue_features=FILEPATH_IDENTIFICATION_LIGHTGLUE_FEATURES,
213
+ filepath_identification_db=FILEPATH_IDENTIFICATION_DB,
214
+ extractor_type=extractor_type,
215
+ n_keypoints=n_keypoints,
216
+ threshold_wasserstein=threshold_wasserstein,
217
+ )
218
+ model_segmentation = loaded_models["segmentation"]
219
+ example_filepaths = examples(dir_examples=DIR_EXAMPLES)
220
+ default_value_input = Image.open(example_filepaths[DEFAULT_IMAGE_INDEX])
221
+ demo = make_ui(loaded_models=loaded_models)
222
+ demo.launch()
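
For reference, a minimal sketch (not part of this commit) of reproducing what the "Identify" button does without the Gradio UI, reusing the constants defined at the top of app.py and the artifacts installed under ./data/pipeline/; the extractor settings mirror the values hard-coded in the __main__ block:

    from PIL import Image

    import app
    import pipeline
    from utils import select_best_device

    loaded_models = pipeline.load_models(
        device=select_best_device(),
        filepath_weights_segmentation_model=app.FILEPATH_WEIGHTS_SEGMENTATION_MODEL,
        filepath_weights_pose_model=app.FILEPATH_WEIGHTS_POSE_MODEL,
        filepath_identification_lightglue_features=app.FILEPATH_IDENTIFICATION_LIGHTGLUE_FEATURES,
        filepath_identification_db=app.FILEPATH_IDENTIFICATION_DB,
        extractor_type="aliked",
        n_keypoints=1024,
        threshold_wasserstein=0.084,
    )

    # Run the full pipeline on the first bundled example image and print the identification.
    image = Image.open(app.examples(app.DIR_EXAMPLES)[0])
    results = pipeline.run(loaded_models=loaded_models, pil_image=image)
    output = results["stages"]["identification"]["output"]
    print(output["type"], output.get("match", {}).get("pit"))
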
data/04_models/pipeline/webapp/installed/cropped_images/0495c348-a87c-4e70-8a1f-9e07e8510977.jpg ADDED
data/04_models/pipeline/webapp/installed/cropped_images/4d7bd307-1224-41e6-ba89-82dd85e69ea9.jpg ADDED
data/04_models/pipeline/webapp/installed/cropped_images/8042e02f-1d73-4251-8b76-3f634ffe0196.jpg ADDED
data/04_models/pipeline/webapp/installed/cropped_images/8ed739d7-d24f-4919-bd1d-339e34536a9b.jpg ADDED
data/04_models/pipeline/webapp/installed/cropped_images/ec1c6fe4-b2b5-491a-9d7f-542b85e2b078.jpg ADDED
data/images/2023-08-03_1157_7_900088000908738_F.jpg ADDED
data/images/2023-08-16_1547_7_989001006037192_F.jpg ADDED
data/images/2023-10-20_1604_7_900088000909142_F.jpg ADDED
data/images/2023-10-20_1625_7_900088000913636_F.jpg ADDED
data/images/989,001,006,004,046_F.jpg ADDED
data/pipeline/config.yaml ADDED
@@ -0,0 +1,13 @@
1
+ filepath_db: db/db.csv
2
+ filepath_model_segmentation_weights: models/segmentation/weights.pt
3
+ filepath_model_pose_weights: models/pose/weights.pt
4
+ filepath_model_identification_features: models/identification/features.pt
5
+ filepath_model_identification_config: models/identification/config.yaml
6
+ filepath_config: config.yaml
7
+ root_dir: .
8
+ dir_cropped_images: cropped_images
9
+ dir_models: models
10
+ dir_models_segmentation: models/segmentation
11
+ dir_models_pose: models/pose
12
+ dir_models_identification: models/identification
13
+ dir_db: db
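
One possible way to address the FIXME in app.py ("get this from the config instead") would be to read this file and resolve its relative paths against the installed pipeline directory; a hedged sketch, assuming PyYAML is available in the environment (it is not pinned in requirements.txt):

    from pathlib import Path

    import yaml  # PyYAML, assumed to be available; not pinned in requirements.txt

    root = Path("./data/pipeline/")
    with open(root / "config.yaml") as f:
        config = yaml.safe_load(f)

    # Resolve the relative paths declared in the config against the install root.
    filepath_db = root / config["filepath_db"]
    filepath_segmentation_weights = root / config["filepath_model_segmentation_weights"]
    filepath_pose_weights = root / config["filepath_model_pose_weights"]
    print(filepath_db, filepath_segmentation_weights, filepath_pose_weights)
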
data/pipeline/db/db.csv ADDED
@@ -0,0 +1,6 @@
1
+ ,Unnamed: 0.3,Unnamed: 0.2,Unnamed: 0.1,Unnamed: 0,filepath,year,pit,created_at_filepath,created_at_exif,is_electrofishing,is_guide_angling,exif_make,exif_model,exif_focal_length,exif_image_width,exif_image_height,exif_shutter_speed,exif_aperture,exif_brightness,coordinates_lat,coordinates_lon,uuid,success,filepath_crop,pose_theta,pose_fish_side,name
2
+ 0,0,0,1311,1312,data/01_raw/elk-river/2023/Guide/Fish/Nupqu 2/2023-09-15_1040_2_900088000908738_F.jpg,2023,900088000908738,2023-09-15 10:40:00,,False,True,,,,4032.0,3024.0,,,,,,8042e02f-1d73-4251-8b76-3f634ffe0196,True,data/04_models/pipeline/webapp/installed/cropped_images/8042e02f-1d73-4251-8b76-3f634ffe0196.jpg,2.9707219143490304,left,Norma Fisher
3
+ 1,1,1,2306,2308,"data/01_raw/elk-river/2022/Boat_Electrofishing/Fish/August/989,001,006,037,192_F.jpg",2022,989001006037192,,2022-08-10 14:24:34,True,False,Apple,iPhone 12 Pro Max,5.1,4032.0,3024.0,10.17001733102253,1.3561438092556088,7.527949339379251,49.447309527777776,-115.05773913888888,0495c348-a87c-4e70-8a1f-9e07e8510977,True,data/04_models/pipeline/webapp/installed/cropped_images/0495c348-a87c-4e70-8a1f-9e07e8510977.jpg,3.0887346928329578,left,Jorge Sullivan
4
+ 2,2,2,142,142,data/01_raw/elk-river/2023/Boat Electrofishing/Fish/Aug 16 2023/2023-08-16_1146_7_900088000913914_F.jpg,2023,900088000913914,2023-08-16 11:46:00,2023-08-16 11:47:39,True,False,Apple,iPhone 13 Pro Max,5.7,4032.0,3024.0,9.47009,1.169925,7.41103,49.479367833333335,-115.06903196666666,4d7bd307-1224-41e6-ba89-82dd85e69ea9,True,data/04_models/pipeline/webapp/installed/cropped_images/4d7bd307-1224-41e6-ba89-82dd85e69ea9.jpg,-0.0302701344557054,left,Elizabeth Woods
5
+ 3,3,3,2264,2266,"data/01_raw/elk-river/2022/Boat_Electrofishing/Fish/August/989,001,006,004,046_F_2.jpg",2022,989001006004046,,2022-08-05 14:00:47,True,False,Apple,iPhone 13 Pro Max,5.7,4032.0,3024.0,10.445739257101238,1.1699250021066825,8.260268601117538,49.374336702777775,-115.01017866666666,8ed739d7-d24f-4919-bd1d-339e34536a9b,True,data/04_models/pipeline/webapp/installed/cropped_images/8ed739d7-d24f-4919-bd1d-339e34536a9b.jpg,-0.1349107606696157,left,Susan Wagner
6
+ 4,4,4,1825,1826,"data/01_raw/elk-river/2022/Boat_Electrofishing/Fish/October/900,088,000,909,142_F.jpg",2022,900088000909142,,2022-10-05 14:14:45,True,False,Apple,iPhone 12 Pro Max,5.1,4032.0,3024.0,6.922439780250739,1.3561438092556088,5.0997288781417165,49.60164128611111,-114.96475872222224,ec1c6fe4-b2b5-491a-9d7f-542b85e2b078,True,data/04_models/pipeline/webapp/installed/cropped_images/ec1c6fe4-b2b5-491a-9d7f-542b85e2b078.jpg,-3.138177921264415,left,Peter Montgomery
data/pipeline/models/identification/config.yaml ADDED
@@ -0,0 +1,2 @@
1
+ n_keypoints: 1024
2
+ extractor_type: aliked
data/pipeline/models/identification/features.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:053c31af351b7aa685c003a26435eadabfcece6d046f10060f8ad5f90ab9e30e
3
+ size 2689138
data/pipeline/models/pose/weights.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4081cc996b6d436bc0131e5756b076cdd58c372a319df3aaa688abf559b286c8
3
+ size 5701249
data/pipeline/models/segmentation/weights.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6bc54dad5a4d62746ddf4e7edb4402642e905663048da1ff61af969a2d2db604
3
+ size 6015837
data/summary.csv ADDED
@@ -0,0 +1,6 @@
1
+ ,Unnamed: 0.1,Unnamed: 0,filepath,year,pit,created_at_filepath,created_at_exif,is_electrofishing,is_guide_angling,exif_make,exif_model,exif_focal_length,exif_image_width,exif_image_height,exif_shutter_speed,exif_aperture,exif_brightness,coordinates_lat,coordinates_lon,uuid,success,filepath_crop,pose_theta,pose_fish_side
2
+ 0,1311,1312,data/01_raw/elk-river/2023/Guide/Fish/Nupqu 2/2023-09-15_1040_2_900088000908738_F.jpg,2023,900088000908738,2023-09-15 10:40:00,,False,True,,,,4032.0,3024.0,,,,,,8042e02f-1d73-4251-8b76-3f634ffe0196,True,data/03_processed/identification/input/images/8042e02f-1d73-4251-8b76-3f634ffe0196.jpg,2.9707219143490304,left
3
+ 1,2306,2308,"data/01_raw/elk-river/2022/Boat_Electrofishing/Fish/August/989,001,006,037,192_F.jpg",2022,989001006037192,,2022-08-10 14:24:34,True,False,Apple,iPhone 12 Pro Max,5.1,4032.0,3024.0,10.17001733102253,1.3561438092556088,7.527949339379251,49.447309527777776,-115.05773913888888,0495c348-a87c-4e70-8a1f-9e07e8510977,True,data/03_processed/identification/input/images/0495c348-a87c-4e70-8a1f-9e07e8510977.jpg,3.0887346928329578,left
4
+ 2,142,142,data/01_raw/elk-river/2023/Boat Electrofishing/Fish/Aug 16 2023/2023-08-16_1146_7_900088000913914_F.jpg,2023,900088000913914,2023-08-16 11:46:00,2023-08-16 11:47:39,True,False,Apple,iPhone 13 Pro Max,5.7,4032.0,3024.0,9.47009,1.169925,7.41103,49.479367833333335,-115.06903196666666,4d7bd307-1224-41e6-ba89-82dd85e69ea9,True,data/03_processed/identification/input/images/4d7bd307-1224-41e6-ba89-82dd85e69ea9.jpg,-0.0302701344557054,left
5
+ 3,2264,2266,"data/01_raw/elk-river/2022/Boat_Electrofishing/Fish/August/989,001,006,004,046_F_2.jpg",2022,989001006004046,,2022-08-05 14:00:47,True,False,Apple,iPhone 13 Pro Max,5.7,4032.0,3024.0,10.445739257101238,1.1699250021066825,8.260268601117538,49.374336702777775,-115.01017866666666,8ed739d7-d24f-4919-bd1d-339e34536a9b,True,data/03_processed/identification/input/images/8ed739d7-d24f-4919-bd1d-339e34536a9b.jpg,-0.1349107606696157,left
6
+ 4,1825,1826,"data/01_raw/elk-river/2022/Boat_Electrofishing/Fish/October/900,088,000,909,142_F.jpg",2022,900088000909142,,2022-10-05 14:14:45,True,False,Apple,iPhone 12 Pro Max,5.1,4032.0,3024.0,6.922439780250739,1.3561438092556088,5.0997288781417165,49.60164128611111,-114.96475872222224,ec1c6fe4-b2b5-491a-9d7f-542b85e2b078,True,data/03_processed/identification/input/images/ec1c6fe4-b2b5-491a-9d7f-542b85e2b078.jpg,-3.138177921264415,left
identification.py ADDED
@@ -0,0 +1,241 @@
1
+ """
2
+ Module to manage the identification model. One can load and run inference on a
3
+ new image.
4
+ """
5
+
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Any
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import torch
13
+ from lightglue import ALIKED, DISK, SIFT, LightGlue, SuperPoint
14
+ from lightglue.utils import numpy_image_to_torch, rbd
15
+ from PIL import Image
16
+
17
+ from utils import (
18
+ extractor_type_to_extractor,
19
+ extractor_type_to_matcher,
20
+ get_scores,
21
+ wasserstein,
22
+ )
23
+ from viz2d import keypoints_as_pil_image, matches_as_pil_image
24
+
25
+
26
+ @dataclass
27
+ class IdentificationModel:
28
+ extractor: SIFT | ALIKED | DISK | SuperPoint
29
+ extractor_type: str
30
+ threshold_wasserstein: float
31
+ n_keypoints: int
32
+ matcher: LightGlue
33
+ features_dict: dict[str, torch.Tensor]
34
+ df_db: pd.DataFrame
35
+
36
+
37
+ def load(
38
+ device: torch.device,
39
+ filepath_features: Path,
40
+ filepath_db: Path,
41
+ extractor_type: str,
42
+ n_keypoints: int,
43
+ threshold_wasserstein: float,
44
+ ) -> IdentificationModel:
45
+ """
46
+ Load the IdentificationModel provided the arguments.
47
+
48
+ Args:
49
+ device (torch.device): cpu|cuda
50
+ filepath_features (Path): filepath to the torch cached features on the
51
+ dataset one wants to predict on.
52
+ filepath_db (Path): filepath to the csv file containing the dataset to
53
+ compare with.
54
+ extractor_type (str): in {sift, disk, aliked, superpoint}.
55
+ n_keypoints (int): maximum number of keypoints to extract with the extractor.
56
+ threshold_wasserstein (float): threshold for the wasserstein distance to consider it a match.
57
+
58
+ Returns:
59
+ IdentificationModel: an IdentificationModel instance.
60
+
61
+ Raises:
62
+ AssertionError when the extractor_type or n_keypoints are not valid.
63
+ """
64
+
65
+ allowed_extractor_types = ["sift", "disk", "aliked", "superpoint"]
66
+ assert (
67
+ extractor_type in allowed_extractor_types
68
+ ), f"extractor_type should be in {allowed_extractor_types}"
69
+ assert 1 <= n_keypoints <= 5000, f"n_keypoints should be in range 1..5000"
70
+ assert (
71
+ 0.0 <= threshold_wasserstein <= 1.0
72
+ ), f"threshold_wasserstein should be in 0..1"
73
+
74
+ extractor = extractor_type_to_extractor(
75
+ device=device,
76
+ extractor_type=extractor_type,
77
+ n_keypoints=n_keypoints,
78
+ )
79
+ matcher = extractor_type_to_matcher(
80
+ device=device,
81
+ extractor_type=extractor_type,
82
+ )
83
+ features_dict = torch.load(filepath_features)
84
+ df_db = pd.read_csv(filepath_db)
85
+
86
+ return IdentificationModel(
87
+ extractor_type=extractor_type,
88
+ n_keypoints=n_keypoints,
89
+ extractor=extractor,
90
+ matcher=matcher,
91
+ features_dict=features_dict,
92
+ df_db=df_db,
93
+ threshold_wasserstein=threshold_wasserstein,
94
+ )
95
+
96
+
97
+ def _make_prediction_dict(
98
+ model: IdentificationModel,
99
+ indexed_matches: dict[str, dict[str, torch.Tensor]],
100
+ ) -> dict[str, Any]:
101
+ """
102
+ Return the prediction dict. Two types of predictions can be made:
103
+ 1. A new individual
104
+ 2. A match from the dataset
105
+
106
+ Returns:
107
+ type (str): new|match
108
+ match (dict): dict containing the following keys if type==match.
109
+ pit (str): the PIT of the matched individual.
110
+ name (str): the name of the matched individual.
111
+ filepath_crop_closest (Path): the filepath to the matched individual.
112
+ features (torch.Tensor): LightGlue Features of the matched individual.
113
+ matches (torch.Tensor): LightGlue Matches of the matched individual.
114
+ """
115
+ indexed_scores = {k: get_scores(v) for k, v in indexed_matches.items()}
116
+ indexed_wasserstein = {k: wasserstein(v) for k, v in indexed_scores.items()}
117
+ sorted_wasserstein = sorted(
118
+ indexed_wasserstein.items(), key=lambda item: item[1], reverse=True
119
+ )
120
+ shared_record = {
121
+ "indexed_matches": indexed_matches,
122
+ "indexed_scores": indexed_scores,
123
+ "indexed_wasserstein": indexed_wasserstein,
124
+ "sorted_wasserstein": sorted_wasserstein,
125
+ }
126
+ if not sorted_wasserstein:
127
+ return {"type": "new", **shared_record}
128
+ elif model.threshold_wasserstein > sorted_wasserstein[0][1]:
129
+ return {"type": "new", **shared_record}
130
+ else:
131
+ prediction_uuid = sorted_wasserstein[0][0]
132
+ db_row = model.df_db[model.df_db["uuid"] == prediction_uuid].iloc[0]
133
+ return {
134
+ "type": "match",
135
+ "match": {
136
+ "pit": db_row["pit"],
137
+ "name": db_row["name"],
138
+ "filepath_crop": db_row["filepath_crop"],
139
+ "features": model.features_dict[prediction_uuid],
140
+ "matches": indexed_matches[prediction_uuid],
141
+ },
142
+ **shared_record,
143
+ }
144
+
145
+
146
+ # FIXME: Properly run a batch inference here to make it fast on GPU.
147
+ def _batch_inference(
148
+ model: IdentificationModel,
149
+ feats0: dict,
150
+ ) -> dict[str, dict[str, torch.Tensor]]:
151
+ """
152
+ Run batch inference on feats0 with the IdentificationModel.
153
+ Returns an indexed_matches datastructure containing the results of each run
154
+ for the given uuid in the features_dict.
155
+ """
156
+ indexed_matches = {}
157
+ for uuid in model.features_dict.keys():
158
+ matches01 = model.matcher(
159
+ {"image0": feats0, "image1": model.features_dict[uuid]}
160
+ )
161
+ indexed_matches[uuid] = matches01
162
+ return indexed_matches
163
+
164
+
165
+ def predict(model: IdentificationModel, pil_image: Image.Image) -> dict:
166
+ """
167
+ Run inference on the pil_image on all the features_dict entries from the
168
+ IdentificationModel.
169
+
170
+ Note: It will try to optimize inference depending on the available device
171
+ (cpu|gpu).
172
+
173
+ Args:
174
+ model (IdentificationModel): identification model to run inference with.
175
+ pil_image (PIL): input image to run the inference on.
176
+
177
+ Returns:
178
+ type (str): new|match.
179
+ source (dict): contains the `features` of the input image.
180
+ match (dict): dict containing the following keys if type==match.
181
+ pit (str): the PIT of the matched individual.
182
+ name (str): the name of the matched individual.
183
+ filepath_crop (Path): the filepath to the matched individual.
184
+ features (torch.Tensor): LightGlue Features of the matched individual.
185
+ matches (torch.Tensor): LightGlue Matches of the matched individual.
186
+ """
187
+ # Disable gradient computation to make inference faster
188
+ torch.set_grad_enabled(False)
189
+ torch_image = numpy_image_to_torch(np.array(pil_image))
190
+ feats0 = model.extractor.extract(torch_image)
191
+ indexed_matches = _batch_inference(model=model, feats0=feats0)
192
+ prediction_dict = _make_prediction_dict(
193
+ model=model,
194
+ indexed_matches=indexed_matches,
195
+ )
196
+ return {"source": {"features": feats0}, **prediction_dict}
197
+
198
+
199
+ def generate_visualization(pil_image: Image.Image, prediction: dict) -> dict:
200
+ if "type" not in prediction:
201
+ return {}
202
+ elif prediction["type"] == "match":
203
+ pil_image_masked_closest = Image.open(prediction["match"]["filepath_crop"])
204
+ torch_image0 = np.array(pil_image)
205
+ torch_image1 = np.array(pil_image_masked_closest)
206
+ torch_images = [torch_image0, torch_image1]
207
+ feats0 = prediction["source"]["features"]
208
+ feats1 = prediction["match"]["features"]
209
+ matches01 = prediction["match"]["matches"]
210
+
211
+ feats0, feats1, matches01 = [
212
+ rbd(x) for x in [feats0, feats1, matches01]
213
+ ] # remove batch dimension
214
+ pil_image_matches = matches_as_pil_image(
215
+ torch_images=torch_images,
216
+ feats0=feats0,
217
+ feats1=feats1,
218
+ matches01=matches01,
219
+ mode="column",
220
+ )
221
+ pil_image_keypoints_source = keypoints_as_pil_image(
222
+ torch_image=torch_image0,
223
+ feats=feats0,
224
+ ps=23,
225
+ )
226
+ return {
227
+ "matches": pil_image_matches,
228
+ "keypoints_source": pil_image_keypoints_source,
229
+ }
230
+ elif prediction["type"] == "new":
231
+ torch_image0 = np.array(pil_image)
232
+ feats0 = prediction["source"]["features"]
233
+ feats0 = rbd(feats0) # remove the batch dimension
234
+ pil_image_keypoints_source = keypoints_as_pil_image(
235
+ torch_image=torch_image0,
236
+ feats=feats0,
237
+ ps=23,
238
+ )
239
+ return {"keypoints_source": pil_image_keypoints_source}
240
+ else:
241
+ return {}
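
A hedged sketch (not part of this commit) of using this module standalone on one of the cropped images shipped with the commit; the extractor settings and threshold mirror the values hard-coded in app.py:

    from pathlib import Path

    from PIL import Image

    import identification
    from utils import select_best_device

    model = identification.load(
        device=select_best_device(),
        filepath_features=Path("data/pipeline/models/identification/features.pt"),
        filepath_db=Path("data/pipeline/db/db.csv"),
        extractor_type="aliked",
        n_keypoints=1024,
        threshold_wasserstein=0.084,
    )

    # Identify an already-cropped image (one of the crops referenced by db.csv).
    crop = Image.open(
        "data/04_models/pipeline/webapp/installed/cropped_images/8042e02f-1d73-4251-8b76-3f634ffe0196.jpg"
    )
    prediction = identification.predict(model=model, pil_image=crop)
    print(prediction["type"])
    if prediction["type"] == "match":
        print(prediction["match"]["pit"], prediction["match"]["name"])
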
pipeline.py ADDED
@@ -0,0 +1,348 @@
1
+ from pathlib import Path
2
+ from typing import Any
3
+
4
+ import numpy as np
5
+ import torch
6
+ from PIL import Image
7
+ from ultralytics import YOLO
8
+
9
+ import identification
10
+ import pose
11
+ import segmentation
12
+ from identification import IdentificationModel
13
+ from utils import (
14
+ PictureLayout,
15
+ crop,
16
+ get_picture_layout,
17
+ get_segmentation_mask_crop_box,
18
+ )
19
+
20
+
21
+ def load_pose_and_segmentation_models(
22
+ filepath_weights_segmentation_model: Path,
23
+ filepath_weights_pose_model: Path,
24
+ ) -> dict[str, YOLO]:
25
+ """
26
+ Load into memory the models used by the pipeline.
27
+
28
+ Returns:
29
+ segmentation (YOLO): segmentation model.
30
+ pose (YOLO): pose estimation model.
31
+ """
32
+ model_segmentation = segmentation.load_pretrained_model(
33
+ str(filepath_weights_segmentation_model)
34
+ )
35
+
36
+ model_pose = pose.load_pretrained_model(str(filepath_weights_pose_model))
37
+ return {
38
+ "segmentation": model_segmentation,
39
+ "pose": model_pose,
40
+ }
41
+
42
+
43
+ def load_models(
44
+ filepath_weights_segmentation_model: Path,
45
+ filepath_weights_pose_model: Path,
46
+ device: torch.device,
47
+ filepath_identification_lightglue_features: Path,
48
+ filepath_identification_db: Path,
49
+ extractor_type: str,
50
+ n_keypoints: int,
51
+ threshold_wasserstein: float,
52
+ ) -> dict[str, YOLO | IdentificationModel]:
53
+ """
54
+ Load into memory the models used by the pipeline.
55
+
56
+ Returns:
57
+ segmentation (YOLO): segmentation model.
58
+ pose (YOLO): pose estimation model.
59
+ identification (IdentificationModel): identification model.
60
+ """
61
+ loaded_pose_seg_models = load_pose_and_segmentation_models(
62
+ filepath_weights_segmentation_model=filepath_weights_segmentation_model,
63
+ filepath_weights_pose_model=filepath_weights_pose_model,
64
+ )
65
+
66
+ model_identification = identification.load(
67
+ device=device,
68
+ filepath_features=filepath_identification_lightglue_features,
69
+ filepath_db=filepath_identification_db,
70
+ n_keypoints=n_keypoints,
71
+ extractor_type=extractor_type,
72
+ threshold_wasserstein=threshold_wasserstein,
73
+ )
74
+
75
+ return {**loaded_pose_seg_models, "identification": model_identification}
76
+
77
+
78
+ def run_preprocess(pil_image: Image.Image) -> dict[str, Any]:
79
+ """
80
+ Run the preprocess stage of the pipeline.
81
+
82
+ Args:
83
+ pil_image (PIL): original image.
84
+
85
+ Returns:
86
+ pil_image (PIL): rotated image to make it a landscape.
87
+ layout (PictureLayout): layout type of the input image.
88
+ """
89
+ picture_layout = get_picture_layout(pil_image=pil_image)
90
+
91
+ # If the image is in Portrait Mode, we turn it into Landscape
92
+ pil_image_preprocessed = (
93
+ pil_image.rotate(angle=90, expand=True)
94
+ if picture_layout == PictureLayout.PORTRAIT
95
+ else pil_image
96
+ )
97
+ return {
98
+ "pil_image": pil_image_preprocessed,
99
+ "layout": picture_layout,
100
+ }
101
+
102
+
103
+ def run_pose(model: YOLO, pil_image: Image.Image) -> dict[str, Any]:
104
+ """
105
+ Run the pose stage of the pipeline.
106
+
107
+ Args:
108
+ model (YOLO): loaded pose estimation model.
109
+ pil_image (PIL): Image to run the model on.
110
+
111
+ Returns:
112
+ prediction: Raw prediction from the model.
113
+ orig_image: original image used for inference after the preprocessing
114
+ stages applied by ultralytics.
115
+ keypoints_xy (np.ndarray): keypoints in xy format.
116
+ keypoints_xyn (np.ndarray): keyoints in xyn format.
117
+ theta (float): angle in radians to rotate the image to re-align it
118
+ horizontally.
119
+ side (FishSide): Predicted side of the fish.
120
+ """
121
+ return pose.predict(model=model, pil_image=pil_image)
122
+
123
+
124
+ def run_crop(
125
+ pil_image_mask: Image.Image,
126
+ pil_image_masked: Image.Image,
127
+ padding: int,
128
+ ) -> dict[str, Any]:
129
+ """
130
+ Run the crop on the mask and masked images.
131
+
132
+ Args:
133
+ pil_image_mask (PIL): Image containing the segmentation mask.
134
+ pil_image_masked (PIL): Image containing the applied pil_image_mask on
135
+ the original image.
136
+ padding (int): by how much do we want to pad the result image?
137
+
138
+ Returns:
139
+ box (Tuple[int, int, int, int]): 4 tuple representing a rectangle (x1,
140
+ y1, x2, y2) with the upper left corner given first.
141
+ pil_image (PIL): cropped masked image.
142
+ """
143
+
144
+ box_crop = get_segmentation_mask_crop_box(
145
+ pil_image_mask=pil_image_mask,
146
+ padding=padding,
147
+ )
148
+ pil_image_masked_cropped = crop(
149
+ pil_image=pil_image_masked,
150
+ box=box_crop,
151
+ )
152
+ return {
153
+ "box": box_crop,
154
+ "pil_image": pil_image_masked_cropped,
155
+ }
156
+
157
+
158
+ def run_rotation(
159
+ pil_image: Image.Image,
160
+ angle_rad: float,
161
+ keypoints_xy: np.ndarray,
162
+ ) -> dict[str, Any]:
163
+ """
164
+ Run the rotation stage of the pipeline.
165
+
166
+ Args:
167
+ pil_image (PIL): image to run the rotation on.
168
+ angle_rad (float): angle in radian to rotate the image.
169
+ keypoints_xy (np.ndarray): keypoints from the pose estimation
170
+ prediction in xy format.
171
+
172
+ Returns:
173
+ array_image (np.ndarray): rotated array_image as a 2D numpy array.
174
+ keypoints_xy (np.ndarray): rotated keypoints in xy format.
175
+ pil_image (PIL): rotated PIL image.
176
+ """
177
+ results_rotation = pose.rotate_image_and_keypoints_xy(
178
+ angle_rad=angle_rad,
179
+ array_image=np.array(pil_image),
180
+ keypoints_xy=keypoints_xy,
181
+ )
182
+ pil_image_rotated = Image.fromarray(results_rotation["array_image"])
183
+
184
+ return {
185
+ "pil_image": pil_image_rotated,
186
+ "array_image": results_rotation["array_image"],
187
+ "keypoints_xy": results_rotation["keypoints_xy"],
188
+ }
189
+
190
+
191
+ def run_segmentation(model: YOLO, pil_image: Image.Image) -> dict[str, Any]:
192
+ """
193
+ Run the segmentation stage of the pipeline.
194
+
195
+ Args:
196
+ pil_image (PIL): image to run the segmentation on.
197
+ model (YOLO): segmentation model.
199
+
200
+ Returns:
201
+ prediction: Raw prediction from the model.
202
+ orig_image: original image used for inference
203
+ after preprocessing stages applied by
204
+ ultralytics.
205
+ mask (PIL): postprocessed mask in white and black format - used for visualization
206
+ mask_raw (np.ndarray): Raw mask not postprocessed
207
+ masked (PIL): mask applied to the pil_image.
208
+ """
209
+ results_segmentation = segmentation.predict(
210
+ model=model,
211
+ pil_image=pil_image,
212
+ )
213
+ return results_segmentation
214
+
215
+
216
+ def run_pre_identification_stages(
217
+ loaded_models: dict[str, YOLO],
218
+ pil_image: Image.Image,
219
+ param_crop_padding: int = 0,
220
+ ) -> dict[str, Any]:
221
+ """
222
+ Run the partial ML pipeline on `pil_image` up to identifying the fish. It
223
+ prepares the input image `pil_image` to make it possible to identify it.
224
+
225
+ Args:
226
+ loaded_models (dict[str, YOLO]): result of calling `load_models`.
227
+ pil_image (PIL): Image to run the pipeline on.
228
+ param_crop_padding (int): how much to pad the resulting segmented
229
+ image when cropped.
230
+
231
+ Returns:
232
+ order (list[str]): the stages and their order.
233
+ stages (dict[str, Any]): the description of each stage, its
234
+ input and output.
235
+ """
236
+
237
+ # Unpacking the loaded models
238
+ model_pose = loaded_models["pose"]
239
+ model_segmentation = loaded_models["segmentation"]
240
+
241
+ # Stage: Preprocess
242
+ results_preprocess = run_preprocess(pil_image=pil_image)
243
+
244
+ # Stage: Pose estimation
245
+ pil_image_preprocessed = results_preprocess["pil_image"]
246
+ results_pose = run_pose(model=model_pose, pil_image=pil_image_preprocessed)
247
+
248
+ # Stage: Rotation
249
+ results_rotation = run_rotation(
250
+ pil_image=pil_image_preprocessed,
251
+ keypoints_xy=results_pose["keypoints_xy"],
252
+ angle_rad=results_pose["theta"],
253
+ )
254
+
255
+ # Stage: Segmentation
256
+ pil_image_rotated = Image.fromarray(results_rotation["array_image"])
257
+ results_segmentation = run_segmentation(
258
+ model=model_segmentation, pil_image=pil_image_rotated
259
+ )
260
+
261
+ # Stage: Crop
262
+ results_crop = run_crop(
263
+ pil_image_mask=results_segmentation["mask"],
264
+ pil_image_masked=results_segmentation["masked"],
265
+ padding=param_crop_padding,
266
+ )
267
+
268
+ return {
269
+ "order": [
270
+ "preprocess",
271
+ "pose",
272
+ "rotation",
273
+ "segmentation",
274
+ "crop",
275
+ ],
276
+ "stages": {
277
+ "preprocess": {
278
+ "input": {"pil_image": pil_image},
279
+ "output": results_preprocess,
280
+ },
281
+ "pose": {
282
+ "input": {"pil_image": pil_image_preprocessed},
283
+ "output": results_pose,
284
+ },
285
+ "rotation": {
286
+ "input": {
287
+ "pil_image": pil_image_preprocessed,
288
+ "angle_rad": results_pose["theta"],
289
+ "keypoints_xy": results_pose["keypoints_xy"],
290
+ },
291
+ "output": results_rotation,
292
+ },
293
+ "segmentation": {
294
+ "input": {"pil_image": pil_image_rotated},
295
+ "output": results_segmentation,
296
+ },
297
+ "crop": {
298
+ "input": {
299
+ "pil_image_mask": results_segmentation["mask"],
300
+ "pil_image_masked": results_segmentation["masked"],
301
+ "padding": param_crop_padding,
302
+ },
303
+ "output": results_crop,
304
+ },
305
+ },
306
+ }
307
+
308
+
309
+ def run(
310
+ loaded_models: dict[str, YOLO | IdentificationModel],
311
+ pil_image: Image.Image,
312
+ param_crop_padding: int = 0,
313
+ ) -> dict[str, Any]:
314
+ """
315
+ Run the ML pipeline on `pil_image`.
316
+
317
+ Args:
318
+ loaded_models (dict[str, YOLO | IdentificationModel]): result of calling `load_models`.
319
+ pil_image (PIL): Image to run the pipeline on.
320
+ param_crop_padding (int): how much to pad the resulting segmented
321
+ image when cropped.
322
+
323
+ Returns:
324
+ order (list[str]): the stages and their order.
325
+ stages (dict[str, Any]): the description of each stage, its
326
+ input and output.
327
+ """
328
+ model_identification = loaded_models["identification"]
329
+
330
+ results = run_pre_identification_stages(
331
+ loaded_models=loaded_models,
332
+ pil_image=pil_image,
333
+ param_crop_padding=param_crop_padding,
334
+ )
335
+
336
+ results_crop = results["stages"]["crop"]["output"]
337
+ results_identification = identification.predict(
338
+ model=model_identification,
339
+ pil_image=results_crop["pil_image"],
340
+ )
341
+
342
+ results["order"].append("identification")
343
+ results["stages"]["identification"] = {
344
+ "input": {"pil_image": results_crop["pil_image"]},
345
+ "output": results_identification,
346
+ }
347
+
348
+ return results
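
A hedged sketch (not part of this commit) of running only the pre-identification stages on a bundled example image and inspecting the nested result structure returned by the pipeline:

    from PIL import Image

    import pipeline

    loaded = pipeline.load_pose_and_segmentation_models(
        filepath_weights_segmentation_model="data/pipeline/models/segmentation/weights.pt",
        filepath_weights_pose_model="data/pipeline/models/pose/weights.pt",
    )
    image = Image.open("data/images/2023-08-03_1157_7_900088000908738_F.jpg")
    results = pipeline.run_pre_identification_stages(loaded_models=loaded, pil_image=image)

    # Each stage records its input and output under results["stages"], in results["order"].
    for stage in results["order"]:
        print(stage, sorted(results["stages"][stage]["output"].keys()))
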
pose.py ADDED
@@ -0,0 +1,166 @@
1
+ """
2
+ Module to manage the pose detection model.
3
+ """
4
+
5
+ from enum import Enum
6
+ from pathlib import Path
7
+ from typing import Any
8
+
9
+ import numpy as np
10
+ from PIL import Image
11
+ from ultralytics import YOLO
12
+
13
+ import yolo
14
+ from utils import get_angle_correction, get_keypoint, rotate_image_and_keypoints_xy
15
+
16
+
17
+ class FishSide(Enum):
18
+ """
19
+ Represents the Side of the Fish.
20
+ """
21
+
22
+ RIGHT = "right"
23
+ LEFT = "left"
24
+
25
+
26
+ def predict_fish_side(
27
+ array_image: np.ndarray,
28
+ keypoints_xy: np.ndarray,
29
+ classes_dictionnary: dict[int, str],
30
+ ) -> FishSide:
31
+ """
32
+ Predict which side of the fish is displayed on the image.
33
+
34
+ Args:
35
+ array_image (np.ndarray): numpy array representing the image.
36
+ keypoints_xy (np.ndarray): detected keypoints on array_image in
37
+ xy format.
38
+ classes_dictionnary (dict[int, str]): mapping of class instance
39
+ to.
40
+
41
+ Returns:
42
+ FishSide: Predicted side of the fish.
43
+ """
44
+
45
+ theta = get_angle_correction(
46
+ keypoints_xy=keypoints_xy,
47
+ array_image=array_image,
48
+ classes_dictionnary=classes_dictionnary,
49
+ )
50
+ rotation_results = rotate_image_and_keypoints_xy(
51
+ angle_rad=theta, array_image=array_image, keypoints_xy=keypoints_xy
52
+ )
53
+
54
+ # We check if the eyes is on the left/right of one of the fins.
55
+ k_eye = get_keypoint(
56
+ class_name="eye",
57
+ keypoints=rotation_results["keypoints_xy"],
58
+ classes_dictionnary=classes_dictionnary,
59
+ )
60
+ k_anal_fin_base = get_keypoint(
61
+ class_name="anal_fin_base",
62
+ keypoints=rotation_results["keypoints_xy"],
63
+ classes_dictionnary=classes_dictionnary,
64
+ )
65
+
66
+ if k_eye[0] <= k_anal_fin_base[0]:
67
+ return FishSide.LEFT
68
+ else:
69
+ return FishSide.RIGHT
70
+
71
+
72
+ # Model prediction classes
73
+ CLASSES_DICTIONNARY = {
74
+ 0: "eye",
75
+ 1: "front_fin_base",
76
+ 2: "tail_bottom_tip",
77
+ 3: "tail_top_tip",
78
+ 4: "dorsal_fin_base",
79
+ 5: "pelvic_fin_base",
80
+ 6: "anal_fin_base",
81
+ }
82
+
83
+
84
+ def load_pretrained_model(model_str: str) -> YOLO:
85
+ """
86
+ Load the pretrained model.
87
+ """
88
+ return yolo.load_pretrained_model(model_str)
89
+
90
+
91
+ def train(
92
+ model: YOLO,
93
+ data_yaml_path: Path,
94
+ params: dict,
95
+ project: Path = Path("data/04_models/yolo/"),
96
+ experiment_name: str = "train",
97
+ ):
98
+ """Main function for running a train run. It saves the results
99
+ under `project / experiment_name`.
100
+
101
+ Args:
102
+ model (YOLO): result of `load_pretrained_model`.
103
+ data_yaml_path (Path): filepath to the data.yaml file that specifies the split and classes to train on
104
+ params (dict): parameters to override when running the training. See https://docs.ultralytics.com/modes/train/#train-settings for a complete list of parameters.
105
+ project (Path): root path to store the run artifacts and results.
106
+ experiment_name (str): name of the experiment, that is added to the project root path to store the run.
107
+ """
108
+ return yolo.train(
109
+ model=model,
110
+ data_yaml_path=data_yaml_path,
111
+ params=params,
112
+ project=project,
113
+ experiment_name=experiment_name,
114
+ )
115
+
116
+
117
+ def predict(
118
+ model: YOLO,
119
+ pil_image: Image.Image,
120
+ classes_dictionnary: dict[int, str] = CLASSES_DICTIONNARY,
121
+ ) -> dict[str, Any]:
122
+ """
123
+ Given a loaded model and a PIL image, it returns a map containing the
124
+ keypoints predictions.
125
+
126
+ Args:
127
+ model (YOLO): loaded YOLO model for pose estimation.
128
+ pil_image (PIL): image to run the model on.
129
+ classes_dictionnary (dict[int, str]): mapping of class instance to
130
+ class name.
131
+
132
+ Returns:
133
+ prediction: Raw prediction from the model.
134
+ orig_image: original image used for inference after the preprocessing
135
+ stages applied by ultralytics.
136
+ keypoints_xy (np.ndarray): keypoints in xy format.
137
+ keypoints_xyn (np.ndarray): keypoints in xyn format.
138
+ theta (float): angle in radians to rotate the image to re-align it
139
+ horizontally.
140
+ side (FishSide): Predicted side of the fish.
141
+ """
142
+ predictions = model(pil_image)
144
+ orig_image = predictions[0].orig_img
145
+ keypoints_xy = predictions[0].keypoints.xy.cpu().numpy().squeeze()
146
+
147
+ theta = get_angle_correction(
148
+ keypoints_xy=keypoints_xy,
149
+ array_image=orig_image,
150
+ classes_dictionnary=classes_dictionnary,
151
+ )
152
+
153
+ side = predict_fish_side(
154
+ array_image=orig_image,
155
+ keypoints_xy=keypoints_xy,
156
+ classes_dictionnary=classes_dictionnary,
157
+ )
158
+
159
+ return {
160
+ "prediction": predictions[0],
161
+ "orig_image": orig_image,
162
+ "keypoints_xy": keypoints_xy,
163
+ "keypoints_xyn": predictions[0].keypoints.xyn.cpu().numpy().squeeze(),
164
+ "theta": theta,
165
+ "side": side,
166
+ }
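
A hedged sketch (not part of this commit) of calling the pose model directly on a bundled example image; it prints the predicted side and the correction angle in degrees, the same quantities surfaced in the webapp's Details tab:

    import math

    from PIL import Image

    import pose

    model = pose.load_pretrained_model("data/pipeline/models/pose/weights.pt")
    image = Image.open("data/images/2023-10-20_1604_7_900088000909142_F.jpg")
    result = pose.predict(model=model, pil_image=image)
    print(result["side"].value, f"{math.degrees(result['theta']):0.1f} degrees")
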
requirements.txt ADDED
@@ -0,0 +1,8 @@
1
+ gradio==5.4.*
2
+ pandas==2.2.*
3
+ torch==2.5.*
4
+ numpy==2.1.*
5
+ tqdm==4.66.*
6
+ ultralytics==8.3.*
7
+ matplotlib==3.9.*
8
+ lightglue @ git+https://github.com/cvg/LightGlue.git@edb2b838efb2ecfe3f88097c5fad9887d95aedad
segmentation.py ADDED
@@ -0,0 +1,103 @@
1
+ """
2
+ Module to manage the segmentation YOLO model.
3
+ """
4
+
5
+ from pathlib import Path
6
+ from typing import Any
7
+
8
+ import cv2
9
+ import numpy as np
10
+ from PIL import Image
11
+ from ultralytics import YOLO
12
+
13
+ import yolo
14
+
15
+
16
+ def load_pretrained_model(model_str: str) -> YOLO:
17
+ """
18
+ Load the pretrained model.
19
+ """
20
+ return yolo.load_pretrained_model(model_str)
21
+
22
+
23
+ def train(
24
+ model: YOLO,
25
+ data_yaml_path: Path,
26
+ params: dict,
27
+ project: Path = Path("data/04_models/yolo/"),
28
+ experiment_name: str = "train",
29
+ ):
30
+ """Main function for running a train run. It saves the results
31
+ under `project / experiment_name`.
32
+
33
+ Args:
34
+ model (YOLO): result of `load_pretrained_model`.
35
+ data_yaml_path (Path): filepath to the data.yaml file that specifies the split and classes to train on
36
+ params (dict): parameters to override when running the training. See https://docs.ultralytics.com/modes/train/#train-settings for a complete list of parameters.
37
+ project (Path): root path to store the run artifacts and results.
38
+ experiment_name (str): name of the experiment, that is added to the project root path to store the run.
39
+ """
40
+ return yolo.train(
41
+ model=model,
42
+ data_yaml_path=data_yaml_path,
43
+ params=params,
44
+ project=project,
45
+ experiment_name=experiment_name,
46
+ )
47
+
48
+
49
+ def predict(model: YOLO, pil_image: Image.Image) -> dict[str, Any]:
50
+ """
51
+ Given a loaded model and a PIL image, it returns a map
52
+ containing the segmentation predictions.
53
+
54
+ Args:
55
+ model (YOLO): loaded YOLO model for segmentation.
56
+ pil_image (PIL): image to run the model on.
57
+
58
+ Returns:
59
+ prediction: Raw prediction from the model.
60
+ orig_image: original image used for inference
61
+ after preprocessing stages applied by
62
+ ultralytics.
63
+ mask (PIL): postprocessed mask in white and black format - used for visualization
64
+ mask_raw (np.ndarray): Raw mask not postprocessed
65
+ masked (PIL): mask applied to the pil_image.
66
+ """
67
+ predictions = model(pil_image)
68
+ mask_raw = predictions[0].masks[0].data.cpu().numpy().transpose(1, 2, 0).squeeze()
69
+ # Convert single channel grayscale to 3 channel image
70
+ mask_3channel = cv2.merge((mask_raw, mask_raw, mask_raw))
71
+ # Get the size of the original image (height, width, channels)
72
+ h2, w2, c2 = predictions[0].orig_img.shape
73
+ # Resize the mask to the same size as the image (can probably be removed if image is the same size as the model)
74
+ mask = cv2.resize(mask_3channel, (w2, h2))
75
+ # Convert BGR to HSV
76
+ hsv = cv2.cvtColor(mask, cv2.COLOR_BGR2HSV)
77
+
78
+ # Define range of brightness in HSV
79
+ lower_black = np.array([0, 0, 0])
80
+ upper_black = np.array([0, 0, 1])
81
+
82
+ # Create a mask. Threshold the HSV image to get everything black
83
+ mask = cv2.inRange(mask, lower_black, upper_black)
84
+
85
+ # Invert the mask to get everything but black
86
+ mask = cv2.bitwise_not(mask)
87
+
88
+ # Apply the mask to the original image
89
+ masked = cv2.bitwise_and(
90
+ predictions[0].orig_img,
91
+ predictions[0].orig_img,
92
+ mask=mask,
93
+ )
94
+
95
+ # BGR to RGB and PIL conversion
96
+ image_output2 = Image.fromarray(masked[:, :, ::-1])
98
+ return {
99
+ "prediction": predictions[0],
100
+ "mask": Image.fromarray(mask),
101
+ "mask_raw": mask_raw,
102
+ "masked": image_output2,
103
+ }
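
A hedged sketch (not part of this commit) chaining this module with the crop helpers from utils.py, mirroring the segmentation and crop stages of pipeline.py:

    from PIL import Image

    import segmentation
    from utils import crop, get_segmentation_mask_crop_box

    model = segmentation.load_pretrained_model("data/pipeline/models/segmentation/weights.pt")
    image = Image.open("data/images/2023-10-20_1625_7_900088000913636_F.jpg")
    result = segmentation.predict(model=model, pil_image=image)

    # Crop the masked image down to the bounding box of the segmentation mask.
    box = get_segmentation_mask_crop_box(pil_image_mask=result["mask"], padding=0)
    masked_cropped = crop(pil_image=result["masked"], box=box)
    masked_cropped.save("masked_cropped.jpg")
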
ui.py ADDED
File without changes
utils.py ADDED
@@ -0,0 +1,659 @@
1
+ import math
2
+ from enum import Enum
3
+ from pathlib import Path
4
+ from typing import Tuple
5
+
6
+ import matplotlib.pyplot as plt
7
+ import numpy as np
8
+ import torch
9
+ from lightglue import ALIKED, DISK, SIFT, LightGlue, SuperPoint
10
+ from PIL import Image
11
+ from scipy.stats import wasserstein_distance
12
+
13
+
14
+ def select_best_device() -> torch.device:
15
+ """
16
+ Select best available device (cpu or cuda) based on availability.
17
+ """
18
+ if torch.cuda.is_available():
19
+ return torch.device("cuda")
20
+ else:
21
+ return torch.device("cpu")
22
+
23
+
24
+ def bgr_to_rgb(a: np.ndarray) -> np.ndarray:
25
+ """
26
+ Turn a BGR numpy array into a RGB numpy array.
27
+ """
28
+ return a[:, :, ::-1]
29
+
30
+
31
+ ALLOWED_EXTRACTOR_TYPES = ["sift", "disk", "superpoint", "aliked"]
32
+
33
+
34
+ def extractor_type_to_extractor(
35
+ device: torch.device,
36
+ extractor_type: str,
37
+ n_keypoints: int = 1024,
38
+ ):
39
+ """
40
+ Given an extractor_type in {'sift', 'superpoint', 'aliked', 'disk'},
41
+ returns a LightGlue extractor.
42
+
43
+ Args:
44
+ device (torch.device): cpu or cuda
45
+ extractor_type (str): in {sift, superpoint, aliked, disk}
46
+ n_keypoints (int): maximum number of keypoints to extract with the
47
+ extractor. Higher values give better accuracy but slower inference.
48
+
49
+ Returns:
50
+ LightGlue extractor: ALIKED | DISK | SIFT | SuperPoint
51
+
52
+ Raises:
53
+ AssertionError: when the n_keypoints are outside the valid range
54
+ 0..5000
55
+ AssertionError: when extractor_type is not valid
56
+ """
57
+ assert 0 <= n_keypoints <= 5000, "n_keypoints should be in range 0..5000"
58
+ assert (
59
+ extractor_type in ALLOWED_EXTRACTOR_TYPES
60
+ ), f"extractor type {extractor_type} should be in {ALLOWED_EXTRACTOR_TYPES}."
61
+
62
+ if extractor_type == "sift":
63
+ return SIFT(max_num_keypoints=n_keypoints).eval().to(device)
64
+ elif extractor_type == "superpoint":
65
+ return SuperPoint(max_num_keypoints=n_keypoints).eval().to(device)
66
+ elif extractor_type == "disk":
67
+ return DISK(max_num_keypoints=n_keypoints).eval().to(device)
68
+ elif extractor_type == "aliked":
69
+ return ALIKED(max_num_keypoints=n_keypoints).eval().to(device)
70
+ else:
71
+ raise Exception("extractor_type is not valid")
72
+
73
+
74
+ def extractor_type_to_matcher(device: torch.device, extractor_type: str) -> LightGlue:
75
+ """
76
+ Return the LightGlue matcher given an `extractor_type`.
77
+
78
+ Args:
79
+ device (torch.device): cpu or cuda
80
+ extractor_type (str): in {sift, superpoint, aliked, disk}
81
+
82
+ Returns:
83
+ LightGlue Matcher
84
+ """
85
+ assert (
86
+ extractor_type in ALLOWED_EXTRACTOR_TYPES
87
+ ), f"extractor type {extractor_type} should be in {ALLOWED_EXTRACTOR_TYPES}."
88
+ return LightGlue(features=extractor_type).eval().to(device)
89
+
90
+
91
+ def get_scores(matches: dict[str, torch.Tensor]) -> np.ndarray:
92
+ """
93
+ Given a `matches` dict from the LightGlue matcher output, it returns the
94
+ scores as a numpy array.
95
+ """
96
+ return matches["matching_scores0"][0].to("cpu").numpy()
97
+
98
+
99
+ def wasserstein(scores: np.ndarray) -> float:
100
+ """
101
+ Return the Wasserstein distance of the scores against the null
102
+ distribution.
103
+ The greater the distance, the farther away it is from the null
104
+ distribution.
105
+ """
106
+ x_null_distribution = [0.0] * 1024
107
+ return wasserstein_distance(x_null_distribution, scores).item()
108
+
109
+
110
+ class PictureLayout(Enum):
111
+ """
112
+ Layout of a picture.
113
+ """
114
+
115
+ PORTRAIT = "portrait"
116
+ LANDSCAPE = "landscape"
117
+ SQUARE = "square"
118
+
119
+
120
+ def crop(
121
+ pil_image: Image.Image,
122
+ box: Tuple[int, int, int, int],
123
+ ) -> Image.Image:
124
+ """
125
+ Crop a pil_image based on the provided rectangle in (x1, y1,
126
+ x2, y2) format - with the upper left corner given first.
127
+ """
128
+ return pil_image.crop(box=box)
129
+
130
+
131
+ def get_picture_layout(pil_image: Image.Image) -> PictureLayout:
132
+ """
133
+ Return the picture layout.
134
+ """
135
+ width, height = pil_image.size
136
+
137
+ if width > height:
138
+ return PictureLayout.LANDSCAPE
139
+ elif width == height:
140
+ return PictureLayout.SQUARE
141
+ else:
142
+ return PictureLayout.PORTRAIT
143
+
144
+
145
+ def get_segmentation_mask_crop_box(
146
+ pil_image_mask: Image.Image,
147
+ padding: int = 0,
148
+ ) -> Tuple[int, int, int, int]:
149
+ """
150
+ Return a crop box for the given pil_image that contains the segmentation mask (black and white).
151
+
152
+ Args:
153
+ pil_image_mask (PIL): image containing the segmentation mask
154
+ padding (int): how much to pad around the segmentation mask.
155
+
156
+ Returns:
157
+ Rectangle (Tuple[int, int, int, int]): 4 tuple representing a rectangle (x1, y1, x2, y2) with the upper left corner given first.
158
+ """
159
+ array_image_mask = np.array(pil_image_mask)
160
+ a = np.where(array_image_mask != 0)
161
+ y_min = np.min(a[0]).item()
162
+ y_max = np.max(a[0]).item()
163
+ x_min = np.min(a[1]).item()
164
+ x_max = np.max(a[1]).item()
165
+ box = (x_min, y_min, x_max, y_max)
166
+ box_with_padding = (
167
+ box[0] - padding,
168
+ box[1] - padding,
169
+ box[2] + padding,
170
+ box[3] + padding,
171
+ )
172
+ return box_with_padding
173
+
174
+
175
+ def scale_keypoints_to_image_size(
176
+ image_width: int,
177
+ image_height: int,
178
+ keypoints_xyn: np.ndarray,
179
+ ) -> np.ndarray:
180
+ """
181
+ Given keypoints in xyn format, it returns new keypoints in xy format.
182
+
183
+ Args:
184
+ image_width (int): width of the image
185
+ image_height (int): height of the image
186
+ keypoints_xyn (np.ndarray): 2D numpy array representing the keypoints
187
+ in xyn format.
188
+
189
+ Returns:
190
+ keypoints_xy (np.ndarray): 2D numpy array representing the keypoints in
191
+ xy format.
192
+ """
193
+ keypoints_xy = keypoints_xyn.copy()
194
+ keypoints_xy[:, 0] = keypoints_xyn[:, 0] * image_width
195
+ keypoints_xy[:, 1] = keypoints_xyn[:, 1] * image_height
196
+ return keypoints_xy
197
+
198
+
199
+ def normalize_keypoints_to_image_size(
200
+ image_width: int,
201
+ image_height: int,
202
+ keypoints_xy: np.ndarray,
203
+ ) -> np.ndarray:
204
+ """
205
+ Given keypoints in xy format, it returns new keypoints in xyn format.
206
+
207
+ Args:
208
+ image_width (int): width of the image
209
+ image_height (int): height of the image
210
+ keypoints_xy (np.ndarray): 2D numpy array representing the keypoints
211
+ in xy format.
212
+
213
+ Returns:
214
+ keypoints_xyn (np.ndarray): 2D numpy array representing the keypoints in
215
+ xyn format.
216
+ """
217
+ keypoints_xyn = keypoints_xy.copy()
218
+ keypoints_xyn[:, 0] = keypoints_xy[:, 0] / image_width
219
+ keypoints_xyn[:, 1] = keypoints_xy[:, 1] / image_height
220
+ return keypoints_xyn
221
+
222
+
223
+ def show_keypoints_xy(
224
+ array_image: np.ndarray,
225
+ keypoints_xy: np.ndarray,
226
+ classes_dictionnary: dict[int, str],
227
+ verbose: bool = True,
228
+ ) -> None:
229
+ """
230
+ Show keypoints on top of an `array_image`, useful in jupyter notebooks for
231
+ instance.
232
+
233
+ Args:
234
+ array_image (np.ndarray): numpy array representing an image.
235
+ keypoints_xy (np.ndarray): 2D numpy array representing the keypoints in
236
+ xy format.
237
+ classes_dictionnary (dict[int, str]): Model prediction classes.
238
+ verbose (bool): should we make the image verbose by adding some label
239
+ for each keypoint?
240
+ """
241
+
242
+ colors = ["r", "g", "b", "c", "m", "y", "w"]
243
+ plt.imshow(array_image)
244
+ label_margin = 20
245
+ height, width, _ = array_image.shape
246
+
247
+ for class_inst, class_name in classes_dictionnary.items():
248
+ color = colors[class_inst]
249
+ x, y = keypoints_xy[class_inst]
250
+ plt.scatter(x=[x], y=[y], c=color)
251
+ if verbose:
252
+ plt.annotate(class_name, (x - label_margin, y - label_margin), c="w")
253
+
254
+
255
+ def draw_keypoints_xy_on_ax(
256
+ ax,
257
+ array_image: np.ndarray,
258
+ keypoints_xy: np.ndarray,
259
+ classes_dictionnary: dict,
260
+ verbose: bool = True,
261
+ ) -> None:
262
+ """
263
+ Draw keypoints on top of an `array_image`, useful in jupyter notebooks for
264
+ instance.
265
+
266
+ Args:
267
+ array_image (np.ndarray): numpy array representing an image.
268
+ keypoints_xy (np.ndarray): 2D numpy array representing the keypoints in
269
+ xy format.
270
+ classes_dictionnary (dict[int, str]): Model prediction classes.
271
+ verbose (bool): should we make the image verbose by adding some label
272
+ for each keypoint?
273
+ """
274
+
275
+ colors = ["r", "g", "b", "c", "m", "y", "w"]
276
+ ax.imshow(array_image)
277
+ label_margin = 20
278
+ height, width, _ = array_image.shape
279
+
280
+ for class_inst, class_name in classes_dictionnary.items():
281
+ color = colors[class_inst]
282
+ x, y = keypoints_xy[class_inst]
283
+ ax.scatter(x=[x], y=[y], c=color)
284
+
285
+ if verbose:
286
+ ax.annotate(class_name, (x - label_margin, y - label_margin), c="w")
287
+
288
+ k_pelvic_fin_base = get_keypoint(
289
+ class_name="pelvic_fin_base",
290
+ keypoints=keypoints_xy,
291
+ classes_dictionnary=classes_dictionnary,
292
+ )
293
+ k_anal_fin_base = get_keypoint(
294
+ class_name="anal_fin_base",
295
+ keypoints=keypoints_xy,
296
+ classes_dictionnary=classes_dictionnary,
297
+ )
298
+
299
+ ax.axline(k_pelvic_fin_base, k_anal_fin_base, c="lime")
300
+
301
+
302
+ def show_keypoints_xyn(
303
+ array_image: np.ndarray,
304
+ keypoints_xyn: np.ndarray,
305
+ classes_dictionnary: dict,
306
+ verbose: bool = True,
307
+ ) -> None:
308
+ """
309
+ Draw keypoints on top of an `array_image`, useful in jupyter notebooks for
310
+ instance.
311
+
312
+ Args:
313
+ array_image (np.ndarray): numpy array representing an image.
314
+ keypoints_xyn (np.ndarray): 2D numpy array representing the keypoints in
315
+ xyn format.
316
+ classes_dictionnary (dict[int, str]): Model prediction classes.
317
+ verbose (bool): should we make the image verbose by adding some label
318
+ for each keypoint?
319
+ """
320
+ height, width, _ = array_image.shape
321
+ keypoints_xy = scale_keypoints_to_image_size(
322
+ image_height=height,
323
+ image_width=width,
324
+ keypoints_xyn=keypoints_xyn,
325
+ )
326
+ show_keypoints_xy(
327
+ array_image=array_image,
328
+ keypoints_xy=keypoints_xy,
329
+ classes_dictionnary=classes_dictionnary,
330
+ verbose=verbose,
331
+ )
332
+
333
+
334
+ def rotate_point(
335
+ clockwise: bool,
336
+ origin: Tuple[float, float],
337
+ point: Tuple[float, float],
338
+ angle: float,
339
+ ) -> Tuple[float, float]:
340
+ """
341
+ Rotate a point clockwise or counterclockwise by a given angle around a
342
+ given origin.
343
+
344
+ Args:
345
+ clockwise (bool): should the rotation be clockwise?
346
+ origin (Tuple[float, float]): origin 2D point to perform the rotation.
347
+ point (Tuple[float, float]): 2D point to rotate.
348
+ angle (float): angle in radians.
349
+
350
+ Returns:
351
+ rotated_point (Tuple[float, float]): rotated point after applying the
352
+ 2D transformation.
353
+ """
354
+ if clockwise:
355
+ angle = -angle
356
+
357
+ ox, oy = origin
358
+ px, py = point
359
+
360
+ qx = ox + math.cos(angle) * (px - ox) - math.sin(angle) * (py - oy)
361
+ qy = oy + math.sin(angle) * (px - ox) + math.cos(angle) * (py - oy)
362
+
363
+ return qx, qy
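
A quick sanity check of the formula (not part of the module): rotating the point (1, 0) by 90 degrees counterclockwise around the origin should land on (0, 1).

import math

qx, qy = rotate_point(
    clockwise=False,
    origin=(0.0, 0.0),
    point=(1.0, 0.0),
    angle=math.pi / 2,
)
assert abs(qx) < 1e-9 and abs(qy - 1.0) < 1e-9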
364
+
365
+
366
+ def rotate_image(angle_rad: float, array_image: np.ndarray, expand=False) -> np.ndarray:
367
+ """
368
+ Rotate an `array_image` by an angle defined in radians, clockwise using the
369
+ center as origin.
370
+
371
+ Args:
372
+ angle_rad (float): angle in radians.
373
+ array_image (np.ndarray): numpy array representing the image to rotate.
374
+ expand (bool): whether to expand the output image so that no part of
375
+ the rotated content gets truncated (useful for non-square images).
376
+ """
377
+ angle_degrees = math.degrees(angle_rad)
378
+ return np.array(Image.fromarray(array_image).rotate(angle_degrees, expand=expand))
379
+
380
+
381
+ def rotate_keypoints_xy(
382
+ angle_rad: float,
383
+ keypoints_xy: np.ndarray,
384
+ origin: Tuple[float, float],
385
+ clockwise: bool = True,
386
+ ) -> np.ndarray:
387
+ """
388
+ Rotate keypoints by an angle defined in radians, clockwise or
389
+ counterclockwise around the given `origin` point.
390
+
391
+ Args:
392
+ angle_rad (float): angle in radians.
393
+ origin (Tuple[float, float]): origin 2D point to perform the rotation.
394
+ keypoints_xy (np.ndarray): 2D numpy array representing the keypoints in
395
+ xy format.
396
+ clockwise (bool): should the rotation be clockwise?
397
+
398
+ Returns:
399
+ rotated_keypoints_xy (np.ndarray): rotated keypoints in xy format.
400
+ """
401
+ return np.array(
402
+ [
403
+ rotate_point(
404
+ clockwise=clockwise,
405
+ origin=origin,
406
+ point=(kp[0].item(), kp[1].item()),
407
+ angle=angle_rad,
408
+ )
409
+ for kp in keypoints_xy
410
+ ]
411
+ )
412
+
413
+
414
+ def rotate_image_and_keypoints_xy(
415
+ angle_rad: float,
416
+ array_image: np.ndarray,
417
+ keypoints_xy: np.ndarray,
418
+ ) -> dict[str, np.ndarray]:
419
+ """
420
+ Rotate the image and its keypoints by `angle_rad` around the image center.
421
+
422
+ Args:
423
+ angle_rad (float): angle in radians.
424
+ array_image (np.ndarray): numpy array representing the image to rotate.
425
+ keypoints_xy (np.ndarray): 2D numpy array representing the keypoints in
426
+ xy format.
427
+
428
+ Returns:
429
+ array_image (np.ndarray): rotated image as a numpy array.
430
+ keypoints_xy (np.ndarray): rotated keypoints in xy format.
431
+ """
432
+ height, width, _ = array_image.shape
433
+ center_x, center_y = int(width / 2), int(height / 2)
434
+ origin = (center_x, center_y)
435
+ image_rotated = rotate_image(angle_rad=angle_rad, array_image=array_image)
436
+ keypoints_xy_rotated = rotate_keypoints_xy(
437
+ angle_rad=angle_rad, keypoints_xy=keypoints_xy, origin=origin, clockwise=True
438
+ )
439
+
440
+ return {
441
+ "array_image": image_rotated,
442
+ "keypoints_xy": keypoints_xy_rotated,
443
+ }
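
A minimal usage sketch with a dummy image and made-up keypoints (the values are arbitrary):

import math

import numpy as np

array_image = np.zeros((100, 200, 3), dtype=np.uint8)  # dummy RGB image, H=100, W=200
keypoints_xy = np.array([[20.0, 30.0], [150.0, 80.0]])

result = rotate_image_and_keypoints_xy(
    angle_rad=math.pi / 2,
    array_image=array_image,
    keypoints_xy=keypoints_xy,
)
print(result["array_image"].shape)  # (100, 200, 3): same size, since expand is False
print(result["keypoints_xy"])       # the two keypoints rotated around the image center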
444
+
445
+
446
+ def get_keypoint(
447
+ class_name: str,
448
+ keypoints: np.ndarray,
449
+ classes_dictionnary: dict[int, str],
450
+ ) -> np.ndarray:
451
+ """
452
+ Return the keypoint for the provided `class_name` (e.g. eye, front_fin_base).
453
+
454
+ Raises:
455
+ AssertionError: when `class_name` is not one of the known classes or when the number of keypoints does not match the number of classes.
456
+ """
457
+ assert (
458
+ class_name in classes_dictionnary.values()
459
+ ), f"class_name should be in {classes_dictionnary.values()}"
460
+ assert len(classes_dictionnary) == len(
461
+ keypoints
462
+ ), "Number of provided keypoints does not match the number of class names"
463
+
464
+ class_name_to_class_inst = {v: k for k, v in classes_dictionnary.items()}
465
+ return keypoints[class_name_to_class_inst[class_name]]
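
For illustration, with a hypothetical class mapping (the real indices come from the pose model configuration):

import numpy as np

classes_dictionnary = {0: "eye", 1: "pelvic_fin_base", 2: "anal_fin_base"}  # hypothetical
keypoints_xy = np.array([[10.0, 12.0], [40.0, 55.0], [70.0, 52.0]])

k_eye = get_keypoint(
    class_name="eye",
    keypoints=keypoints_xy,
    classes_dictionnary=classes_dictionnary,
)
print(k_eye)  # [10. 12.]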
466
+
467
+
468
+ def to_direction_vector(p1: np.ndarray, p2: np.ndarray) -> np.ndarray:
469
+ """
470
+ Return the direction vector between two points p1 and p2.
471
+ """
472
+ assert len(p1) == len(p2), "p1 and p2 should have the same length"
473
+ return p2 - p1
474
+
475
+
476
+ def is_upside_down(
477
+ keypoints_xy: np.ndarray,
478
+ classes_dictionnary: dict[int, str],
479
+ ) -> bool:
480
+ """
481
+ Is the fish upside down? The check compares the vertical positions of the dorsal fin base and the pelvic fin base.
482
+ """
483
+ k_pelvic_fin_base = get_keypoint(
484
+ class_name="pelvic_fin_base",
485
+ keypoints=keypoints_xy,
486
+ classes_dictionnary=classes_dictionnary,
487
+ )
488
+ k_anal_fin_base = get_keypoint(
489
+ class_name="anal_fin_base",
490
+ keypoints=keypoints_xy,
491
+ classes_dictionnary=classes_dictionnary,
492
+ )
493
+ k_dorsal_fin_base = get_keypoint(
494
+ class_name="dorsal_fin_base",
495
+ keypoints=keypoints_xy,
496
+ classes_dictionnary=classes_dictionnary,
497
+ )
498
+
499
+ print(f"dorsal_fin_base: {k_dorsal_fin_base}")
500
+ print(f"pelvic_fin_base: {k_pelvic_fin_base}")
501
+ print(f"anal_fin_base: {k_anal_fin_base}")
502
+ return (k_dorsal_fin_base[1] > k_pelvic_fin_base[1]).item()
503
+
504
+
505
+ def get_direction_vector(
506
+ keypoints_xy: np.ndarray, classes_dictionnary: dict[int, str]
507
+ ) -> np.ndarray:
508
+ """
509
+ Get the direction vector for the realignment.
510
+ """
511
+ # Align horizontally the fish based on its pelvic fin base and its anal fin base
512
+ k_pelvic_fin_base = get_keypoint(
513
+ class_name="pelvic_fin_base",
514
+ keypoints=keypoints_xy,
515
+ classes_dictionnary=classes_dictionnary,
516
+ )
517
+ k_anal_fin_base = get_keypoint(
518
+ class_name="anal_fin_base",
519
+ keypoints=keypoints_xy,
520
+ classes_dictionnary=classes_dictionnary,
521
+ )
522
+
523
+ return to_direction_vector(
524
+ p1=k_pelvic_fin_base, p2=k_anal_fin_base
525
+ ) # line between the pelvic and anal fins
526
+
527
+
528
+ def get_reference_vector() -> np.ndarray:
529
+ """
530
+ Get the reference vector to align the direction vector to.
531
+ """
532
+ return np.array([1, 0]) # horizontal axis
533
+
534
+
535
+ def get_angle(v1: np.ndarray, v2: np.ndarray) -> float:
536
+ """
537
+ Return the angle (counterclockwise) in radians between vectors v1 and v2.
538
+ """
539
+ cos_theta = (
540
+ np.dot(v1, v2) / (np.linalg.norm(v1, ord=2) * np.linalg.norm(v2, ord=2))
541
+ ).item()
542
+ return -math.acos(cos_theta)
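
A small check with made-up vectors: the angle between the vertical vector (0, 1) and the horizontal reference (1, 0) has magnitude pi/2.

import math

import numpy as np

v1 = np.array([0.0, 1.0])
v_ref = np.array([1.0, 0.0])
theta = get_angle(v1, v_ref)
assert math.isclose(abs(theta), math.pi / 2)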
543
+
544
+
545
+ def is_aligned(keypoints_xy: np.ndarray, classes_dictionnary: dict[int, str]) -> bool:
546
+ """
547
+ Return whether the keypoints' direction vector is now properly aligned
548
+ with the horizontal reference vector.
549
+ """
550
+ v1 = get_direction_vector(
551
+ keypoints_xy=keypoints_xy, classes_dictionnary=classes_dictionnary
552
+ )
553
+ v_ref = get_reference_vector()
554
+ theta = get_angle(v1, v_ref)
555
+ return abs(theta) <= 0.001
556
+
557
+
558
+ def get_angle_correction_sign(
559
+ angle_rad: float,
560
+ array_image: np.ndarray,
561
+ keypoints_xy: np.ndarray,
562
+ classes_dictionnary: dict[int, str],
563
+ ) -> int:
564
+ """
565
+ Return 1 or -1: the sign to apply to `angle_rad` so that the rotation aligns the keypoints horizontally.
566
+ """
567
+ rotation_results = rotate_image_and_keypoints_xy(
568
+ angle_rad=angle_rad, array_image=array_image, keypoints_xy=keypoints_xy
569
+ )
570
+ if not is_aligned(
571
+ keypoints_xy=rotation_results["keypoints_xy"],
572
+ classes_dictionnary=classes_dictionnary,
573
+ ):
574
+ return -1
575
+ else:
576
+ return 1
577
+
578
+
579
+ def get_angle_correction(
580
+ keypoints_xy: np.ndarray,
581
+ array_image: np.ndarray,
582
+ classes_dictionnary: dict[int, str],
583
+ ) -> float:
584
+ """
585
+ Get the angle correction in radians that aligns the fish (based on the
586
+ keypoints) horizontally.
587
+ """
588
+ v1 = get_direction_vector(
589
+ keypoints_xy=keypoints_xy, classes_dictionnary=classes_dictionnary
590
+ )
591
+ v_ref = get_reference_vector()
592
+ theta = get_angle(v1, v_ref)
593
+
594
+ angle_sign = get_angle_correction_sign(
595
+ angle_rad=theta,
596
+ array_image=array_image,
597
+ keypoints_xy=keypoints_xy,
598
+ classes_dictionnary=classes_dictionnary,
599
+ )
600
+ theta = angle_sign * theta
601
+ rotation_results = rotate_image_and_keypoints_xy(
602
+ angle_rad=theta, array_image=array_image, keypoints_xy=keypoints_xy
603
+ )
604
+
605
+ # Check whether the fish is upside down
606
+ if is_upside_down(
607
+ keypoints_xy=rotation_results["keypoints_xy"],
608
+ classes_dictionnary=classes_dictionnary,
609
+ ):
610
+ print("the fish is upside down...")
611
+ return theta + math.pi
612
+ else:
613
+ print("The fish is not upside down")
614
+ return theta # No need to rotate the fish more
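
Putting the pieces together, a hedged end-to-end sketch; the class mapping, keypoints and image below are all made up (the real values come from the trained pose model):

import math

import numpy as np

classes_dictionnary = {
    0: "pelvic_fin_base",
    1: "anal_fin_base",
    2: "dorsal_fin_base",
}  # hypothetical subset of the model classes
keypoints_xy = np.array([[80.0, 120.0], [160.0, 150.0], [120.0, 40.0]])
array_image = np.zeros((200, 300, 3), dtype=np.uint8)  # dummy RGB image

theta = get_angle_correction(
    keypoints_xy=keypoints_xy,
    array_image=array_image,
    classes_dictionnary=classes_dictionnary,
)
rotation_results = rotate_image_and_keypoints_xy(
    angle_rad=theta, array_image=array_image, keypoints_xy=keypoints_xy
)
print(f"correction angle: {math.degrees(theta):.1f} degrees")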
615
+
616
+
617
+ def show_algorithm_steps(
618
+ image_filepath: Path,
619
+ keypoints_xy: np.ndarray,
620
+ rotation_results: dict,
621
+ theta: float,
622
+ classes_dictionnary: dict,
623
+ ) -> None:
624
+ """
625
+ Display a matplotlib figure that details step by step the result of the rotation.
626
+ Keypoints can be overlaid on the images.
627
+ """
628
+ array_image = np.array(Image.open(image_filepath))
629
+ array_image_final = np.array(
630
+ Image.open(image_filepath).rotate(math.degrees(theta), expand=True)
631
+ )
632
+
633
+ fig, axs = plt.subplots(1, 4, figsize=(20, 4))
634
+ fig.suptitle(f"{image_filepath.name}")
635
+ print(f"image_filepath: {image_filepath}")
636
+
637
+ # Hiding the x and y axis ticks
638
+ for ax in axs:
639
+ ax.xaxis.set_visible(False)
640
+ ax.yaxis.set_visible(False)
641
+
642
+ axs[0].set_title("original")
643
+ axs[0].imshow(array_image)
644
+ axs[1].set_title("predicted keypoints")
645
+ draw_keypoints_xy_on_ax(
646
+ ax=axs[1],
647
+ array_image=array_image,
648
+ keypoints_xy=keypoints_xy,
649
+ classes_dictionnary=classes_dictionnary,
650
+ )
651
+ axs[2].set_title(f"rotation of {math.degrees(theta):.1f} degrees")
652
+ draw_keypoints_xy_on_ax(
653
+ ax=axs[2],
654
+ array_image=rotation_results["array_image"],
655
+ keypoints_xy=rotation_results["keypoints_xy"],
656
+ classes_dictionnary=classes_dictionnary,
657
+ )
658
+ axs[3].set_title("final")
659
+ axs[3].imshow(array_image_final)
viz2d.py ADDED
@@ -0,0 +1,326 @@
1
+ """
2
+ 2D visualization primitives based on Matplotlib.
3
+ 1) Plot images with `plot_images`.
4
+ 2) Call `plot_keypoints` or `plot_matches` any number of times.
5
+ 3) Optionally: save a .png or .pdf plot (nice in papers!) with `save_plot`.
6
+ """
7
+
8
+ import io
9
+ from typing import Callable
10
+
11
+ import matplotlib
12
+ import matplotlib.patheffects as path_effects
13
+ import matplotlib.pyplot as plt
14
+ import numpy as np
15
+ import torch
16
+ from PIL import Image
17
+
18
+
19
+ def pyplot_to_pil_image(plot_fn: Callable[..., None]) -> Image.Image:
20
+ """
21
+ Run the side-effectful `plot_fn` (which draws with pyplot) and capture the
22
+ resulting figure as a PIL image by writing it to an in-memory buffer.
23
+ """
24
+ plot_fn()
25
+ buf = io.BytesIO()
26
+ plt.savefig(buf, format="png")
27
+ buf.seek(0) # Move to the beginning of the BytesIO buffer
28
+ return Image.open(buf)
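
A minimal usage sketch (the plotted data is arbitrary; `plt` is the module-level pyplot import):

def plot_fn():
    plt.figure()
    plt.plot([0, 1, 2], [0, 1, 4])

pil_image = pyplot_to_pil_image(plot_fn=plot_fn)
print(pil_image.size)  # width and height in pixels of the rendered figure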
29
+
30
+
31
+ def cm_RdGn(x):
32
+ """Custom colormap: red (0) -> yellow (0.5) -> green (1)."""
33
+ x = np.clip(x, 0, 1)[..., None] * 2
34
+ c = x * np.array([[0, 1.0, 0]]) + (2 - x) * np.array([[1.0, 0, 0]])
35
+ return np.clip(c, 0, 1)
36
+
37
+
38
+ def cm_BlRdGn(x_):
39
+ """Custom colormap: blue (-1) -> red (0.0) -> green (1)."""
40
+ x = np.clip(x_, 0, 1)[..., None] * 2
41
+ c = x * np.array([[0, 1.0, 0, 1.0]]) + (2 - x) * np.array([[1.0, 0, 0, 1.0]])
42
+
43
+ xn = -np.clip(x_, -1, 0)[..., None] * 2
44
+ cn = xn * np.array([[0, 0.1, 1, 1.0]]) + (2 - xn) * np.array([[1.0, 0, 0, 1.0]])
45
+ out = np.clip(np.where(x_[..., None] < 0, cn, c), 0, 1)
46
+ return out
47
+
48
+
49
+ def cm_prune(x_):
50
+ """Custom colormap to visualize pruning"""
51
+ if isinstance(x_, torch.Tensor):
52
+ x_ = x_.cpu().numpy()
53
+ max_i = max(x_)
54
+ norm_x = np.where(x_ == max_i, -1, (x_ - 1) / 9)
55
+ return cm_BlRdGn(norm_x)
56
+
57
+
58
+ def matches_as_pil_image(
59
+ torch_images,
60
+ feats0,
61
+ feats1,
62
+ matches01,
63
+ mode: str = "column",
64
+ ) -> Image.Image:
65
+ """
66
+ Generate a PIL image outlining the keypoints from `feats0` and `feats1` and
67
+ how they match.
68
+ Overlay it on the torch_images.
69
+ """
70
+
71
+ def plot_fn():
72
+ kpts0, kpts1, matches = (
73
+ feats0["keypoints"],
74
+ feats1["keypoints"],
75
+ matches01["matches"],
76
+ )
77
+ m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]
78
+
79
+ axes = plot_images(imgs=torch_images, mode=mode)
80
+ plot_matches(m_kpts0, m_kpts1, color="lime", lw=0.2)
81
+
82
+ return pyplot_to_pil_image(plot_fn=plot_fn)
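
A hedged sketch with random tensors standing in for real LightGlue feature and match outputs (shapes and values are illustrative only):

import torch

torch_images = [torch.rand(3, 128, 128), torch.rand(3, 128, 128)]
feats0 = {"keypoints": torch.rand(32, 2) * 127}
feats1 = {"keypoints": torch.rand(32, 2) * 127}
matches01 = {"matches": torch.stack([torch.arange(8), torch.arange(8)], dim=1)}

pil_image = matches_as_pil_image(torch_images, feats0, feats1, matches01, mode="column")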
83
+
84
+
85
+ def keypoints_as_pil_image(
86
+ torch_image: torch.Tensor,
87
+ feats: dict[str, torch.Tensor],
88
+ color: str = "blue",
89
+ ps: int = 10,
90
+ ) -> Image.Image:
91
+ """
92
+ Generate a PIL image outlining the keypoints from `feats` and overlay it on
93
+ the torch_image.
94
+ """
95
+
96
+ def plot_fn():
97
+ kpts = feats["keypoints"]
98
+ plot_images([torch_image])
99
+ plot_keypoints(kpts=[kpts], colors=color, ps=ps)
100
+
101
+ return pyplot_to_pil_image(plot_fn=plot_fn)
102
+
103
+
104
+ def matching_keypoints_as_pil_image(
105
+ torch_images,
106
+ feats0,
107
+ feats1,
108
+ matches01,
109
+ mode: str = "column",
110
+ ) -> Image.Image:
111
+ """
112
+ Generate a PIL image outlining the keypoints from `feats0` and `feats1`.
113
+ Overlay it on the torch_images.
114
+ """
115
+
116
+ def plot_fn():
117
+ kpts0, kpts1 = (
118
+ feats0["keypoints"],
119
+ feats1["keypoints"],
120
+ )
121
+ kpc0, kpc1 = cm_prune(matches01["prune0"]), cm_prune(matches01["prune1"])
122
+ plot_images(torch_images, mode=mode)
123
+ plot_keypoints([kpts0, kpts1], colors=[kpc0, kpc1], ps=10)
124
+
125
+ return pyplot_to_pil_image(plot_fn=plot_fn)
126
+
127
+
128
+ def as_pil_image(
129
+ torch_images,
130
+ feats0,
131
+ feats1,
132
+ matches01,
133
+ mode: str = "column",
134
+ ) -> Image.Image:
+ """
+ Same as `matches_as_pil_image`, but builds the figure and writes it to the
+ buffer inline instead of going through `pyplot_to_pil_image`.
+ """
135
+ kpts0, kpts1, matches = (
136
+ feats0["keypoints"],
137
+ feats1["keypoints"],
138
+ matches01["matches"],
139
+ )
140
+ m_kpts0, m_kpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]
141
+
142
+ axes = plot_images(imgs=torch_images, mode=mode)
143
+ plot_matches(m_kpts0, m_kpts1, color="lime", lw=0.2)
144
+
145
+ buf = io.BytesIO()
146
+ plt.savefig(buf, format="png")
147
+
148
+ buf.seek(0) # Move to the beginning of the BytesIO buffer
149
+
150
+ # Open the buffered PNG with PIL
151
+ image = Image.open(buf)
152
+ return image
153
+
154
+
155
+ def plot_images(
156
+ imgs,
157
+ titles=None,
158
+ cmaps="gray",
159
+ dpi=100,
160
+ pad=0.5,
161
+ adaptive=True,
162
+ mode: str = "column",
163
+ ):
164
+ """Plot a set of images horizontally.
165
+ Args:
166
+ imgs: list of NumPy RGB (H, W, 3) or PyTorch RGB (3, H, W) or mono (H, W).
167
+ titles: a list of strings, as titles for each image.
168
+ cmaps: colormaps for monochrome images.
169
+ adaptive: whether the figure size should fit the image aspect ratios.
170
+ mode (str): value in {column, row}
171
+ """
172
+ assert mode in [
173
+ "column",
174
+ "row",
175
+ ], f"mode is not valid, should be in ['column', 'row']."
176
+
177
+ # conversion to (H, W, 3) for torch.Tensor
178
+ imgs = [
179
+ (
180
+ img.permute(1, 2, 0).cpu().numpy()
181
+ if (isinstance(img, torch.Tensor) and img.dim() == 3)
182
+ else img
183
+ )
184
+ for img in imgs
185
+ ]
186
+
187
+ n = len(imgs)
188
+ if not isinstance(cmaps, (list, tuple)):
189
+ cmaps = [cmaps] * n
190
+
191
+ if adaptive:
192
+ ratios = [i.shape[1] / i.shape[0] for i in imgs] # W / H
193
+ elif mode == "row":
194
+ ratios = [4 / 3] * n
195
+ elif mode == "column":
196
+ ratios = [1 / 3] * n
197
+ else:
198
+ ratios = [4 / 3] * n
199
+
200
+ if mode == "column":
201
+ figsize = [10, 5]
202
+ fig, ax = plt.subplots(
203
+ n, 1, figsize=figsize, dpi=dpi, gridspec_kw={"height_ratios": ratios}
204
+ )
205
+ elif mode == "row":
206
+ figsize = [sum(ratios) * 4.5, 4.5]
207
+ fig, ax = plt.subplots(
208
+ 1, n, figsize=figsize, dpi=dpi, gridspec_kw={"width_ratios": ratios}
209
+ )
210
+
211
+ if n == 1:
212
+ ax = [ax]
213
+ for i in range(n):
214
+ ax[i].imshow(imgs[i], cmap=plt.get_cmap(cmaps[i]))
215
+ ax[i].get_yaxis().set_ticks([])
216
+ ax[i].get_xaxis().set_ticks([])
217
+ ax[i].set_axis_off()
218
+ for spine in ax[i].spines.values(): # remove frame
219
+ spine.set_visible(False)
220
+ if titles:
221
+ ax[i].set_title(titles[i])
222
+
223
+ fig.tight_layout(pad=pad)
224
+
225
+
226
+ def plot_keypoints(kpts, colors="lime", ps=4, axes=None, a=1.0):
227
+ """Plot keypoints for existing images.
228
+ Args:
229
+ kpts: list of ndarrays of size (N, 2).
230
+ colors: string, or list of colors (one per set of keypoints).
231
+ ps: size of the keypoints as float.
232
+ """
233
+ if not isinstance(colors, list):
234
+ colors = [colors] * len(kpts)
235
+ if not isinstance(a, list):
236
+ a = [a] * len(kpts)
237
+ if axes is None:
238
+ axes = plt.gcf().axes
239
+ for ax, k, c, alpha in zip(axes, kpts, colors, a):
240
+ if isinstance(k, torch.Tensor):
241
+ k = k.cpu().numpy()
242
+ ax.scatter(k[:, 0], k[:, 1], c=c, s=ps, linewidths=0, alpha=alpha)
243
+
244
+
245
+ def plot_matches(kpts0, kpts1, color=None, lw=1.5, ps=4, a=1.0, labels=None, axes=None):
246
+ """Plot matches for a pair of existing images.
247
+ Args:
248
+ kpts0, kpts1: corresponding keypoints of size (N, 2).
249
+ color: color of each match, string or RGB tuple. Random if not given.
250
+ lw: width of the lines.
251
+ ps: size of the end points (no endpoint if ps=0)
252
+ axes: the two axes to draw the matches on (defaults to the first two axes of the current figure).
253
+ a: alpha opacity of the match lines.
254
+ """
255
+ fig = plt.gcf()
256
+ if axes is None:
257
+ ax = fig.axes
258
+ ax0, ax1 = ax[0], ax[1]
259
+ else:
260
+ ax0, ax1 = axes
261
+ if isinstance(kpts0, torch.Tensor):
262
+ kpts0 = kpts0.cpu().numpy()
263
+ if isinstance(kpts1, torch.Tensor):
264
+ kpts1 = kpts1.cpu().numpy()
265
+ assert len(kpts0) == len(kpts1)
266
+ if color is None:
267
+ color = matplotlib.cm.hsv(np.random.rand(len(kpts0))).tolist()
268
+ elif len(color) > 0 and not isinstance(color[0], (tuple, list)):
269
+ color = [color] * len(kpts0)
270
+
271
+ if lw > 0:
272
+ for i in range(len(kpts0)):
273
+ line = matplotlib.patches.ConnectionPatch(
274
+ xyA=(kpts0[i, 0], kpts0[i, 1]),
275
+ xyB=(kpts1[i, 0], kpts1[i, 1]),
276
+ coordsA=ax0.transData,
277
+ coordsB=ax1.transData,
278
+ axesA=ax0,
279
+ axesB=ax1,
280
+ zorder=1,
281
+ color=color[i],
282
+ linewidth=lw,
283
+ clip_on=True,
284
+ alpha=a,
285
+ label=None if labels is None else labels[i],
286
+ picker=5.0,
287
+ )
288
+ line.set_annotation_clip(True)
289
+ fig.add_artist(line)
290
+
291
+ # freeze the axes to prevent the transform to change
292
+ ax0.autoscale(enable=False)
293
+ ax1.autoscale(enable=False)
294
+
295
+ if ps > 0:
296
+ ax0.scatter(kpts0[:, 0], kpts0[:, 1], c=color, s=ps)
297
+ ax1.scatter(kpts1[:, 0], kpts1[:, 1], c=color, s=ps)
298
+
299
+
300
+ def add_text(
301
+ idx,
302
+ text,
303
+ pos=(0.01, 0.99),
304
+ fs=15,
305
+ color="w",
306
+ lcolor="k",
307
+ lwidth=2,
308
+ ha="left",
309
+ va="top",
310
+ ):
311
+ ax = plt.gcf().axes[idx]
312
+ t = ax.text(
313
+ *pos, text, fontsize=fs, ha=ha, va=va, color=color, transform=ax.transAxes
314
+ )
315
+ if lcolor is not None:
316
+ t.set_path_effects(
317
+ [
318
+ path_effects.Stroke(linewidth=lwidth, foreground=lcolor),
319
+ path_effects.Normal(),
320
+ ]
321
+ )
322
+
323
+
324
+ def save_plot(path, **kw):
325
+ """Save the current figure without any white margin."""
326
+ plt.savefig(path, bbox_inches="tight", pad_inches=0, **kw)
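
Following the workflow from the module docstring, a minimal usage sketch with dummy images and keypoints:

import numpy as np

img0 = np.random.rand(120, 160, 3)
img1 = np.random.rand(120, 160, 3)
kpts0 = np.random.rand(20, 2) * [160, 120]
kpts1 = np.random.rand(20, 2) * [160, 120]

plot_images([img0, img1], titles=["image 0", "image 1"], mode="row")
plot_keypoints([kpts0, kpts1], colors="lime", ps=6)
plot_matches(kpts0, kpts1, color="lime", lw=0.5)
save_plot("matches.png")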
yolo.py ADDED
@@ -0,0 +1,68 @@
1
+ """
2
+ Generic helper functions to interact with Ultralytics YOLO
3
+ models.
4
+ """
5
+
6
+ from pathlib import Path
7
+
8
+ from ultralytics import YOLO
9
+
10
+
11
+ def load_pretrained_model(model_str: str) -> YOLO:
12
+ """Loads the pretrained `model`"""
13
+ return YOLO(model_str)
14
+
15
+
16
+ DEFAULT_TRAIN_PARAMS = {
17
+ "batch": 16,
18
+ "epochs": 100,
19
+ "patience": 100,
20
+ "imgsz": 640,
21
+ "lr0": 0.01,
22
+ "lrf": 0.01,
23
+ "optimizer": "auto",
24
+ # data augmentation
25
+ "mixup": 0.0,
26
+ "close_mosaic": 10,
27
+ "degrees": 0.0,
28
+ "translate": 0.1,
29
+ "flipud": 0.0,
30
+ "fliplr": 0.5,
31
+ }
32
+
33
+
34
+ def train(
35
+ model: YOLO,
36
+ data_yaml_path: Path,
37
+ params: dict,
38
+ project: Path = Path("data/04_models/yolo/"),
39
+ experiment_name: str = "train",
40
+ ):
41
+ """Main function for running a train run. It saves the results
42
+ under `project / experiment_name`.
43
+
44
+ Args:
45
+ model (YOLO): result of `load_pretrained_model`.
46
+ data_yaml_path (Path): filepath to the data.yaml file that specifies the split and classes to train on
47
+ params (dict): parameters to override when running the training. See https://docs.ultralytics.com/modes/train/#train-settings for a complete list of parameters.
48
+ project (Path): root path to store the run artifacts and results.
49
+ experiment_name (str): name of the experiment, that is added to the project root path to store the run.
50
+ """
51
+ assert data_yaml_path.exists(), f"data_yaml_path does not exist, {data_yaml_path}"
52
+ params = {**DEFAULT_TRAIN_PARAMS, **params}
53
+ model.train(
54
+ project=str(project),
55
+ name=experiment_name,
56
+ data=data_yaml_path.absolute(),
57
+ epochs=params["epochs"],
58
+ lr0=params["lr0"],
59
+ lrf=params["lrf"],
60
+ optimizer=params["optimizer"],
61
+ imgsz=params["imgsz"],
62
+ close_mosaic=params["close_mosaic"],
63
+ # Data Augmentation parameters
64
+ mixup=params["mixup"],
65
+ degrees=params["degrees"],
66
+ flipud=params["flipud"],
67
+ translate=params["translate"],
68
+ )
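
A hedged usage sketch; the data.yaml path and experiment name below are hypothetical and must point to an existing dataset definition:

from pathlib import Path

model = load_pretrained_model("yolov8n-pose.pt")  # any ultralytics model string or weights path
train(
    model=model,
    data_yaml_path=Path("data/03_processed/pose/data.yaml"),  # hypothetical path
    params={"epochs": 50, "imgsz": 640},
    experiment_name="pose-baseline",
)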