Update script.py
Browse files
script.py
CHANGED
@@ -55,10 +55,11 @@ class PytorchWorker:
|
|
55 |
return logits.tolist()
|
56 |
|
57 |
|
58 |
-
def make_submission(test_metadata,
|
59 |
"""Make submission with given """
|
60 |
-
|
61 |
-
|
|
|
62 |
|
63 |
predictions = []
|
64 |
|
@@ -66,8 +67,15 @@ def make_submission(test_metadata, model_path, model_name, output_csv_path="./su
|
|
66 |
image_path = os.path.join(images_root_path, row.filename)
|
67 |
|
68 |
test_image = Image.open(image_path).convert("RGB")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
-
logits =
|
71 |
|
72 |
predictions.append(np.argmax(logits))
|
73 |
|
@@ -85,7 +93,11 @@ if __name__ == "__main__":
|
|
85 |
zip_ref.extractall("/tmp/data")
|
86 |
|
87 |
# MODEL_PATH = "pytorch_model.bin"
|
88 |
-
MODEL_PATH = "
|
|
|
|
|
|
|
|
|
89 |
# MODEL_NAME = "tf_efficientnet_b1.ap_in1k"
|
90 |
MODEL_NAME = "swinv2_tiny_window16_256.ms_in1k"
|
91 |
|
@@ -95,8 +107,7 @@ if __name__ == "__main__":
|
|
95 |
|
96 |
make_submission(
|
97 |
test_metadata=test_metadata,
|
98 |
-
|
99 |
model_name=MODEL_NAME,
|
100 |
# images_root_path='/home/zeleznyt/mnt/data-ntis/projects/korpusy_cv/SnakeCLEF2024/val/SnakeCLEF2023-medium_size'
|
101 |
)
|
102 |
-
s
|
|
|
55 |
return logits.tolist()
|
56 |
|
57 |
|
58 |
+
def make_submission(test_metadata, model_paths, model_name, output_csv_path="./submission.csv", images_root_path="/tmp/data/private_testset"):
|
59 |
"""Make submission with given """
|
60 |
+
models = []
|
61 |
+
for m in model_paths:
|
62 |
+
models.append(PytorchWorker(m, model_name))
|
63 |
|
64 |
predictions = []
|
65 |
|
|
|
67 |
image_path = os.path.join(images_root_path, row.filename)
|
68 |
|
69 |
test_image = Image.open(image_path).convert("RGB")
|
70 |
+
flipped_image = test_image.transpose(Image.FLIP_LEFT_RIGHT)
|
71 |
+
|
72 |
+
result_logits = []
|
73 |
+
|
74 |
+
for model in models:
|
75 |
+
result_logits += model.predict_image(test_image)
|
76 |
+
# result_logits += model.predict_image(flipped_image)
|
77 |
|
78 |
+
logits = np.average(np.array(result_logits), 0)
|
79 |
|
80 |
predictions.append(np.argmax(logits))
|
81 |
|
|
|
93 |
zip_ref.extractall("/tmp/data")
|
94 |
|
95 |
# MODEL_PATH = "pytorch_model.bin"
|
96 |
+
MODEL_PATH = ["2405_PSCmetric2_0001_best_accuracy.pth",
|
97 |
+
"2405_PSCmetric_3x_best_accuracy.pth",
|
98 |
+
"2405_PSCmetric_real3_best_accuracy.pth",
|
99 |
+
"2405_cls_boost_best_accuracy.pth"
|
100 |
+
]
|
101 |
# MODEL_NAME = "tf_efficientnet_b1.ap_in1k"
|
102 |
MODEL_NAME = "swinv2_tiny_window16_256.ms_in1k"
|
103 |
|
|
|
107 |
|
108 |
make_submission(
|
109 |
test_metadata=test_metadata,
|
110 |
+
model_paths=MODEL_PATH,
|
111 |
model_name=MODEL_NAME,
|
112 |
# images_root_path='/home/zeleznyt/mnt/data-ntis/projects/korpusy_cv/SnakeCLEF2024/val/SnakeCLEF2023-medium_size'
|
113 |
)
|
|