Spaces:
Runtime error
Runtime error
Upload folder using huggingface_hub
Browse files- .dockerignore +3 -0
- .gitattributes +3 -33
- .gitignore +6 -0
- .idea/.gitignore +8 -0
- .idea/GranaMeasure_interface-main.iml +10 -0
- .idea/inspectionProfiles/Project_Default.xml +101 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +4 -0
- .idea/modules.xml +8 -0
- .idea/other.xml +6 -0
- .idea/vcs.xml +6 -0
- .idea/workspace.xml +157 -0
- Dockerfile +27 -0
- LICENSE +21 -0
- README.md +4 -8
- angle_calculation/angle_model.py +444 -0
- angle_calculation/classic.py +349 -0
- angle_calculation/envelope_correction.py +34 -0
- angle_calculation/granum_utils.py +80 -0
- angle_calculation/image_transforms.py +34 -0
- angle_calculation/sampling.py +142 -0
- app.py +602 -0
- grana_detection/mmwrapper.py +42 -0
- model.py +629 -0
- period_calculation/config.py +19 -0
- period_calculation/data_reader.py +861 -0
- period_calculation/image_transforms.py +79 -0
- period_calculation/models/abstract_model.py +61 -0
- period_calculation/models/gauss_model.py +237 -0
- period_calculation/period_measurer.py +54 -0
- requirements.txt +11 -0
- settings.py +1 -0
- styles.css +47 -0
- weights/AS_square_v16.ckpt +3 -0
- weights/model_weights_detector.pt +3 -0
- weights/period_measurer_weights-1.298_real_full-fa12970.ckpt +3 -0
- weights/yolo/20240604_yolov8_segm_ABRCR1_all_train4_best.pt +3 -0
- weights/yolo/current_yolo.pt +3 -0
.dockerignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
out/
|
2 |
+
grana/
|
3 |
+
.*
|
.gitattributes
CHANGED
@@ -1,35 +1,5 @@
|
|
1 |
-
*.
|
2 |
-
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
-
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
-
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
|
25 |
-
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
*.pt filter=lfs diff=lfs merge=lfs -text
|
3 |
*.pth filter=lfs diff=lfs merge=lfs -text
|
4 |
+
weights/AS_square_v16.ckpt filter=lfs diff=lfs merge=lfs -text
|
5 |
+
weights/period_measurer_weights-1.298_real_full-fa12970.ckpt filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.pyc
|
2 |
+
**/.ipynb_checkpoints/
|
3 |
+
venv/
|
4 |
+
results_*
|
5 |
+
out/
|
6 |
+
grana/
|
.idea/.gitignore
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Default ignored files
|
2 |
+
/shelf/
|
3 |
+
/workspace.xml
|
4 |
+
# Datasource local storage ignored files
|
5 |
+
/dataSources/
|
6 |
+
/dataSources.local.xml
|
7 |
+
# Editor-based HTTP Client requests
|
8 |
+
/httpRequests/
|
.idea/GranaMeasure_interface-main.iml
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<module type="PYTHON_MODULE" version="4">
|
3 |
+
<component name="NewModuleRootManager">
|
4 |
+
<content url="file://$MODULE_DIR$">
|
5 |
+
<excludeFolder url="file://$MODULE_DIR$/venv" />
|
6 |
+
</content>
|
7 |
+
<orderEntry type="jdk" jdkName="Python 3.10 (GranaMeasure_interface-main)" jdkType="Python SDK" />
|
8 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
9 |
+
</component>
|
10 |
+
</module>
|
.idea/inspectionProfiles/Project_Default.xml
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="InspectionProjectProfileManager">
|
2 |
+
<profile version="1.0">
|
3 |
+
<option name="myName" value="Project Default" />
|
4 |
+
<inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
5 |
+
<Languages>
|
6 |
+
<language minSize="253" name="Python" />
|
7 |
+
</Languages>
|
8 |
+
</inspection_tool>
|
9 |
+
<inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
|
10 |
+
<inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
|
11 |
+
<option name="ignoredPackages">
|
12 |
+
<value>
|
13 |
+
<list size="73">
|
14 |
+
<item index="0" class="java.lang.String" itemvalue="pandas" />
|
15 |
+
<item index="1" class="java.lang.String" itemvalue="PyYAML" />
|
16 |
+
<item index="2" class="java.lang.String" itemvalue="django-polymorphic" />
|
17 |
+
<item index="3" class="java.lang.String" itemvalue="dacite" />
|
18 |
+
<item index="4" class="java.lang.String" itemvalue="django-recaptcha" />
|
19 |
+
<item index="5" class="java.lang.String" itemvalue="python-dateutil" />
|
20 |
+
<item index="6" class="java.lang.String" itemvalue="psycopg2-binary" />
|
21 |
+
<item index="7" class="java.lang.String" itemvalue="python-dotenv" />
|
22 |
+
<item index="8" class="java.lang.String" itemvalue="MarkupSafe" />
|
23 |
+
<item index="9" class="java.lang.String" itemvalue="astroid" />
|
24 |
+
<item index="10" class="java.lang.String" itemvalue="kombu" />
|
25 |
+
<item index="11" class="java.lang.String" itemvalue="django-extensions" />
|
26 |
+
<item index="12" class="java.lang.String" itemvalue="docopt" />
|
27 |
+
<item index="13" class="java.lang.String" itemvalue="sentry-sdk" />
|
28 |
+
<item index="14" class="java.lang.String" itemvalue="django-tinymce" />
|
29 |
+
<item index="15" class="java.lang.String" itemvalue="django-braces" />
|
30 |
+
<item index="16" class="java.lang.String" itemvalue="umpy" />
|
31 |
+
<item index="17" class="java.lang.String" itemvalue="gpxpy" />
|
32 |
+
<item index="18" class="java.lang.String" itemvalue="docutils" />
|
33 |
+
<item index="19" class="java.lang.String" itemvalue="lxml" />
|
34 |
+
<item index="20" class="java.lang.String" itemvalue="Markdown" />
|
35 |
+
<item index="21" class="java.lang.String" itemvalue="django-formtools" />
|
36 |
+
<item index="22" class="java.lang.String" itemvalue="django-celery-beat" />
|
37 |
+
<item index="23" class="java.lang.String" itemvalue="django-crispy-forms" />
|
38 |
+
<item index="24" class="java.lang.String" itemvalue="pylibmc" />
|
39 |
+
<item index="25" class="java.lang.String" itemvalue="tablib" />
|
40 |
+
<item index="26" class="java.lang.String" itemvalue="django-import-export" />
|
41 |
+
<item index="27" class="java.lang.String" itemvalue="gunicorn" />
|
42 |
+
<item index="28" class="java.lang.String" itemvalue="simplejson" />
|
43 |
+
<item index="29" class="java.lang.String" itemvalue="dataclasses-json" />
|
44 |
+
<item index="30" class="java.lang.String" itemvalue="django-hstore" />
|
45 |
+
<item index="31" class="java.lang.String" itemvalue="django-mptt" />
|
46 |
+
<item index="32" class="java.lang.String" itemvalue="boto3" />
|
47 |
+
<item index="33" class="java.lang.String" itemvalue="django-extra-views" />
|
48 |
+
<item index="34" class="java.lang.String" itemvalue="django-filter" />
|
49 |
+
<item index="35" class="java.lang.String" itemvalue="mock" />
|
50 |
+
<item index="36" class="java.lang.String" itemvalue="django-allauth" />
|
51 |
+
<item index="37" class="java.lang.String" itemvalue="django-taggit" />
|
52 |
+
<item index="38" class="java.lang.String" itemvalue="supervisor" />
|
53 |
+
<item index="39" class="java.lang.String" itemvalue="django-auth-ldap" />
|
54 |
+
<item index="40" class="java.lang.String" itemvalue="djangorestframework-gis" />
|
55 |
+
<item index="41" class="java.lang.String" itemvalue="frictionless" />
|
56 |
+
<item index="42" class="java.lang.String" itemvalue="django-bulk-update" />
|
57 |
+
<item index="43" class="java.lang.String" itemvalue="requests" />
|
58 |
+
<item index="44" class="java.lang.String" itemvalue="django-storages" />
|
59 |
+
<item index="45" class="java.lang.String" itemvalue="numpy" />
|
60 |
+
<item index="46" class="java.lang.String" itemvalue="rest-pandas" />
|
61 |
+
<item index="47" class="java.lang.String" itemvalue="Jinja2" />
|
62 |
+
<item index="48" class="java.lang.String" itemvalue="drf-yasg" />
|
63 |
+
<item index="49" class="java.lang.String" itemvalue="sqlparse" />
|
64 |
+
<item index="50" class="java.lang.String" itemvalue="docker" />
|
65 |
+
<item index="51" class="java.lang.String" itemvalue="celery" />
|
66 |
+
<item index="52" class="java.lang.String" itemvalue="ipdb" />
|
67 |
+
<item index="53" class="java.lang.String" itemvalue="akismet" />
|
68 |
+
<item index="54" class="java.lang.String" itemvalue="djangorestframework" />
|
69 |
+
<item index="55" class="java.lang.String" itemvalue="billiard" />
|
70 |
+
<item index="56" class="java.lang.String" itemvalue="funcy" />
|
71 |
+
<item index="57" class="java.lang.String" itemvalue="six" />
|
72 |
+
<item index="58" class="java.lang.String" itemvalue="django-celery-results" />
|
73 |
+
<item index="59" class="java.lang.String" itemvalue="amqp" />
|
74 |
+
<item index="60" class="java.lang.String" itemvalue="ipython" />
|
75 |
+
<item index="61" class="java.lang.String" itemvalue="ffmpeg-python" />
|
76 |
+
<item index="62" class="java.lang.String" itemvalue="django-debug-toolbar" />
|
77 |
+
<item index="63" class="java.lang.String" itemvalue="logilab-common" />
|
78 |
+
<item index="64" class="java.lang.String" itemvalue="pykwalify" />
|
79 |
+
<item index="65" class="java.lang.String" itemvalue="django-grappelli" />
|
80 |
+
<item index="66" class="java.lang.String" itemvalue="watchdog" />
|
81 |
+
<item index="67" class="java.lang.String" itemvalue="Sphinx" />
|
82 |
+
<item index="68" class="java.lang.String" itemvalue="azure-storage-blob" />
|
83 |
+
<item index="69" class="java.lang.String" itemvalue="Django" />
|
84 |
+
<item index="70" class="java.lang.String" itemvalue="django-timezone-field" />
|
85 |
+
<item index="71" class="java.lang.String" itemvalue="pytz" />
|
86 |
+
<item index="72" class="java.lang.String" itemvalue="Pillow" />
|
87 |
+
</list>
|
88 |
+
</value>
|
89 |
+
</option>
|
90 |
+
</inspection_tool>
|
91 |
+
<inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
|
92 |
+
<option name="ignoredErrors">
|
93 |
+
<list>
|
94 |
+
<option value="N803" />
|
95 |
+
<option value="N802" />
|
96 |
+
<option value="N806" />
|
97 |
+
</list>
|
98 |
+
</option>
|
99 |
+
</inspection_tool>
|
100 |
+
</profile>
|
101 |
+
</component>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<component name="InspectionProjectProfileManager">
|
2 |
+
<settings>
|
3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
4 |
+
<version value="1.0" />
|
5 |
+
</settings>
|
6 |
+
</component>
|
.idea/misc.xml
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (GranaMeasure_interface-main)" project-jdk-type="Python SDK" />
|
4 |
+
</project>
|
.idea/modules.xml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ProjectModuleManager">
|
4 |
+
<modules>
|
5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/GranaMeasure_interface-main.iml" filepath="$PROJECT_DIR$/.idea/GranaMeasure_interface-main.iml" />
|
6 |
+
</modules>
|
7 |
+
</component>
|
8 |
+
</project>
|
.idea/other.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="PySciProjectComponent">
|
4 |
+
<option name="PY_SCI_VIEW_SUGGESTED" value="true" />
|
5 |
+
</component>
|
6 |
+
</project>
|
.idea/vcs.xml
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="VcsDirectoryMappings">
|
4 |
+
<mapping directory="$PROJECT_DIR$" vcs="Git" />
|
5 |
+
</component>
|
6 |
+
</project>
|
.idea/workspace.xml
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
2 |
+
<project version="4">
|
3 |
+
<component name="ChangeListManager">
|
4 |
+
<list default="true" id="f10c5f0a-4791-498a-9005-10ee84337c97" name="Changes" comment="" />
|
5 |
+
<option name="SHOW_DIALOG" value="false" />
|
6 |
+
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
7 |
+
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
8 |
+
<option name="LAST_RESOLUTION" value="IGNORE" />
|
9 |
+
</component>
|
10 |
+
<component name="FileTemplateManagerImpl">
|
11 |
+
<option name="RECENT_TEMPLATES">
|
12 |
+
<list>
|
13 |
+
<option value="CSS File" />
|
14 |
+
<option value="Python Script" />
|
15 |
+
</list>
|
16 |
+
</option>
|
17 |
+
</component>
|
18 |
+
<component name="Git.Settings">
|
19 |
+
<option name="RECENT_BRANCH_BY_REPOSITORY">
|
20 |
+
<map>
|
21 |
+
<entry key="$PROJECT_DIR$" value="development" />
|
22 |
+
</map>
|
23 |
+
</option>
|
24 |
+
<option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
|
25 |
+
</component>
|
26 |
+
<component name="GitSEFilterConfiguration">
|
27 |
+
<file-type-list>
|
28 |
+
<filtered-out-file-type name="LOCAL_BRANCH" />
|
29 |
+
<filtered-out-file-type name="REMOTE_BRANCH" />
|
30 |
+
<filtered-out-file-type name="TAG" />
|
31 |
+
<filtered-out-file-type name="COMMIT_BY_MESSAGE" />
|
32 |
+
</file-type-list>
|
33 |
+
</component>
|
34 |
+
<component name="ProjectId" id="2bog1ms5jMwIjHWHezXITkymVNk" />
|
35 |
+
<component name="ProjectLevelVcsManager" settingsEditedManually="true" />
|
36 |
+
<component name="ProjectViewState">
|
37 |
+
<option name="hideEmptyMiddlePackages" value="true" />
|
38 |
+
<option name="showLibraryContents" value="true" />
|
39 |
+
</component>
|
40 |
+
<component name="PropertiesComponent">
|
41 |
+
<property name="ASKED_ADD_EXTERNAL_FILES" value="true" />
|
42 |
+
<property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
|
43 |
+
<property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
|
44 |
+
<property name="SHARE_PROJECT_CONFIGURATION_FILES" value="true" />
|
45 |
+
<property name="WebServerToolWindowFactoryState" value="false" />
|
46 |
+
<property name="last_opened_file_path" value="$USER_HOME$/BirdNET-Analyzer" />
|
47 |
+
<property name="list.type.of.created.stylesheet" value="CSS" />
|
48 |
+
<property name="node.js.detected.package.eslint" value="true" />
|
49 |
+
<property name="node.js.selected.package.eslint" value="(autodetect)" />
|
50 |
+
<property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
|
51 |
+
</component>
|
52 |
+
<component name="RecentsManager">
|
53 |
+
<key name="CopyFile.RECENT_KEYS">
|
54 |
+
<recent name="$PROJECT_DIR$/examples" />
|
55 |
+
<recent name="$PROJECT_DIR$/mock_data" />
|
56 |
+
<recent name="$PROJECT_DIR$" />
|
57 |
+
</key>
|
58 |
+
<key name="MoveFile.RECENT_KEYS">
|
59 |
+
<recent name="$PROJECT_DIR$/images" />
|
60 |
+
</key>
|
61 |
+
</component>
|
62 |
+
<component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
|
63 |
+
<component name="TaskManager">
|
64 |
+
<task active="true" id="Default" summary="Default task">
|
65 |
+
<changelist id="f10c5f0a-4791-498a-9005-10ee84337c97" name="Changes" comment="" />
|
66 |
+
<created>1706886633279</created>
|
67 |
+
<option name="number" value="Default" />
|
68 |
+
<option name="presentableId" value="Default" />
|
69 |
+
<updated>1706886633279</updated>
|
70 |
+
<workItem from="1706886636008" duration="1452000" />
|
71 |
+
<workItem from="1707129545673" duration="18000" />
|
72 |
+
<workItem from="1707475281532" duration="689000" />
|
73 |
+
<workItem from="1707734574375" duration="136000" />
|
74 |
+
<workItem from="1707829113355" duration="7000" />
|
75 |
+
<workItem from="1715267971553" duration="18107000" />
|
76 |
+
<workItem from="1715791673881" duration="83000" />
|
77 |
+
<workItem from="1718701069639" duration="1074000" />
|
78 |
+
<workItem from="1718702835088" duration="2728000" />
|
79 |
+
<workItem from="1718705968353" duration="1911000" />
|
80 |
+
<workItem from="1718708463096" duration="1371000" />
|
81 |
+
<workItem from="1718710203255" duration="3374000" />
|
82 |
+
<workItem from="1718719127184" duration="3910000" />
|
83 |
+
<workItem from="1718729675691" duration="3859000" />
|
84 |
+
<workItem from="1718733940161" duration="2339000" />
|
85 |
+
<workItem from="1718736568891" duration="827000" />
|
86 |
+
<workItem from="1718737438910" duration="21000" />
|
87 |
+
<workItem from="1718740972224" duration="4830000" />
|
88 |
+
<workItem from="1718745952723" duration="1618000" />
|
89 |
+
<workItem from="1718793757299" duration="380000" />
|
90 |
+
<workItem from="1719923572507" duration="1390000" />
|
91 |
+
<workItem from="1720770804165" duration="6223000" />
|
92 |
+
<workItem from="1720780265020" duration="5909000" />
|
93 |
+
<workItem from="1720791565290" duration="1486000" />
|
94 |
+
<workItem from="1721039720633" duration="18000" />
|
95 |
+
<workItem from="1721127968492" duration="8000" />
|
96 |
+
<workItem from="1723193047208" duration="1397000" />
|
97 |
+
<workItem from="1723545275936" duration="11000" />
|
98 |
+
<workItem from="1727785730725" duration="1830000" />
|
99 |
+
<workItem from="1727787983064" duration="2377000" />
|
100 |
+
<workItem from="1727790929474" duration="2947000" />
|
101 |
+
<workItem from="1728301291811" duration="2719000" />
|
102 |
+
<workItem from="1728475950279" duration="10306000" />
|
103 |
+
<workItem from="1728545217803" duration="6586000" />
|
104 |
+
<workItem from="1728562720043" duration="7663000" />
|
105 |
+
<workItem from="1728657746085" duration="8583000" />
|
106 |
+
<workItem from="1729079338205" duration="3308000" />
|
107 |
+
<workItem from="1729252342732" duration="600000" />
|
108 |
+
<workItem from="1729255292344" duration="3000" />
|
109 |
+
<workItem from="1729255305399" duration="170000" />
|
110 |
+
<workItem from="1730227437944" duration="1638000" />
|
111 |
+
</task>
|
112 |
+
<task id="LOCAL-00001" summary="File format + new layout WIP">
|
113 |
+
<created>1718743330580</created>
|
114 |
+
<option name="number" value="00001" />
|
115 |
+
<option name="presentableId" value="LOCAL-00001" />
|
116 |
+
<option name="project" value="LOCAL" />
|
117 |
+
<updated>1718743330580</updated>
|
118 |
+
</task>
|
119 |
+
<task id="LOCAL-00002" summary="Color theme setup">
|
120 |
+
<created>1718745108453</created>
|
121 |
+
<option name="number" value="00002" />
|
122 |
+
<option name="presentableId" value="LOCAL-00002" />
|
123 |
+
<option name="project" value="LOCAL" />
|
124 |
+
<updated>1718745108453</updated>
|
125 |
+
</task>
|
126 |
+
<task id="LOCAL-00003" summary="Theming">
|
127 |
+
<created>1718747250552</created>
|
128 |
+
<option name="number" value="00003" />
|
129 |
+
<option name="presentableId" value="LOCAL-00003" />
|
130 |
+
<option name="project" value="LOCAL" />
|
131 |
+
<updated>1718747250552</updated>
|
132 |
+
</task>
|
133 |
+
<option name="localTasksCounter" value="4" />
|
134 |
+
<servers />
|
135 |
+
</component>
|
136 |
+
<component name="TypeScriptGeneratedFilesManager">
|
137 |
+
<option name="version" value="3" />
|
138 |
+
</component>
|
139 |
+
<component name="VcsManagerConfiguration">
|
140 |
+
<option name="ADD_EXTERNAL_FILES_SILENTLY" value="true" />
|
141 |
+
<MESSAGE value="File format + new layout WIP" />
|
142 |
+
<MESSAGE value="Color theme setup" />
|
143 |
+
<MESSAGE value="Theming" />
|
144 |
+
<option name="LAST_COMMIT_MESSAGE" value="Theming" />
|
145 |
+
</component>
|
146 |
+
<component name="XDebuggerManager">
|
147 |
+
<breakpoint-manager>
|
148 |
+
<breakpoints>
|
149 |
+
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
150 |
+
<url>file://$PROJECT_DIR$/app.py</url>
|
151 |
+
<line>172</line>
|
152 |
+
<option name="timeStamp" value="1" />
|
153 |
+
</line-breakpoint>
|
154 |
+
</breakpoints>
|
155 |
+
</breakpoint-manager>
|
156 |
+
</component>
|
157 |
+
</project>
|
Dockerfile
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.9 as build
|
2 |
+
|
3 |
+
RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
|
4 |
+
|
5 |
+
# Set up a new user named "user" with user ID 1000
|
6 |
+
RUN useradd -m -u 1000 user
|
7 |
+
|
8 |
+
# Switch to the "user" user
|
9 |
+
USER user
|
10 |
+
|
11 |
+
# Set home to the user's home directory
|
12 |
+
ENV HOME=/home/user \
|
13 |
+
PATH=/home/user/.local/bin:$PATH
|
14 |
+
|
15 |
+
# Set the working directory to the user's home directory
|
16 |
+
WORKDIR $HOME/app
|
17 |
+
|
18 |
+
COPY requirements.txt .
|
19 |
+
RUN pip install --no-cache-dir -r ./requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu
|
20 |
+
# RUN mim install mmengine
|
21 |
+
# RUN mim install "mmcv==2.1.0" & mim install "mmdet==3.3.0"
|
22 |
+
|
23 |
+
FROM build as final
|
24 |
+
|
25 |
+
COPY --chown=user . .
|
26 |
+
|
27 |
+
CMD python app.py
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2023 center4ml
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -1,12 +1,8 @@
|
|
1 |
---
|
2 |
title: GranaMeasure
|
3 |
-
emoji: 📚
|
4 |
-
colorFrom: purple
|
5 |
-
colorTo: gray
|
6 |
-
sdk: gradio
|
7 |
-
sdk_version: 5.4.0
|
8 |
app_file: app.py
|
9 |
-
|
|
|
10 |
---
|
11 |
-
|
12 |
-
|
|
|
1 |
---
|
2 |
title: GranaMeasure
|
|
|
|
|
|
|
|
|
|
|
3 |
app_file: app.py
|
4 |
+
sdk: gradio
|
5 |
+
sdk_version: 4.20.0
|
6 |
---
|
7 |
+
# GranaMeasure_interface
|
8 |
+
interface code for GranaMeasure
|
angle_calculation/angle_model.py
ADDED
@@ -0,0 +1,444 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import pytorch_lightning as pl
|
4 |
+
import timm
|
5 |
+
|
6 |
+
# from hydra.utils import instantiate
|
7 |
+
from scipy.stats import circmean, circstd
|
8 |
+
from scipy import ndimage
|
9 |
+
from skimage.transform import resize
|
10 |
+
|
11 |
+
from sampling import get_crop_batch
|
12 |
+
from granum_utils import get_circle_mask
|
13 |
+
import image_transforms
|
14 |
+
from envelope_correction import calculate_best_angle_from_mask
|
15 |
+
## loss
|
16 |
+
|
17 |
+
class ConfidenceScaler:
|
18 |
+
def __init__(self, data: np.ndarray):
|
19 |
+
self.data = data
|
20 |
+
self.data.sort()
|
21 |
+
def __call__(self, x):
|
22 |
+
return np.searchsorted(self.data,x) / len(self.data)
|
23 |
+
|
24 |
+
class PatchedPredictor:
|
25 |
+
def __init__(self,
|
26 |
+
model,
|
27 |
+
crop_size=96,
|
28 |
+
normalization=dict(mean=0,std=1),
|
29 |
+
n_samples=32,
|
30 |
+
mask=None,# 'circle', None
|
31 |
+
filter_outliers=True,
|
32 |
+
apply_radon=False, # apply Radon transform
|
33 |
+
radon_size=(128,128), # (int, int) reshape radon transformed image to this shape,
|
34 |
+
angle_confidence_threshold=0,
|
35 |
+
use_envelope_correction=True
|
36 |
+
):
|
37 |
+
self.model = model
|
38 |
+
self.crop_size = crop_size
|
39 |
+
self.normalization = normalization
|
40 |
+
self.n_samples = n_samples
|
41 |
+
if mask not in [None, 'circle']:
|
42 |
+
raise ValueError(f'unknown mask {mask}')
|
43 |
+
self.mask = mask
|
44 |
+
self.filter_outliers = filter_outliers
|
45 |
+
|
46 |
+
self.apply_radon = apply_radon
|
47 |
+
self.radon_size = radon_size
|
48 |
+
|
49 |
+
self.angle_confidence_threshold = angle_confidence_threshold
|
50 |
+
self.use_envelope_correction = use_envelope_correction
|
51 |
+
|
52 |
+
@torch.no_grad()
|
53 |
+
def __call__(self, img: np.ndarray, mask: np.ndarray):
|
54 |
+
pl.seed_everything(44)
|
55 |
+
# get crops with different scales and rotation
|
56 |
+
crops, angles_tta, scales_tta = get_crop_batch(
|
57 |
+
img, mask,
|
58 |
+
crop_size=self.crop_size,
|
59 |
+
samples_per_scale=self.n_samples,
|
60 |
+
use_variance_threshold=True
|
61 |
+
)
|
62 |
+
if len(crops) == 0:
|
63 |
+
return dict(
|
64 |
+
est_angle=np.nan,
|
65 |
+
est_angle_confidence=0.,
|
66 |
+
)
|
67 |
+
|
68 |
+
# preprocess batch (normalize, mask, transform)
|
69 |
+
batch = self._preprocess_batch(crops)
|
70 |
+
|
71 |
+
# predict for batch - we don't use period and lumen anymore
|
72 |
+
preds_direction, preds_period, preds_lumen_width = self.model(batch)
|
73 |
+
# # convert to numpy
|
74 |
+
# preds_direction = preds_direction.numpy()
|
75 |
+
# preds_period = preds_period.numpy()
|
76 |
+
# preds_lumen_width = preds_lumen_width.numpy()
|
77 |
+
|
78 |
+
# aggregate angles
|
79 |
+
est_angles = (preds_direction - angles_tta) % 180
|
80 |
+
est_angle = circmean(est_angles, low=-90, high=90) + 90
|
81 |
+
est_angle_std = circstd(est_angles, low=-90, high=90)
|
82 |
+
est_angle_confidence = self._std_to_confidence(est_angle_std, 10) # confidence 0.5 for std =10 degrees
|
83 |
+
|
84 |
+
if est_angle_confidence < self.angle_confidence_threshold:
|
85 |
+
est_angle = np.nan
|
86 |
+
est_angle_confidence = 0.
|
87 |
+
|
88 |
+
if self.use_envelope_correction and (not np.isnan(est_angle)):
|
89 |
+
angle_correction = -calculate_best_angle_from_mask(
|
90 |
+
ndimage.rotate(mask, -est_angle, reshape=True, order=0)
|
91 |
+
)
|
92 |
+
est_angle += angle_correction
|
93 |
+
|
94 |
+
return dict(
|
95 |
+
est_angle=est_angle,
|
96 |
+
est_angle_confidence=est_angle_confidence,
|
97 |
+
)
|
98 |
+
|
99 |
+
def _apply_radon(self, batch): # may reauire circle mask
|
100 |
+
crops_radon = image_transforms.batched_radon(batch.numpy())
|
101 |
+
crops_radon = np.transpose(resize(np.transpose(crops_radon, (1, 2, 0)), self.radon_size), (2, 0, 1))
|
102 |
+
return torch.tensor(crops_radon)
|
103 |
+
|
104 |
+
def _preprocess_batch(self, batch):
|
105 |
+
if self.mask == 'circle':
|
106 |
+
mask = get_circle_mask(batch.shape[1])
|
107 |
+
batch[:,mask] = 0
|
108 |
+
if self.apply_radon:
|
109 |
+
batch = self._apply_radon(batch)
|
110 |
+
batch = ((batch/255) - self.normalization['mean'])/self.normalization['std']
|
111 |
+
return batch.unsqueeze(1) # add channel dimension
|
112 |
+
|
113 |
+
def _filter_outliers(self, x, qmin=0.25, qmax=0.75):
|
114 |
+
x_min, x_max = np.quantile(x, [qmin, qmax])
|
115 |
+
return x[(x>=x_min) & (x<=x_max)]
|
116 |
+
|
117 |
+
def _std_to_confidence(self, x, x_thr, y_thr=0.5):
|
118 |
+
"""transform [0, inf] to [1,0], such that f(x_thr)=y_thr"""
|
119 |
+
return 1 / (1+x*(1-y_thr)/(x_thr*y_thr))
|
120 |
+
|
121 |
+
class CosineLoss(torch.nn.Module):
|
122 |
+
def __init__(self, p=1, degrees=False, scale=1):
|
123 |
+
super().__init__()
|
124 |
+
self.p = p
|
125 |
+
self.degrees = degrees
|
126 |
+
self.scale = scale
|
127 |
+
def forward(self, x, y):
|
128 |
+
if self.degrees:
|
129 |
+
x = torch.deg2rad(x)
|
130 |
+
y = torch.deg2rad(y)
|
131 |
+
return torch.mean((1-torch.cos(x-y))**self.p) * self.scale
|
132 |
+
|
133 |
+
## model
|
134 |
+
class AngleParser2d(torch.nn.Module):
|
135 |
+
def __init__(self, angle_range=180):
|
136 |
+
super().__init__()
|
137 |
+
self.angle_range = angle_range
|
138 |
+
def forward(self, batch):
|
139 |
+
# r = torch.linalg.norm(batch, dim=1)
|
140 |
+
preds_y_proj = torch.sigmoid(batch[:,0]) - 0.5
|
141 |
+
preds_x_proj = torch.sigmoid(batch[:,1]) - 0.5
|
142 |
+
preds_direction = self.angle_range/360.*torch.rad2deg(torch.arctan2(preds_y_proj, preds_x_proj))
|
143 |
+
return preds_direction
|
144 |
+
|
145 |
+
class AngleRegularizer(torch.nn.Module):
|
146 |
+
def __init__(self, strength=1.0, scale=1.0, p=2):
|
147 |
+
super().__init__()
|
148 |
+
self.strength = strength
|
149 |
+
self.scale = scale
|
150 |
+
self.p = p
|
151 |
+
def forward(self, batch):
|
152 |
+
r = torch.linalg.norm(batch, dim=1)
|
153 |
+
return self.strength * torch.norm(r - self.scale, p=self.p)
|
154 |
+
|
155 |
+
class AngleRegularizerLog(torch.nn.Module):
|
156 |
+
def __init__(self, strength=1.0, scale=1.0, p=2):
|
157 |
+
super().__init__()
|
158 |
+
self.strength = strength
|
159 |
+
self.scale = scale
|
160 |
+
self.p = p
|
161 |
+
def forward(self, batch):
|
162 |
+
r = torch.linalg.norm(batch, dim=1)
|
163 |
+
return self.strength * torch.norm(torch.log(r/self.scale), p=self.p)
|
164 |
+
|
165 |
+
class StripsModel(pl.LightningModule):
|
166 |
+
def __init__(self,
|
167 |
+
model_name = 'resnet18',
|
168 |
+
lr=0.001,
|
169 |
+
optimizer_hparams=dict(),
|
170 |
+
lr_hparams=dict(classname='MultiStepLR', kwargs=dict(milestones=[100, 150], gamma=0.1)),
|
171 |
+
loss_hparams=dict(rotation_weight=10., lumen_fraction_weight=50.),
|
172 |
+
angle_hparams=dict(angle_range=180.),
|
173 |
+
regularizer_hparams=None,
|
174 |
+
sigmoid_smoother=10.
|
175 |
+
):
|
176 |
+
super().__init__()
|
177 |
+
# Exports the hyperparameters to a YAML file, and create "self.hparams" namespace
|
178 |
+
self.save_hyperparameters()
|
179 |
+
# Create model - implemented in non-abstract classes
|
180 |
+
self.model = timm.create_model(model_name, in_chans=1, num_classes=4) #2 + self.hparams.angle_hparams['ndim'])
|
181 |
+
self.angle_parser = AngleParser2d(**self.hparams.angle_hparams)
|
182 |
+
self.regularizer = self._get_regularizer(self.hparams.regularizer_hparams)
|
183 |
+
self.losses = {
|
184 |
+
'direction': CosineLoss(2., True),
|
185 |
+
'period': torch.nn.functional.mse_loss,
|
186 |
+
'lumen_fraction': torch.nn.functional.mse_loss
|
187 |
+
}
|
188 |
+
self.losses_weights = {
|
189 |
+
'direction': self.hparams.loss_hparams['rotation_weight'],
|
190 |
+
'period': 1,
|
191 |
+
'lumen_fraction': self.hparams.loss_hparams['lumen_fraction_weight'],
|
192 |
+
'regularization': self.hparams.loss_hparams.get('regularization_weight', 0.)
|
193 |
+
}
|
194 |
+
|
195 |
+
def _get_regularizer(self, regularizer_params):
|
196 |
+
if regularizer_params is None:
|
197 |
+
return None
|
198 |
+
else:
|
199 |
+
return instantiate(regularizer_params)
|
200 |
+
|
201 |
+
|
202 |
+
def forward(self, x, return_raw=False):
|
203 |
+
"""get predictions from image batch"""
|
204 |
+
preds = self.model(x) # preds: logit angle_sin, logit angle_cos, period, logit lumen fraction or logit angle, period, logit lumen fraction
|
205 |
+
preds_direction = self.angle_parser(preds)
|
206 |
+
preds_period = preds[:,-2]
|
207 |
+
preds_lumen_fraction = torch.sigmoid(preds[:,-1]*self.hparams.sigmoid_smoother) #lumen fraction is between 0 and 1, so we take sigmoid fo this
|
208 |
+
|
209 |
+
outputs = [preds_direction, preds_period, preds_lumen_fraction]
|
210 |
+
if return_raw:
|
211 |
+
outputs.append(preds)
|
212 |
+
|
213 |
+
return tuple(outputs)
|
214 |
+
|
215 |
+
def configure_optimizers(self):
|
216 |
+
# AdamW is Adam with a correct implementation of weight decay (see here
|
217 |
+
# for details: https://arxiv.org/pdf/1711.05101.pdf)
|
218 |
+
optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, **self.hparams.optimizer_hparams)
|
219 |
+
# scheduler = getattr(torch.optim.lr_scheduler, self.hparams.lr_hparams['classname'])(optimizer, **self.hparams.lr_hparams['kwargs'])
|
220 |
+
scheduler = instantiate({**self.hparams.lr_hparams, '_partial_': True})(optimizer)
|
221 |
+
return [optimizer], [scheduler]
|
222 |
+
|
223 |
+
def process_batch_supervised(self, batch):
|
224 |
+
"""get predictions, losses and mean errors (MAE)"""
|
225 |
+
|
226 |
+
# get predictions
|
227 |
+
preds = {}
|
228 |
+
preds['direction'], preds['period'], preds['lumen_fraction'], preds_raw = self.forward(batch['image'], return_raw=True) # preds: angle, period, lumen fraction, raw preds
|
229 |
+
|
230 |
+
# calculate losses
|
231 |
+
losses = {
|
232 |
+
'direction': self.losses['direction'](2*batch['direction'], 2*preds['direction']),
|
233 |
+
'period': self.losses['period'](batch['period'], preds['period']),
|
234 |
+
'lumen_fraction': self.losses['lumen_fraction'](batch['lumen_fraction'], preds['lumen_fraction']),
|
235 |
+
}
|
236 |
+
if self.regularizer is not None:
|
237 |
+
losses['regularization'] = self.regularizer(preds_raw[:,:2])
|
238 |
+
|
239 |
+
losses['final'] = \
|
240 |
+
losses['direction']*self.losses_weights['direction'] + \
|
241 |
+
losses['period']*self.losses_weights['period'] + \
|
242 |
+
losses['lumen_fraction']*self.losses_weights['lumen_fraction'] + \
|
243 |
+
losses.get('regularization', 0.)*self.losses_weights.get('regularization', 0.)
|
244 |
+
|
245 |
+
# calculate mean errors
|
246 |
+
period_difference = np.mean(abs(
|
247 |
+
batch['period'].detach().cpu().numpy() - \
|
248 |
+
preds['period'].detach().cpu().numpy()
|
249 |
+
))
|
250 |
+
|
251 |
+
a1 = batch['direction'].detach().cpu().numpy()
|
252 |
+
a2 = preds['direction'].detach().cpu().numpy()
|
253 |
+
angle_difference = np.mean(0.5*np.degrees(np.arccos(np.cos(2*np.radians(a2-a1)))))
|
254 |
+
|
255 |
+
lumen_fraction_difference = np.mean(abs(preds['lumen_fraction'].detach().cpu().numpy()-batch['lumen_fraction'].detach().cpu().numpy()))
|
256 |
+
|
257 |
+
mae = {
|
258 |
+
'period': period_difference,
|
259 |
+
'direction': angle_difference,
|
260 |
+
'lumen_fraction': lumen_fraction_difference
|
261 |
+
}
|
262 |
+
|
263 |
+
return preds, losses, mae
|
264 |
+
|
265 |
+
def log_all(self, losses, mae, prefix=''):
|
266 |
+
self.log(f"{prefix}angle_loss", losses['direction'].item())
|
267 |
+
self.log(f"{prefix}period_loss", losses['period'].item())
|
268 |
+
self.log(f"{prefix}lumen_fraction_loss", losses['lumen_fraction'].item())
|
269 |
+
self.log(f"{prefix}period_difference", mae['period'])
|
270 |
+
self.log(f"{prefix}angle_difference", mae['direction'])
|
271 |
+
self.log(f"{prefix}lumen_fraction_difference", mae['lumen_fraction'])
|
272 |
+
self.log(f"{prefix}loss", losses['final'])
|
273 |
+
if 'regularization' in losses:
|
274 |
+
self.log(f"{prefix}regularization_loss", losses['regularization'].item())
|
275 |
+
|
276 |
+
def training_step(self, batch, batch_idx):
|
277 |
+
# "batch" is the output of the training data loader.
|
278 |
+
preds, losses, mae = self.process_batch_supervised(batch)
|
279 |
+
self.log_all(losses, mae, prefix='train_')
|
280 |
+
|
281 |
+
return losses['final']
|
282 |
+
|
283 |
+
def validation_step(self, batch, batch_idx):
|
284 |
+
preds, losses, mae = self.process_batch_supervised(batch)
|
285 |
+
self.log_all(losses, mae, prefix='val_')
|
286 |
+
|
287 |
+
def test_step(self, batch, batch_idx):
|
288 |
+
preds, losses, mae = self.process_batch_supervised(batch)
|
289 |
+
self.log_all(losses, mae, prefix='test_')
|
290 |
+
|
291 |
+
|
292 |
+
class StripsModelLumenWidth(pl.LightningModule):
|
293 |
+
def __init__(self,
|
294 |
+
model_name = 'resnet18',
|
295 |
+
lr=0.001,
|
296 |
+
optimizer_hparams=dict(),
|
297 |
+
lr_hparams=dict(classname='MultiStepLR', kwargs=dict(milestones=[100, 150], gamma=0.1)),
|
298 |
+
loss_hparams=dict(rotation_weight=10., lumen_width_weight=50.),
|
299 |
+
angle_hparams=dict(angle_range=180.),
|
300 |
+
regularizer_hparams=None,
|
301 |
+
sigmoid_smoother=10.
|
302 |
+
):
|
303 |
+
super().__init__()
|
304 |
+
# Exports the hyperparameters to a YAML file, and create "self.hparams" namespace
|
305 |
+
self.save_hyperparameters()
|
306 |
+
# Create model - implemented in non-abstract classes
|
307 |
+
self.model = timm.create_model(model_name, in_chans=1, num_classes=4) #2 + self.hparams.angle_hparams['ndim'])
|
308 |
+
self.angle_parser = AngleParser2d(**self.hparams.angle_hparams)
|
309 |
+
self.regularizer = self._get_regularizer(self.hparams.regularizer_hparams)
|
310 |
+
self.losses = {
|
311 |
+
'direction': CosineLoss(2., True),
|
312 |
+
'period': torch.nn.functional.mse_loss,
|
313 |
+
'lumen_width': torch.nn.functional.mse_loss
|
314 |
+
}
|
315 |
+
self.losses_weights = {
|
316 |
+
'direction': self.hparams.loss_hparams['rotation_weight'],
|
317 |
+
'period': 1,
|
318 |
+
'lumen_width': self.hparams.loss_hparams['lumen_width_weight'],
|
319 |
+
'regularization': self.hparams.loss_hparams.get('regularization_weight', 0.)
|
320 |
+
}
|
321 |
+
|
322 |
+
def _get_regularizer(self, regularizer_params):
|
323 |
+
if regularizer_params is None:
|
324 |
+
return None
|
325 |
+
else:
|
326 |
+
return instantiate(regularizer_params)
|
327 |
+
|
328 |
+
def forward(self, x, return_raw=False):
|
329 |
+
"""get predictions from image batch"""
|
330 |
+
preds = self.model(x) # preds: logit angle_sin, logit angle_cos, period, logit lumen fraction or logit angle, period, logit lumen fraction
|
331 |
+
preds_direction = self.angle_parser(preds)
|
332 |
+
preds_period = preds[:,-2]
|
333 |
+
preds_lumen_width = preds[:,-1] #lumen fraction is between 0 and 1, so we take sigmoid fo this
|
334 |
+
|
335 |
+
outputs = [preds_direction, preds_period, preds_lumen_width]
|
336 |
+
if return_raw:
|
337 |
+
outputs.append(preds)
|
338 |
+
|
339 |
+
return tuple(outputs)
|
340 |
+
|
341 |
+
def configure_optimizers(self):
|
342 |
+
# AdamW is Adam with a correct implementation of weight decay (see here
|
343 |
+
# for details: https://arxiv.org/pdf/1711.05101.pdf)
|
344 |
+
optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, **self.hparams.optimizer_hparams)
|
345 |
+
# scheduler = getattr(torch.optim.lr_scheduler, self.hparams.lr_hparams['classname'])(optimizer, **self.hparams.lr_hparams['kwargs'])
|
346 |
+
scheduler = instantiate({**self.hparams.lr_hparams, '_partial_': True})(optimizer)
|
347 |
+
return [optimizer], [scheduler]
|
348 |
+
|
349 |
+
def process_batch_supervised(self, batch):
|
350 |
+
"""get predictions, losses and mean errors (MAE)"""
|
351 |
+
|
352 |
+
# get predictions
|
353 |
+
preds = {}
|
354 |
+
preds['direction'], preds['period'], preds['lumen_width'], preds_raw = self.forward(batch['image'], return_raw=True) # preds: angle, period, lumen fraction, raw preds
|
355 |
+
|
356 |
+
# calculate losses
|
357 |
+
losses = {
|
358 |
+
'direction': self.losses['direction'](2*batch['direction'], 2*preds['direction']),
|
359 |
+
'period': self.losses['period'](batch['period'], preds['period']),
|
360 |
+
'lumen_width': self.losses['lumen_width'](batch['lumen_width'], preds['lumen_width']),
|
361 |
+
}
|
362 |
+
if self.regularizer is not None:
|
363 |
+
losses['regularization'] = self.regularizer(preds_raw[:,:2])
|
364 |
+
|
365 |
+
losses['final'] = \
|
366 |
+
losses['direction']*self.losses_weights['direction'] + \
|
367 |
+
losses['period']*self.losses_weights['period'] + \
|
368 |
+
losses['lumen_width']*self.losses_weights['lumen_width'] + \
|
369 |
+
losses.get('regularization', 0.)*self.losses_weights.get('regularization', 0.)
|
370 |
+
|
371 |
+
# calculate mean errors
|
372 |
+
period_difference = np.mean(abs(
|
373 |
+
batch['period'].detach().cpu().numpy() - \
|
374 |
+
preds['period'].detach().cpu().numpy()
|
375 |
+
))
|
376 |
+
|
377 |
+
a1 = batch['direction'].detach().cpu().numpy()
|
378 |
+
a2 = preds['direction'].detach().cpu().numpy()
|
379 |
+
angle_difference = np.mean(0.5*np.degrees(np.arccos(np.cos(2*np.radians(a2-a1)))))
|
380 |
+
|
381 |
+
lumen_width_difference = np.mean(abs(preds['lumen_width'].detach().cpu().numpy()-batch['lumen_width'].detach().cpu().numpy()))
|
382 |
+
|
383 |
+
lumen_fraction_pred = preds['lumen_width'].detach().cpu().numpy()/preds['period'].detach().cpu().numpy()
|
384 |
+
lumen_fraction_gt = batch['lumen_width'].detach().cpu().numpy()/batch['period'].detach().cpu().numpy()
|
385 |
+
lumen_fraction_difference = np.mean(abs(lumen_fraction_pred-lumen_fraction_gt))
|
386 |
+
|
387 |
+
mae = {
|
388 |
+
'period': period_difference,
|
389 |
+
'direction': angle_difference,
|
390 |
+
'lumen_width': lumen_width_difference,
|
391 |
+
'lumen_fraction': lumen_fraction_difference
|
392 |
+
}
|
393 |
+
|
394 |
+
return preds, losses, mae
|
395 |
+
|
396 |
+
def log_all(self, losses, mae, prefix=''):
|
397 |
+
for k, v in losses.items():
|
398 |
+
self.log(f'{prefix}{k}_loss', v.item() if isinstance(v, torch.Tensor) else v)
|
399 |
+
for k, v in mae.items():
|
400 |
+
self.log(f'{prefix}{k}_difference', v.item() if isinstance(v, torch.Tensor) else v)
|
401 |
+
|
402 |
+
def training_step(self, batch, batch_idx):
|
403 |
+
# "batch" is the output of the training data loader.
|
404 |
+
preds, losses, mae = self.process_batch_supervised(batch)
|
405 |
+
self.log_all(losses, mae, prefix='train_')
|
406 |
+
|
407 |
+
return losses['final']
|
408 |
+
|
409 |
+
def validation_step(self, batch, batch_idx):
|
410 |
+
preds, losses, mae = self.process_batch_supervised(batch)
|
411 |
+
self.log_all(losses, mae, prefix='val_')
|
412 |
+
|
413 |
+
def test_step(self, batch, batch_idx):
|
414 |
+
preds, losses, mae = self.process_batch_supervised(batch)
|
415 |
+
self.log_all(losses, mae, prefix='test_')
|
416 |
+
|
417 |
+
|
418 |
+
|
419 |
+
# class StripsModel(StripsModelGeneral):
|
420 |
+
# def __init__(self, model_name, *args, **kwargs):
|
421 |
+
# super().__init__( *args, **kwargs)
|
422 |
+
# self.model = timm.create_model(model_name, in_chans=1, num_classes=4)
|
423 |
+
# def forward(self, x):
|
424 |
+
# """get predictions from image batch"""
|
425 |
+
# preds = self.model(x) # preds: logit angle_sin, logit angle_cos, period, logit lumen fraction
|
426 |
+
# preds_sin = 1. - 2*torch.sigmoid(preds[:,0])
|
427 |
+
# preds_cos = 1. - 2*torch.sigmoid(preds[:,1])
|
428 |
+
# preds_direction = 0.5*torch.rad2deg(torch.arctan2(preds_sin, preds_cos))
|
429 |
+
# preds_period = preds[:,2]
|
430 |
+
# preds_lumen_fraction = torch.sigmoid(preds[:,3]) #lumen fraction is between 0 and 1, so we take sigmoid fo this
|
431 |
+
# return preds_direction, preds_period, preds_lumen_fraction
|
432 |
+
|
433 |
+
# class StripsModelAngle1(StripsModelGeneral):
|
434 |
+
# def __init__(self, model_name, *args, **kwargs):
|
435 |
+
# super().__init__( *args, **kwargs)
|
436 |
+
# self.model = timm.create_model(model_name, in_chans=1, num_classes=3)
|
437 |
+
# def forward(self, x):
|
438 |
+
# """get predictions from image batch"""
|
439 |
+
# preds = self.model(x) # preds: logit angle_sin, logit angle
|
440 |
+
# preds_direction = torch.pi * torch.sigmoid(preds[:,0])
|
441 |
+
# preds_period = preds[:,1]
|
442 |
+
# preds_lumen_fraction = torch.sigmoid(preds[:,2]) #lumen fraction is between 0 and 1, so we take sigmoid fo this
|
443 |
+
# return preds_direction, preds_period, preds_lumen_fraction
|
444 |
+
|
angle_calculation/classic.py
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from scipy import signal
|
3 |
+
from scipy import ndimage
|
4 |
+
from scipy.fftpack import next_fast_len
|
5 |
+
from skimage.transform import rotate
|
6 |
+
from skimage._shared.utils import convert_to_float
|
7 |
+
from skimage.transform import warp
|
8 |
+
import matplotlib.pyplot as plt
|
9 |
+
import cv2
|
10 |
+
from copy import deepcopy
|
11 |
+
|
12 |
+
def get_directional_std(image, theta=None,*, preserve_range=False):
|
13 |
+
|
14 |
+
if image.ndim != 2:
|
15 |
+
raise ValueError('The input image must be 2-D')
|
16 |
+
if theta is None:
|
17 |
+
theta = np.arange(180)
|
18 |
+
|
19 |
+
image = convert_to_float(image.copy(), preserve_range) #TODO: needed?
|
20 |
+
|
21 |
+
shape_min = min(image.shape)
|
22 |
+
img_shape = np.array(image.shape)
|
23 |
+
|
24 |
+
# Crop image to make it square
|
25 |
+
slices = tuple(slice(int(np.ceil(excess / 2)),
|
26 |
+
int(np.ceil(excess / 2) + shape_min))
|
27 |
+
if excess > 0 else slice(None)
|
28 |
+
for excess in (img_shape - shape_min))
|
29 |
+
image = image[slices]
|
30 |
+
shape_min = min(image.shape)
|
31 |
+
img_shape = np.array(image.shape)
|
32 |
+
|
33 |
+
radius = shape_min // 2
|
34 |
+
coords = np.array(np.ogrid[:image.shape[0], :image.shape[1]],
|
35 |
+
dtype=object)
|
36 |
+
dist = ((coords - img_shape // 2) ** 2).sum(0)
|
37 |
+
outside_reconstruction_circle = dist > radius ** 2
|
38 |
+
image[outside_reconstruction_circle] = 0
|
39 |
+
|
40 |
+
valid_square_slice = slice(int(np.ceil(radius*(1-1/np.sqrt(2)))), int(np.ceil(radius*(1+1/np.sqrt(2)))) )
|
41 |
+
|
42 |
+
# padded_image is always square
|
43 |
+
if image.shape[0] != image.shape[1]:
|
44 |
+
raise ValueError('padded_image must be a square')
|
45 |
+
center = image.shape[0] // 2
|
46 |
+
result = np.zeros(len(theta))
|
47 |
+
|
48 |
+
for i, angle in enumerate(np.deg2rad(theta)):
|
49 |
+
cos_a, sin_a = np.cos(angle), np.sin(angle)
|
50 |
+
R = np.array([[cos_a, sin_a, -center * (cos_a + sin_a - 1)],
|
51 |
+
[-sin_a, cos_a, -center * (cos_a - sin_a - 1)],
|
52 |
+
[0, 0, 1]])
|
53 |
+
rotated = warp(image, R, clip=False)
|
54 |
+
result[i] = rotated[valid_square_slice, valid_square_slice].std(axis=0).mean()
|
55 |
+
return result
|
56 |
+
|
57 |
+
def acf2d(x, nlags=None):
|
58 |
+
xo = x - x.mean(axis=0)
|
59 |
+
n = len(x)
|
60 |
+
if nlags is None:
|
61 |
+
nlags = n -1
|
62 |
+
lag_len = nlags
|
63 |
+
|
64 |
+
xi = np.arange(1, n + 1)
|
65 |
+
d = np.expand_dims(np.hstack((xi, xi[:-1][::-1])),1)
|
66 |
+
|
67 |
+
nobs = len(xo)
|
68 |
+
n = next_fast_len(2 * nobs + 1)
|
69 |
+
Frf = np.fft.fft(xo, n=n, axis=0)
|
70 |
+
|
71 |
+
acov = np.fft.ifft(Frf * np.conjugate(Frf), axis=0)[:nobs] / d[nobs - 1 :]
|
72 |
+
acov = acov.real
|
73 |
+
ac = acov[: nlags + 1] / acov[:1]
|
74 |
+
return ac
|
75 |
+
|
76 |
+
def get_period(acf_table, n_samples=50):
|
77 |
+
#TODO: use peak heights to select best candidates. use std to eliminate outliers
|
78 |
+
period_candidates = []
|
79 |
+
period_candidates_hights = []
|
80 |
+
for i in np.random.randint(0, acf_table.shape[1], min(acf_table.shape[1], n_samples)):
|
81 |
+
peaks = signal.find_peaks(acf_table[:,i])[0]
|
82 |
+
if len(peaks) == 0:
|
83 |
+
continue
|
84 |
+
peak_idx = peaks[0]
|
85 |
+
period_candidates.append(peak_idx)
|
86 |
+
period_candidates_hights.append(acf_table[peak_idx,i])
|
87 |
+
period_candidates = np.array(period_candidates)
|
88 |
+
period_candidates_hights = np.array(period_candidates_hights)
|
89 |
+
|
90 |
+
if len(period_candidates) == 0:
|
91 |
+
return np.nan, np.nan
|
92 |
+
elif len(period_candidates) == 1:
|
93 |
+
return period_candidates[0], np.nan
|
94 |
+
q1, q3 = np.quantile(period_candidates, [0.25, 0.75])
|
95 |
+
candidates_std = np.std(period_candidates[(period_candidates>=q1)&(period_candidates<=q3)])
|
96 |
+
# return period_candidates, period_candidates_hights
|
97 |
+
return np.median(period_candidates), candidates_std
|
98 |
+
|
99 |
+
def get_rotation_with_confidence(padded_image, blur_size=4, make_plots=True):
|
100 |
+
std_by_angle = get_directional_std(cv2.blur(padded_image, (blur_size,blur_size)))
|
101 |
+
rotation_angle = np.argmin(std_by_angle)
|
102 |
+
|
103 |
+
rotation_quality = 1 - np.min(std_by_angle)/np.median(std_by_angle)
|
104 |
+
if make_plots:
|
105 |
+
plt.plot(std_by_angle)
|
106 |
+
plt.axvline(rotation_angle, c='k')
|
107 |
+
plt.title(f'quality: {rotation_quality:0.2f}')
|
108 |
+
return rotation_angle, rotation_quality
|
109 |
+
|
110 |
+
def calculate_autocorrelation(oriented_img, blur_kernel=(7,1), make_plots=True):
|
111 |
+
autocorrelation = acf2d(cv2.blur(oriented_img.T, blur_kernel))
|
112 |
+
if make_plots:
|
113 |
+
fig, axs = plt.subplots(ncols=2, figsize=(12,6))
|
114 |
+
axs[0].imshow(autocorrelation)
|
115 |
+
axs[1].plot(autocorrelation.sum(axis=1))
|
116 |
+
return autocorrelation
|
117 |
+
|
118 |
+
def get_period_with_confidence(autocorrelation_tab, n_samples=30):
|
119 |
+
period, period_std = get_period(autocorrelation_tab, n_samples=n_samples)
|
120 |
+
if period_std == np.nan:
|
121 |
+
period_confidence = 0.001
|
122 |
+
else:
|
123 |
+
period_confidence = period/(period+2*period_std)
|
124 |
+
return period, period_confidence
|
125 |
+
|
126 |
+
def calculate_white_fraction(img, blur_size=4, make_plots=True): #TODO: add mask
|
127 |
+
blurred = cv2.blur(img, (blur_size, blur_size))
|
128 |
+
blurred_sum = blurred.sum(axis=0)
|
129 |
+
lower, upper = np.quantile(blurred_sum, [0.15, 0.85])
|
130 |
+
sign = blurred_sum > (lower+upper)/2
|
131 |
+
|
132 |
+
sign_change = sign[:-1] != sign[1:]
|
133 |
+
sign_change_indices = np.where(sign_change)[0]
|
134 |
+
|
135 |
+
if len(sign_change_indices) >= 2 + (sign[-1] == sign[0]):
|
136 |
+
cut_first = sign_change_indices[0]+1
|
137 |
+
|
138 |
+
if sign[-1] == sign[0]:
|
139 |
+
cut_last = sign_change_indices[-2]
|
140 |
+
else:
|
141 |
+
cut_last = sign_change_indices[-1]
|
142 |
+
|
143 |
+
white_fraction = np.mean(sign[cut_first:cut_last])
|
144 |
+
else:
|
145 |
+
white_fraction = np.nan
|
146 |
+
cut_first, cut_last = None, None
|
147 |
+
if make_plots:
|
148 |
+
fig, axs = plt.subplots(ncols=3, figsize=(16,6))
|
149 |
+
blurred_sum_normalized = blurred_sum - blurred_sum.min()
|
150 |
+
blurred_sum_normalized /= blurred_sum_normalized.max()
|
151 |
+
axs[0].plot(blurred_sum_normalized)
|
152 |
+
axs[0].plot(sign)
|
153 |
+
axs[1].plot(blurred_sum_normalized[cut_first:cut_last])
|
154 |
+
axs[1].plot(sign[cut_first:cut_last])
|
155 |
+
axs[2].imshow(img, cmap='gray')
|
156 |
+
for i, idx in enumerate(sign_change_indices):
|
157 |
+
plt.axvline(idx, c=['r', 'lime'][i%2])
|
158 |
+
fig.suptitle(f'fraction: {white_fraction:0.2f}')
|
159 |
+
|
160 |
+
return white_fraction
|
161 |
+
|
162 |
+
def process_img_crop(img, nm_per_px=1, make_plots=False, return_extra=False):
|
163 |
+
|
164 |
+
# image must be square
|
165 |
+
assert img.shape[0] == img.shape[1]
|
166 |
+
crop_size = img.shape[0]
|
167 |
+
|
168 |
+
# find orientation
|
169 |
+
rotation_angle, rotation_quality = get_rotation_with_confidence(img, blur_size=4, make_plots=make_plots)
|
170 |
+
|
171 |
+
# rotate and crop image
|
172 |
+
crop_margin = int((1 - 1/np.sqrt(2))*crop_size*0.5)
|
173 |
+
oriented_img = rotate(img, -rotation_angle)[2*crop_margin:-crop_margin, crop_margin:-crop_margin]
|
174 |
+
|
175 |
+
# calculate autocorrelation
|
176 |
+
autocorrelation = calculate_autocorrelation(oriented_img, blur_kernel=(7,1), make_plots=make_plots)
|
177 |
+
|
178 |
+
# find period
|
179 |
+
period, period_confidence = get_period_with_confidence(autocorrelation)
|
180 |
+
if make_plots:
|
181 |
+
print(f'period: {period}, confidence: {period_confidence}')
|
182 |
+
|
183 |
+
# find white fraction
|
184 |
+
white_fraction = calculate_white_fraction(oriented_img, make_plots=make_plots)
|
185 |
+
white_width = white_fraction*period
|
186 |
+
|
187 |
+
result = {
|
188 |
+
'direction': rotation_angle,
|
189 |
+
'direction confidence': rotation_quality,
|
190 |
+
'period': period*nm_per_px,
|
191 |
+
'period confidence': period_confidence,
|
192 |
+
'lumen width': white_width*nm_per_px
|
193 |
+
}
|
194 |
+
if return_extra:
|
195 |
+
result['extra'] = {
|
196 |
+
'autocorrelation': autocorrelation,
|
197 |
+
'oriented_img': oriented_img
|
198 |
+
}
|
199 |
+
|
200 |
+
return result
|
201 |
+
|
202 |
+
def get_top_k(a, k):
|
203 |
+
ind = np.argpartition(a, -k)[-k:]
|
204 |
+
return a[ind]
|
205 |
+
|
206 |
+
def get_crops(img, distance_map, crop_size, N_sample):
|
207 |
+
crop_r= np.sqrt(2)*crop_size / 2
|
208 |
+
possible_positions_y, possible_positions_x = np.where(distance_map >= crop_r)
|
209 |
+
no_edge_mask = (possible_positions_y>crop_r) & \
|
210 |
+
(possible_positions_x>crop_r) & \
|
211 |
+
(possible_positions_y<(distance_map.shape[0]-crop_r)) & \
|
212 |
+
(possible_positions_x<(distance_map.shape[1]-crop_r))
|
213 |
+
|
214 |
+
possible_positions_x = possible_positions_x[no_edge_mask]
|
215 |
+
possible_positions_y = possible_positions_y[no_edge_mask]
|
216 |
+
N_available = len(possible_positions_x)
|
217 |
+
positions_indices = np.random.choice(np.arange(N_available), min(N_sample, N_available), replace=False)
|
218 |
+
|
219 |
+
for idx in positions_indices:
|
220 |
+
yield img[possible_positions_y[idx]-crop_size//2:possible_positions_y[idx]+crop_size//2,possible_positions_x[idx]-crop_size//2:possible_positions_x[idx]+crop_size//2].copy()
|
221 |
+
|
222 |
+
def sliced_mean(x, slice_size):
|
223 |
+
cs_y = np.cumsum(x, axis=0)
|
224 |
+
cs_y = np.concatenate((np.zeros((1, cs_y.shape[1]), dtype=cs_y.dtype), cs_y), axis=0)
|
225 |
+
slices_y = (cs_y[slice_size:] - cs_y[:-slice_size])/slice_size
|
226 |
+
cs_xy = np.cumsum(slices_y, axis=1)
|
227 |
+
cs_xy = np.concatenate((np.zeros((cs_xy.shape[0], 1), dtype=cs_xy.dtype), cs_xy), axis=1)
|
228 |
+
slices_xy = (cs_xy[:,slice_size:] - cs_xy[:,:-slice_size])/slice_size
|
229 |
+
return slices_xy
|
230 |
+
|
231 |
+
def sliced_var(x, slice_size):
|
232 |
+
x = x.astype('float64')
|
233 |
+
return sliced_mean(x**2, slice_size) - sliced_mean(x, slice_size)**2
|
234 |
+
|
235 |
+
def select_samples(granum_image, granum_mask, crop_size=96, n_samples=64, granum_fraction_min=1.0, variance_p=2):
|
236 |
+
granum_occupancy = sliced_mean(granum_mask, crop_size)
|
237 |
+
possible_indices = np.stack(np.where(granum_occupancy >= granum_fraction_min), axis=1)
|
238 |
+
|
239 |
+
if variance_p == 0:
|
240 |
+
p = np.ones(len(possible_indices))
|
241 |
+
else:
|
242 |
+
variance_map = sliced_var(granum_image, crop_size)
|
243 |
+
p = variance_map[possible_indices[:,0], possible_indices[:,1]]**variance_p
|
244 |
+
p /= np.sum(p)
|
245 |
+
|
246 |
+
chosen_indices = np.random.choice(
|
247 |
+
np.arange(len(possible_indices)),
|
248 |
+
min(len(possible_indices), n_samples),
|
249 |
+
replace=False,
|
250 |
+
p = p
|
251 |
+
)
|
252 |
+
|
253 |
+
crops = []
|
254 |
+
for crop_idx, idx in enumerate(chosen_indices):
|
255 |
+
crops.append(
|
256 |
+
granum_image[
|
257 |
+
possible_indices[idx,0]:possible_indices[idx,0]+crop_size,
|
258 |
+
possible_indices[idx,1]:possible_indices[idx,1]+crop_size
|
259 |
+
]
|
260 |
+
)
|
261 |
+
return np.array(crops)
|
262 |
+
|
263 |
+
def calculate_distance_map(mask):
|
264 |
+
padded = np.pad(mask, pad_width=1, mode='constant', constant_values=False)
|
265 |
+
distance_map_padded = ndimage.distance_transform_edt(padded)
|
266 |
+
return distance_map_padded[1:-1,1:-1]
|
267 |
+
|
268 |
+
|
269 |
+
def measure_object(
|
270 |
+
img, mask,
|
271 |
+
nm_per_px=1, n_tries = 3,
|
272 |
+
direction_thr_min = 0.07, direction_thr_enough = 0.1,
|
273 |
+
crop_size = 200,
|
274 |
+
**kwargs):
|
275 |
+
|
276 |
+
distance_map = calculate_distance_map(mask)
|
277 |
+
crop_size = min(crop_size, int(min(get_top_k(distance_map.flatten(), n_tries)*0.5**0.5)))
|
278 |
+
|
279 |
+
direction_confidence = 0
|
280 |
+
best_stripes_data = {}
|
281 |
+
for i, img_crop in enumerate(get_crops(img, distance_map, crop_size, n_tries)):
|
282 |
+
stripes_data = process_img_crop(img_crop, nm_per_px=nm_per_px)
|
283 |
+
if stripes_data['direction confidence'] >= direction_confidence:
|
284 |
+
best_stripes_data = deepcopy(stripes_data)
|
285 |
+
direction_confidence = stripes_data['direction confidence']
|
286 |
+
if direction_confidence > direction_thr_enough:
|
287 |
+
break
|
288 |
+
|
289 |
+
result = best_stripes_data
|
290 |
+
|
291 |
+
if direction_confidence >= direction_thr_min:
|
292 |
+
|
293 |
+
mask_oriented = rotate(mask, 90-result['direction'], resize=True).astype('bool')
|
294 |
+
idx_begin_x, idx_end_x = np.where(np.any(mask_oriented, axis=0))[0][np.array([0, -1])]
|
295 |
+
idx_begin_y, idx_end_y = np.where(np.any(mask_oriented, axis=1))[0][np.array([0, -1])]
|
296 |
+
result['mask_oriented'] = mask_oriented[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x]
|
297 |
+
result['img_oriented'] = rotate(img, 90-result['direction'], resize=True)[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x]
|
298 |
+
|
299 |
+
# measurements = measure_granum_shape(result['mask_oriented'], nm_per_px=nm_per_px, oriented=True)
|
300 |
+
# else:
|
301 |
+
# measurements = measure_granum_shape(mask, nm_per_px=nm_per_px, oriented=False)
|
302 |
+
|
303 |
+
# result.update(**measurements)
|
304 |
+
# N_layers = result['height'] / result['period']
|
305 |
+
# if np.isfinite(N_layers):
|
306 |
+
# N_layers = round(N_layers)
|
307 |
+
|
308 |
+
return result
|
309 |
+
|
310 |
+
# def measure_object(
|
311 |
+
# img, mask,
|
312 |
+
# nm_per_px=1, n_tries = 3,
|
313 |
+
# direction_thr_min = 0.07, direction_thr_enough = 0.1,
|
314 |
+
# crop_size = 200,
|
315 |
+
# **kwargs):
|
316 |
+
|
317 |
+
# distance_map = calculate_distance_map(mask)
|
318 |
+
# crop_size = min(crop_size, int((min(get_top_k(distance_map.flatten(), n_tries)*0.5)**0.5)))
|
319 |
+
|
320 |
+
# direction_confidence = 0
|
321 |
+
# best_stripes_data = {}
|
322 |
+
# for i, img_crop in enumerate(select_samples(img, mask, crop_size=crop_size, n_samples=n_tries)):
|
323 |
+
# stripes_data = process_img_crop(img_crop, nm_per_px=nm_per_px)
|
324 |
+
# if stripes_data['direction_confidence'] >= direction_confidence:
|
325 |
+
# best_stripes_data = deepcopy(stripes_data)
|
326 |
+
# direction_confidence = stripes_data['direction_confidence']
|
327 |
+
# if direction_confidence > direction_thr_enough:
|
328 |
+
# break
|
329 |
+
|
330 |
+
# result = best_stripes_data
|
331 |
+
|
332 |
+
# if direction_confidence >= direction_thr_min:
|
333 |
+
|
334 |
+
# mask_oriented = rotate(mask, 90-result['direction'], resize=True).astype('bool')
|
335 |
+
# idx_begin_x, idx_end_x = np.where(np.any(mask_oriented, axis=0))[0][np.array([0, -1])]
|
336 |
+
# idx_begin_y, idx_end_y = np.where(np.any(mask_oriented, axis=1))[0][np.array([0, -1])]
|
337 |
+
# result['mask_oriented'] = mask_oriented[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x]
|
338 |
+
# result['img_oriented'] = rotate(img, 90-result['direction'], resize=True)[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x]
|
339 |
+
|
340 |
+
# # measurements = measure_granum_shape(result['mask_oriented'], nm_per_px=nm_per_px, oriented=True)
|
341 |
+
# # else:
|
342 |
+
# # measurements = measure_granum_shape(mask, nm_per_px=nm_per_px, oriented=False)
|
343 |
+
|
344 |
+
# # result.update(**measurements)
|
345 |
+
# # N_layers = result['height'] / result['period']
|
346 |
+
# # if np.isfinite(N_layers):
|
347 |
+
# # N_layers = round(N_layers)
|
348 |
+
|
349 |
+
# return result #{**measurements, **best_stripes_data, 'N layers': N_layers}
|
angle_calculation/envelope_correction.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import scipy
|
3 |
+
|
4 |
+
def detect_boundaries(mask, axis):
|
5 |
+
# calculate the boundaries of the mask
|
6 |
+
#axis = 0 results in x_from, x_to
|
7 |
+
#axis = 1 results in y_from, y_to
|
8 |
+
|
9 |
+
|
10 |
+
sum = mask.sum(axis=axis)
|
11 |
+
|
12 |
+
ind_from = min(sum.nonzero()[0])
|
13 |
+
ind_to = max(sum.nonzero()[0])
|
14 |
+
return ind_from, ind_to
|
15 |
+
|
16 |
+
def area(mask):
|
17 |
+
x1, y1 = detect_boundaries(mask, 0)
|
18 |
+
a = y1 - x1
|
19 |
+
x2, y2 = detect_boundaries(mask, 1)
|
20 |
+
b = y2 - x2
|
21 |
+
|
22 |
+
return (a * b, x1, y1, x2, y2)
|
23 |
+
|
24 |
+
def calculate_best_angle_from_mask(mask, angles=np.arange(-10,10,0.5)):
|
25 |
+
areas = []
|
26 |
+
for angle in angles:
|
27 |
+
rotated_mask = scipy.ndimage.rotate(mask, angle, reshape=True, order = 0) # order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
|
28 |
+
this_area = area(rotated_mask)
|
29 |
+
areas.append(this_area[0])
|
30 |
+
|
31 |
+
best_angle = angles[np.argmin(areas)]
|
32 |
+
return best_angle
|
33 |
+
|
34 |
+
|
angle_calculation/granum_utils.py
ADDED
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image, ImageDraw, ImageFont
|
2 |
+
import numpy as np
|
3 |
+
from scipy import ndimage
|
4 |
+
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from typing import Any, List
|
7 |
+
from zipfile import ZipFile
|
8 |
+
|
9 |
+
def add_text(image: Image.Image, text: str, location=(0.5, 0.5), color='red', size=40) -> Image.Image:
|
10 |
+
draw = ImageDraw.Draw(image)
|
11 |
+
font = ImageFont.load_default(size=size)
|
12 |
+
draw.text((int(image.size[0]*location[0]), int(image.size[1]*location[1])), text, font=font, fill=color)
|
13 |
+
return image
|
14 |
+
|
15 |
+
|
16 |
+
def select_unique_mask(mask):
|
17 |
+
"""if mask consists of multiple parts, select the largest"""
|
18 |
+
blobs = ndimage.label(mask)[0]
|
19 |
+
blob_labels, blob_sizes = np.unique(blobs, return_counts=True)
|
20 |
+
best_blob_label = blob_labels[1:][np.argmax(blob_sizes[1:])]
|
21 |
+
return blobs == best_blob_label
|
22 |
+
|
23 |
+
def object_slice(mask, margin=128):
|
24 |
+
rows = np.any(mask, axis=1)
|
25 |
+
cols = np.any(mask, axis=0)
|
26 |
+
row_min, row_max = np.where(rows)[0][[0, -1]]
|
27 |
+
col_min, col_max = np.where(cols)[0][[0, -1]]
|
28 |
+
|
29 |
+
# Create a slice object for the bounding box
|
30 |
+
bounding_box_slice = (
|
31 |
+
slice(max(0,row_min-margin), min(row_max + 1+margin, len(rows)+1)),
|
32 |
+
slice(max(0,col_min-margin), min(col_max + 1+margin, len(cols)+1))
|
33 |
+
)
|
34 |
+
|
35 |
+
return bounding_box_slice
|
36 |
+
|
37 |
+
def resize_to(image: Image.Image, s=4032) -> Image.Image:
|
38 |
+
w, h = image.size
|
39 |
+
longest_size = max(h, w)
|
40 |
+
|
41 |
+
resize_factor = longest_size / s
|
42 |
+
|
43 |
+
resized_image = image.resize((int(w/resize_factor), int(h/resize_factor)))
|
44 |
+
return resized_image
|
45 |
+
|
46 |
+
def rolling_mean(x, window):
|
47 |
+
cs = np.r_[0, np.cumsum(x)]
|
48 |
+
rolling_sum = cs[window:] - cs[:-window]
|
49 |
+
return rolling_sum/window
|
50 |
+
|
51 |
+
@dataclass
|
52 |
+
class Granum:
|
53 |
+
image: Any = None#Optional[np.ndarray] = None
|
54 |
+
mask: Any = None #Optional[np.ndarray] = None
|
55 |
+
scaler: Any = None
|
56 |
+
nm_per_px: float = float('nan')
|
57 |
+
detection_confidence: float = float('nan')
|
58 |
+
|
59 |
+
def zip_files(files: List[str], output_name: str) -> None:
|
60 |
+
with ZipFile(output_name, "w") as zipObj:
|
61 |
+
for file in files:
|
62 |
+
zipObj.write(file)
|
63 |
+
|
64 |
+
def filter_boundary_detections(masks, scaler=None):
|
65 |
+
last_index_right = -1 if scaler is None else masks.shape[1]-1-scaler.pad_right
|
66 |
+
last_index_bottom = -1 if scaler is None else masks.shape[2]-1-scaler.pad_bottom
|
67 |
+
doesnt_touch_boundary_mask = ~(np.any(masks[:,0,:] != 0, axis=1) | np.any(masks[:,last_index_right:,:] != 0, axis=(1,2)) | np.any(masks[:,:,0] != 0, axis=1) | np.any(masks[:,:,last_index_bottom:] != 0, axis=(1,2)))
|
68 |
+
return doesnt_touch_boundary_mask
|
69 |
+
|
70 |
+
def get_circle_mask(shape, r=None):
|
71 |
+
if isinstance(shape, int):
|
72 |
+
shape = (shape, shape)
|
73 |
+
if r is None:
|
74 |
+
r = min(shape)/2
|
75 |
+
X, Y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
|
76 |
+
center_x = shape[1] / 2 - 0.5
|
77 |
+
center_y = shape[0] / 2 - 0.5
|
78 |
+
|
79 |
+
mask = ((X-center_x)**2 + (Y-center_y)**2) >= r**2
|
80 |
+
return mask
|
angle_calculation/image_transforms.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
|
4 |
+
def batched_radon(image_batch):
|
5 |
+
batch_size, img_size = image_batch.shape[:2]
|
6 |
+
if batch_size > 512: # limit batch size to 512 because cv2.warpAffine fails for batch> 512
|
7 |
+
return np.concatenate([batched_radon(image_batch[i:i+512]) for i in range(0,batch_size,512)], axis=0)
|
8 |
+
theta = np.arange(180)
|
9 |
+
radon_image = np.zeros((image_batch.shape[0], img_size, len(theta)),
|
10 |
+
dtype='float32')
|
11 |
+
|
12 |
+
for i, angle in enumerate(theta):
|
13 |
+
M = cv2.getRotationMatrix2D(((img_size-1)/2.0,(img_size-1)/2.0),angle,1)
|
14 |
+
rotated = cv2.warpAffine(np.transpose(image_batch, (1, 2, 0)),M,(img_size,img_size))
|
15 |
+
if batch_size == 1: # cv2.warpAffine cancels batch dimension if equal to 1
|
16 |
+
rotated = rotated[:,:, np.newaxis]
|
17 |
+
rotated = np.transpose(rotated, (2, 0, 1))
|
18 |
+
rotated = rotated / np.array(255, dtype='float32')
|
19 |
+
radon_image[:, :, i] = rotated.sum(axis=1)
|
20 |
+
return radon_image
|
21 |
+
|
22 |
+
def get_center_crop_coords(height: int, width: int, crop_height: int, crop_width: int):
|
23 |
+
"""from https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/crops/functional.py"""
|
24 |
+
y1 = (height - crop_height) // 2
|
25 |
+
y2 = y1 + crop_height
|
26 |
+
x1 = (width - crop_width) // 2
|
27 |
+
x2 = x1 + crop_width
|
28 |
+
return x1, y1, x2, y2
|
29 |
+
|
30 |
+
def center_crop(img: np.ndarray, crop_height: int, crop_width: int):
|
31 |
+
height, width = img.shape[:2]
|
32 |
+
x1, y1, x2, y2 = get_center_crop_coords(height, width, crop_height, crop_width)
|
33 |
+
img = img[y1:y2, x1:x2]
|
34 |
+
return img
|
angle_calculation/sampling.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from PIL import Image
|
2 |
+
from typing import Tuple
|
3 |
+
import numpy as np
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
from scipy import ndimage
|
6 |
+
import torch
|
7 |
+
from torchvision.transforms import functional as tvf
|
8 |
+
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
def sliced_mean(x, slice_size):
|
12 |
+
cs_y = np.cumsum(x, axis=0)
|
13 |
+
cs_y = np.concatenate((np.zeros((1, cs_y.shape[1]), dtype=cs_y.dtype), cs_y), axis=0)
|
14 |
+
slices_y = (cs_y[slice_size:] - cs_y[:-slice_size])/slice_size
|
15 |
+
cs_xy = np.cumsum(slices_y, axis=1)
|
16 |
+
cs_xy = np.concatenate((np.zeros((cs_xy.shape[0], 1), dtype=cs_xy.dtype), cs_xy), axis=1)
|
17 |
+
slices_xy = (cs_xy[:,slice_size:] - cs_xy[:,:-slice_size])/slice_size
|
18 |
+
return slices_xy
|
19 |
+
|
20 |
+
def sliced_var(x, slice_size):
|
21 |
+
x = x.astype('float64')
|
22 |
+
return sliced_mean(x**2, slice_size) - sliced_mean(x, slice_size)**2
|
23 |
+
|
24 |
+
def calculate_local_variance(img, var_window):
|
25 |
+
"""return local variance map with the same size as input image"""
|
26 |
+
var = sliced_var(img, var_window)
|
27 |
+
|
28 |
+
left_pad = var_window // 2 -1
|
29 |
+
right_pad = var_window -1 - left_pad
|
30 |
+
var_padded = np.pad(
|
31 |
+
var,
|
32 |
+
pad_width=(
|
33 |
+
(left_pad,right_pad),
|
34 |
+
(left_pad,right_pad)
|
35 |
+
))
|
36 |
+
return var_padded
|
37 |
+
|
38 |
+
def get_crop_batch(img: np.ndarray, mask: np.ndarray, crop_size=96, crop_scales=np.geomspace(0.5, 2, 7), samples_per_scale=32, use_variance_threshold=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
39 |
+
"""
|
40 |
+
Generate a batch of cropped images from an input image and corresponding mask, at various scales and rotations.
|
41 |
+
|
42 |
+
Parameters
|
43 |
+
----------
|
44 |
+
img : np.ndarray
|
45 |
+
The input image from which crops are generated.
|
46 |
+
mask : np.ndarray
|
47 |
+
The binary mask indicating the region of interest in the image.
|
48 |
+
crop_size : int, optional
|
49 |
+
The size of the square crop.
|
50 |
+
crop_scales : np.ndarray, optional
|
51 |
+
An array of scale factors to apply to the crop size.
|
52 |
+
samples_per_scale : int, optional
|
53 |
+
Number of samples to generate per scale factor.
|
54 |
+
use_variance_threshold : bool, optional
|
55 |
+
Flag to use variance thresholding for selecting crop locations.
|
56 |
+
|
57 |
+
Returns
|
58 |
+
-------
|
59 |
+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
60 |
+
A tuple containing the tensor of crops, their rotation angles, and scale factors.
|
61 |
+
"""
|
62 |
+
|
63 |
+
# pad
|
64 |
+
pad_size = int(np.ceil(0.5*crop_size*max(crop_scales)*(np.sqrt(2)-1)))
|
65 |
+
img_padded = np.pad(img, pad_size)
|
66 |
+
mask_padded = np.pad(mask, pad_size)
|
67 |
+
|
68 |
+
# distance map
|
69 |
+
distance_map_padded = ndimage.distance_transform_edt(mask_padded)
|
70 |
+
# TODO: adjust scales and samples_per_scale
|
71 |
+
|
72 |
+
if use_variance_threshold:
|
73 |
+
variance_window = min(crop_size//2, min(img.shape))
|
74 |
+
variance_map_padded = np.pad(calculate_local_variance(img, variance_window), pad_size)
|
75 |
+
variance_median = np.ma.median(np.ma.masked_where(distance_map_padded<0.5*variance_window, variance_map_padded))
|
76 |
+
variance_mask = variance_map_padded >= variance_median
|
77 |
+
else:
|
78 |
+
variance_mask = np.ones_like(mask_padded)
|
79 |
+
|
80 |
+
# initilize output
|
81 |
+
crops_granum = []
|
82 |
+
angles_granum = []
|
83 |
+
scales_granum = []
|
84 |
+
# loop over scales
|
85 |
+
for scale in crop_scales:
|
86 |
+
half_crop_size_scaled = int(np.floor(scale*0.5*crop_size)) # half of crop size after scaling
|
87 |
+
crop_pad = int(np.ceil((np.sqrt(2) - 1)*half_crop_size_scaled)) # pad added in order to allow rotation
|
88 |
+
half_crop_size_external = half_crop_size_scaled + crop_pad # size of "external crop" which will be rotated
|
89 |
+
|
90 |
+
possible_indices = np.stack(np.where(variance_mask & (distance_map_padded >= 2*half_crop_size_scaled)), axis=1)
|
91 |
+
if len(possible_indices) == 0:
|
92 |
+
continue
|
93 |
+
chosen_indices = np.random.choice(np.arange(len(possible_indices)), min(len(possible_indices), samples_per_scale), replace=False)
|
94 |
+
|
95 |
+
crops = [
|
96 |
+
img_padded[y-half_crop_size_external:y+half_crop_size_external, x-half_crop_size_external:x+half_crop_size_external] for y, x in possible_indices[chosen_indices]
|
97 |
+
]
|
98 |
+
|
99 |
+
# rotate
|
100 |
+
rotation_angles = np.random.rand(len(crops))*180 - 90
|
101 |
+
crops = [
|
102 |
+
ndimage.rotate(crop, angle, reshape=False)[crop_pad:-crop_pad,crop_pad:-crop_pad] for crop, angle in zip(crops, rotation_angles)
|
103 |
+
]
|
104 |
+
# add to output
|
105 |
+
crops_granum.append(tvf.resize(torch.tensor(np.array(crops)), (crop_size,crop_size),antialias=True)) # resize crops to crop_size
|
106 |
+
angles_granum.extend(rotation_angles.tolist())
|
107 |
+
scales_granum.extend([scale]*len(crops))
|
108 |
+
|
109 |
+
if len(angles_granum) == 0:
|
110 |
+
return [], [], []
|
111 |
+
|
112 |
+
crops_granum = torch.concat(crops_granum)
|
113 |
+
angles_granum = torch.tensor(angles_granum, dtype=torch.float)
|
114 |
+
scales_granum = torch.tensor(scales_granum, dtype=torch.float)
|
115 |
+
|
116 |
+
return crops_granum, angles_granum, scales_granum
|
117 |
+
|
118 |
+
def get_crop_batch_from_path(img_path, mask_path=None, use_variance_threshold=False):
|
119 |
+
"""
|
120 |
+
Load an image and its mask from file paths and generate a batch of cropped images.
|
121 |
+
|
122 |
+
Parameters
|
123 |
+
----------
|
124 |
+
img_path : str
|
125 |
+
Path to the input image.
|
126 |
+
mask_path : str, optional
|
127 |
+
Path to the binary mask image. If None, assumes mask path by replacing image extension with '.npy'.
|
128 |
+
use_variance_threshold : bool, optional
|
129 |
+
Flag to use variance thresholding for selecting crop locations.
|
130 |
+
|
131 |
+
Returns
|
132 |
+
-------
|
133 |
+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
|
134 |
+
A tuple containing the tensor of crops, their rotation angles, and scale factors, obtained from the specified image path.
|
135 |
+
"""
|
136 |
+
if mask_path is None:
|
137 |
+
mask_path = str(Path(img_path).with_suffix('.npy'))
|
138 |
+
mask = np.load(mask_path)
|
139 |
+
img = np.array(Image.open(img_path))[:,:,0]
|
140 |
+
|
141 |
+
return get_crop_batch(img, mask, use_variance_threshold=use_variance_threshold)
|
142 |
+
|
app.py
ADDED
@@ -0,0 +1,602 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import shutil
|
3 |
+
import time
|
4 |
+
import uuid
|
5 |
+
from datetime import datetime
|
6 |
+
from decimal import Decimal
|
7 |
+
|
8 |
+
import gradio as gr
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
|
11 |
+
from settings import DEMO
|
12 |
+
|
13 |
+
plt.switch_backend("agg") # fix for "RuntimeError: main thread is not in main loop"
|
14 |
+
import numpy as np
|
15 |
+
import pandas as pd
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
from model import GranaAnalyser
|
19 |
+
|
20 |
+
ga = GranaAnalyser(
|
21 |
+
"weights/yolo/20240604_yolov8_segm_ABRCR1_all_train4_best.pt",
|
22 |
+
"weights/AS_square_v16.ckpt",
|
23 |
+
"weights/period_measurer_weights-1.298_real_full-fa12970.ckpt",
|
24 |
+
)
|
25 |
+
|
26 |
+
|
27 |
+
def calc_ratio(pixels, nano):
|
28 |
+
"""
|
29 |
+
Calculates ratio of pixels to nanometers and returns as str to populate ratio_input
|
30 |
+
:param pixels:
|
31 |
+
:param nano:
|
32 |
+
:return:
|
33 |
+
"""
|
34 |
+
if not (pixels and nano):
|
35 |
+
pass
|
36 |
+
else:
|
37 |
+
res = pixels / nano
|
38 |
+
return res
|
39 |
+
|
40 |
+
|
41 |
+
# https://jakevdp.github.io/PythonDataScienceHandbook/05.13-kernel-density-estimation.html
|
42 |
+
def KDE(dataset, h):
|
43 |
+
# the Kernel function
|
44 |
+
def K(x):
|
45 |
+
return np.exp(-(x ** 2) / 2) / np.sqrt(2 * np.pi)
|
46 |
+
|
47 |
+
n_samples = dataset.size
|
48 |
+
|
49 |
+
x_range = dataset # x-value range for plotting KDEs
|
50 |
+
|
51 |
+
total_sum = 0
|
52 |
+
# iterate over datapoints
|
53 |
+
for i, xi in enumerate(dataset):
|
54 |
+
total_sum += K((x_range - xi) / h)
|
55 |
+
|
56 |
+
y_range = total_sum / (h * n_samples)
|
57 |
+
|
58 |
+
return y_range
|
59 |
+
|
60 |
+
|
61 |
+
def prepare_files_for_download(
|
62 |
+
dir_name,
|
63 |
+
grana_data,
|
64 |
+
aggregated_data,
|
65 |
+
detection_visualizations_dict,
|
66 |
+
images_grana_dict,
|
67 |
+
):
|
68 |
+
"""
|
69 |
+
Save and zip files for download
|
70 |
+
:param dir_name:
|
71 |
+
:param grana_data: DataFrame containing all grana measurements
|
72 |
+
:param aggregated_data: dict containing aggregated measurements
|
73 |
+
:return:
|
74 |
+
"""
|
75 |
+
dir_to_zip = f"{dir_name}/to_zip"
|
76 |
+
|
77 |
+
# raw data
|
78 |
+
grana_data_csv_path = f"{dir_to_zip}/grana_raw_data.csv"
|
79 |
+
grana_data.to_csv(grana_data_csv_path, index=False)
|
80 |
+
|
81 |
+
# aggregated measurements
|
82 |
+
aggregated_csv_path = f"{dir_to_zip}/grana_aggregated_data.csv"
|
83 |
+
aggregated_data.to_csv(aggregated_csv_path)
|
84 |
+
|
85 |
+
# annotated pictures
|
86 |
+
masked_images_dir = f"{dir_to_zip}/annotated_images"
|
87 |
+
os.makedirs(masked_images_dir)
|
88 |
+
for img_name, img in detection_visualizations_dict.items():
|
89 |
+
filename_split = img_name.split(".")
|
90 |
+
extension = filename_split[-1]
|
91 |
+
filename = ".".join(filename_split[:-1])
|
92 |
+
filename = f"{filename}_annotated.{extension}"
|
93 |
+
img.save(f"{masked_images_dir}/{filename}")
|
94 |
+
|
95 |
+
# single_grana images
|
96 |
+
grana_images_dir = f"{dir_to_zip}/single_grana_images"
|
97 |
+
os.makedirs(grana_images_dir)
|
98 |
+
org_images_dict = pd.Series(
|
99 |
+
grana_data["source image"].values, index=grana_data["granum ID"]
|
100 |
+
).to_dict()
|
101 |
+
for img_name, img in images_grana_dict.items():
|
102 |
+
org_filename = org_images_dict[img_name]
|
103 |
+
org_filename_split = org_filename.split(".")
|
104 |
+
org_filename_no_ext = ".".join(org_filename_split[:-1])
|
105 |
+
img_name_ext = f"{org_filename_no_ext}_granum_{str(img_name)}.png"
|
106 |
+
img.save(f"{grana_images_dir}/{img_name_ext}")
|
107 |
+
|
108 |
+
# zip all files
|
109 |
+
date_str = datetime.today().strftime("%Y-%m-%d")
|
110 |
+
zip_name = f"GRANA_results_{date_str}"
|
111 |
+
zip_path = f"{dir_name}/{zip_name}"
|
112 |
+
shutil.make_archive(zip_path, "zip", dir_to_zip)
|
113 |
+
|
114 |
+
# delete to_zip dir
|
115 |
+
zip_dir_path = os.path.join(os.getcwd(), dir_to_zip)
|
116 |
+
shutil.rmtree(zip_dir_path)
|
117 |
+
|
118 |
+
download_file_path = f"{zip_path}.zip"
|
119 |
+
return download_file_path
|
120 |
+
|
121 |
+
|
122 |
+
def show_info_on_submit(s):
|
123 |
+
return (
|
124 |
+
gr.Button(interactive=False),
|
125 |
+
gr.Button(interactive=False),
|
126 |
+
gr.Row(visible=True),
|
127 |
+
gr.Row(visible=False),
|
128 |
+
)
|
129 |
+
|
130 |
+
|
131 |
+
def load_css():
|
132 |
+
with open("styles.css", "r") as f:
|
133 |
+
css_content = f.read()
|
134 |
+
return css_content
|
135 |
+
|
136 |
+
|
137 |
+
primary_hue = gr.themes.Color(
|
138 |
+
c50="#e1f8ee",
|
139 |
+
c100="#b7efd5",
|
140 |
+
c200="#8de6bd",
|
141 |
+
c300="#63dda5",
|
142 |
+
c400="#39d48d",
|
143 |
+
c500="#27b373",
|
144 |
+
c600="#1e8958",
|
145 |
+
c700="#155f3d",
|
146 |
+
c800="#0c3522",
|
147 |
+
c900="#030b07",
|
148 |
+
c950="#000",
|
149 |
+
)
|
150 |
+
|
151 |
+
|
152 |
+
theme = gr.themes.Default(
|
153 |
+
primary_hue=primary_hue,
|
154 |
+
font=[gr.themes.GoogleFont("Ubuntu"), "ui-sans-serif", "system-ui", "sans-serif"],
|
155 |
+
)
|
156 |
+
|
157 |
+
|
158 |
+
def draw_violin_plot(y, ylabel, title):
|
159 |
+
# only generate plot for 3 or more values
|
160 |
+
if y.count() < 3:
|
161 |
+
return None
|
162 |
+
|
163 |
+
# Colors
|
164 |
+
RED_DARK = "#850e00"
|
165 |
+
DARK_GREEN = "#0c3522"
|
166 |
+
BRIGHT_GREEN = "#8de6bd"
|
167 |
+
|
168 |
+
# Create jittered version of "x" (which is only 1)
|
169 |
+
x_jittered = []
|
170 |
+
kde = KDE(y, (y.max() - y.min()) / y.size / 2)
|
171 |
+
kde = kde / kde.max() * 0.2
|
172 |
+
for y_val in kde:
|
173 |
+
x_jittered.append(1 + np.random.uniform(-y_val, y_val, 1))
|
174 |
+
|
175 |
+
fig = plt.figure()
|
176 |
+
ax = fig.add_subplot(1, 1, 1)
|
177 |
+
ax.scatter(x=x_jittered, y=y, s=20, alpha=0.4, c=DARK_GREEN)
|
178 |
+
|
179 |
+
violins = ax.violinplot(
|
180 |
+
y,
|
181 |
+
widths=0.45,
|
182 |
+
bw_method="silverman",
|
183 |
+
showmeans=False,
|
184 |
+
showmedians=False,
|
185 |
+
showextrema=False,
|
186 |
+
)
|
187 |
+
|
188 |
+
# change violin color
|
189 |
+
for pc in violins["bodies"]:
|
190 |
+
pc.set_facecolor(BRIGHT_GREEN)
|
191 |
+
|
192 |
+
# add a boxplot to ax
|
193 |
+
# but make the whiskers length equal to 1 SD, i.e. in the proportion of the IQ range, but this length should start from the mean but be visible from the box boundary
|
194 |
+
lower = np.mean(y) - 1 * np.std(y)
|
195 |
+
upper = np.mean(y) + 1 * np.std(y)
|
196 |
+
|
197 |
+
medianprops = dict(linewidth=1, color="black", solid_capstyle="butt")
|
198 |
+
boxplot_stats = [
|
199 |
+
{
|
200 |
+
"med": np.median(y),
|
201 |
+
"q1": np.percentile(y, 25),
|
202 |
+
"q3": np.percentile(y, 75),
|
203 |
+
"whislo": lower,
|
204 |
+
"whishi": upper,
|
205 |
+
}
|
206 |
+
]
|
207 |
+
|
208 |
+
ax.bxp(
|
209 |
+
boxplot_stats, # data for the boxplot
|
210 |
+
showfliers=False, # do not show the outliers beyond the caps.
|
211 |
+
showcaps=True, # show the caps
|
212 |
+
medianprops=medianprops,
|
213 |
+
)
|
214 |
+
|
215 |
+
# Add mean value point
|
216 |
+
ax.scatter(1, y.mean(), s=30, color=RED_DARK, zorder=3)
|
217 |
+
|
218 |
+
ax.set_xticks([])
|
219 |
+
ax.set_ylabel(ylabel)
|
220 |
+
ax.set_title(title)
|
221 |
+
fig.tight_layout()
|
222 |
+
|
223 |
+
return fig
|
224 |
+
|
225 |
+
|
226 |
+
def transform_aggregated_results_table(results_dict):
|
227 |
+
MEASUREMENT_HEADER = "measurement [unit]"
|
228 |
+
VALUE_HEADER = "value +-SD"
|
229 |
+
|
230 |
+
def get_value_str(value, std):
|
231 |
+
if np.isnan(value) or np.isnan(std):
|
232 |
+
return "-"
|
233 |
+
value_str = str(Decimal(str(value)).quantize(Decimal("0.01")))
|
234 |
+
std_str = str(Decimal(str(std)).quantize(Decimal("0.01")))
|
235 |
+
return f"{value_str} +-{std_str}"
|
236 |
+
|
237 |
+
def append_to_dict(new_key, old_val_key, old_sd_key):
|
238 |
+
aggregated_dict[MEASUREMENT_HEADER].append(new_key)
|
239 |
+
value_str = get_value_str(results_dict[old_val_key], results_dict[old_sd_key])
|
240 |
+
aggregated_dict[VALUE_HEADER].append(value_str)
|
241 |
+
|
242 |
+
aggregated_dict = {MEASUREMENT_HEADER: [], VALUE_HEADER: []}
|
243 |
+
|
244 |
+
# area
|
245 |
+
append_to_dict("area [nm^2]", "area nm^2", "area nm^2 std")
|
246 |
+
|
247 |
+
# perimeter
|
248 |
+
append_to_dict("perimeter [nm]", "perimeter nm", "perimeter nm std")
|
249 |
+
|
250 |
+
# diameter
|
251 |
+
append_to_dict("diameter [nm]", "diameter nm", "diameter nm std")
|
252 |
+
|
253 |
+
# height
|
254 |
+
append_to_dict("height [nm]", "height nm", "height nm std")
|
255 |
+
|
256 |
+
# number of layers
|
257 |
+
append_to_dict("number of thylakoids", "Number of layers", "Number of layers std")
|
258 |
+
|
259 |
+
# SRD
|
260 |
+
append_to_dict("SRD [nm]", "period nm", "period nm std")
|
261 |
+
|
262 |
+
# GSI
|
263 |
+
append_to_dict("GSI", "GSI", "GSI std")
|
264 |
+
|
265 |
+
# N grana
|
266 |
+
aggregated_dict[MEASUREMENT_HEADER].append("number of grana")
|
267 |
+
aggregated_dict[VALUE_HEADER].append(str(int(results_dict["N grana"])))
|
268 |
+
|
269 |
+
return aggregated_dict
|
270 |
+
|
271 |
+
|
272 |
+
def rename_columns_in_results_table(results_table):
|
273 |
+
column_names = {
|
274 |
+
"Granum ID": "granum ID",
|
275 |
+
"File name": "source image",
|
276 |
+
"area nm^2": "area [nm^2]",
|
277 |
+
"perimeter nm": "perimeter [nm]",
|
278 |
+
"diameter nm": "diameter [nm]",
|
279 |
+
"height nm": "height [nm]",
|
280 |
+
"Number of layers": "number of thylakoids",
|
281 |
+
"period nm": "SRD [nm]",
|
282 |
+
"period SD nm": "SRD SD [nm]",
|
283 |
+
}
|
284 |
+
results_table = results_table.rename(columns=column_names)
|
285 |
+
return results_table
|
286 |
+
|
287 |
+
|
288 |
+
with gr.Blocks(css=load_css(), theme=theme) as demo:
|
289 |
+
|
290 |
+
svg = """
|
291 |
+
<svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 30.73 33.38">
|
292 |
+
<defs>
|
293 |
+
<style>
|
294 |
+
.cls-1 {
|
295 |
+
fill: #27b373;
|
296 |
+
stroke-width: 0px;
|
297 |
+
}
|
298 |
+
</style>
|
299 |
+
</defs>
|
300 |
+
<path class="cls-1" d="M19.69,11.73h-3.22c-2.74,0-4.96,2.22-4.96,4.96h0c0,2.74,2.22,4.96,4.96,4.96h3.43c.56,0,1,.51.89,1.09-.08.43-.49.72-.92.72h-8.62c-.74,0-1.34-.6-1.34-1.34v-10.87c0-.74.6-1.34,1.34-1.34h13.44c2.73,0,4.95-2.22,4.95-4.95h0c0-2.75-2.22-4.97-4.96-4.97h-13.85C4.85,0,0,4.85,0,10.83v11.71c0,5.98,4.85,10.83,10.83,10.83h9.07c5.76,0,10.49-4.52,10.81-10.21.35-6.29-4.72-11.44-11.02-11.44ZM19.9,31.4h-9.07c-4.89,0-8.85-3.96-8.85-8.85v-11.71C1.98,5.95,5.95,1.98,10.83,1.98h13.81c1.64,0,2.97,1.33,2.97,2.97h0c0,1.65-1.33,2.97-2.96,2.97h-13.4c-1.83,0-3.32,1.49-3.32,3.32v10.87c0,1.83,1.49,3.32,3.32,3.32h8.56c1.51,0,2.83-1.12,2.97-2.62.16-1.72-1.2-3.16-2.88-3.16h-3.52c-1.64,0-2.97-1.33-2.97-2.97h0c0-1.64,1.33-2.97,2.97-2.97h3.34c4.83,0,8.9,3.81,9.01,8.64s-3.9,9.04-8.84,9.04Z"/>
|
301 |
+
<path class="cls-1" d="M19.9,29.41h-9.07c-3.79,0-6.87-3.07-6.87-6.87v-11.71c0-3.79,3.07-6.87,6.87-6.87h13.81c.55,0,.99.44.99.99h0c0,.55-.44.99-.99.99h-13.81c-2.7,0-4.88,2.19-4.88,4.88v11.71c0,2.7,2.19,4.88,4.88,4.88h8.94c2.64,0,4.91-2.05,5-4.7s-2.12-5.05-4.87-5.05h-3.52c-.55,0-.99-.44-.99-.99h0c0-.55.44-.99.99-.99h3.36c3.74,0,6.9,2.92,7.01,6.66.11,3.87-3.01,7.06-6.85,7.06Z"/>
|
302 |
+
</svg>
|
303 |
+
"""
|
304 |
+
|
305 |
+
gr.HTML(
|
306 |
+
f'<div class="header"><div id="header-logo">{svg}</div><div id="header-text">GRANA<div></div>'
|
307 |
+
)
|
308 |
+
|
309 |
+
with gr.Row(elem_classes="input-row"): # input
|
310 |
+
with gr.Column():
|
311 |
+
gr.HTML(
|
312 |
+
"<h1>1. Choose images to upload. All the images need to be of the same scale and experimental variant.</h1>"
|
313 |
+
)
|
314 |
+
img_input = gr.File(file_count="multiple")
|
315 |
+
|
316 |
+
gr.HTML("<h1>2. Set the scale of the images for the measurements.</h1>")
|
317 |
+
with gr.Row():
|
318 |
+
with gr.Column():
|
319 |
+
gr.HTML("Either provide pixel per nanometer ratio...")
|
320 |
+
ratio_input = gr.Number(
|
321 |
+
label="pixel per nm", precision=3, step=0.001
|
322 |
+
)
|
323 |
+
|
324 |
+
with gr.Column():
|
325 |
+
gr.HTML("...or length of the scale bar in pixels and nanometers.")
|
326 |
+
pixels_input = gr.Number(label="Length in pixels")
|
327 |
+
nano_input = gr.Number(label="Length in nanometers")
|
328 |
+
|
329 |
+
pixels_input.change(
|
330 |
+
calc_ratio,
|
331 |
+
inputs=[pixels_input, nano_input],
|
332 |
+
outputs=ratio_input,
|
333 |
+
)
|
334 |
+
nano_input.change(
|
335 |
+
calc_ratio,
|
336 |
+
inputs=[pixels_input, nano_input],
|
337 |
+
outputs=ratio_input,
|
338 |
+
)
|
339 |
+
|
340 |
+
with gr.Row():
|
341 |
+
clear_btn = gr.ClearButton(img_input, "Clear")
|
342 |
+
submit_btn = gr.Button("Submit", variant="primary")
|
343 |
+
|
344 |
+
with gr.Row(visible=False) as loading_row:
|
345 |
+
with gr.Column():
|
346 |
+
gr.HTML(
|
347 |
+
"<div class='processed-info'>Images are being processed. This may take a while...</div>"
|
348 |
+
)
|
349 |
+
|
350 |
+
with gr.Row(visible=False) as output_row:
|
351 |
+
with gr.Column():
|
352 |
+
gr.HTML(
|
353 |
+
'<div class="results-header">Results</div>'
|
354 |
+
"<p>Full results are a zip file containing:<p>"
|
355 |
+
"<ul>- grana_raw_data.csv: a table with full grana measurements,</ul>"
|
356 |
+
"<ul>- grana_aggregated_data.csv: a table with aggregated measurements,</ul>"
|
357 |
+
'<ul>- directory "annotated_images" with all submitted images with masks on detected grana,</ul>'
|
358 |
+
'<ul>- directory "single_grana_images" with images of all detected grana.</ul>'
|
359 |
+
"<p>Note that GRANA only stores the result files for 1 hour.</p>",
|
360 |
+
elem_classes="input-row",
|
361 |
+
)
|
362 |
+
with gr.Row(elem_classes="input-row"):
|
363 |
+
download_file_out = gr.DownloadButton(
|
364 |
+
label="Download results",
|
365 |
+
variant="primary",
|
366 |
+
elem_classes="margin-bottom",
|
367 |
+
)
|
368 |
+
with gr.Row():
|
369 |
+
gr.HTML(
|
370 |
+
'<h2 class="title">Annotated images</h2>'
|
371 |
+
"Gallery of uploaded images with masks of recognized grana structures. "
|
372 |
+
"Each granum mask is "
|
373 |
+
"labeled with its number. Note that only fully visible grana in the image are masked."
|
374 |
+
)
|
375 |
+
with gr.Row(elem_classes="margin-bottom"):
|
376 |
+
gallery_out = gr.Gallery(
|
377 |
+
columns=4,
|
378 |
+
rows=2,
|
379 |
+
object_fit="contain",
|
380 |
+
label="Detection visualizations",
|
381 |
+
show_download_button=False,
|
382 |
+
)
|
383 |
+
|
384 |
+
with gr.Row(elem_classes="input-row"):
|
385 |
+
gr.HTML(
|
386 |
+
'<h2 class="title">Aggregated results for all uploaded images</h2>'
|
387 |
+
)
|
388 |
+
with gr.Row(elem_classes=["input-row", "margin-bottom"]):
|
389 |
+
table_out = gr.Dataframe(label="Aggregated data")
|
390 |
+
|
391 |
+
with gr.Row():
|
392 |
+
gr.HTML(
|
393 |
+
'<h2 class="title">Violin graphs</h2>'
|
394 |
+
"These graphs present aggregated results for selected structural parameters. "
|
395 |
+
"The graph for each parameter is only generated if three or more values are available. "
|
396 |
+
"Each graph "
|
397 |
+
"displays individual data points, a box plot indicating the first and third quartiles, whiskers "
|
398 |
+
"marking the standard deviation (SD), the median value (horizontal line on the box plot), "
|
399 |
+
"the mean value (red dot), and a density plot where the width represents the frequency."
|
400 |
+
)
|
401 |
+
|
402 |
+
with gr.Row():
|
403 |
+
area_plot_out = gr.Plot(label="Area")
|
404 |
+
perimeter_plot_out = gr.Plot(label="Perimeter")
|
405 |
+
gsi_plot_out = gr.Plot(label="GSI")
|
406 |
+
|
407 |
+
with gr.Row(elem_classes="margin-bottom"):
|
408 |
+
diameter_plot_out = gr.Plot(label="Diameter")
|
409 |
+
height_plot_out = gr.Plot(label="Height")
|
410 |
+
srd_plot_out = gr.Plot(label="SRD")
|
411 |
+
|
412 |
+
with gr.Row():
|
413 |
+
gr.HTML(
|
414 |
+
'<h2 class="title">Recognized and rotated grana structures</h2>'
|
415 |
+
)
|
416 |
+
|
417 |
+
with gr.Row(elem_classes="margin-bottom"):
|
418 |
+
gallery_single_grana_out = gr.Gallery(
|
419 |
+
columns=4,
|
420 |
+
rows=2,
|
421 |
+
object_fit="contain",
|
422 |
+
label="Single grana images",
|
423 |
+
show_download_button=False,
|
424 |
+
)
|
425 |
+
|
426 |
+
with gr.Row():
|
427 |
+
gr.HTML(
|
428 |
+
'<h2 class="title">Full results</h2>'
|
429 |
+
"Note that structural parameters other than area and perimeter are only calculated for the grana "
|
430 |
+
"whose direction and/or SRD could be estimated."
|
431 |
+
)
|
432 |
+
with gr.Row():
|
433 |
+
table_full_out = gr.Dataframe(label="Full measurements data")
|
434 |
+
|
435 |
+
submit_btn.click(
|
436 |
+
show_info_on_submit,
|
437 |
+
inputs=[submit_btn],
|
438 |
+
outputs=[submit_btn, clear_btn, loading_row, output_row],
|
439 |
+
)
|
440 |
+
|
441 |
+
def enable_submit():
|
442 |
+
return (
|
443 |
+
gr.Button(interactive=True),
|
444 |
+
gr.Button(interactive=True),
|
445 |
+
gr.Row(visible=False),
|
446 |
+
)
|
447 |
+
|
448 |
+
def gradio_analize_image(images, scale):
|
449 |
+
"""
|
450 |
+
Model accepts following parameters:
|
451 |
+
:param images: list of images to be processed, in either tiff or png format
|
452 |
+
:param scale: float, nm to pixel ratio
|
453 |
+
|
454 |
+
Model returns the following objects:
|
455 |
+
- detection_visualizations: list of images with masks to be displayed as gallery and served to download
|
456 |
+
as zip of images
|
457 |
+
- grana_data: dataframe with measurements for each image to be served to download as a csv file
|
458 |
+
- images_grana: list of images with single grana to be served to download as zip of images
|
459 |
+
- aggregated_data: dataframe with aggregated measurements for all images to be displayed as table and served
|
460 |
+
to download as csv
|
461 |
+
"""
|
462 |
+
|
463 |
+
# validate that at least one image has been uploaded
|
464 |
+
if images is None or len(images) == 0:
|
465 |
+
raise gr.Error("Please upload at least one image")
|
466 |
+
|
467 |
+
# on demo instance, we limit the number of images to 5
|
468 |
+
if DEMO:
|
469 |
+
if len(images) > 5:
|
470 |
+
raise gr.Error("In demo version it is possible to analyze up to 5 images.")
|
471 |
+
|
472 |
+
# validate that scale has been provided correctly
|
473 |
+
if scale is None or scale == 0:
|
474 |
+
raise gr.Error("Please provide scale. Use dot as decimal separator")
|
475 |
+
|
476 |
+
# validate that all images are png or tiff
|
477 |
+
for image in images:
|
478 |
+
if not image.name.lower().endswith((".png", ".tif", ".jpg", ".jpeg")):
|
479 |
+
raise gr.Error("Only png, tiff, jpg ang jpeg images are supported")
|
480 |
+
|
481 |
+
# clean up previous results
|
482 |
+
# find all directories in current working directory that start with "results_"
|
483 |
+
# that were created more than 1 hour ago and delete them with all contents
|
484 |
+
for directory_name in os.listdir():
|
485 |
+
if directory_name.startswith("results_"):
|
486 |
+
dir_path = os.path.join(os.getcwd(), directory_name)
|
487 |
+
if os.path.isdir(dir_path):
|
488 |
+
if time.time() - os.path.getctime(dir_path) > 60 * 60:
|
489 |
+
shutil.rmtree(dir_path)
|
490 |
+
|
491 |
+
# create a directory for results
|
492 |
+
results_dir_name = "results_{uuid}".format(uuid=uuid.uuid4().hex)
|
493 |
+
os.makedirs(results_dir_name)
|
494 |
+
zip_dir_name = f"{results_dir_name}/to_zip"
|
495 |
+
os.makedirs(zip_dir_name)
|
496 |
+
|
497 |
+
# model takes a dict of images, so we need to convert input to list of PIL.PngImagePlugin.PngImageFile or
|
498 |
+
# PIL.TiffImagePlugin.TiffImageFile objects
|
499 |
+
images_dict = {
|
500 |
+
image.name.split("/")[-1]: Image.open(image.name)
|
501 |
+
for i, image in enumerate(images)
|
502 |
+
}
|
503 |
+
|
504 |
+
# model works here
|
505 |
+
(
|
506 |
+
detection_visualizations_dict,
|
507 |
+
grana_data,
|
508 |
+
images_grana_dict,
|
509 |
+
aggregated_data,
|
510 |
+
) = ga.predict(images_dict, scale)
|
511 |
+
detection_visualizations = list(detection_visualizations_dict.values())
|
512 |
+
images_grana = list(images_grana_dict.values())
|
513 |
+
|
514 |
+
# rearrange aggregated data to be displayed as table
|
515 |
+
aggregated_dict = transform_aggregated_results_table(aggregated_data)
|
516 |
+
aggregated_df_transposed = pd.DataFrame.from_dict(aggregated_dict)
|
517 |
+
|
518 |
+
# rename columns in full results
|
519 |
+
grana_data = rename_columns_in_results_table(grana_data)
|
520 |
+
|
521 |
+
# save files returned by model to disk so they can be retrieved for downloading
|
522 |
+
download_file_path = prepare_files_for_download(
|
523 |
+
results_dir_name,
|
524 |
+
grana_data,
|
525 |
+
aggregated_df_transposed,
|
526 |
+
detection_visualizations_dict,
|
527 |
+
images_grana_dict,
|
528 |
+
)
|
529 |
+
|
530 |
+
# generate plot
|
531 |
+
area_fig = draw_violin_plot(
|
532 |
+
grana_data["area [nm^2]"].dropna(),
|
533 |
+
"Granum area [nm^2]",
|
534 |
+
"Grana areas from all uploaded images",
|
535 |
+
)
|
536 |
+
perimeter_fig = draw_violin_plot(
|
537 |
+
grana_data["perimeter [nm]"].dropna(),
|
538 |
+
"Granum perimeter [nm]",
|
539 |
+
"Grana perimeters from all uploaded images",
|
540 |
+
)
|
541 |
+
gsi_fig = draw_violin_plot(
|
542 |
+
grana_data["GSI"].dropna(),
|
543 |
+
"GSI",
|
544 |
+
"GSI from all uploaded images",
|
545 |
+
)
|
546 |
+
diameter_fig = draw_violin_plot(
|
547 |
+
grana_data["diameter [nm]"].dropna(),
|
548 |
+
"Granum diameter [nm]",
|
549 |
+
"Grana diameters from all uploaded images",
|
550 |
+
)
|
551 |
+
height_fig = draw_violin_plot(
|
552 |
+
grana_data["height [nm]"].dropna(),
|
553 |
+
"Granum height [nm]",
|
554 |
+
"Grana heights from all uploaded images",
|
555 |
+
)
|
556 |
+
srd_fig = draw_violin_plot(
|
557 |
+
grana_data["SRD [nm]"].dropna(), "SRD [nm]", "SRD from all uploaded images"
|
558 |
+
)
|
559 |
+
|
560 |
+
return [
|
561 |
+
gr.Row(visible=True),
|
562 |
+
gr.Row(visible=True),
|
563 |
+
download_file_path,
|
564 |
+
detection_visualizations,
|
565 |
+
aggregated_df_transposed,
|
566 |
+
area_fig,
|
567 |
+
perimeter_fig,
|
568 |
+
gsi_fig,
|
569 |
+
diameter_fig,
|
570 |
+
height_fig,
|
571 |
+
srd_fig,
|
572 |
+
images_grana,
|
573 |
+
grana_data,
|
574 |
+
]
|
575 |
+
|
576 |
+
submit_btn.click(
|
577 |
+
fn=gradio_analize_image,
|
578 |
+
inputs=[
|
579 |
+
img_input,
|
580 |
+
ratio_input,
|
581 |
+
],
|
582 |
+
outputs=[
|
583 |
+
loading_row,
|
584 |
+
output_row,
|
585 |
+
# file_download_checkboxes,
|
586 |
+
download_file_out,
|
587 |
+
gallery_out,
|
588 |
+
table_out,
|
589 |
+
area_plot_out,
|
590 |
+
perimeter_plot_out,
|
591 |
+
gsi_plot_out,
|
592 |
+
diameter_plot_out,
|
593 |
+
height_plot_out,
|
594 |
+
srd_plot_out,
|
595 |
+
gallery_single_grana_out,
|
596 |
+
table_full_out,
|
597 |
+
],
|
598 |
+
).then(fn=enable_submit, inputs=[], outputs=[submit_btn, clear_btn, loading_row])
|
599 |
+
|
600 |
+
demo.launch(
|
601 |
+
share=False, debug=True, server_name="0.0.0.0", allowed_paths=["images/logo.svg"]
|
602 |
+
)
|
grana_detection/mmwrapper.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from typing import Union, Optional
|
4 |
+
from PIL import Image
|
5 |
+
from mmdet.apis import DetInferencer
|
6 |
+
from ultralytics.engine.results import Results
|
7 |
+
import warnings
|
8 |
+
|
9 |
+
class MMDetector(DetInferencer):
|
10 |
+
def __call__(
|
11 |
+
self,
|
12 |
+
inputs,
|
13 |
+
) -> Results:
|
14 |
+
"""Call the inferencer as in DetInferencer but for single image.
|
15 |
+
|
16 |
+
Args:
|
17 |
+
inputs (np.ndarray | str): Inputs for the inferencer.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
Result: yolo-like result
|
21 |
+
"""
|
22 |
+
|
23 |
+
ori_inputs = self._inputs_to_list(inputs)
|
24 |
+
|
25 |
+
data = list(self.preprocess(
|
26 |
+
ori_inputs, batch_size=1))[0][1]
|
27 |
+
|
28 |
+
preds = self.forward(data)[0]
|
29 |
+
|
30 |
+
yolo_result = Results(
|
31 |
+
orig_img=ori_inputs[0], path="", names=[""],
|
32 |
+
boxes=torch.cat((preds.pred_instances.bboxes, preds.pred_instances.scores.unsqueeze(-1), preds.pred_instances.labels.unsqueeze(-1)), dim=1),
|
33 |
+
masks=preds.pred_instances.masks
|
34 |
+
)
|
35 |
+
|
36 |
+
return yolo_result
|
37 |
+
|
38 |
+
def predict(self, source: Image.Image, conf=None):
|
39 |
+
"""yolo interface"""
|
40 |
+
if conf is not None:
|
41 |
+
warnings.warn(f"confidence value {conf} ignored")
|
42 |
+
return [self.__call__(np.array(source.convert("RGB")))]
|
model.py
ADDED
@@ -0,0 +1,629 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import itertools
|
2 |
+
import warnings
|
3 |
+
from io import BytesIO
|
4 |
+
from copy import deepcopy
|
5 |
+
from pathlib import Path
|
6 |
+
from typing import List, Tuple, Dict, Optional, Any, Union
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
|
9 |
+
from PIL import Image
|
10 |
+
import numpy as np
|
11 |
+
import pandas as pd
|
12 |
+
import matplotlib.pyplot as plt
|
13 |
+
from scipy import ndimage
|
14 |
+
from skimage import measure
|
15 |
+
import torchvision.transforms.functional as tvf
|
16 |
+
from ultralytics import YOLO
|
17 |
+
import torch
|
18 |
+
import cv2
|
19 |
+
import gradio
|
20 |
+
|
21 |
+
import sys, os
|
22 |
+
sys.path.append(os.path.abspath('angle_calculation'))
|
23 |
+
# from classic import measure_object
|
24 |
+
from sampling import get_crop_batch
|
25 |
+
from angle_model import PatchedPredictor, StripsModelLumenWidth
|
26 |
+
|
27 |
+
from period_calculation.period_measurer import PeriodMeasurer
|
28 |
+
# from grana_detection.mmwrapper import MMDetector # mmdet installation in docker is problematic for now
|
29 |
+
|
30 |
+
@dataclass
|
31 |
+
class Granum:
|
32 |
+
id: Optional[int] = None
|
33 |
+
image: Any = None
|
34 |
+
mask: Any = None
|
35 |
+
scaler: Any = None
|
36 |
+
nm_per_px: float = float('nan')
|
37 |
+
detection_confidence: float = float('nan')
|
38 |
+
img_oriented: Optional[np.ndarray] = None # oriented fragment of the image
|
39 |
+
mask_oriented: Optional[np.ndarray] = None # oriented fragment of the mask
|
40 |
+
measurements: dict = field(default_factory=dict) # dict with grana measurements
|
41 |
+
|
42 |
+
class ScalerPadder:
|
43 |
+
"""resize and pad image to specific range.
|
44 |
+
minimal_pad: obligatory padding, e.g. required for detector
|
45 |
+
"""
|
46 |
+
def __init__(self, target_size=1024, target_short_edge_min=640, minimal_pad=16, pad_to_multiply=32):
|
47 |
+
self.minimal_pad = minimal_pad
|
48 |
+
self.target_size = target_size - 2*self.minimal_pad # detection pad is necessary padding size
|
49 |
+
self.target_short_edge_min = target_short_edge_min - 2*self.minimal_pad
|
50 |
+
|
51 |
+
self.max_size_nm = 6000 # training images covers ~3100 nm
|
52 |
+
self.min_size_nm = 2400 # training images covers ~3100 nm
|
53 |
+
self.pad_to_multiply = pad_to_multiply
|
54 |
+
|
55 |
+
def transform(self, image: Image.Image, px_per_nm: float=1.298) -> Image.Image:
|
56 |
+
self.original_size = image.size
|
57 |
+
self.original_px_per_nm = px_per_nm
|
58 |
+
w, h = self.original_size
|
59 |
+
longest_size = max(h, w)
|
60 |
+
img_size_nm = longest_size / px_per_nm
|
61 |
+
if img_size_nm > self.max_size_nm:
|
62 |
+
error_message = f'too large image, image size: {img_size_nm:0.1f}nm, max allowed: {self.max_size_nm}nm'
|
63 |
+
# raise ValueError(error_message)
|
64 |
+
# warnings.warn(warning_message)
|
65 |
+
gradio.Warning(error_message)
|
66 |
+
# add_text(image, warning_message, location=(0.1, 0.1), color='blue', size=int(40*longest_size/self.target_size))
|
67 |
+
|
68 |
+
self.resize_factor = self.target_size / (max(self.min_size_nm, img_size_nm) * px_per_nm)
|
69 |
+
self.px_per_nm_transformed = px_per_nm * self.resize_factor
|
70 |
+
|
71 |
+
resized_image = resize_with_cv2(image, (int(h*self.resize_factor), int(w*self.resize_factor)))
|
72 |
+
|
73 |
+
if w >= h:
|
74 |
+
pad_w = self.target_size-resized_image.size[0]
|
75 |
+
pad_h = max(0, self.target_short_edge_min-resized_image.size[1])
|
76 |
+
else:
|
77 |
+
pad_w = max(0, self.target_short_edge_min-resized_image.size[0])
|
78 |
+
pad_h = self.target_size-resized_image.size[1]
|
79 |
+
|
80 |
+
# apply minimal padding
|
81 |
+
pad_w += 2*self.minimal_pad
|
82 |
+
pad_h += 2*self.minimal_pad
|
83 |
+
|
84 |
+
# round to multiplication
|
85 |
+
pad_w += (self.pad_to_multiply - resized_image.size[0]%self.pad_to_multiply)%self.pad_to_multiply
|
86 |
+
pad_h += (self.pad_to_multiply - resized_image.size[1]%self.pad_to_multiply)%self.pad_to_multiply
|
87 |
+
|
88 |
+
self.pad_right = pad_w // 2
|
89 |
+
self.pad_left = pad_w - self.pad_right
|
90 |
+
|
91 |
+
self.pad_up = pad_h // 2
|
92 |
+
self.pad_bottom = pad_h - self.pad_up
|
93 |
+
|
94 |
+
padded_image = tvf.pad(resized_image, [self.pad_left,self.pad_up, self.pad_right, self.pad_bottom], padding_mode='reflect') # fill 114 as in YOLO
|
95 |
+
return padded_image
|
96 |
+
|
97 |
+
@property
|
98 |
+
def unpad_slice(self) -> Tuple[slice]:
|
99 |
+
return slice(self.pad_up,-self.pad_bottom if self.pad_bottom>0 else None), slice(self.pad_left,-self.pad_right if self.pad_right>0 else None)
|
100 |
+
|
101 |
+
def inverse_transform(self, image: Union[np.ndarray, Image.Image], output_size: Optional[Tuple[int]]=None, output_nm_per_px: Optional[float]=None, return_pil: bool=True) -> Image.Image:
|
102 |
+
if isinstance(image, Image.Image):
|
103 |
+
image = np.array(image)
|
104 |
+
# h, w = image.shape[:2]
|
105 |
+
# unpadded_image = image[self.pad_up:h-self.pad_bottom,self.pad_left:w-self.pad_right]
|
106 |
+
unapdded_image = image[self.unpad_slice]
|
107 |
+
|
108 |
+
if output_size is not None and output_nm_per_px is not None:
|
109 |
+
raise ValueError("one of output_size or output_nm_per_px must not be None")
|
110 |
+
elif output_nm_per_px is not None:
|
111 |
+
resize_factor = self.original_nm_per_px/output_nm_per_px
|
112 |
+
output_size = (int(self.original_size[0]*resize_factor), int(self.original_size[1]*resize_factor))
|
113 |
+
elif output_size is None:
|
114 |
+
output_size = self.original_size
|
115 |
+
resized_image = resize_with_cv2(unapdded_image, (output_size[1],output_size[0]), return_pil=return_pil) #Image.fromarray(unpadded_image).resize(self.original_size)
|
116 |
+
|
117 |
+
return resized_image
|
118 |
+
|
119 |
+
def close_contour(contour):
|
120 |
+
if not np.array_equal(contour[0], contour[-1]):
|
121 |
+
contour = np.vstack((contour, contour[0]))
|
122 |
+
return contour
|
123 |
+
|
124 |
+
def binary_mask_to_polygon(binary_mask, tol=0.01):
|
125 |
+
padded_binary_mask = np.pad(binary_mask, pad_width=1, mode='constant', constant_values=0)
|
126 |
+
contours = measure.find_contours(padded_binary_mask, 0.5)
|
127 |
+
# assert len(contours) == 1 #raise error if there are more than 1 contour
|
128 |
+
contour = contours[0]
|
129 |
+
contour -= 1 # correct for padding
|
130 |
+
contour = close_contour(contour)
|
131 |
+
|
132 |
+
polygon = measure.approximate_polygon(contour, tol)
|
133 |
+
|
134 |
+
polygon = np.flip(polygon, axis=1)
|
135 |
+
# after padding and subtracting 1 we may get -0.5 points in our polygon. Replace it with 0
|
136 |
+
polygon = np.where(polygon>=0, polygon, 0)
|
137 |
+
# segmentation = polygon.ravel().tolist()
|
138 |
+
|
139 |
+
return polygon
|
140 |
+
|
141 |
+
def measure_shape(binary_mask):
|
142 |
+
contour = binary_mask_to_polygon(binary_mask)
|
143 |
+
perimeter = np.sum(np.linalg.norm(contour[:-1] - contour[1:], axis=1))
|
144 |
+
area = np.sum(binary_mask)
|
145 |
+
|
146 |
+
return perimeter, area
|
147 |
+
|
148 |
+
def calculate_gsi(perimeter, height, area):
|
149 |
+
a = 0.5*(perimeter - 2*height)
|
150 |
+
return 1 - area/(a*height)
|
151 |
+
|
152 |
+
def object_slice(mask):
|
153 |
+
rows = np.any(mask, axis=1)
|
154 |
+
cols = np.any(mask, axis=0)
|
155 |
+
row_min, row_max = np.where(rows)[0][[0, -1]]
|
156 |
+
col_min, col_max = np.where(cols)[0][[0, -1]]
|
157 |
+
|
158 |
+
# Create a slice object for the bounding box
|
159 |
+
bounding_box_slice = (slice(row_min, row_max + 1), slice(col_min, col_max + 1))
|
160 |
+
|
161 |
+
return bounding_box_slice
|
162 |
+
|
163 |
+
|
164 |
+
def figure_to_pil(fig):
|
165 |
+
buf = BytesIO()
|
166 |
+
fig.savefig(buf, format='png')
|
167 |
+
buf.seek(0)
|
168 |
+
|
169 |
+
# Load the image from the buffer as a PIL Image
|
170 |
+
image = deepcopy(Image.open(buf))
|
171 |
+
|
172 |
+
# Close the buffer
|
173 |
+
buf.close()
|
174 |
+
return image
|
175 |
+
|
176 |
+
def resize_to(image: Image.Image, s: int=4032, return_factor: bool =False) -> Image.Image:
|
177 |
+
w, h = image.size
|
178 |
+
longest_size = max(h, w)
|
179 |
+
|
180 |
+
resize_factor = longest_size / s
|
181 |
+
|
182 |
+
resized_image = image.resize((int(w/resize_factor), int(h/resize_factor)))
|
183 |
+
if return_factor:
|
184 |
+
return resized_image, resize_factor
|
185 |
+
return resized_image
|
186 |
+
|
187 |
+
def resize_with_cv2(image, shape, return_pil=True):
|
188 |
+
"""resize using cv2 with cv2.INTER_LINEAR - consistent with YOLO"""
|
189 |
+
h, w = shape
|
190 |
+
if isinstance(image, Image.Image):
|
191 |
+
image = np.array(image)
|
192 |
+
|
193 |
+
resized = cv2.resize(image, (w, h), interpolation=cv2.INTER_LINEAR)
|
194 |
+
if return_pil:
|
195 |
+
return Image.fromarray(resized)
|
196 |
+
else:
|
197 |
+
return resized
|
198 |
+
|
199 |
+
def select_unique_mask(mask):
|
200 |
+
"""if mask consists of multiple parts, select the largest"""
|
201 |
+
if not np.any(mask): # if mask is empty, return without change
|
202 |
+
return mask
|
203 |
+
blobs = ndimage.label(mask)[0]
|
204 |
+
blob_labels, blob_sizes = np.unique(blobs, return_counts=True)
|
205 |
+
best_blob_label = blob_labels[1:][np.argmax(blob_sizes[1:])]
|
206 |
+
return blobs == best_blob_label
|
207 |
+
|
208 |
+
def sliced_mean(x, slice_size):
|
209 |
+
cs_y = np.cumsum(x, axis=0)
|
210 |
+
cs_y = np.concatenate((np.zeros((1, cs_y.shape[1]), dtype=cs_y.dtype), cs_y), axis=0)
|
211 |
+
slices_y = (cs_y[slice_size:] - cs_y[:-slice_size])/slice_size
|
212 |
+
cs_xy = np.cumsum(slices_y, axis=1)
|
213 |
+
cs_xy = np.concatenate((np.zeros((cs_xy.shape[0], 1), dtype=cs_xy.dtype), cs_xy), axis=1)
|
214 |
+
slices_xy = (cs_xy[:,slice_size:] - cs_xy[:,:-slice_size])/slice_size
|
215 |
+
return slices_xy
|
216 |
+
|
217 |
+
def sliced_var(x, slice_size):
|
218 |
+
x = x.astype('float64')
|
219 |
+
return sliced_mean(x**2, slice_size) - sliced_mean(x, slice_size)**2
|
220 |
+
|
221 |
+
def calculate_distance_map(mask):
|
222 |
+
padded = np.pad(mask, pad_width=1, mode='constant', constant_values=False)
|
223 |
+
distance_map_padded = ndimage.distance_transform_edt(padded)
|
224 |
+
return distance_map_padded[1:-1,1:-1]
|
225 |
+
|
226 |
+
def select_samples(granum_image, granum_mask, crop_size=96, n_samples=64, granum_fraction_min=0.75, variance_p=0.):
|
227 |
+
granum_occupancy = sliced_mean(granum_mask, crop_size)
|
228 |
+
possible_indices = np.stack(np.where(granum_occupancy >= granum_fraction_min), axis=1)
|
229 |
+
|
230 |
+
if variance_p == 0:
|
231 |
+
p = np.ones(len(possible_indices))
|
232 |
+
else:
|
233 |
+
variance_map = sliced_var(granum_image, crop_size)
|
234 |
+
p = variance_map[possible_indices[:,0], possible_indices[:,1]]**variance_p
|
235 |
+
p /= np.sum(p)
|
236 |
+
|
237 |
+
chosen_indices = np.random.choice(
|
238 |
+
np.arange(len(possible_indices)),
|
239 |
+
min(len(possible_indices), n_samples),
|
240 |
+
replace=False,
|
241 |
+
p = p
|
242 |
+
)
|
243 |
+
|
244 |
+
crops = []
|
245 |
+
for crop_idx, idx in enumerate(chosen_indices):
|
246 |
+
crops.append(
|
247 |
+
granum_image[
|
248 |
+
possible_indices[idx,0]:possible_indices[idx,0]+crop_size,
|
249 |
+
possible_indices[idx,1]:possible_indices[idx,1]+crop_size
|
250 |
+
]
|
251 |
+
)
|
252 |
+
return np.array(crops)
|
253 |
+
|
254 |
+
def calculate_height(mask_oriented): #HACK
|
255 |
+
span = mask_oriented.shape[0] - np.argmax(mask_oriented[::-1], axis=0) - np.argmax(mask_oriented, axis=0)
|
256 |
+
return np.quantile(span, 0.8)
|
257 |
+
|
258 |
+
def calculate_diameter(mask_oriented):
|
259 |
+
"""returns mean diameter"""
|
260 |
+
# calculate 0.25 and 0.75 lines
|
261 |
+
vertical_mask = np.any(mask_oriented, axis=1)
|
262 |
+
upper_granum_bound = np.argmax(vertical_mask)
|
263 |
+
lower_granum_bound = mask_oriented.shape[0] - np.argmax(vertical_mask[::-1])
|
264 |
+
upper = round(0.75*upper_granum_bound + 0.25*lower_granum_bound)
|
265 |
+
lower = max(upper+1, round(0.25*upper_granum_bound + 0.75*lower_granum_bound))
|
266 |
+
valid_rows_slice = slice(upper, lower)
|
267 |
+
|
268 |
+
# calculate diameters
|
269 |
+
span = mask_oriented.shape[1] - np.argmax(mask_oriented[valid_rows_slice,::-1], axis=1) - np.argmax(mask_oriented[valid_rows_slice], axis=1)
|
270 |
+
return np.mean(span)
|
271 |
+
|
272 |
+
def robust_mean(x, q=0.1):
|
273 |
+
x_med = np.median(x)
|
274 |
+
deviations = abs(x- x_med)
|
275 |
+
if max(deviations) == 0:
|
276 |
+
mask = np.ones(len(x), dtype='bool')
|
277 |
+
else:
|
278 |
+
threshold = np.quantile(deviations, 1-q)
|
279 |
+
mask = x[deviations<= threshold]
|
280 |
+
|
281 |
+
return np.mean(x[mask])
|
282 |
+
|
283 |
+
def rotate_image_and_mask(image, mask, direction):
|
284 |
+
mask_oriented = ndimage.rotate(mask.astype('int'), -direction, reshape=True).astype('bool')
|
285 |
+
idx_begin_x, idx_end_x = np.where(np.any(mask_oriented, axis=0))[0][np.array([0, -1])]
|
286 |
+
idx_begin_y, idx_end_y = np.where(np.any(mask_oriented, axis=1))[0][np.array([0, -1])]
|
287 |
+
img_oriented = ndimage.rotate(image, -direction, reshape=True) #[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x]
|
288 |
+
return img_oriented, mask_oriented
|
289 |
+
|
290 |
+
class GranaAnalyser:
|
291 |
+
def __init__(self, weights_detector: str, weights_orientation: str, weights_period: str, period_sd_threshold_nm: float=2.5) -> None:
|
292 |
+
"""
|
293 |
+
Initializes the GranaAnalyser with specified weights for detection and measuring.
|
294 |
+
|
295 |
+
This method loads the weights for the grana detection and measuring algorithms
|
296 |
+
from the specified file paths. It also loads mock data for visualization and
|
297 |
+
analysis purposes.
|
298 |
+
|
299 |
+
Parameters:
|
300 |
+
weights_detector (str): The file path to the weights file for the grana detection algorithm.
|
301 |
+
weights_orientation (str): The file path to the weights file for the grana orientation algorithm.
|
302 |
+
weights_period (str): The file path to the weights file for the grana period algorithm.
|
303 |
+
"""
|
304 |
+
self.detector = YOLO(weights_detector)
|
305 |
+
|
306 |
+
self.orienter = PatchedPredictor(
|
307 |
+
StripsModelLumenWidth.load_from_checkpoint(weights_orientation, map_location='cpu').eval(),
|
308 |
+
normalization = dict(mean=0.250, std=0.135),
|
309 |
+
n_samples=32,
|
310 |
+
mask=None,
|
311 |
+
crop_size=64,
|
312 |
+
angle_confidence_threshold=0.2
|
313 |
+
)
|
314 |
+
|
315 |
+
self.measurement_px_per_nm = 1/0.768 # image scale required for measurement
|
316 |
+
|
317 |
+
self.period_measurer = PeriodMeasurer(
|
318 |
+
weights_period,
|
319 |
+
px_per_nm=self.measurement_px_per_nm,
|
320 |
+
sd_threshold_nm=period_sd_threshold_nm,
|
321 |
+
period_threshold_nm_min=14, period_threshold_nm_max=30
|
322 |
+
)
|
323 |
+
|
324 |
+
|
325 |
+
def get_grana_data(self, image, detections, scaler, border_margin=1, min_count=1) -> List[Granum]:
|
326 |
+
"""filter detections and create grana data"""
|
327 |
+
image_numpy = np.array(image)
|
328 |
+
if image_numpy.ndim == 3:
|
329 |
+
image_numpy = image_numpy[:,:,0]
|
330 |
+
|
331 |
+
mask_all = None
|
332 |
+
grana = []
|
333 |
+
for mask, confidence in zip(
|
334 |
+
detections.masks.data.cpu().numpy().astype('bool'),
|
335 |
+
detections.boxes.conf.cpu().numpy()
|
336 |
+
):
|
337 |
+
granum_mask = select_unique_mask(mask[scaler.unpad_slice])
|
338 |
+
# check if mask is empty after padding
|
339 |
+
if not np.any(granum_mask):
|
340 |
+
continue
|
341 |
+
granum_mask = ndimage.binary_fill_holes(granum_mask)
|
342 |
+
|
343 |
+
# check if touches boundary:
|
344 |
+
if (np.sum(granum_mask[:border_margin])>min_count) or \
|
345 |
+
(np.sum(granum_mask[-border_margin:])>min_count) or \
|
346 |
+
(np.sum(granum_mask[:,:border_margin])>min_count) or \
|
347 |
+
(np.sum(granum_mask[:,-border_margin:])>min_count):
|
348 |
+
|
349 |
+
continue
|
350 |
+
|
351 |
+
# check grana overlap
|
352 |
+
if mask_all is None:
|
353 |
+
mask_all = granum_mask
|
354 |
+
else:
|
355 |
+
intersection = mask_all & granum_mask
|
356 |
+
|
357 |
+
if intersection.sum() >= (granum_mask.sum() * 0.2):
|
358 |
+
continue
|
359 |
+
mask_all = mask_all | granum_mask
|
360 |
+
|
361 |
+
granum = Granum(
|
362 |
+
image = image,
|
363 |
+
mask = granum_mask,
|
364 |
+
scaler=scaler,
|
365 |
+
detection_confidence=float(confidence)
|
366 |
+
)
|
367 |
+
|
368 |
+
granum.image_numpy = image_numpy
|
369 |
+
grana.append(granum)
|
370 |
+
return grana
|
371 |
+
|
372 |
+
def measure_grana(self, grana: List[Granum], measurement_image: np.ndarray) -> List[Granum]:
|
373 |
+
"""measure grana: includes orientation detection, period detection and geometric measurements"""
|
374 |
+
for granum in grana:
|
375 |
+
measurement_mask = resize_with_cv2(granum.mask.astype(np.uint8), measurement_image.shape[:2], return_pil=False).astype('bool')
|
376 |
+
|
377 |
+
granum.bounding_box_slice = object_slice(measurement_mask)
|
378 |
+
granum.image_crop = measurement_image[granum.bounding_box_slice][:,:]
|
379 |
+
granum.mask_crop = measurement_mask[granum.bounding_box_slice]
|
380 |
+
|
381 |
+
# initialize measurements
|
382 |
+
granum.measurements = {}
|
383 |
+
|
384 |
+
# measure shape
|
385 |
+
granum.measurements['perimeter px'], granum.measurements['area px'] = measure_shape(granum.mask_crop)
|
386 |
+
|
387 |
+
# measrure orientation
|
388 |
+
orienter_predictions = self.orienter(granum.image_crop, granum.mask_crop)
|
389 |
+
granum.measurements['direction'] = orienter_predictions["est_angle"]
|
390 |
+
granum.measurements['direction confidence'] = orienter_predictions["est_angle_confidence"]
|
391 |
+
|
392 |
+
if not np.isnan(granum.measurements["direction"]):
|
393 |
+
img_oriented, mask_oriented = rotate_image_and_mask(granum.image_crop, granum.mask_crop, granum.measurements["direction"])
|
394 |
+
oriented_granum_slice = object_slice(mask_oriented)
|
395 |
+
granum.img_oriented = img_oriented[oriented_granum_slice]
|
396 |
+
granum.mask_oriented = mask_oriented[oriented_granum_slice]
|
397 |
+
granum.measurements['height px'] = calculate_height(granum.mask_oriented)
|
398 |
+
granum.measurements['GSI'] = calculate_gsi(
|
399 |
+
granum.measurements['perimeter px'],
|
400 |
+
granum.measurements['height px'],
|
401 |
+
granum.measurements['area px']
|
402 |
+
)
|
403 |
+
granum.measurements['diameter px'] = calculate_diameter(granum.mask_oriented)
|
404 |
+
|
405 |
+
oriented_granum_slice = object_slice(granum.mask_oriented)
|
406 |
+
granum.measurements["period nm"], granum.measurements["period SD nm"] = self.period_measurer(granum.img_oriented, granum.mask_oriented)
|
407 |
+
|
408 |
+
if not pd.isna(granum.measurements['period nm']):
|
409 |
+
granum.measurements['Number of layers'] = round(granum.measurements['height px']/ self.measurement_px_per_nm / granum.measurements['period nm'])
|
410 |
+
|
411 |
+
return grana
|
412 |
+
|
413 |
+
def extract_grana_data(self, grana: List[Granum]) -> pd.DataFrame:
|
414 |
+
"""collect and scale grana data"""
|
415 |
+
grana_data = []
|
416 |
+
for granum in grana:
|
417 |
+
granum_entry = {
|
418 |
+
'Granum ID': granum.id,
|
419 |
+
'detection confidence': granum.detection_confidence
|
420 |
+
}
|
421 |
+
# fill with None if absent:
|
422 |
+
for key in ['direction', 'Number of layers', 'GSI', 'period nm', 'period SD nm']:
|
423 |
+
granum_entry[key] = granum.measurements.get(key, None)
|
424 |
+
# scale linearly:
|
425 |
+
for key in ['height px', 'diameter px', 'perimeter px', 'perimeter px']:
|
426 |
+
granum_entry[f"{key[:-3]} nm"] = granum.measurements.get(key, np.nan) / self.measurement_px_per_nm
|
427 |
+
# scale quadratically
|
428 |
+
granum_entry['area nm^2'] = granum.measurements['area px'] / self.measurement_px_per_nm**2
|
429 |
+
|
430 |
+
grana_data.append(granum_entry)
|
431 |
+
|
432 |
+
return pd.DataFrame(grana_data)
|
433 |
+
|
434 |
+
def visualize_detections(self, grana: List[Granum], image: Image.Image) -> Image.Image:
|
435 |
+
visualization_longer_edge = 1024
|
436 |
+
scale = visualization_longer_edge/max(image.size)
|
437 |
+
visualization_size = (round(scale*image.size[0]), round(scale*image.size[1]))
|
438 |
+
visualization_image = np.array(image.resize(visualization_size).convert('RGB'))
|
439 |
+
|
440 |
+
if len(grana) > 0:
|
441 |
+
grana_mask = resize_with_cv2(
|
442 |
+
np.any(np.array([granum.mask for granum in grana]),axis=0).astype(np.uint8),
|
443 |
+
visualization_size[::-1],
|
444 |
+
return_pil=False
|
445 |
+
).astype('bool')
|
446 |
+
visualization_image[grana_mask]= (0.7*visualization_image[grana_mask] + 0.3*np.array([[[39, 179, 115]]])).astype(np.uint8)
|
447 |
+
|
448 |
+
|
449 |
+
for granum in grana:
|
450 |
+
scale = visualization_longer_edge/max(granum.mask.shape)
|
451 |
+
y, x = ndimage.center_of_mass(granum.mask)
|
452 |
+
cv2.putText(visualization_image, f'{granum.id}', org=(int(x*scale)-10, int(y*scale)+10), fontFace=cv2.FONT_HERSHEY_SIMPLEX , fontScale=1, color=(39, 179, 115),thickness = 2)
|
453 |
+
|
454 |
+
return Image.fromarray(visualization_image)
|
455 |
+
|
456 |
+
|
457 |
+
def generate_grana_images(self, grana: List[Granum], image_name: str ="") -> List[Image.Image]:
|
458 |
+
grana_images = {}
|
459 |
+
for granum in grana:
|
460 |
+
fig, ax = plt.subplots()
|
461 |
+
if granum.img_oriented is None:
|
462 |
+
image_to_plot = granum.image_crop
|
463 |
+
mask_to_plot = granum.mask_crop
|
464 |
+
extra_caption = " orientation and period unknown"
|
465 |
+
else:
|
466 |
+
image_to_plot = granum.img_oriented
|
467 |
+
mask_to_plot = granum.mask_oriented
|
468 |
+
extra_caption = ""
|
469 |
+
|
470 |
+
ax.imshow(0.5*255*(~mask_to_plot) +image_to_plot*(1-0.5*(~mask_to_plot)), cmap='gray', vmin=0, vmax=255)
|
471 |
+
ax.axis('off')
|
472 |
+
ax.set_title(f'[{granum.id}]{image_name}\n{extra_caption}')
|
473 |
+
granum_image = figure_to_pil(fig)
|
474 |
+
grana_images[granum.id] = granum_image
|
475 |
+
plt.close('all')
|
476 |
+
|
477 |
+
return grana_images
|
478 |
+
|
479 |
+
def format_data(self, grana_data: pd.DataFrame) -> pd.DataFrame:
|
480 |
+
rounding_roles = {'area nm^2': 0, 'perimeter nm': 1, 'diameter nm': 1, 'height nm': 1, 'period nm': 1, 'period SD nm': 2, 'GSI':2, 'direction': 1}
|
481 |
+
rounded_data = grana_data.round(rounding_roles)
|
482 |
+
columns_order = ['Granum ID', 'File name', 'area nm^2', 'perimeter nm', 'GSI','diameter nm', 'height nm', 'Number of layers','period nm', 'period SD nm', 'direction']
|
483 |
+
return rounded_data[columns_order]
|
484 |
+
|
485 |
+
|
486 |
+
def aggregate_data(self, grana_data: pd.DataFrame, confidence: Optional[float]=None) -> Dict:
|
487 |
+
if confidence is None:
|
488 |
+
filtered = grana_data
|
489 |
+
else:
|
490 |
+
filtered = grana_data.loc[grana_data['aggregated confidence'] >= confidence]
|
491 |
+
aggregation = filtered[['area nm^2', 'perimeter nm', 'diameter nm', 'height nm', 'Number of layers', 'period nm', 'GSI']].mean().to_dict()
|
492 |
+
aggregation_std = filtered[['area nm^2', 'perimeter nm', 'diameter nm', 'height nm', 'Number of layers', 'period nm', 'GSI']].std().to_dict()
|
493 |
+
aggregation_std = {f"{k} std": v for k, v in aggregation_std.items()}
|
494 |
+
aggregation_result = {**aggregation, **aggregation_std, 'N grana': len(filtered)}
|
495 |
+
return aggregation_result
|
496 |
+
|
497 |
+
def predict_on_single(self, image: Image.Image, scale: float, detection_confidence: float=0.25, granum_id_start=1, image_name: str = "") -> Tuple[List[Image.Image], pd.DataFrame, List[Image.Image]]:
|
498 |
+
"""
|
499 |
+
Predicts and aggregates data related to grana using a dictionary of images.
|
500 |
+
|
501 |
+
Parameters:
|
502 |
+
image (Image.Image): PIL Image object to be analyzed
|
503 |
+
scale (float): scale of the image: px per nm.
|
504 |
+
detection_confidence (float): The detection confidence threshold shape measurement
|
505 |
+
|
506 |
+
Returns:
|
507 |
+
Tuple[Image.Image, pandas.DataFrame, List[Image.Image]]:
|
508 |
+
A tuple containing:
|
509 |
+
- detection_visualization (Image.Image): PIL image representing
|
510 |
+
the detection visualizations.
|
511 |
+
- grana_data (pandas.DataFrame): A DataFrame containing the simulated granum data.
|
512 |
+
- grana_images (List[Image.Image]): A list of PIL images of the grana.
|
513 |
+
"""
|
514 |
+
# convert to grayscale
|
515 |
+
image = image.convert("L")
|
516 |
+
|
517 |
+
# detect
|
518 |
+
scaler = ScalerPadder(target_size=1024, target_short_edge_min=640)
|
519 |
+
scaled_image = scaler.transform(image, px_per_nm=scale)
|
520 |
+
detections = self.detector.predict(source=scaled_image, conf=detection_confidence)[0]
|
521 |
+
|
522 |
+
# get grana data
|
523 |
+
grana = self.get_grana_data(image, detections, scaler)
|
524 |
+
for granum_id, granum in enumerate(grana, start=granum_id_start):
|
525 |
+
granum.id = granum_id
|
526 |
+
|
527 |
+
|
528 |
+
# visualize detections
|
529 |
+
detection_visualization = self.visualize_detections(grana, image)
|
530 |
+
|
531 |
+
# measure grana
|
532 |
+
measurement_image_resize_factor = self.measurement_px_per_nm / scale
|
533 |
+
measurement_image_shape = (
|
534 |
+
int(image.size[1]*measurement_image_resize_factor),
|
535 |
+
int(image.size[0]*measurement_image_resize_factor)
|
536 |
+
)
|
537 |
+
measurement_image = resize_with_cv2( # numpy image in scale valid for measurement
|
538 |
+
image, measurement_image_shape, return_pil=False
|
539 |
+
)
|
540 |
+
grana = self.measure_grana(grana, measurement_image)
|
541 |
+
|
542 |
+
# pandas DataFrame
|
543 |
+
grana_data = self.extract_grana_data(grana)
|
544 |
+
|
545 |
+
# list of PIL images
|
546 |
+
grana_images = self.generate_grana_images(grana, image_name=image_name)
|
547 |
+
|
548 |
+
return detection_visualization, grana_data, grana_images
|
549 |
+
|
550 |
+
def predict(self, images: Dict[str, Image.Image], scale: float, detection_confidence: float=0.25, parameter_confidence: Optional[float]=None) -> Tuple[List[Image.Image], pd.DataFrame, List[Image.Image], Dict]:
|
551 |
+
"""
|
552 |
+
Predicts and aggregates data related to grana using a dictionary of images.
|
553 |
+
|
554 |
+
Parameters:
|
555 |
+
images (Dict[str, Image.Image]): A dictionary of PIL Image objects to be analyzed,
|
556 |
+
keyed by their names.
|
557 |
+
scale (float): scale of the image: px per nm
|
558 |
+
detection_confidence (float): The detection confidence threshold shape measurement
|
559 |
+
parameter_confidence (float): The confidence threshold used for data aggregation. Only
|
560 |
+
data with aggregated confidence above this threshold will
|
561 |
+
be considered.
|
562 |
+
|
563 |
+
Returns:
|
564 |
+
Tuple[List[Image.Image], pandas.DataFrame, List[Image.Image], Dict]:
|
565 |
+
A tuple containing:
|
566 |
+
- detection_visualizations (List[Image.Image]): A list of PIL images representing
|
567 |
+
the detection visualizations.
|
568 |
+
- grana_data (pandas.DataFrame): A DataFrame containing the simulated granum data.
|
569 |
+
- grana_images (List[Image.Image]): A list of PIL images of the grana.
|
570 |
+
- aggregated_data (Dict): A dictionary containing the aggregated data results.
|
571 |
+
"""
|
572 |
+
detection_visualizations_all = {}
|
573 |
+
grana_data_all = None
|
574 |
+
grana_images_all = {}
|
575 |
+
|
576 |
+
granum_id_start = 1
|
577 |
+
for image_name, image in images.items():
|
578 |
+
detection_visualization, grana_data, grana_images = self.predict_on_single(image, scale=scale, detection_confidence=detection_confidence, granum_id_start=granum_id_start, image_name=image_name)
|
579 |
+
granum_id_start += len(grana_data)
|
580 |
+
detection_visualizations_all[image_name] = detection_visualization
|
581 |
+
grana_images_all.update(grana_images)
|
582 |
+
|
583 |
+
grana_data['File name'] = image_name
|
584 |
+
if grana_data_all is None:
|
585 |
+
grana_data_all = grana_data
|
586 |
+
else:
|
587 |
+
# grana_data['Granum ID'] += len(grana_data_all)
|
588 |
+
grana_data_all = pd.concat([grana_data_all, grana_data])
|
589 |
+
|
590 |
+
# dict
|
591 |
+
# grana_data_all.to_csv('grana_data_all.csv', index=False)
|
592 |
+
aggregated_data = self.aggregate_data(grana_data_all, parameter_confidence)
|
593 |
+
|
594 |
+
formatted_grana_data = self.format_data(grana_data_all)
|
595 |
+
|
596 |
+
return detection_visualizations_all, formatted_grana_data, grana_images_all, aggregated_data
|
597 |
+
|
598 |
+
|
599 |
+
class GranaDetector(GranaAnalyser):
|
600 |
+
"""supplementary class for grana detection only
|
601 |
+
"""
|
602 |
+
def __init__(self, weights_detector: str, detector_config: Optional[str] = None, model_type="yolo") -> None:
|
603 |
+
|
604 |
+
if model_type == "yolo":
|
605 |
+
self.detector = YOLO(weights_detector)
|
606 |
+
elif model_type == "mmdetection":
|
607 |
+
self.detector = MMDetector(model=detector_config, weights=weights_detector)
|
608 |
+
else:
|
609 |
+
raise NotImplementedError()
|
610 |
+
|
611 |
+
def predict_on_single(self, image: Image.Image, scale: float, detection_confidence: float=0.25, granum_id_start=1, use_scaling=True, granum_border_margin=1, granum_border_min_count=1, scaler_sizes=(1024, 640)) -> List[Granum]:
|
612 |
+
# convert to grayscale
|
613 |
+
image = image.convert("L")
|
614 |
+
|
615 |
+
# detect
|
616 |
+
if use_scaling:
|
617 |
+
scaler = ScalerPadder(target_size=scaler_sizes[0], target_short_edge_min=scaler_sizes[1])
|
618 |
+
else:
|
619 |
+
#dummy scaler
|
620 |
+
scaler = ScalerPadder(target_size=max(image.size), target_short_edge_min=min(image.size), minimal_pad=0, pad_to_multiply=1)
|
621 |
+
scaled_image = scaler.transform(image, scale=scale)
|
622 |
+
detections = self.detector.predict(source=scaled_image, conf=detection_confidence)[0]
|
623 |
+
|
624 |
+
# get grana data
|
625 |
+
grana = self.get_grana_data(image, detections, scaler, border_margin=granum_border_margin, min_count=granum_border_min_count)
|
626 |
+
for i_granum, granum in enumerate(grana, start=1):
|
627 |
+
granum.id = i_granum
|
628 |
+
|
629 |
+
return grana
|
period_calculation/config.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
import albumentations as A
|
3 |
+
from albumentations.pytorch import ToTensorV2
|
4 |
+
|
5 |
+
|
6 |
+
transforms = [
|
7 |
+
A.Normalize(**{'mean': 0.2845, 'std': 0.1447}, max_pixel_value=1.0),
|
8 |
+
# Applies the formula (img - mean * max_pixel_value) / (std * max_pixel_value)
|
9 |
+
ToTensorV2()
|
10 |
+
]
|
11 |
+
|
12 |
+
model_config = {
|
13 |
+
'receptive_field_height': 220,
|
14 |
+
'receptive_field_width': 38,
|
15 |
+
'stride_height': 64,
|
16 |
+
'stride_width': 2,
|
17 |
+
'image_height': 476,
|
18 |
+
'image_width': 476}
|
19 |
+
|
period_calculation/data_reader.py
ADDED
@@ -0,0 +1,861 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import skimage.io
|
4 |
+
from pathlib import Path
|
5 |
+
import torch
|
6 |
+
import scipy
|
7 |
+
from PIL import Image, ImageFilter, ImageChops
|
8 |
+
# from config import model_config
|
9 |
+
from period_calculation.config import model_config
|
10 |
+
|
11 |
+
# Function to add Gaussian noise
|
12 |
+
|
13 |
+
def add_microscope_noise(base_image_as_numpy, noise_intensity):
|
14 |
+
###### The code below is for adding noise to the image
|
15 |
+
# noise intensity is a number between 0 and 1
|
16 |
+
# --- priginal implementation was provided by Michał Bykowski
|
17 |
+
# --- and adapted
|
18 |
+
# This routine works with PIL images and numpy internally (changing formats as it goes)
|
19 |
+
# but the input and output are numpy arrays
|
20 |
+
|
21 |
+
def add_noise(image, mean=0, std_dev=50): # std_dev impacts the amount of noise
|
22 |
+
# Generating noise
|
23 |
+
noise = np.random.normal(mean, std_dev, (image.height, image.width))
|
24 |
+
# Adding noise to the image
|
25 |
+
noisy_image = np.array(image) + noise
|
26 |
+
# Ensuring the values remain within valid grayscale range
|
27 |
+
noisy_image = np.clip(noisy_image, 0, 255)
|
28 |
+
return Image.fromarray(noisy_image.astype('uint8'))
|
29 |
+
|
30 |
+
|
31 |
+
base_image = Image.fromarray(base_image_as_numpy)
|
32 |
+
gray_value = 128
|
33 |
+
gray = Image.new('L', base_image.size, color=gray_value)
|
34 |
+
|
35 |
+
|
36 |
+
gray = add_noise(gray, std_dev=noise_intensity * 76)
|
37 |
+
gray = gray.filter(ImageFilter.GaussianBlur(radius=3))
|
38 |
+
gray = add_noise(gray, std_dev=noise_intensity * 23)
|
39 |
+
gray = gray.filter(ImageFilter.GaussianBlur(radius=2))
|
40 |
+
gray = add_noise(gray, std_dev=noise_intensity * 15)
|
41 |
+
|
42 |
+
# soft light works as in Photoshop
|
43 |
+
# Superimposes two images on top of each other using the Soft Light algorithm
|
44 |
+
result = ImageChops.soft_light(base_image, gray)
|
45 |
+
|
46 |
+
return np.array(result)
|
47 |
+
|
48 |
+
def detect_boundaries(mask, axis):
|
49 |
+
# calculate the boundaries of the mask
|
50 |
+
#axis = 0 results in x_from, x_to
|
51 |
+
#axis = 1 results in y_from, y_to
|
52 |
+
|
53 |
+
|
54 |
+
sum = mask.sum(axis=axis)
|
55 |
+
|
56 |
+
ind_from = min(sum.nonzero()[0])
|
57 |
+
ind_to = max(sum.nonzero()[0])
|
58 |
+
return ind_from, ind_to
|
59 |
+
|
60 |
+
def add_symmetric_filling_beyond_mask(img, mask):
|
61 |
+
for x in range(img.shape[1]):
|
62 |
+
if sum(mask[:, x]) != 0: #if there is at least one nonzero index
|
63 |
+
nonzero_indices = mask[:, x].nonzero()[0]
|
64 |
+
|
65 |
+
y_min = min(nonzero_indices)
|
66 |
+
y_max = max(nonzero_indices)
|
67 |
+
|
68 |
+
if y_max == y_min: #there is only one point
|
69 |
+
img[:, x] = img[y_min, x]
|
70 |
+
else:
|
71 |
+
next = y_min + 1
|
72 |
+
step = +1 # we start by going upwards
|
73 |
+
for y in reversed(range(y_min)):
|
74 |
+
img[y, x] = img[next, x]
|
75 |
+
if next == y_max or next == y_min: #we hit the boundaries - we reverse
|
76 |
+
step *= -1 #reverse direction
|
77 |
+
next += step
|
78 |
+
|
79 |
+
next = y_max - 1
|
80 |
+
step = -1 # we start by going downwards
|
81 |
+
for y in range(y_max + 1, img.shape[0]): #we hit the boundaries - we reverse
|
82 |
+
img[y, x] = img[next, x]
|
83 |
+
if next == y_max or next == y_min:
|
84 |
+
step *= -1 # reverse direction
|
85 |
+
next += step
|
86 |
+
return img
|
87 |
+
class AbstractDataset(torch.utils.data.Dataset):
|
88 |
+
|
89 |
+
def __init__(self,
|
90 |
+
model = None,
|
91 |
+
transforms=[],
|
92 |
+
#### distortions during training ####
|
93 |
+
hv_symmetry=True, # True or False
|
94 |
+
|
95 |
+
min_horizontal_subsampling = 50, # None to turn off; or minimal percentage of horizontal size of the image
|
96 |
+
min_vertical_subsampling = 70, # None to turn off; or minimal percentage of vertical size of the image
|
97 |
+
max_random_tilt = 3, # None to turn off; or maximum tilt in degrees
|
98 |
+
max_add_colors_to_histogram = 10, # 0 to turn off; or points of the histogram to be added
|
99 |
+
max_remove_colors_from_histogram = 30, # 0 to turn off; or points of the histogram to be removed
|
100 |
+
max_noise_intensity = 3.0, # 0.0 to turn off; or max intensity of the noise
|
101 |
+
|
102 |
+
gaussian_phase_transforms_epoch=None, # None to turn off; or number of the epoch when the gaussian phase starts
|
103 |
+
min_horizontal_subsampling_gaussian_phase = 30, # None to turn off; or minimal percentage of horizontal size of the image
|
104 |
+
min_vertical_subsampling_gaussian_phase = 70, # None to turn off; or minimal percentage of vertical size of the image
|
105 |
+
max_random_tilt_gaussian_phase = 2, # None to turn off; or maximum tilt in degrees
|
106 |
+
max_add_colors_to_histogram_gaussian_phase = 10, # 0 to turn off; or points of the histogram to be added
|
107 |
+
max_remove_colors_from_histogram_gaussian_phase = 60, # 0 to turn off; or points of the histogram to be removed
|
108 |
+
max_noise_intensity_gaussian_phase = 3.5, # 0.0 to turn off; or max intensity of the noise
|
109 |
+
|
110 |
+
#### controling variables ####
|
111 |
+
transform_level=2, # 0 - no transforms, 1 - only the basic transform, 2 - all transforms, -1 - subsampling for high images
|
112 |
+
retain_raw_images=False,
|
113 |
+
retain_masks=False):
|
114 |
+
|
115 |
+
|
116 |
+
self.model = model # we need that to check epoch number during training
|
117 |
+
|
118 |
+
self.hv_symmetry = hv_symmetry
|
119 |
+
|
120 |
+
self.min_horizontal_subsampling = min_horizontal_subsampling
|
121 |
+
self.min_vertical_subsampling = min_vertical_subsampling
|
122 |
+
self.max_random_tilt = max_random_tilt
|
123 |
+
self.max_add_colors_to_histogram = max_add_colors_to_histogram
|
124 |
+
self.max_remove_colors_from_histogram = max_remove_colors_from_histogram
|
125 |
+
self.max_noise_intensity = max_noise_intensity
|
126 |
+
|
127 |
+
self.gaussian_phase_transforms_epoch = gaussian_phase_transforms_epoch
|
128 |
+
self.min_horizontal_subsampling_gaussian_phase = min_horizontal_subsampling_gaussian_phase
|
129 |
+
self.min_vertical_subsampling_gaussian_phase = min_vertical_subsampling_gaussian_phase
|
130 |
+
self.max_random_tilt_gaussian_phase = max_random_tilt_gaussian_phase
|
131 |
+
self.max_add_colors_to_histogram_gaussian_phase = max_add_colors_to_histogram_gaussian_phase
|
132 |
+
self.max_remove_colors_from_histogram_gaussian_phase = max_remove_colors_from_histogram_gaussian_phase
|
133 |
+
self.max_noise_intensity_gaussian_phase = max_noise_intensity_gaussian_phase
|
134 |
+
|
135 |
+
self.image_height = model_config['image_height']
|
136 |
+
self.image_width = model_config['image_width']
|
137 |
+
|
138 |
+
self.transform_level = transform_level
|
139 |
+
self.retain_raw_images = retain_raw_images
|
140 |
+
self.retain_masks = retain_masks
|
141 |
+
self.transforms = transforms
|
142 |
+
|
143 |
+
|
144 |
+
def get_image_and_mask(self, row):
|
145 |
+
raise NotImplementedError("Subclass needs to implement this method")
|
146 |
+
|
147 |
+
def load_and_transform_image_and_mask(self, row):
|
148 |
+
img, mask = self.get_image_and_mask(row)
|
149 |
+
|
150 |
+
angle = row['angle']
|
151 |
+
#check if gaussian phase is on
|
152 |
+
if self.gaussian_phase_transforms_epoch is not None and self.model.current_epoch >= self.gaussian_phase_transforms_epoch:
|
153 |
+
max_random_tilt = self.max_random_tilt_gaussian_phase
|
154 |
+
max_noise_intensity = self.max_noise_intensity_gaussian_phase
|
155 |
+
min_horizontal_subsampling = self.min_horizontal_subsampling_gaussian_phase
|
156 |
+
min_vertical_subsampling = self.min_vertical_subsampling_gaussian_phase
|
157 |
+
max_add_colors_to_histogram = self.max_add_colors_to_histogram_gaussian_phase
|
158 |
+
max_remove_colors_from_histogram = self.max_remove_colors_from_histogram_gaussian_phase
|
159 |
+
else:
|
160 |
+
max_random_tilt = self.max_random_tilt
|
161 |
+
max_noise_intensity = self.max_noise_intensity
|
162 |
+
min_horizontal_subsampling = self.min_horizontal_subsampling
|
163 |
+
min_vertical_subsampling = self.min_vertical_subsampling
|
164 |
+
max_add_colors_to_histogram = self.max_add_colors_to_histogram
|
165 |
+
max_remove_colors_from_histogram = self.max_remove_colors_from_histogram
|
166 |
+
|
167 |
+
|
168 |
+
|
169 |
+
|
170 |
+
|
171 |
+
|
172 |
+
if self.transform_level >= 2 and max_random_tilt is not None:
|
173 |
+
####### RANDOM TILT
|
174 |
+
angle += np.random.uniform(-max_random_tilt, max_random_tilt)
|
175 |
+
|
176 |
+
img = scipy.ndimage.rotate(img, 90 - angle, reshape=True, order=3) # HORIZONTAL POSITION
|
177 |
+
###the part of the image that is added after rotation is all black (0s)
|
178 |
+
mask = scipy.ndimage.rotate(mask, 90 - angle, reshape=True, order = 0) # HORIZONTAL POSITION
|
179 |
+
#order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
|
180 |
+
|
181 |
+
############# CROP
|
182 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
183 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
184 |
+
|
185 |
+
#crop the image to the verical and horizontal limits.
|
186 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
187 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
188 |
+
|
189 |
+
|
190 |
+
img_raw = img.copy()
|
191 |
+
|
192 |
+
|
193 |
+
if self.transform_level >= 2:
|
194 |
+
########## ADDING NOISE
|
195 |
+
|
196 |
+
if max_noise_intensity > 0.0:
|
197 |
+
noise_intensity = np.random.random() * max_noise_intensity
|
198 |
+
noisy_img = add_microscope_noise(img, noise_intensity=noise_intensity)
|
199 |
+
img[mask] = noisy_img[mask]
|
200 |
+
|
201 |
+
if self.transform_level == -1:
|
202 |
+
#special case where we take at most 300 middle pixels from the image
|
203 |
+
# (vertical subsampling)
|
204 |
+
# to handle very latge images correctly
|
205 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
206 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
207 |
+
|
208 |
+
y_size = y_to - y_from + 1
|
209 |
+
|
210 |
+
random_size = 300 #not so random, ay?
|
211 |
+
|
212 |
+
if y_size > random_size:
|
213 |
+
random_start = y_size // 2 - random_size // 2
|
214 |
+
|
215 |
+
y_from = random_start
|
216 |
+
y_to = random_start + random_size - 1
|
217 |
+
|
218 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
219 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
220 |
+
|
221 |
+
# recrop the image if necessary
|
222 |
+
# -- even after only horizontal subsampling it may be necessary to recrop the image
|
223 |
+
|
224 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
225 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
226 |
+
|
227 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
228 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
229 |
+
|
230 |
+
if self.transform_level >= 1:
|
231 |
+
############## HORIZONTAL SUBSAMPLING
|
232 |
+
if min_horizontal_subsampling is not None:
|
233 |
+
x_size = x_to - x_from + 1
|
234 |
+
|
235 |
+
# add some random horizontal shift
|
236 |
+
random_size = np.random.randint(x_size * min_horizontal_subsampling / 100.0, x_size + 1)
|
237 |
+
random_start = np.random.randint(0, x_size - random_size + 1) + x_from
|
238 |
+
|
239 |
+
img = img[:, random_start:(random_start + random_size)]
|
240 |
+
mask = mask[:, random_start:(random_start + random_size)]
|
241 |
+
|
242 |
+
############ VERTICAL SUBSAMPLING
|
243 |
+
if min_vertical_subsampling is not None:
|
244 |
+
|
245 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
246 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
247 |
+
|
248 |
+
y_size = y_to - y_from + 1
|
249 |
+
|
250 |
+
random_size = np.random.randint(y_size * min_vertical_subsampling / 100.0, y_size + 1)
|
251 |
+
random_start = np.random.randint(0, y_size - random_size + 1) + y_from
|
252 |
+
|
253 |
+
y_from = random_start
|
254 |
+
y_to = random_start + random_size - 1
|
255 |
+
|
256 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
257 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
258 |
+
|
259 |
+
if min_horizontal_subsampling is not None or min_vertical_subsampling is not None:
|
260 |
+
#recrop the image if necessary
|
261 |
+
# -- even after only horizontal subsampling it may be necessary to recrop the image
|
262 |
+
|
263 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
264 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
265 |
+
|
266 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
267 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
268 |
+
|
269 |
+
|
270 |
+
######### ADD SYMMETRIC FILLING OF THE IMAGE BEYOND THE MASK
|
271 |
+
#img = add_symmetric_filling_beyond_mask(img, mask)
|
272 |
+
#This leaves holes in the image, so we will not use it
|
273 |
+
|
274 |
+
#plt.imshow(img)
|
275 |
+
#plt.show()
|
276 |
+
######### HORIZONTAL AND VERTICAL SYMMETRY.
|
277 |
+
# When superimposed, the result is 180 degree rotation
|
278 |
+
if self.transform_level >= 1 and self.hv_symmetry:
|
279 |
+
for axis in range(2):
|
280 |
+
if np.random.randint(0, 2) % 2 == 0:
|
281 |
+
img = np.flip(img, axis = axis)
|
282 |
+
mask = np.flip(mask, axis = axis)
|
283 |
+
#plt.imshow(img)
|
284 |
+
#plt.show()
|
285 |
+
|
286 |
+
if self.transform_level >= 2 and (max_add_colors_to_histogram > 0 or max_remove_colors_from_histogram > 0):
|
287 |
+
lower_bound = np.random.randint(-max_add_colors_to_histogram, max_remove_colors_from_histogram + 1)
|
288 |
+
upper_bound = np.random.randint(255 - max_remove_colors_from_histogram, 255 + max_add_colors_to_histogram + 1)
|
289 |
+
# first clip the values outstanding from the range (lower_bound -- upper_bound)
|
290 |
+
img[mask] = np.clip(img[mask], lower_bound, upper_bound)
|
291 |
+
# the range (lower_bound -- upper_bound) gets mapped to the range (0--255)
|
292 |
+
# but only in a portion of the image where mask = True
|
293 |
+
img[mask] = np.interp(img[mask], (lower_bound, upper_bound), (0, 255)).astype(np.uint8)
|
294 |
+
|
295 |
+
#### since preserve_range in skimage.transform.resize is set to False, the image
|
296 |
+
#### will be converted to float. Consult:
|
297 |
+
# https://scikit-image.org/docs/stable/api/skimage.transform.html#skimage.transform.resize
|
298 |
+
# https://scikit-image.org/docs/dev/user_guide/data_types.html
|
299 |
+
|
300 |
+
# In our case the image gets conveted to floats ranging 0-1
|
301 |
+
old_height = img.shape[0]
|
302 |
+
img = skimage.transform.resize(img, (self.image_height, self.image_width), order=3)
|
303 |
+
new_height = img.shape[0]
|
304 |
+
mask = skimage.transform.resize(mask, (self.image_height, self.image_width), order=0, preserve_range=True)
|
305 |
+
# order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
|
306 |
+
|
307 |
+
scale_factor = new_height / old_height
|
308 |
+
|
309 |
+
|
310 |
+
#plt.imshow(img)
|
311 |
+
#plt.show()
|
312 |
+
#plt.imshow(mask)
|
313 |
+
#plt.show()
|
314 |
+
return img, mask, scale_factor, img_raw
|
315 |
+
|
316 |
+
def get_annotations_row(self, idx):
|
317 |
+
raise NotImplementedError("Subclass needs to implement this method")
|
318 |
+
|
319 |
+
def __getitem__(self, idx):
|
320 |
+
row = self.get_annotations_row(idx)
|
321 |
+
|
322 |
+
image, mask, scale_factor, image_raw = self.load_and_transform_image_and_mask(row)
|
323 |
+
|
324 |
+
image_data = {
|
325 |
+
'image': image,
|
326 |
+
}
|
327 |
+
|
328 |
+
for transform in self.transforms:
|
329 |
+
image_data = transform(**image_data)
|
330 |
+
# transform operates on image field ONLY of image_data, and returns a dictionary with the same keys
|
331 |
+
|
332 |
+
ret_dict = {
|
333 |
+
'image': image_data['image'],
|
334 |
+
'period_px': torch.tensor(row['period_nm'] * scale_factor * row['px_per_nm'], dtype=torch.float32),
|
335 |
+
'filename': row['granum_image'],
|
336 |
+
'px_per_nm': row['px_per_nm'],
|
337 |
+
'scale': scale_factor, # the scale factor is used to calculate the true period error
|
338 |
+
# (before scale) in losses and metrics
|
339 |
+
'neutral': -self.transforms[0].mean/self.transforms[0].std #value of 0 after the scale transform
|
340 |
+
}
|
341 |
+
|
342 |
+
if self.retain_raw_images:
|
343 |
+
ret_dict['image_raw'] = image_raw
|
344 |
+
|
345 |
+
if self.retain_masks:
|
346 |
+
ret_dict['mask'] = mask
|
347 |
+
|
348 |
+
return ret_dict
|
349 |
+
|
350 |
+
def __len__(self):
|
351 |
+
raise NotImplementedError("Subclass needs to implement this method")
|
352 |
+
|
353 |
+
class ImageDataset(AbstractDataset):
|
354 |
+
def __init__(self, annotations, data_dir: Path, *args, **kwargs):
|
355 |
+
super().__init__(*args, **kwargs)
|
356 |
+
self.data_dir = Path(data_dir)
|
357 |
+
|
358 |
+
self.id = 1
|
359 |
+
|
360 |
+
if isinstance(annotations, str):
|
361 |
+
annotations = data_dir / annotations #make it a Path object relative to data_dir
|
362 |
+
|
363 |
+
if isinstance(annotations, Path):
|
364 |
+
self.annotations = pd.read_csv(data_dir / annotations)
|
365 |
+
no_period = ['27_k7 [1]_4.png']
|
366 |
+
del_img = ['38_k42[1]_19.png', 'n6363_araLL_60kx_6 [1]_0.png', '27_hs8 [1]_5.png', '27_k7 [1]_20.png',
|
367 |
+
'F1_1_60kx_01 [1]_2.png']
|
368 |
+
self.annotations = self.annotations[~self.annotations['granum_image'].isin(no_period)]
|
369 |
+
self.annotations = self.annotations[~self.annotations['granum_image'].isin(del_img)]
|
370 |
+
else:
|
371 |
+
self.annotations = annotations
|
372 |
+
|
373 |
+
def get_image_and_mask(self, row):
|
374 |
+
filename = row['granum_image']
|
375 |
+
img_path = self.data_dir / filename
|
376 |
+
img_raw = skimage.io.imread(img_path)
|
377 |
+
|
378 |
+
img = img_raw[:, :, 0] # all three channels are equal, with the exception
|
379 |
+
# of the last channel which is the full blue (0,0,255) for outside the mask (so blue channel is 255, red and green are 0)
|
380 |
+
mask = (img_raw != (0, 0, 255)).any(axis=2)
|
381 |
+
return img, mask
|
382 |
+
|
383 |
+
def get_annotations_row(self, idx):
|
384 |
+
row = self.annotations.iloc[idx].to_dict()
|
385 |
+
row['idx'] = idx
|
386 |
+
return row
|
387 |
+
|
388 |
+
def __len__(self):
|
389 |
+
return len(self.annotations)
|
390 |
+
|
391 |
+
class ArtificialDataset(AbstractDataset):
|
392 |
+
def __init__(self,
|
393 |
+
min_period = 20,
|
394 |
+
max_period = 140,
|
395 |
+
white_fraction_min = 0.15,
|
396 |
+
white_fraction_max=0.45,
|
397 |
+
|
398 |
+
noise_min_sd = 0.0,
|
399 |
+
noise_max_sd = 100.0,
|
400 |
+
noise_max_sd_everywhere = 20.0, # 20.0
|
401 |
+
leftovers_max = 5,
|
402 |
+
|
403 |
+
get_real_masks_dataset = None, #None or instance of ImageDataset
|
404 |
+
*args, **kwargs):
|
405 |
+
super().__init__(*args, **kwargs)
|
406 |
+
self.id = 0
|
407 |
+
self.min_period = min_period
|
408 |
+
self.max_period = max_period
|
409 |
+
self.white_fraction_min = white_fraction_min
|
410 |
+
self.white_fraction_max = white_fraction_max
|
411 |
+
|
412 |
+
self.receptive_field_height = model_config['receptive_field_height']
|
413 |
+
self.stride_height = model_config['stride_height']
|
414 |
+
self.receptive_field_width = model_config['receptive_field_width']
|
415 |
+
self.stride_width = model_config['stride_width']
|
416 |
+
|
417 |
+
self.noise_min_sd = noise_min_sd
|
418 |
+
self.noise_max_sd = noise_max_sd
|
419 |
+
self.noise_max_sd_everywhere = noise_max_sd_everywhere
|
420 |
+
|
421 |
+
self.leftovers_max = leftovers_max
|
422 |
+
|
423 |
+
self.get_real_masks_dataset = get_real_masks_dataset
|
424 |
+
|
425 |
+
|
426 |
+
def get_image_and_mask(self, row):
|
427 |
+
# generate a rectangular image of black and white horizontal stripes
|
428 |
+
# with black stripes varying with white stripes
|
429 |
+
|
430 |
+
period_px = row['period_nm'] * row['px_per_nm']
|
431 |
+
# white occupying 5-20 % of a total period (white+black)
|
432 |
+
white_px = np.random.randint(period_px * self.white_fraction_min, period_px * self.white_fraction_max + 1)
|
433 |
+
|
434 |
+
|
435 |
+
# mask is rectangle of True values
|
436 |
+
img = np.zeros((self.image_height, self.image_width), dtype=np.uint8)
|
437 |
+
mask = np.ones((self.image_height, self.image_width), dtype=bool)
|
438 |
+
black_px = period_px - white_px
|
439 |
+
random_start = np.random.randint(0, period_px+1)
|
440 |
+
for i in range(self.image_height):
|
441 |
+
if (random_start+i) % (black_px + white_px) < black_px:
|
442 |
+
# sample width with random numbers from 0 to 101
|
443 |
+
img[i, :] = np.random.randint(0, 101, self.image_width)
|
444 |
+
else:
|
445 |
+
#sample width with random numbers from 156 to 255
|
446 |
+
img[i, :] = np.random.randint(156, 256, self.image_width)
|
447 |
+
|
448 |
+
if self.noise_max_sd_everywhere > self.noise_min_sd:
|
449 |
+
sd = np.random.uniform(self.noise_min_sd, self.noise_max_sd_everywhere)
|
450 |
+
noise = np.random.normal(0, sd, (self.image_height, self.image_width))
|
451 |
+
img = np.clip(img+noise.astype(img.dtype), 0, 255)
|
452 |
+
|
453 |
+
if self.noise_max_sd > self.noise_min_sd:
|
454 |
+
# there is also a metagrid in the image
|
455 |
+
# consisting of overlapping receptive fields of size 190x42
|
456 |
+
# with stride 64x4
|
457 |
+
# the metagrid is 5x102
|
458 |
+
overlapping_fields_count_height = (self.image_height - self.receptive_field_height) // self.stride_height + 1
|
459 |
+
overlapping_fields_count_width = (self.image_width - self.receptive_field_width) // self.stride_width + 1
|
460 |
+
|
461 |
+
|
462 |
+
sd = np.random.uniform(self.noise_min_sd, self.noise_max_sd)
|
463 |
+
noise = np.random.normal(0, sd, (self.image_height, self.image_width))
|
464 |
+
|
465 |
+
#there will be some left-over metagrid rectangles
|
466 |
+
leftovers_count = np.random.randint(1, self.leftovers_max)
|
467 |
+
for i in range(leftovers_count):
|
468 |
+
metagrid_row = np.random.randint(0, overlapping_fields_count_height)
|
469 |
+
metagrid_col = np.random.randint(0, overlapping_fields_count_width)
|
470 |
+
#zero-out the noise inside the selected metagrid
|
471 |
+
noise[metagrid_row * self.stride_height:metagrid_row * self.stride_height + self.receptive_field_height + 1, \
|
472 |
+
metagrid_col * self.stride_width :metagrid_col * self.stride_width + self.receptive_field_width + 1] = 0
|
473 |
+
|
474 |
+
#add noise to the image
|
475 |
+
img = np.clip(img+noise.astype(img.dtype), 0, 255)
|
476 |
+
|
477 |
+
if self.get_real_masks_dataset is not None:
|
478 |
+
ret_dict = self.get_real_masks_dataset.__getitem__(row['idx'] % len(self.get_real_masks_dataset))
|
479 |
+
mask = ret_dict['mask'] #this mask is already sized target height-by-width
|
480 |
+
|
481 |
+
img[mask == False] = 0
|
482 |
+
|
483 |
+
return img, mask
|
484 |
+
|
485 |
+
def get_annotations_row(self, idx):
|
486 |
+
return {'idx': idx,
|
487 |
+
'period_nm': np.random.randint(self.min_period, self.max_period),
|
488 |
+
'px_per_nm': 1.0,
|
489 |
+
'granum_image': 'artificial_%d.png' % idx,
|
490 |
+
'angle': 90}
|
491 |
+
|
492 |
+
|
493 |
+
def __len__(self):
|
494 |
+
return 237 # number of samples as in real data in the train set (70% of 339 is 237,3)
|
495 |
+
|
496 |
+
|
497 |
+
class AdHocDataset(AbstractDataset):
|
498 |
+
def __init__(self, images_masks_pxpernm: list[tuple[np.ndarray, np.ndarray, float]], *args, **kwargs):
|
499 |
+
super().__init__(*args, **kwargs)
|
500 |
+
self.data = images_masks_pxpernm
|
501 |
+
|
502 |
+
def __len__(self):
|
503 |
+
return len(self.data)
|
504 |
+
|
505 |
+
def __getitem__(self, idx):
|
506 |
+
image, mask, px_per_nm = self.data[idx]
|
507 |
+
|
508 |
+
image, mask, scale_factor, image_raw = self.load_and_transform_image_and_mask(image, mask)
|
509 |
+
|
510 |
+
image_data = {
|
511 |
+
'image': image,
|
512 |
+
}
|
513 |
+
|
514 |
+
for transform in self.transforms:
|
515 |
+
image_data = transform(**image_data)
|
516 |
+
# transform operates on image field ONLY of image_data, and returns a dictionary with the same keys
|
517 |
+
|
518 |
+
ret_dict = {
|
519 |
+
'image': image_data['image'],
|
520 |
+
'period_px': torch.tensor(0, dtype=torch.float32),
|
521 |
+
'filename': str(idx),
|
522 |
+
'px_per_nm': px_per_nm,
|
523 |
+
'scale': scale_factor, # the scale factor is used to calculate the true period error
|
524 |
+
# (before scale) in losses and metrics
|
525 |
+
'neutral': -self.transforms[0].mean/self.transforms[0].std #value of 0 after the scale transform
|
526 |
+
}
|
527 |
+
|
528 |
+
if self.retain_raw_images:
|
529 |
+
ret_dict['image_raw'] = image_raw
|
530 |
+
|
531 |
+
if self.retain_masks:
|
532 |
+
ret_dict['mask'] = mask
|
533 |
+
|
534 |
+
return ret_dict
|
535 |
+
|
536 |
+
|
537 |
+
def load_and_transform_image_and_mask(self, img, mask):
|
538 |
+
|
539 |
+
angle = 90
|
540 |
+
#check if gaussian phase is on
|
541 |
+
if self.gaussian_phase_transforms_epoch is not None and self.model.current_epoch >= self.gaussian_phase_transforms_epoch:
|
542 |
+
max_random_tilt = self.max_random_tilt_gaussian_phase
|
543 |
+
max_noise_intensity = self.max_noise_intensity_gaussian_phase
|
544 |
+
min_horizontal_subsampling = self.min_horizontal_subsampling_gaussian_phase
|
545 |
+
min_vertical_subsampling = self.min_vertical_subsampling_gaussian_phase
|
546 |
+
max_add_colors_to_histogram = self.max_add_colors_to_histogram_gaussian_phase
|
547 |
+
max_remove_colors_from_histogram = self.max_remove_colors_from_histogram_gaussian_phase
|
548 |
+
else:
|
549 |
+
max_random_tilt = self.max_random_tilt
|
550 |
+
max_noise_intensity = self.max_noise_intensity
|
551 |
+
min_horizontal_subsampling = self.min_horizontal_subsampling
|
552 |
+
min_vertical_subsampling = self.min_vertical_subsampling
|
553 |
+
max_add_colors_to_histogram = self.max_add_colors_to_histogram
|
554 |
+
max_remove_colors_from_histogram = self.max_remove_colors_from_histogram
|
555 |
+
|
556 |
+
|
557 |
+
if self.transform_level >= 2 and max_random_tilt is not None:
|
558 |
+
####### RANDOM TILT
|
559 |
+
angle += np.random.uniform(-max_random_tilt, max_random_tilt)
|
560 |
+
|
561 |
+
|
562 |
+
img = scipy.ndimage.rotate(img, 90 - angle, reshape=True, order=3) # HORIZONTAL POSITION
|
563 |
+
###the part of the image that is added after rotation is all black (0s)
|
564 |
+
mask = scipy.ndimage.rotate(mask, 90 - angle, reshape=True, order = 0) # HORIZONTAL POSITION
|
565 |
+
#order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
|
566 |
+
|
567 |
+
############# CROP
|
568 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
569 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
570 |
+
|
571 |
+
#crop the image to the verical and horizontal limits.
|
572 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
573 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
574 |
+
|
575 |
+
|
576 |
+
img_raw = img.copy()
|
577 |
+
|
578 |
+
|
579 |
+
if self.transform_level >= 2:
|
580 |
+
########## ADDING NOISE
|
581 |
+
|
582 |
+
if max_noise_intensity > 0.0:
|
583 |
+
noise_intensity = np.random.random() * max_noise_intensity
|
584 |
+
noisy_img = add_microscope_noise(img, noise_intensity=noise_intensity)
|
585 |
+
img[mask] = noisy_img[mask]
|
586 |
+
|
587 |
+
if self.transform_level == -1:
|
588 |
+
#special case where we take at most 300 middle pixels from the image
|
589 |
+
# (vertical subsampling)
|
590 |
+
# to handle very latge images correctly
|
591 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
592 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
593 |
+
|
594 |
+
y_size = y_to - y_from + 1
|
595 |
+
|
596 |
+
random_size = 300 #not so random, ay?
|
597 |
+
|
598 |
+
if y_size > random_size:
|
599 |
+
random_start = y_size // 2 - random_size // 2
|
600 |
+
|
601 |
+
y_from = random_start
|
602 |
+
y_to = random_start + random_size - 1
|
603 |
+
|
604 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
605 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
606 |
+
|
607 |
+
# recrop the image if necessary
|
608 |
+
# -- even after only horizontal subsampling it may be necessary to recrop the image
|
609 |
+
|
610 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
611 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
612 |
+
|
613 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
614 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
615 |
+
|
616 |
+
if self.transform_level >= 1:
|
617 |
+
############## HORIZONTAL SUBSAMPLING
|
618 |
+
if min_horizontal_subsampling is not None:
|
619 |
+
x_size = x_to - x_from + 1
|
620 |
+
|
621 |
+
# add some random horizontal shift
|
622 |
+
random_size = np.random.randint(x_size * min_horizontal_subsampling / 100.0, x_size + 1)
|
623 |
+
random_start = np.random.randint(0, x_size - random_size + 1) + x_from
|
624 |
+
|
625 |
+
img = img[:, random_start:(random_start + random_size)]
|
626 |
+
mask = mask[:, random_start:(random_start + random_size)]
|
627 |
+
|
628 |
+
############ VERTICAL SUBSAMPLING
|
629 |
+
if min_vertical_subsampling is not None:
|
630 |
+
|
631 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
632 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
633 |
+
|
634 |
+
y_size = y_to - y_from + 1
|
635 |
+
|
636 |
+
random_size = np.random.randint(y_size * min_vertical_subsampling / 100.0, y_size + 1)
|
637 |
+
random_start = np.random.randint(0, y_size - random_size + 1) + y_from
|
638 |
+
|
639 |
+
y_from = random_start
|
640 |
+
y_to = random_start + random_size - 1
|
641 |
+
|
642 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
643 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
644 |
+
|
645 |
+
if min_horizontal_subsampling is not None or min_vertical_subsampling is not None:
|
646 |
+
#recrop the image if necessary
|
647 |
+
# -- even after only horizontal subsampling it may be necessary to recrop the image
|
648 |
+
|
649 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
650 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
651 |
+
|
652 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
653 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
654 |
+
|
655 |
+
|
656 |
+
######### ADD SYMMETRIC FILLING OF THE IMAGE BEYOND THE MASK
|
657 |
+
#img = add_symmetric_filling_beyond_mask(img, mask)
|
658 |
+
#This leaves holes in the image, so we will not use it
|
659 |
+
|
660 |
+
#plt.imshow(img)
|
661 |
+
#plt.show()
|
662 |
+
######### HORIZONTAL AND VERTICAL SYMMETRY.
|
663 |
+
# When superimposed, the result is 180 degree rotation
|
664 |
+
if self.transform_level >= 1 and self.hv_symmetry:
|
665 |
+
for axis in range(2):
|
666 |
+
if np.random.randint(0, 2) % 2 == 0:
|
667 |
+
img = np.flip(img, axis = axis)
|
668 |
+
mask = np.flip(mask, axis = axis)
|
669 |
+
#plt.imshow(img)
|
670 |
+
#plt.show()
|
671 |
+
|
672 |
+
if self.transform_level >= 2 and (max_add_colors_to_histogram > 0 or max_remove_colors_from_histogram > 0):
|
673 |
+
lower_bound = np.random.randint(-max_add_colors_to_histogram, max_remove_colors_from_histogram + 1)
|
674 |
+
upper_bound = np.random.randint(255 - max_remove_colors_from_histogram, 255 + max_add_colors_to_histogram + 1)
|
675 |
+
# first clip the values outstanding from the range (lower_bound -- upper_bound)
|
676 |
+
img[mask] = np.clip(img[mask], lower_bound, upper_bound)
|
677 |
+
# the range (lower_bound -- upper_bound) gets mapped to the range (0--255)
|
678 |
+
# but only in a portion of the image where mask = True
|
679 |
+
img[mask] = np.interp(img[mask], (lower_bound, upper_bound), (0, 255)).astype(np.uint8)
|
680 |
+
|
681 |
+
#### since preserve_range in skimage.transform.resize is set to False, the image
|
682 |
+
#### will be converted to float. Consult:
|
683 |
+
# https://scikit-image.org/docs/stable/api/skimage.transform.html#skimage.transform.resize
|
684 |
+
# https://scikit-image.org/docs/dev/user_guide/data_types.html
|
685 |
+
|
686 |
+
# In our case the image gets conveted to floats ranging 0-1
|
687 |
+
old_height = img.shape[0]
|
688 |
+
img = skimage.transform.resize(img, (self.image_height, self.image_width), order=3)
|
689 |
+
new_height = img.shape[0]
|
690 |
+
mask = skimage.transform.resize(mask, (self.image_height, self.image_width), order=0, preserve_range=True)
|
691 |
+
# order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
|
692 |
+
|
693 |
+
scale_factor = new_height / old_height
|
694 |
+
|
695 |
+
|
696 |
+
#plt.imshow(img)
|
697 |
+
#plt.show()
|
698 |
+
#plt.imshow(mask)
|
699 |
+
#plt.show()
|
700 |
+
return img, mask, scale_factor, img_raw
|
701 |
+
|
702 |
+
|
703 |
+
class AdHocDataset2(AbstractDataset):
|
704 |
+
def __init__(self, images_masks_pxpernm: list[tuple[np.ndarray, np.ndarray, float]], *args, **kwargs):
|
705 |
+
super().__init__(*args, **kwargs)
|
706 |
+
self.data = images_masks_pxpernm
|
707 |
+
|
708 |
+
def __len__(self):
|
709 |
+
return len(self.data)
|
710 |
+
|
711 |
+
def __getitem__(self, idx):
|
712 |
+
image, mask, px_per_nm = self.data[idx]
|
713 |
+
|
714 |
+
image, mask, scale_factor, image_raw = self.load_and_transform_image_and_mask(image, mask)
|
715 |
+
|
716 |
+
image_data = {
|
717 |
+
'image': image,
|
718 |
+
}
|
719 |
+
|
720 |
+
for transform in self.transforms:
|
721 |
+
image_data = transform(**image_data)
|
722 |
+
# transform operates on image field ONLY of image_data, and returns a dictionary with the same keys
|
723 |
+
|
724 |
+
ret_dict = {
|
725 |
+
'image': image_data['image'],
|
726 |
+
'scale': scale_factor, # the scale factor is used to calculate the true period error
|
727 |
+
# (before scale) in losses and metrics
|
728 |
+
'neutral': -self.transforms[0].mean/self.transforms[0].std #value of 0 after the scale transform
|
729 |
+
}
|
730 |
+
|
731 |
+
return ret_dict
|
732 |
+
|
733 |
+
|
734 |
+
def load_and_transform_image_and_mask(self, img, mask):
|
735 |
+
|
736 |
+
img_raw = img.copy()
|
737 |
+
|
738 |
+
if self.transform_level == -1:
|
739 |
+
#special case where we take at most 300 middle pixels from the image
|
740 |
+
# (vertical subsampling)
|
741 |
+
# to handle very latge images correctly
|
742 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
743 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
744 |
+
|
745 |
+
y_size = y_to - y_from + 1
|
746 |
+
|
747 |
+
max_size = 300
|
748 |
+
|
749 |
+
if y_size > max_size:
|
750 |
+
random_start = y_size // 2 - max_size // 2
|
751 |
+
|
752 |
+
y_from = random_start
|
753 |
+
y_to = random_start + max_size - 1
|
754 |
+
|
755 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
756 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
757 |
+
|
758 |
+
# recrop the image if necessary
|
759 |
+
# -- even after only horizontal subsampling it may be necessary to recrop the image
|
760 |
+
|
761 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
762 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
763 |
+
|
764 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
765 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
766 |
+
|
767 |
+
|
768 |
+
#### since preserve_range in skimage.transform.resize is set to False, the image
|
769 |
+
#### will be converted to float. Consult:
|
770 |
+
# https://scikit-image.org/docs/stable/api/skimage.transform.html#skimage.transform.resize
|
771 |
+
# https://scikit-image.org/docs/dev/user_guide/data_types.html
|
772 |
+
|
773 |
+
# In our case the image gets conveted to floats ranging 0-1
|
774 |
+
old_height = img.shape[0]
|
775 |
+
img = skimage.transform.resize(img, (self.image_height, self.image_width), order=3)
|
776 |
+
new_height = img.shape[0]
|
777 |
+
mask = skimage.transform.resize(mask, (self.image_height, self.image_width), order=0, preserve_range=True)
|
778 |
+
# order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
|
779 |
+
|
780 |
+
scale_factor = new_height / old_height
|
781 |
+
|
782 |
+
return img, mask, scale_factor, img_raw
|
783 |
+
|
784 |
+
class AdHocDataset3(AbstractDataset):
|
785 |
+
def __init__(self, images_and_masks: list[tuple[np.ndarray, np.ndarray]], *args, **kwargs):
|
786 |
+
super().__init__(*args, **kwargs)
|
787 |
+
self.data = images_and_masks
|
788 |
+
|
789 |
+
def __len__(self):
|
790 |
+
return len(self.data)
|
791 |
+
|
792 |
+
def __getitem__(self, idx):
|
793 |
+
image, mask = self.data[idx]
|
794 |
+
|
795 |
+
image, mask, scale_factor = self.load_and_transform_image_and_mask(image, mask)
|
796 |
+
|
797 |
+
image_data = {
|
798 |
+
'image': image,
|
799 |
+
}
|
800 |
+
|
801 |
+
for transform in self.transforms:
|
802 |
+
image_data = transform(**image_data)
|
803 |
+
# transform operates on image field ONLY of image_data, and returns a dictionary with the same keys
|
804 |
+
|
805 |
+
ret_dict = {
|
806 |
+
'image': image_data['image'],
|
807 |
+
'scale': scale_factor, # the scale factor is used to calculate the true period error
|
808 |
+
# (before scale) in losses and metrics
|
809 |
+
#value of 0 after the scale transform
|
810 |
+
}
|
811 |
+
|
812 |
+
return ret_dict
|
813 |
+
|
814 |
+
|
815 |
+
def load_and_transform_image_and_mask(self, img, mask):
|
816 |
+
|
817 |
+
if self.transform_level == -1:
|
818 |
+
#special case where we take at most 300 middle pixels from the image
|
819 |
+
# (vertical subsampling)
|
820 |
+
# to handle very latge images correctly
|
821 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
822 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
823 |
+
|
824 |
+
y_size = y_to - y_from + 1
|
825 |
+
|
826 |
+
max_size = 300
|
827 |
+
|
828 |
+
if y_size > max_size:
|
829 |
+
random_start = y_size // 2 - max_size // 2
|
830 |
+
|
831 |
+
y_from = random_start
|
832 |
+
y_to = random_start + max_size - 1
|
833 |
+
|
834 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
835 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
836 |
+
|
837 |
+
# recrop the image if necessary
|
838 |
+
# -- even after only horizontal subsampling it may be necessary to recrop the image
|
839 |
+
|
840 |
+
x_from, x_to = detect_boundaries(mask, axis=0)
|
841 |
+
y_from, y_to = detect_boundaries(mask, axis=1)
|
842 |
+
|
843 |
+
img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
|
844 |
+
mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
|
845 |
+
|
846 |
+
|
847 |
+
#### since preserve_range in skimage.transform.resize is set to False, the image
|
848 |
+
#### will be converted to float. Consult:
|
849 |
+
# https://scikit-image.org/docs/stable/api/skimage.transform.html#skimage.transform.resize
|
850 |
+
# https://scikit-image.org/docs/dev/user_guide/data_types.html
|
851 |
+
|
852 |
+
# In our case the image gets conveted to floats ranging 0-1
|
853 |
+
old_height = img.shape[0]
|
854 |
+
img = skimage.transform.resize(img, (self.image_height, self.image_width), order=3)
|
855 |
+
new_height = img.shape[0]
|
856 |
+
mask = skimage.transform.resize(mask, (self.image_height, self.image_width), order=0, preserve_range=True)
|
857 |
+
# order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
|
858 |
+
|
859 |
+
scale_factor = new_height / old_height
|
860 |
+
|
861 |
+
return img, mask, scale_factor
|
period_calculation/image_transforms.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import cv2
|
4 |
+
|
5 |
+
|
6 |
+
def np_batched_radon(image_batch):
|
7 |
+
#image_batch: torch tensor, batch_size x 1 x img_size x img_size
|
8 |
+
# squeeze order #1 and transform to numpy
|
9 |
+
|
10 |
+
image_batch = image_batch.squeeze(1).cpu().numpy()
|
11 |
+
|
12 |
+
batch_size, img_size = image_batch.shape[:2]
|
13 |
+
if batch_size > 512: # limit batch size to 512 because cv2.warpAffine fails for batch> 512
|
14 |
+
return np.concatenate([np_batched_radon(image_batch[i:i+512]) for i in range(0,batch_size,512)], axis=0)
|
15 |
+
theta = np.arange(180)
|
16 |
+
radon_image = np.zeros((image_batch.shape[0], img_size, len(theta)),
|
17 |
+
dtype='float32')
|
18 |
+
|
19 |
+
for i, angle in enumerate(theta):
|
20 |
+
M = cv2.getRotationMatrix2D(((img_size-1)/2.0,(img_size-1)/2.0),angle,1)
|
21 |
+
rotated = cv2.warpAffine(np.transpose(image_batch, (1, 2, 0)),M,(img_size,img_size))
|
22 |
+
|
23 |
+
#plt.imshow(rotated[:,:,0])
|
24 |
+
#plt.show()
|
25 |
+
|
26 |
+
if batch_size == 1: # cv2.warpAffine cancels batch dimension if equal to 1
|
27 |
+
rotated = rotated[:,:, np.newaxis]
|
28 |
+
rotated = np.transpose(rotated, (2, 0, 1)) / 224.0
|
29 |
+
#rotated = rotated / np.array(255, dtype='float32')
|
30 |
+
radon_image[:, :, i] = rotated.sum(axis=1)
|
31 |
+
|
32 |
+
#plot the image
|
33 |
+
|
34 |
+
# plt.imshow(radon_image[0])
|
35 |
+
# plt.show()
|
36 |
+
|
37 |
+
return radon_image
|
38 |
+
|
39 |
+
|
40 |
+
def torch_batched_radon(image_batch, neutral_value):
|
41 |
+
#image_batch: batch_size x 1 x img_size x img_size
|
42 |
+
#np_batched_radon(image_batch - neutral_value)
|
43 |
+
|
44 |
+
image_batch = image_batch - neutral_value # so the 0 value is neutral
|
45 |
+
|
46 |
+
|
47 |
+
batch_size = image_batch.shape[0]
|
48 |
+
img_size = image_batch.shape[2]
|
49 |
+
|
50 |
+
theta = np.arange(180) # we don't need torch here, we will evaluate individual angles below
|
51 |
+
|
52 |
+
radon_image = torch.zeros((batch_size, 1, img_size, len(theta)), dtype=torch.float, device=image_batch.device)
|
53 |
+
|
54 |
+
|
55 |
+
for i, angle in enumerate(theta):
|
56 |
+
#M = cv2.getRotationMatrix2D(((img_size-1)/2.0,(img_size-1)/2.0),angle,1)
|
57 |
+
#calculate the same rotation matrix but with torch:
|
58 |
+
M = torch.tensor(cv2.getRotationMatrix2D(((img_size-1)/2.0,(img_size-1)/2.0),angle,1)).to(image_batch.device, dtype=torch.float32)
|
59 |
+
angle = torch.tensor((angle+90)/180.0*np.pi)
|
60 |
+
M1 = torch.tensor([[torch.sin(angle), torch.cos(angle), 0],
|
61 |
+
[torch.cos(angle), -torch.sin(angle), 0]]).to(image_batch.device, dtype=torch.float32)
|
62 |
+
|
63 |
+
|
64 |
+
# we need to add a batch dimension to the rotation matrix
|
65 |
+
M1 = M1.repeat(batch_size, 1, 1)
|
66 |
+
|
67 |
+
grid = torch.nn.functional.affine_grid(M1, image_batch.shape, align_corners=False)
|
68 |
+
rotated = torch.nn.functional.grid_sample(image_batch, grid, mode='bilinear', padding_mode='zeros', align_corners=False)
|
69 |
+
rotated = rotated.squeeze(1)
|
70 |
+
|
71 |
+
#plt.imshow(rotated[0].cpu().numpy())
|
72 |
+
#plt.show()
|
73 |
+
|
74 |
+
radon_image[:, 0, :, i] = rotated.sum(axis=1) / 224.0 + neutral_value
|
75 |
+
|
76 |
+
#plt.imshow(radon_image[0, 0].cpu().numpy())
|
77 |
+
#plt.show()
|
78 |
+
|
79 |
+
return radon_image
|
period_calculation/models/abstract_model.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import torch
|
3 |
+
|
4 |
+
from hydra.utils import instantiate
|
5 |
+
|
6 |
+
|
7 |
+
class AbstractModel(pl.LightningModule):
|
8 |
+
def __init__(self,
|
9 |
+
lr=0.001,
|
10 |
+
optimizer_hparams=dict(),
|
11 |
+
scheduler=dict(classname='MultiStepLR', kwargs=dict(milestones=[100, 150], gamma=0.1))
|
12 |
+
):
|
13 |
+
super().__init__()
|
14 |
+
# Exports the hyperparameters to a YAML file, and create "self.hparams" namespace
|
15 |
+
self.save_hyperparameters()
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
raise NotImplementedError("Subclass needs to implement this method")
|
19 |
+
|
20 |
+
def configure_optimizers(self):
|
21 |
+
# AdamW is Adam with a correct implementation of weight decay (see here
|
22 |
+
# for details: https://arxiv.org/pdf/1711.05101.pdf)
|
23 |
+
print("configuring the optimizer and lr scheduler with learning rate=%.5f"%self.hparams.lr)
|
24 |
+
optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, **self.hparams.optimizer_hparams)
|
25 |
+
# scheduler = getattr(torch.optim.lr_scheduler, self.hparams.lr_hparams['classname'])(optimizer, **self.hparams.lr_hparams['kwargs'])
|
26 |
+
if self.hparams.scheduler is not None:
|
27 |
+
scheduler = instantiate({**self.hparams.scheduler, '_partial_': True})(optimizer)
|
28 |
+
|
29 |
+
return [optimizer], [scheduler]
|
30 |
+
else:
|
31 |
+
return optimizer
|
32 |
+
|
33 |
+
def additional_losses(self):
|
34 |
+
"""get additional_losses"""
|
35 |
+
return torch.zeros((1))
|
36 |
+
|
37 |
+
def process_batch_supervised(self, batch):
|
38 |
+
"""get predictions, losses and mean errors (MAE)"""
|
39 |
+
raise NotImplementedError("Subclass needs to implement this method")
|
40 |
+
|
41 |
+
def log_all(self, losses, metrics, prefix=''):
|
42 |
+
for k, v in losses.items():
|
43 |
+
self.log(f'{prefix}{k}_loss', v.item() if isinstance(v, torch.Tensor) else v)
|
44 |
+
|
45 |
+
for k, v in metrics.items():
|
46 |
+
self.log(f'{prefix}{k}', v.item() if isinstance(v, torch.Tensor) else v)
|
47 |
+
|
48 |
+
def training_step(self, batch, batch_idx):
|
49 |
+
# "batch" is the output of the training data loader.
|
50 |
+
preds, losses, metrics = self.process_batch_supervised(batch)
|
51 |
+
self.log_all(losses, metrics, prefix='train_')
|
52 |
+
|
53 |
+
return losses['final']
|
54 |
+
|
55 |
+
def validation_step(self, batch, batch_idx):
|
56 |
+
preds, losses, metrics = self.process_batch_supervised(batch)
|
57 |
+
self.log_all(losses, metrics, prefix='val_')
|
58 |
+
|
59 |
+
def test_step(self, batch, batch_idx):
|
60 |
+
preds, losses, metrics = self.process_batch_supervised(batch)
|
61 |
+
self.log_all(losses, metrics, prefix='test_')
|
period_calculation/models/gauss_model.py
ADDED
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
from period_calculation.models.abstract_model import AbstractModel
|
6 |
+
|
7 |
+
from period_calculation.config import model_config # this is a dictionary with the model configuration
|
8 |
+
|
9 |
+
class GaussPeriodModel(AbstractModel):
|
10 |
+
def __init__(self,
|
11 |
+
*args, **kwargs
|
12 |
+
):
|
13 |
+
super().__init__(*args, **kwargs)
|
14 |
+
|
15 |
+
self.seq = torch.nn.Sequential(
|
16 |
+
torch.nn.Conv2d(1, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
17 |
+
torch.nn.ReLU(),
|
18 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
19 |
+
torch.nn.MaxPool2d((2, 2), stride=(2, 2)),
|
20 |
+
|
21 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
22 |
+
torch.nn.ReLU(),
|
23 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
24 |
+
torch.nn.MaxPool2d((2, 1), stride=(2, 1)),
|
25 |
+
|
26 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
27 |
+
torch.nn.ReLU(),
|
28 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
29 |
+
torch.nn.MaxPool2d((2, 1), stride=(2, 1)),
|
30 |
+
|
31 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
32 |
+
torch.nn.ReLU(),
|
33 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
34 |
+
torch.nn.MaxPool2d((2, 1), stride=(2, 1)),
|
35 |
+
|
36 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
37 |
+
torch.nn.MaxPool2d((2, 1), stride=(2, 1)),
|
38 |
+
|
39 |
+
torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
|
40 |
+
torch.nn.MaxPool2d((2, 1), stride=(2, 1)),
|
41 |
+
|
42 |
+
torch.nn.Dropout(0.1)
|
43 |
+
)
|
44 |
+
self.query = torch.nn.Parameter(torch.empty(1, 2, 32)) #two heads only
|
45 |
+
torch.nn.init.xavier_normal_(self.query)
|
46 |
+
|
47 |
+
self.linear1 = torch.nn.Linear(64, 8)
|
48 |
+
self.linear2 = torch.nn.Linear(8, 1)
|
49 |
+
|
50 |
+
self.query_sd = torch.nn.Parameter(torch.empty(1, 2, 32))
|
51 |
+
torch.nn.init.xavier_normal_(self.query_sd)
|
52 |
+
|
53 |
+
self.linear_sd1 = torch.nn.Linear(64, 8)
|
54 |
+
self.linear_sd2 = torch.nn.Linear(8, 1)
|
55 |
+
self.relu = torch.nn.ReLU()
|
56 |
+
|
57 |
+
|
58 |
+
def copy_network_trunk(self, model):
|
59 |
+
# https://discuss.pytorch.org/t/copy-weights-from-only-one-layer-of-one-model-to-another-model-with-different-structure/153419
|
60 |
+
with torch.no_grad():
|
61 |
+
for i, layer in enumerate(model.seq):
|
62 |
+
if i%2 == 0 and i!=20: #convolutional layers are the ones with even indexes with the exeption of the 20th (=dropout)
|
63 |
+
self.seq[i].weight.copy_(layer.weight)
|
64 |
+
self.seq[i].bias.copy_(layer.bias)
|
65 |
+
|
66 |
+
|
67 |
+
def copy_final_layers(self, model):
|
68 |
+
# https://discuss.pytorch.org/t/copy-weights-from-only-one-layer-of-one-model-to-another-model-with-different-structure/153419
|
69 |
+
|
70 |
+
with torch.no_grad():
|
71 |
+
self.linear1.weight.copy_(model.linear1.weight)
|
72 |
+
self.linear1.bias.copy_(model.linear1.bias)
|
73 |
+
|
74 |
+
self.linear2.weight.copy_(model.linear2.weight)
|
75 |
+
self.linear2.bias.copy_(model.linear2.bias)
|
76 |
+
|
77 |
+
self.query.copy_(model.query)
|
78 |
+
|
79 |
+
def duplicate_final_layers(self):
|
80 |
+
# https://discuss.pytorch.org/t/copy-weights-from-only-one-layer-of-one-model-to-another-model-with-different-structure/153419
|
81 |
+
|
82 |
+
with torch.no_grad():
|
83 |
+
self.linear_sd1.weight.copy_(self.linear1.weight)
|
84 |
+
self.linear_sd1.bias.copy_(self.linear1.bias)
|
85 |
+
|
86 |
+
self.linear_sd2.weight.copy_(self.linear2.weight/10)
|
87 |
+
self.linear_sd2.bias.copy_(self.linear2.bias/10)
|
88 |
+
|
89 |
+
self.query_sd.copy_(self.query)
|
90 |
+
|
91 |
+
def forward(self, x, neutral=None, return_raw=False):
|
92 |
+
#https://www.nature.com/articles/s41598-023-43852-x
|
93 |
+
|
94 |
+
# x is sized # batch x 1 x 476 x 476
|
95 |
+
|
96 |
+
preds = self.seq(x) # batch x 32 x 5 x 220
|
97 |
+
features = torch.flatten(preds, 2) # batch x 32 x 1100
|
98 |
+
|
99 |
+
# attention
|
100 |
+
energy = self.query @ features # batch x 2 x 1100
|
101 |
+
weights = torch.nn.functional.softmax(energy, 2) # batch x 2 x 1100
|
102 |
+
response = features @ weights.transpose(1, 2) # batch x 32 x 2
|
103 |
+
response = torch.flatten(response, 1) # batch x 64
|
104 |
+
|
105 |
+
preds = self.linear1(response) # batch x 8
|
106 |
+
preds = self.linear2(self.relu(preds)) # batch x 1
|
107 |
+
|
108 |
+
# attention sd
|
109 |
+
|
110 |
+
energy_sd = self.query_sd @ features # batch x 2 x 1100
|
111 |
+
weights_sd = torch.nn.functional.softmax(energy_sd, 2) # batch x 2 x 1100
|
112 |
+
response_sd = features @ weights_sd.transpose(1, 2) # batch x 32 x 2
|
113 |
+
response_sd = torch.flatten(response_sd, 1) # batch x 64
|
114 |
+
|
115 |
+
preds_sd = self.linear_sd1(response_sd) # batch x 8
|
116 |
+
preds_sd = self.linear_sd2(self.relu(preds_sd)) # batch x 1
|
117 |
+
|
118 |
+
outputs = [ model_config['receptive_field_height']/(preds[:,0]) , torch.exp(preds_sd[:,0]) ]
|
119 |
+
if return_raw:
|
120 |
+
outputs.append(preds)
|
121 |
+
outputs.append(preds_sd)
|
122 |
+
outputs.append(weights)
|
123 |
+
outputs.append(weights_sd)
|
124 |
+
|
125 |
+
return tuple(outputs)
|
126 |
+
|
127 |
+
def additional_losses(self):
|
128 |
+
"""get additional_losses"""
|
129 |
+
# additional (orthogonal) loss
|
130 |
+
# we multiply the two heads and later the MSE loss (towards zero) sums the result in L2 norm
|
131 |
+
# the idea is that the scalar product of two orthogonal vectors is zero
|
132 |
+
scalar_product = torch.cat((self.query[0, 0] * self.query[0, 1], self.query_sd[0, 0] * self.query_sd[0, 1]), dim=0)
|
133 |
+
orthogonal_loss = torch.nn.functional.mse_loss(scalar_product, torch.zeros_like(scalar_product))
|
134 |
+
return orthogonal_loss
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
def process_batch_supervised(self, batch):
|
139 |
+
"""get predictions, losses and mean errors (metrics)"""
|
140 |
+
|
141 |
+
# get predictions
|
142 |
+
preds = {}
|
143 |
+
preds['period_px'], preds['sd'] = self.forward(batch['image'], batch['neutral'][0], return_raw=False) # preds: period, sd, orto, preds_raw
|
144 |
+
|
145 |
+
# https://johaupt.github.io/blog/NN_prediction_uncertainty.html
|
146 |
+
# calculate losses
|
147 |
+
mse_period_px = torch.nn.functional.mse_loss(batch['period_px'],
|
148 |
+
preds['period_px'])
|
149 |
+
|
150 |
+
gaussian_nll = torch.nn.functional.gaussian_nll_loss(batch['period_px'],
|
151 |
+
preds['period_px'],
|
152 |
+
(preds['sd']) ** 2)
|
153 |
+
|
154 |
+
orthogonal_weight = 0.1
|
155 |
+
orthogonal_loss = self.additional_losses()
|
156 |
+
length_of_the_first_phase = 0
|
157 |
+
if self.current_epoch < length_of_the_first_phase:
|
158 |
+
#transition from MSE to Gaussian Negative Log Likelihood with sin/cos over first epochs
|
159 |
+
angle = torch.tensor((self.current_epoch) / (length_of_the_first_phase) * np.pi / 2)
|
160 |
+
total_loss = (gaussian_nll) * torch.sin(angle) + (mse_period_px) * torch.cos(angle) + orthogonal_weight * orthogonal_loss
|
161 |
+
else:
|
162 |
+
total_loss = gaussian_nll + orthogonal_weight * orthogonal_loss
|
163 |
+
|
164 |
+
losses = {
|
165 |
+
'gaussian_nll': gaussian_nll,
|
166 |
+
'mse_period_px': mse_period_px,
|
167 |
+
'orthogonal': orthogonal_loss,
|
168 |
+
'final': total_loss
|
169 |
+
}
|
170 |
+
|
171 |
+
# calculate mean errors
|
172 |
+
ground_truth_detached = batch['period_px'].detach().cpu().numpy()
|
173 |
+
print(ground_truth_detached)
|
174 |
+
mean_detached = preds['period_px'].detach().cpu().numpy()
|
175 |
+
print(mean_detached)
|
176 |
+
sd_detached = preds['sd'].detach().cpu().numpy()
|
177 |
+
print("==>", sd_detached)
|
178 |
+
px_per_nm_detached = batch['px_per_nm'].detach().cpu().numpy()
|
179 |
+
scale_detached = batch['scale'].detach().cpu().numpy()
|
180 |
+
|
181 |
+
period_px_difference = np.mean(abs(
|
182 |
+
ground_truth_detached - mean_detached
|
183 |
+
))
|
184 |
+
|
185 |
+
#initiate both with python array with 5 zeros
|
186 |
+
true_period_px_difference = [0.0] * 5
|
187 |
+
true_period_nm_difference = [0.0] * 5
|
188 |
+
|
189 |
+
for i, dist in enumerate([1.0, 2.0, 3.0, 4.0, 5.0]):
|
190 |
+
true_period_px_difference[i] = (np.sum(abs(
|
191 |
+
((ground_truth_detached - mean_detached) / scale_detached) * (sd_detached / scale_detached <dist))) \
|
192 |
+
/ np.sum(sd_detached / scale_detached < dist)) if np.sum(sd_detached / scale_detached < dist) > 0 else 0
|
193 |
+
|
194 |
+
for i, dist in enumerate([1.0, 2.0, 3.0, 4.0, 5.0]):
|
195 |
+
true_period_nm_difference[i] = (np.sum(abs(
|
196 |
+
((ground_truth_detached - mean_detached) / (scale_detached * px_per_nm_detached)) * (sd_detached / scale_detached <dist))) \
|
197 |
+
/ np.sum(sd_detached / scale_detached < dist)) if np.sum(sd_detached / scale_detached < dist) > 0 else 0
|
198 |
+
|
199 |
+
|
200 |
+
true_period_px_difference_all = np.mean(abs(
|
201 |
+
((ground_truth_detached - mean_detached) / scale_detached)
|
202 |
+
))
|
203 |
+
|
204 |
+
true_period_nm_difference_all = np.mean(abs(
|
205 |
+
((ground_truth_detached - mean_detached) / (scale_detached * px_per_nm_detached))
|
206 |
+
))
|
207 |
+
|
208 |
+
metrics = {
|
209 |
+
'period_px': period_px_difference,
|
210 |
+
'true_period_px_1': true_period_px_difference[0],
|
211 |
+
'true_period_px_2': true_period_px_difference[1],
|
212 |
+
'true_period_px_3': true_period_px_difference[2],
|
213 |
+
'true_period_px_4': true_period_px_difference[3],
|
214 |
+
'true_period_px_5': true_period_px_difference[4],
|
215 |
+
'true_period_px_all': true_period_px_difference_all,
|
216 |
+
|
217 |
+
'true_period_nm_1': true_period_nm_difference[0],
|
218 |
+
'true_period_nm_2': true_period_nm_difference[1],
|
219 |
+
'true_period_nm_3': true_period_nm_difference[2],
|
220 |
+
'true_period_nm_4': true_period_nm_difference[3],
|
221 |
+
'true_period_nm_5': true_period_nm_difference[4],
|
222 |
+
'true_period_nm_all': true_period_nm_difference_all,
|
223 |
+
|
224 |
+
'count_1': np.sum(sd_detached / scale_detached < 1.0),
|
225 |
+
'count_2': np.sum(sd_detached / scale_detached < 2.0),
|
226 |
+
'count_3': np.sum(sd_detached / scale_detached < 3.0),
|
227 |
+
'count_4': np.sum(sd_detached / scale_detached < 4.0),
|
228 |
+
'count_5': np.sum(sd_detached / scale_detached < 5.0),
|
229 |
+
|
230 |
+
'count_all': np.sum(sd_detached > 0.0),
|
231 |
+
'mean_sd': np.mean(sd_detached),
|
232 |
+
'sd_sd': np.std(sd_detached),
|
233 |
+
}
|
234 |
+
|
235 |
+
return preds, losses, metrics
|
236 |
+
|
237 |
+
|
period_calculation/period_measurer.py
ADDED
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from pathlib import Path
|
3 |
+
import albumentations as A
|
4 |
+
from albumentations.pytorch import ToTensorV2
|
5 |
+
import skimage
|
6 |
+
import scipy
|
7 |
+
import numpy as np
|
8 |
+
from pytorch_lightning import seed_everything
|
9 |
+
|
10 |
+
from period_calculation.data_reader import AdHocDataset3
|
11 |
+
from period_calculation.models.gauss_model import GaussPeriodModel
|
12 |
+
|
13 |
+
|
14 |
+
transforms = [
|
15 |
+
A.Normalize(**{'mean': 0.2845, 'std': 0.1447}, max_pixel_value=1.0),
|
16 |
+
# Applies the formula (img - mean * max_pixel_value) / (std * max_pixel_value)
|
17 |
+
ToTensorV2()
|
18 |
+
]
|
19 |
+
|
20 |
+
class PeriodMeasurer:
|
21 |
+
"""returns period in pixels"""
|
22 |
+
def __init__(
|
23 |
+
self, weights_file, image_height=476, image_width=476,
|
24 |
+
px_per_nm = 1,
|
25 |
+
sd_threshold_nm=np.inf,
|
26 |
+
period_threshold_nm_min=0, period_threshold_nm_max=np.inf):
|
27 |
+
|
28 |
+
self.model = GaussPeriodModel.load_from_checkpoint(weights_file).to("cpu") #.eval()?
|
29 |
+
self.px_per_nm = px_per_nm
|
30 |
+
self.sd_threshold_nm = sd_threshold_nm
|
31 |
+
self.period_threshold_nm_min = period_threshold_nm_min
|
32 |
+
self.period_threshold_nm_max = period_threshold_nm_max
|
33 |
+
|
34 |
+
def __call__(self, img: np.ndarray, mask: np.ndarray) -> float:
|
35 |
+
seed_everything(44)
|
36 |
+
dataset = AdHocDataset3(
|
37 |
+
images_and_masks = [(img, mask)],
|
38 |
+
transform_level=-1,
|
39 |
+
retain_raw_images=False,
|
40 |
+
transforms=transforms
|
41 |
+
)
|
42 |
+
|
43 |
+
image_data = dataset[0]
|
44 |
+
with torch.no_grad():
|
45 |
+
y_hat, sd_hat = self.model(image_data["image"].unsqueeze(0), return_raw=False)
|
46 |
+
|
47 |
+
y_hat_nm = (y_hat/image_data["scale"]).item() / self.px_per_nm
|
48 |
+
sd_hat_nm = (sd_hat/image_data["scale"]).item() /self.px_per_nm
|
49 |
+
|
50 |
+
|
51 |
+
if (sd_hat_nm>self.sd_threshold_nm) or (y_hat_nm<self.period_threshold_nm_min) or (y_hat_nm>self.period_threshold_nm_max):
|
52 |
+
y_hat_nm = np.nan
|
53 |
+
|
54 |
+
return y_hat_nm, sd_hat_nm
|
requirements.txt
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pandas==2.1.1
|
2 |
+
torch==2.1.0
|
3 |
+
torchvision==0.16
|
4 |
+
ultralytics==8.0.216
|
5 |
+
scikit-image==0.22.0
|
6 |
+
pytorch-lightning==2.1.2
|
7 |
+
timm==0.9.11
|
8 |
+
albumentations==1.4.10
|
9 |
+
hydra-core==1.3.2
|
10 |
+
gradio==4.44.0
|
11 |
+
albucore==0.0.16
|
settings.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
DEMO = True
|
styles.css
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.header {
|
2 |
+
display: flex;
|
3 |
+
padding: 30px;
|
4 |
+
text-align: center;
|
5 |
+
justify-content: center;
|
6 |
+
}
|
7 |
+
|
8 |
+
#header-text {
|
9 |
+
font-size: 50px;
|
10 |
+
line-height: 50px;
|
11 |
+
}
|
12 |
+
|
13 |
+
#header-logo {
|
14 |
+
width: 50px;
|
15 |
+
height: 50px;
|
16 |
+
margin-right: 10px;
|
17 |
+
/*background-image: url("file=images/logo.svg");*/
|
18 |
+
}
|
19 |
+
|
20 |
+
.input-row {
|
21 |
+
max-width: 900px;
|
22 |
+
margin: 0 auto;
|
23 |
+
}
|
24 |
+
|
25 |
+
.margin-bottom {
|
26 |
+
margin-bottom: 48px;
|
27 |
+
}
|
28 |
+
|
29 |
+
.results-header {
|
30 |
+
margin-top: 48px;
|
31 |
+
text-align: center;
|
32 |
+
font-size: 45px;
|
33 |
+
margin-bottom: 12px;
|
34 |
+
}
|
35 |
+
|
36 |
+
.processed-info {
|
37 |
+
display: flex;
|
38 |
+
padding: 30px;
|
39 |
+
text-align: center;
|
40 |
+
justify-content: center;
|
41 |
+
font-size: 26px;
|
42 |
+
}
|
43 |
+
|
44 |
+
.title {
|
45 |
+
margin-bottom: 8px!important;
|
46 |
+
font-size: 22px;
|
47 |
+
}
|
weights/AS_square_v16.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9c1d7d4e56f0ea28b34dd2457807e9266d0d5539cf5adfb0719fed10791f79c5
|
3 |
+
size 44771917
|
weights/model_weights_detector.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:37fe3d98789572cf147d5f0d1ea99d50a57e5c2028454c3825295e89b6350fd5
|
3 |
+
size 23926765
|
weights/period_measurer_weights-1.298_real_full-fa12970.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:90133b01545ae9d30eafcbdec480410e0f8b9fae3ba1aabb280450ce9589100a
|
3 |
+
size 350396
|
weights/yolo/20240604_yolov8_segm_ABRCR1_all_train4_best.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:038555e9b9f3900ec29c9c79d7b3d5b50aa3ca37b3e665ee7aa2394facc7e20e
|
3 |
+
size 23926765
|
weights/yolo/current_yolo.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:37fe3d98789572cf147d5f0d1ea99d50a57e5c2028454c3825295e89b6350fd5
|
3 |
+
size 23926765
|