kkuczkowska committed on
Commit 8948e19 · verified · 1 Parent(s): b0c7c3a

Upload folder using huggingface_hub

.dockerignore ADDED
@@ -0,0 +1,3 @@
+ out/
+ grana/
+ .*

.gitattributes CHANGED
@@ -1,35 +1,5 @@
- *.7z filter=lfs diff=lfs merge=lfs -text
- *.arrow filter=lfs diff=lfs merge=lfs -text
- *.bin filter=lfs diff=lfs merge=lfs -text
- *.bz2 filter=lfs diff=lfs merge=lfs -text
- *.ckpt filter=lfs diff=lfs merge=lfs -text
- *.ftz filter=lfs diff=lfs merge=lfs -text
- *.gz filter=lfs diff=lfs merge=lfs -text
- *.h5 filter=lfs diff=lfs merge=lfs -text
- *.joblib filter=lfs diff=lfs merge=lfs -text
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
- *.model filter=lfs diff=lfs merge=lfs -text
- *.msgpack filter=lfs diff=lfs merge=lfs -text
- *.npy filter=lfs diff=lfs merge=lfs -text
- *.npz filter=lfs diff=lfs merge=lfs -text
- *.onnx filter=lfs diff=lfs merge=lfs -text
- *.ot filter=lfs diff=lfs merge=lfs -text
- *.parquet filter=lfs diff=lfs merge=lfs -text
- *.pb filter=lfs diff=lfs merge=lfs -text
- *.pickle filter=lfs diff=lfs merge=lfs -text
- *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
  *.pt filter=lfs diff=lfs merge=lfs -text
  *.pth filter=lfs diff=lfs merge=lfs -text
- *.rar filter=lfs diff=lfs merge=lfs -text
- *.safetensors filter=lfs diff=lfs merge=lfs -text
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
- *.tar.* filter=lfs diff=lfs merge=lfs -text
- *.tar filter=lfs diff=lfs merge=lfs -text
- *.tflite filter=lfs diff=lfs merge=lfs -text
- *.tgz filter=lfs diff=lfs merge=lfs -text
- *.wasm filter=lfs diff=lfs merge=lfs -text
- *.xz filter=lfs diff=lfs merge=lfs -text
- *.zip filter=lfs diff=lfs merge=lfs -text
- *.zst filter=lfs diff=lfs merge=lfs -text
- *tfevents* filter=lfs diff=lfs merge=lfs -text
+ weights/AS_square_v16.ckpt filter=lfs diff=lfs merge=lfs -text
+ weights/period_measurer_weights-1.298_real_full-fa12970.ckpt filter=lfs diff=lfs merge=lfs -text

.gitignore ADDED
@@ -0,0 +1,6 @@
+ *.pyc
+ **/.ipynb_checkpoints/
+ venv/
+ results_*
+ out/
+ grana/

.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
+ # Default ignored files
+ /shelf/
+ /workspace.xml
+ # Datasource local storage ignored files
+ /dataSources/
+ /dataSources.local.xml
+ # Editor-based HTTP Client requests
+ /httpRequests/

.idea/GranaMeasure_interface-main.iml ADDED
@@ -0,0 +1,10 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <module type="PYTHON_MODULE" version="4">
+   <component name="NewModuleRootManager">
+     <content url="file://$MODULE_DIR$">
+       <excludeFolder url="file://$MODULE_DIR$/venv" />
+     </content>
+     <orderEntry type="jdk" jdkName="Python 3.10 (GranaMeasure_interface-main)" jdkType="Python SDK" />
+     <orderEntry type="sourceFolder" forTests="false" />
+   </component>
+ </module>

.idea/inspectionProfiles/Project_Default.xml ADDED
@@ -0,0 +1,101 @@
+ <component name="InspectionProjectProfileManager">
+   <profile version="1.0">
+     <option name="myName" value="Project Default" />
+     <inspection_tool class="DuplicatedCode" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+       <Languages>
+         <language minSize="253" name="Python" />
+       </Languages>
+     </inspection_tool>
+     <inspection_tool class="Eslint" enabled="true" level="WARNING" enabled_by_default="true" />
+     <inspection_tool class="PyPackageRequirementsInspection" enabled="true" level="WARNING" enabled_by_default="true">
+       <option name="ignoredPackages">
+         <value>
+           <list size="73">
+             <item index="0" class="java.lang.String" itemvalue="pandas" />
+             <item index="1" class="java.lang.String" itemvalue="PyYAML" />
+             <item index="2" class="java.lang.String" itemvalue="django-polymorphic" />
+             <item index="3" class="java.lang.String" itemvalue="dacite" />
+             <item index="4" class="java.lang.String" itemvalue="django-recaptcha" />
+             <item index="5" class="java.lang.String" itemvalue="python-dateutil" />
+             <item index="6" class="java.lang.String" itemvalue="psycopg2-binary" />
+             <item index="7" class="java.lang.String" itemvalue="python-dotenv" />
+             <item index="8" class="java.lang.String" itemvalue="MarkupSafe" />
+             <item index="9" class="java.lang.String" itemvalue="astroid" />
+             <item index="10" class="java.lang.String" itemvalue="kombu" />
+             <item index="11" class="java.lang.String" itemvalue="django-extensions" />
+             <item index="12" class="java.lang.String" itemvalue="docopt" />
+             <item index="13" class="java.lang.String" itemvalue="sentry-sdk" />
+             <item index="14" class="java.lang.String" itemvalue="django-tinymce" />
+             <item index="15" class="java.lang.String" itemvalue="django-braces" />
+             <item index="16" class="java.lang.String" itemvalue="umpy" />
+             <item index="17" class="java.lang.String" itemvalue="gpxpy" />
+             <item index="18" class="java.lang.String" itemvalue="docutils" />
+             <item index="19" class="java.lang.String" itemvalue="lxml" />
+             <item index="20" class="java.lang.String" itemvalue="Markdown" />
+             <item index="21" class="java.lang.String" itemvalue="django-formtools" />
+             <item index="22" class="java.lang.String" itemvalue="django-celery-beat" />
+             <item index="23" class="java.lang.String" itemvalue="django-crispy-forms" />
+             <item index="24" class="java.lang.String" itemvalue="pylibmc" />
+             <item index="25" class="java.lang.String" itemvalue="tablib" />
+             <item index="26" class="java.lang.String" itemvalue="django-import-export" />
+             <item index="27" class="java.lang.String" itemvalue="gunicorn" />
+             <item index="28" class="java.lang.String" itemvalue="simplejson" />
+             <item index="29" class="java.lang.String" itemvalue="dataclasses-json" />
+             <item index="30" class="java.lang.String" itemvalue="django-hstore" />
+             <item index="31" class="java.lang.String" itemvalue="django-mptt" />
+             <item index="32" class="java.lang.String" itemvalue="boto3" />
+             <item index="33" class="java.lang.String" itemvalue="django-extra-views" />
+             <item index="34" class="java.lang.String" itemvalue="django-filter" />
+             <item index="35" class="java.lang.String" itemvalue="mock" />
+             <item index="36" class="java.lang.String" itemvalue="django-allauth" />
+             <item index="37" class="java.lang.String" itemvalue="django-taggit" />
+             <item index="38" class="java.lang.String" itemvalue="supervisor" />
+             <item index="39" class="java.lang.String" itemvalue="django-auth-ldap" />
+             <item index="40" class="java.lang.String" itemvalue="djangorestframework-gis" />
+             <item index="41" class="java.lang.String" itemvalue="frictionless" />
+             <item index="42" class="java.lang.String" itemvalue="django-bulk-update" />
+             <item index="43" class="java.lang.String" itemvalue="requests" />
+             <item index="44" class="java.lang.String" itemvalue="django-storages" />
+             <item index="45" class="java.lang.String" itemvalue="numpy" />
+             <item index="46" class="java.lang.String" itemvalue="rest-pandas" />
+             <item index="47" class="java.lang.String" itemvalue="Jinja2" />
+             <item index="48" class="java.lang.String" itemvalue="drf-yasg" />
+             <item index="49" class="java.lang.String" itemvalue="sqlparse" />
+             <item index="50" class="java.lang.String" itemvalue="docker" />
+             <item index="51" class="java.lang.String" itemvalue="celery" />
+             <item index="52" class="java.lang.String" itemvalue="ipdb" />
+             <item index="53" class="java.lang.String" itemvalue="akismet" />
+             <item index="54" class="java.lang.String" itemvalue="djangorestframework" />
+             <item index="55" class="java.lang.String" itemvalue="billiard" />
+             <item index="56" class="java.lang.String" itemvalue="funcy" />
+             <item index="57" class="java.lang.String" itemvalue="six" />
+             <item index="58" class="java.lang.String" itemvalue="django-celery-results" />
+             <item index="59" class="java.lang.String" itemvalue="amqp" />
+             <item index="60" class="java.lang.String" itemvalue="ipython" />
+             <item index="61" class="java.lang.String" itemvalue="ffmpeg-python" />
+             <item index="62" class="java.lang.String" itemvalue="django-debug-toolbar" />
+             <item index="63" class="java.lang.String" itemvalue="logilab-common" />
+             <item index="64" class="java.lang.String" itemvalue="pykwalify" />
+             <item index="65" class="java.lang.String" itemvalue="django-grappelli" />
+             <item index="66" class="java.lang.String" itemvalue="watchdog" />
+             <item index="67" class="java.lang.String" itemvalue="Sphinx" />
+             <item index="68" class="java.lang.String" itemvalue="azure-storage-blob" />
+             <item index="69" class="java.lang.String" itemvalue="Django" />
+             <item index="70" class="java.lang.String" itemvalue="django-timezone-field" />
+             <item index="71" class="java.lang.String" itemvalue="pytz" />
+             <item index="72" class="java.lang.String" itemvalue="Pillow" />
+           </list>
+         </value>
+       </option>
+     </inspection_tool>
+     <inspection_tool class="PyPep8NamingInspection" enabled="true" level="WEAK WARNING" enabled_by_default="true">
+       <option name="ignoredErrors">
+         <list>
+           <option value="N803" />
+           <option value="N802" />
+           <option value="N806" />
+         </list>
+       </option>
+     </inspection_tool>
+   </profile>
+ </component>

.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
+ <component name="InspectionProjectProfileManager">
+   <settings>
+     <option name="USE_PROJECT_PROFILE" value="false" />
+     <version value="1.0" />
+   </settings>
+ </component>

.idea/misc.xml ADDED
@@ -0,0 +1,4 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.10 (GranaMeasure_interface-main)" project-jdk-type="Python SDK" />
+ </project>

.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ProjectModuleManager">
+     <modules>
+       <module fileurl="file://$PROJECT_DIR$/.idea/GranaMeasure_interface-main.iml" filepath="$PROJECT_DIR$/.idea/GranaMeasure_interface-main.iml" />
+     </modules>
+   </component>
+ </project>

.idea/other.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="PySciProjectComponent">
+     <option name="PY_SCI_VIEW_SUGGESTED" value="true" />
+   </component>
+ </project>

.idea/vcs.xml ADDED
@@ -0,0 +1,6 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="VcsDirectoryMappings">
+     <mapping directory="$PROJECT_DIR$" vcs="Git" />
+   </component>
+ </project>

.idea/workspace.xml ADDED
@@ -0,0 +1,157 @@
+ <?xml version="1.0" encoding="UTF-8"?>
+ <project version="4">
+   <component name="ChangeListManager">
+     <list default="true" id="f10c5f0a-4791-498a-9005-10ee84337c97" name="Changes" comment="" />
+     <option name="SHOW_DIALOG" value="false" />
+     <option name="HIGHLIGHT_CONFLICTS" value="true" />
+     <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
+     <option name="LAST_RESOLUTION" value="IGNORE" />
+   </component>
+   <component name="FileTemplateManagerImpl">
+     <option name="RECENT_TEMPLATES">
+       <list>
+         <option value="CSS File" />
+         <option value="Python Script" />
+       </list>
+     </option>
+   </component>
+   <component name="Git.Settings">
+     <option name="RECENT_BRANCH_BY_REPOSITORY">
+       <map>
+         <entry key="$PROJECT_DIR$" value="development" />
+       </map>
+     </option>
+     <option name="RECENT_GIT_ROOT_PATH" value="$PROJECT_DIR$" />
+   </component>
+   <component name="GitSEFilterConfiguration">
+     <file-type-list>
+       <filtered-out-file-type name="LOCAL_BRANCH" />
+       <filtered-out-file-type name="REMOTE_BRANCH" />
+       <filtered-out-file-type name="TAG" />
+       <filtered-out-file-type name="COMMIT_BY_MESSAGE" />
+     </file-type-list>
+   </component>
+   <component name="ProjectId" id="2bog1ms5jMwIjHWHezXITkymVNk" />
+   <component name="ProjectLevelVcsManager" settingsEditedManually="true" />
+   <component name="ProjectViewState">
+     <option name="hideEmptyMiddlePackages" value="true" />
+     <option name="showLibraryContents" value="true" />
+   </component>
+   <component name="PropertiesComponent">
+     <property name="ASKED_ADD_EXTERNAL_FILES" value="true" />
+     <property name="RunOnceActivity.OpenProjectViewOnStart" value="true" />
+     <property name="RunOnceActivity.ShowReadmeOnStart" value="true" />
+     <property name="SHARE_PROJECT_CONFIGURATION_FILES" value="true" />
+     <property name="WebServerToolWindowFactoryState" value="false" />
+     <property name="last_opened_file_path" value="$USER_HOME$/BirdNET-Analyzer" />
+     <property name="list.type.of.created.stylesheet" value="CSS" />
+     <property name="node.js.detected.package.eslint" value="true" />
+     <property name="node.js.selected.package.eslint" value="(autodetect)" />
+     <property name="settings.editor.selected.configurable" value="com.jetbrains.python.configuration.PyActiveSdkModuleConfigurable" />
+   </component>
+   <component name="RecentsManager">
+     <key name="CopyFile.RECENT_KEYS">
+       <recent name="$PROJECT_DIR$/examples" />
+       <recent name="$PROJECT_DIR$/mock_data" />
+       <recent name="$PROJECT_DIR$" />
+     </key>
+     <key name="MoveFile.RECENT_KEYS">
+       <recent name="$PROJECT_DIR$/images" />
+     </key>
+   </component>
+   <component name="SpellCheckerSettings" RuntimeDictionaries="0" Folders="0" CustomDictionaries="0" DefaultDictionary="application-level" UseSingleDictionary="true" transferred="true" />
+   <component name="TaskManager">
+     <task active="true" id="Default" summary="Default task">
+       <changelist id="f10c5f0a-4791-498a-9005-10ee84337c97" name="Changes" comment="" />
+       <created>1706886633279</created>
+       <option name="number" value="Default" />
+       <option name="presentableId" value="Default" />
+       <updated>1706886633279</updated>
+       <workItem from="1706886636008" duration="1452000" />
+       <workItem from="1707129545673" duration="18000" />
+       <workItem from="1707475281532" duration="689000" />
+       <workItem from="1707734574375" duration="136000" />
+       <workItem from="1707829113355" duration="7000" />
+       <workItem from="1715267971553" duration="18107000" />
+       <workItem from="1715791673881" duration="83000" />
+       <workItem from="1718701069639" duration="1074000" />
+       <workItem from="1718702835088" duration="2728000" />
+       <workItem from="1718705968353" duration="1911000" />
+       <workItem from="1718708463096" duration="1371000" />
+       <workItem from="1718710203255" duration="3374000" />
+       <workItem from="1718719127184" duration="3910000" />
+       <workItem from="1718729675691" duration="3859000" />
+       <workItem from="1718733940161" duration="2339000" />
+       <workItem from="1718736568891" duration="827000" />
+       <workItem from="1718737438910" duration="21000" />
+       <workItem from="1718740972224" duration="4830000" />
+       <workItem from="1718745952723" duration="1618000" />
+       <workItem from="1718793757299" duration="380000" />
+       <workItem from="1719923572507" duration="1390000" />
+       <workItem from="1720770804165" duration="6223000" />
+       <workItem from="1720780265020" duration="5909000" />
+       <workItem from="1720791565290" duration="1486000" />
+       <workItem from="1721039720633" duration="18000" />
+       <workItem from="1721127968492" duration="8000" />
+       <workItem from="1723193047208" duration="1397000" />
+       <workItem from="1723545275936" duration="11000" />
+       <workItem from="1727785730725" duration="1830000" />
+       <workItem from="1727787983064" duration="2377000" />
+       <workItem from="1727790929474" duration="2947000" />
+       <workItem from="1728301291811" duration="2719000" />
+       <workItem from="1728475950279" duration="10306000" />
+       <workItem from="1728545217803" duration="6586000" />
+       <workItem from="1728562720043" duration="7663000" />
+       <workItem from="1728657746085" duration="8583000" />
+       <workItem from="1729079338205" duration="3308000" />
+       <workItem from="1729252342732" duration="600000" />
+       <workItem from="1729255292344" duration="3000" />
+       <workItem from="1729255305399" duration="170000" />
+       <workItem from="1730227437944" duration="1638000" />
+     </task>
+     <task id="LOCAL-00001" summary="File format + new layout WIP">
+       <created>1718743330580</created>
+       <option name="number" value="00001" />
+       <option name="presentableId" value="LOCAL-00001" />
+       <option name="project" value="LOCAL" />
+       <updated>1718743330580</updated>
+     </task>
+     <task id="LOCAL-00002" summary="Color theme setup">
+       <created>1718745108453</created>
+       <option name="number" value="00002" />
+       <option name="presentableId" value="LOCAL-00002" />
+       <option name="project" value="LOCAL" />
+       <updated>1718745108453</updated>
+     </task>
+     <task id="LOCAL-00003" summary="Theming">
+       <created>1718747250552</created>
+       <option name="number" value="00003" />
+       <option name="presentableId" value="LOCAL-00003" />
+       <option name="project" value="LOCAL" />
+       <updated>1718747250552</updated>
+     </task>
+     <option name="localTasksCounter" value="4" />
+     <servers />
+   </component>
+   <component name="TypeScriptGeneratedFilesManager">
+     <option name="version" value="3" />
+   </component>
+   <component name="VcsManagerConfiguration">
+     <option name="ADD_EXTERNAL_FILES_SILENTLY" value="true" />
+     <MESSAGE value="File format + new layout WIP" />
+     <MESSAGE value="Color theme setup" />
+     <MESSAGE value="Theming" />
+     <option name="LAST_COMMIT_MESSAGE" value="Theming" />
+   </component>
+   <component name="XDebuggerManager">
+     <breakpoint-manager>
+       <breakpoints>
+         <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
+           <url>file://$PROJECT_DIR$/app.py</url>
+           <line>172</line>
+           <option name="timeStamp" value="1" />
+         </line-breakpoint>
+       </breakpoints>
+     </breakpoint-manager>
+   </component>
+ </project>

Dockerfile ADDED
@@ -0,0 +1,27 @@
+ FROM python:3.9 as build
+
+ RUN apt-get update && apt-get install ffmpeg libsm6 libxext6 -y
+
+ # Set up a new user named "user" with user ID 1000
+ RUN useradd -m -u 1000 user
+
+ # Switch to the "user" user
+ USER user
+
+ # Set home to the user's home directory
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ # Set the working directory to the user's home directory
+ WORKDIR $HOME/app
+
+ COPY requirements.txt .
+ RUN pip install --no-cache-dir -r ./requirements.txt --extra-index-url https://download.pytorch.org/whl/cpu
+ # RUN mim install mmengine
+ # RUN mim install "mmcv==2.1.0" && mim install "mmdet==3.3.0"
+
+ FROM build as final
+
+ COPY --chown=user . .
+
+ CMD python app.py

LICENSE ADDED
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2023 center4ml
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.

README.md CHANGED
@@ -1,12 +1,8 @@
  ---
  title: GranaMeasure
- emoji: 📚
- colorFrom: purple
- colorTo: gray
- sdk: gradio
- sdk_version: 5.4.0
  app_file: app.py
- pinned: false
+ sdk: gradio
+ sdk_version: 4.20.0
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ # GranaMeasure_interface
+ interface code for GranaMeasure

angle_calculation/angle_model.py ADDED
@@ -0,0 +1,444 @@
+ import numpy as np
+ import torch
+ import pytorch_lightning as pl
+ import timm
+
+ # from hydra.utils import instantiate  # note: _get_regularizer and configure_optimizers below call instantiate, so this import is needed for training
+ from scipy.stats import circmean, circstd
+ from scipy import ndimage
+ from skimage.transform import resize
+
+ from sampling import get_crop_batch
+ from granum_utils import get_circle_mask
+ import image_transforms
+ from envelope_correction import calculate_best_angle_from_mask
+ ## loss
+
+ class ConfidenceScaler:
+     def __init__(self, data: np.ndarray):
+         self.data = data
+         self.data.sort()
+     def __call__(self, x):
+         return np.searchsorted(self.data, x) / len(self.data)
+
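Note: ConfidenceScaler converts a raw score into its empirical percentile within a reference sample, via np.searchsorted on the data it sorts in place. A minimal usage sketch, with made-up calibration scores:

    import numpy as np
    scaler = ConfidenceScaler(np.array([0.9, 0.2, 0.8, 0.5]))  # hypothetical calibration scores
    print(scaler(0.6))  # 0.5 -- 0.6 ranks above 2 of the 4 reference values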
+ class PatchedPredictor:
+     def __init__(self,
+         model,
+         crop_size=96,
+         normalization=dict(mean=0, std=1),
+         n_samples=32,
+         mask=None,  # 'circle', None
+         filter_outliers=True,
+         apply_radon=False,  # apply Radon transform
+         radon_size=(128, 128),  # (int, int) reshape radon transformed image to this shape
+         angle_confidence_threshold=0,
+         use_envelope_correction=True
+     ):
+         self.model = model
+         self.crop_size = crop_size
+         self.normalization = normalization
+         self.n_samples = n_samples
+         if mask not in [None, 'circle']:
+             raise ValueError(f'unknown mask {mask}')
+         self.mask = mask
+         self.filter_outliers = filter_outliers
+
+         self.apply_radon = apply_radon
+         self.radon_size = radon_size
+
+         self.angle_confidence_threshold = angle_confidence_threshold
+         self.use_envelope_correction = use_envelope_correction
+
+     @torch.no_grad()
+     def __call__(self, img: np.ndarray, mask: np.ndarray):
+         pl.seed_everything(44)
+         # get crops with different scales and rotation
+         crops, angles_tta, scales_tta = get_crop_batch(
+             img, mask,
+             crop_size=self.crop_size,
+             samples_per_scale=self.n_samples,
+             use_variance_threshold=True
+         )
+         if len(crops) == 0:
+             return dict(
+                 est_angle=np.nan,
+                 est_angle_confidence=0.,
+             )
+
+         # preprocess batch (normalize, mask, transform)
+         batch = self._preprocess_batch(crops)
+
+         # predict for batch - we don't use period and lumen anymore
+         preds_direction, preds_period, preds_lumen_width = self.model(batch)
+         # # convert to numpy
+         # preds_direction = preds_direction.numpy()
+         # preds_period = preds_period.numpy()
+         # preds_lumen_width = preds_lumen_width.numpy()
+
+         # aggregate angles
+         est_angles = (preds_direction - angles_tta) % 180
+         est_angle = circmean(est_angles, low=-90, high=90) + 90
+         est_angle_std = circstd(est_angles, low=-90, high=90)
+         est_angle_confidence = self._std_to_confidence(est_angle_std, 10)  # confidence 0.5 for std = 10 degrees
+
+         if est_angle_confidence < self.angle_confidence_threshold:
+             est_angle = np.nan
+             est_angle_confidence = 0.
+
+         if self.use_envelope_correction and (not np.isnan(est_angle)):
+             angle_correction = -calculate_best_angle_from_mask(
+                 ndimage.rotate(mask, -est_angle, reshape=True, order=0)
+             )
+             est_angle += angle_correction
+
+         return dict(
+             est_angle=est_angle,
+             est_angle_confidence=est_angle_confidence,
+         )
+
+     def _apply_radon(self, batch):  # may require a circle mask
+         crops_radon = image_transforms.batched_radon(batch.numpy())
+         crops_radon = np.transpose(resize(np.transpose(crops_radon, (1, 2, 0)), self.radon_size), (2, 0, 1))
+         return torch.tensor(crops_radon)
+
+     def _preprocess_batch(self, batch):
+         if self.mask == 'circle':
+             mask = get_circle_mask(batch.shape[1])
+             batch[:, mask] = 0
+         if self.apply_radon:
+             batch = self._apply_radon(batch)
+         batch = ((batch/255) - self.normalization['mean'])/self.normalization['std']
+         return batch.unsqueeze(1)  # add channel dimension
+
+     def _filter_outliers(self, x, qmin=0.25, qmax=0.75):
+         x_min, x_max = np.quantile(x, [qmin, qmax])
+         return x[(x >= x_min) & (x <= x_max)]
+
+     def _std_to_confidence(self, x, x_thr, y_thr=0.5):
+         """transform [0, inf] to [1, 0], such that f(x_thr) = y_thr"""
+         return 1 / (1 + x*(1-y_thr)/(x_thr*y_thr))
+
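Note: the confidence mapping above is f(x) = 1 / (1 + x*(1-y_thr)/(x_thr*y_thr)), a monotone decay from f(0) = 1 calibrated so that f(x_thr) = y_thr. A quick check with the values used in __call__ (x_thr=10, y_thr=0.5):

    f = lambda x, x_thr=10, y_thr=0.5: 1 / (1 + x*(1-y_thr)/(x_thr*y_thr))
    print(f(0.0))   # 1.0  -- TTA angles perfectly consistent
    print(f(10.0))  # 0.5  -- circular std of 10 degrees hits the calibration point
    print(f(30.0))  # 0.25 -- wide angular spread, low confidence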
+ class CosineLoss(torch.nn.Module):
+     def __init__(self, p=1, degrees=False, scale=1):
+         super().__init__()
+         self.p = p
+         self.degrees = degrees
+         self.scale = scale
+     def forward(self, x, y):
+         if self.degrees:
+             x = torch.deg2rad(x)
+             y = torch.deg2rad(y)
+         return torch.mean((1 - torch.cos(x - y))**self.p) * self.scale
+
+ ## model
134
+ class AngleParser2d(torch.nn.Module):
135
+ def __init__(self, angle_range=180):
136
+ super().__init__()
137
+ self.angle_range = angle_range
138
+ def forward(self, batch):
139
+ # r = torch.linalg.norm(batch, dim=1)
140
+ preds_y_proj = torch.sigmoid(batch[:,0]) - 0.5
141
+ preds_x_proj = torch.sigmoid(batch[:,1]) - 0.5
142
+ preds_direction = self.angle_range/360.*torch.rad2deg(torch.arctan2(preds_y_proj, preds_x_proj))
143
+ return preds_direction
144
+
145
+ class AngleRegularizer(torch.nn.Module):
146
+ def __init__(self, strength=1.0, scale=1.0, p=2):
147
+ super().__init__()
148
+ self.strength = strength
149
+ self.scale = scale
150
+ self.p = p
151
+ def forward(self, batch):
152
+ r = torch.linalg.norm(batch, dim=1)
153
+ return self.strength * torch.norm(r - self.scale, p=self.p)
154
+
155
+ class AngleRegularizerLog(torch.nn.Module):
156
+ def __init__(self, strength=1.0, scale=1.0, p=2):
157
+ super().__init__()
158
+ self.strength = strength
159
+ self.scale = scale
160
+ self.p = p
161
+ def forward(self, batch):
162
+ r = torch.linalg.norm(batch, dim=1)
163
+ return self.strength * torch.norm(torch.log(r/self.scale), p=self.p)
164
+
165
+ class StripsModel(pl.LightningModule):
166
+ def __init__(self,
167
+ model_name = 'resnet18',
168
+ lr=0.001,
169
+ optimizer_hparams=dict(),
170
+ lr_hparams=dict(classname='MultiStepLR', kwargs=dict(milestones=[100, 150], gamma=0.1)),
171
+ loss_hparams=dict(rotation_weight=10., lumen_fraction_weight=50.),
172
+ angle_hparams=dict(angle_range=180.),
173
+ regularizer_hparams=None,
174
+ sigmoid_smoother=10.
175
+ ):
176
+ super().__init__()
177
+ # Exports the hyperparameters to a YAML file, and create "self.hparams" namespace
178
+ self.save_hyperparameters()
179
+ # Create model - implemented in non-abstract classes
180
+ self.model = timm.create_model(model_name, in_chans=1, num_classes=4) #2 + self.hparams.angle_hparams['ndim'])
181
+ self.angle_parser = AngleParser2d(**self.hparams.angle_hparams)
182
+ self.regularizer = self._get_regularizer(self.hparams.regularizer_hparams)
183
+ self.losses = {
184
+ 'direction': CosineLoss(2., True),
185
+ 'period': torch.nn.functional.mse_loss,
186
+ 'lumen_fraction': torch.nn.functional.mse_loss
187
+ }
188
+ self.losses_weights = {
189
+ 'direction': self.hparams.loss_hparams['rotation_weight'],
190
+ 'period': 1,
191
+ 'lumen_fraction': self.hparams.loss_hparams['lumen_fraction_weight'],
192
+ 'regularization': self.hparams.loss_hparams.get('regularization_weight', 0.)
193
+ }
194
+
195
+ def _get_regularizer(self, regularizer_params):
196
+ if regularizer_params is None:
197
+ return None
198
+ else:
199
+ return instantiate(regularizer_params)
200
+
201
+
202
+ def forward(self, x, return_raw=False):
203
+ """get predictions from image batch"""
204
+ preds = self.model(x) # preds: logit angle_sin, logit angle_cos, period, logit lumen fraction or logit angle, period, logit lumen fraction
205
+ preds_direction = self.angle_parser(preds)
206
+ preds_period = preds[:,-2]
207
+ preds_lumen_fraction = torch.sigmoid(preds[:,-1]*self.hparams.sigmoid_smoother) #lumen fraction is between 0 and 1, so we take sigmoid fo this
208
+
209
+ outputs = [preds_direction, preds_period, preds_lumen_fraction]
210
+ if return_raw:
211
+ outputs.append(preds)
212
+
213
+ return tuple(outputs)
214
+
215
+ def configure_optimizers(self):
216
+ # AdamW is Adam with a correct implementation of weight decay (see here
217
+ # for details: https://arxiv.org/pdf/1711.05101.pdf)
218
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, **self.hparams.optimizer_hparams)
219
+ # scheduler = getattr(torch.optim.lr_scheduler, self.hparams.lr_hparams['classname'])(optimizer, **self.hparams.lr_hparams['kwargs'])
220
+ scheduler = instantiate({**self.hparams.lr_hparams, '_partial_': True})(optimizer)
221
+ return [optimizer], [scheduler]
222
+
223
+ def process_batch_supervised(self, batch):
224
+ """get predictions, losses and mean errors (MAE)"""
225
+
226
+ # get predictions
227
+ preds = {}
228
+ preds['direction'], preds['period'], preds['lumen_fraction'], preds_raw = self.forward(batch['image'], return_raw=True) # preds: angle, period, lumen fraction, raw preds
229
+
230
+ # calculate losses
231
+ losses = {
232
+ 'direction': self.losses['direction'](2*batch['direction'], 2*preds['direction']),
233
+ 'period': self.losses['period'](batch['period'], preds['period']),
234
+ 'lumen_fraction': self.losses['lumen_fraction'](batch['lumen_fraction'], preds['lumen_fraction']),
235
+ }
236
+ if self.regularizer is not None:
237
+ losses['regularization'] = self.regularizer(preds_raw[:,:2])
238
+
239
+ losses['final'] = \
240
+ losses['direction']*self.losses_weights['direction'] + \
241
+ losses['period']*self.losses_weights['period'] + \
242
+ losses['lumen_fraction']*self.losses_weights['lumen_fraction'] + \
243
+ losses.get('regularization', 0.)*self.losses_weights.get('regularization', 0.)
244
+
245
+ # calculate mean errors
246
+ period_difference = np.mean(abs(
247
+ batch['period'].detach().cpu().numpy() - \
248
+ preds['period'].detach().cpu().numpy()
249
+ ))
250
+
251
+ a1 = batch['direction'].detach().cpu().numpy()
252
+ a2 = preds['direction'].detach().cpu().numpy()
253
+ angle_difference = np.mean(0.5*np.degrees(np.arccos(np.cos(2*np.radians(a2-a1)))))
254
+
255
+ lumen_fraction_difference = np.mean(abs(preds['lumen_fraction'].detach().cpu().numpy()-batch['lumen_fraction'].detach().cpu().numpy()))
256
+
257
+ mae = {
258
+ 'period': period_difference,
259
+ 'direction': angle_difference,
260
+ 'lumen_fraction': lumen_fraction_difference
261
+ }
262
+
263
+ return preds, losses, mae
264
+
265
+ def log_all(self, losses, mae, prefix=''):
266
+ self.log(f"{prefix}angle_loss", losses['direction'].item())
267
+ self.log(f"{prefix}period_loss", losses['period'].item())
268
+ self.log(f"{prefix}lumen_fraction_loss", losses['lumen_fraction'].item())
269
+ self.log(f"{prefix}period_difference", mae['period'])
270
+ self.log(f"{prefix}angle_difference", mae['direction'])
271
+ self.log(f"{prefix}lumen_fraction_difference", mae['lumen_fraction'])
272
+ self.log(f"{prefix}loss", losses['final'])
273
+ if 'regularization' in losses:
274
+ self.log(f"{prefix}regularization_loss", losses['regularization'].item())
275
+
276
+ def training_step(self, batch, batch_idx):
277
+ # "batch" is the output of the training data loader.
278
+ preds, losses, mae = self.process_batch_supervised(batch)
279
+ self.log_all(losses, mae, prefix='train_')
280
+
281
+ return losses['final']
282
+
283
+ def validation_step(self, batch, batch_idx):
284
+ preds, losses, mae = self.process_batch_supervised(batch)
285
+ self.log_all(losses, mae, prefix='val_')
286
+
287
+ def test_step(self, batch, batch_idx):
288
+ preds, losses, mae = self.process_batch_supervised(batch)
289
+ self.log_all(losses, mae, prefix='test_')
290
+
291
+
292
+ class StripsModelLumenWidth(pl.LightningModule):
293
+ def __init__(self,
294
+ model_name = 'resnet18',
295
+ lr=0.001,
296
+ optimizer_hparams=dict(),
297
+ lr_hparams=dict(classname='MultiStepLR', kwargs=dict(milestones=[100, 150], gamma=0.1)),
298
+ loss_hparams=dict(rotation_weight=10., lumen_width_weight=50.),
299
+ angle_hparams=dict(angle_range=180.),
300
+ regularizer_hparams=None,
301
+ sigmoid_smoother=10.
302
+ ):
303
+ super().__init__()
304
+ # Exports the hyperparameters to a YAML file, and create "self.hparams" namespace
305
+ self.save_hyperparameters()
306
+ # Create model - implemented in non-abstract classes
307
+ self.model = timm.create_model(model_name, in_chans=1, num_classes=4) #2 + self.hparams.angle_hparams['ndim'])
308
+ self.angle_parser = AngleParser2d(**self.hparams.angle_hparams)
309
+ self.regularizer = self._get_regularizer(self.hparams.regularizer_hparams)
310
+ self.losses = {
311
+ 'direction': CosineLoss(2., True),
312
+ 'period': torch.nn.functional.mse_loss,
313
+ 'lumen_width': torch.nn.functional.mse_loss
314
+ }
315
+ self.losses_weights = {
316
+ 'direction': self.hparams.loss_hparams['rotation_weight'],
317
+ 'period': 1,
318
+ 'lumen_width': self.hparams.loss_hparams['lumen_width_weight'],
319
+ 'regularization': self.hparams.loss_hparams.get('regularization_weight', 0.)
320
+ }
321
+
322
+ def _get_regularizer(self, regularizer_params):
323
+ if regularizer_params is None:
324
+ return None
325
+ else:
326
+ return instantiate(regularizer_params)
327
+
328
+ def forward(self, x, return_raw=False):
329
+ """get predictions from image batch"""
330
+ preds = self.model(x) # preds: logit angle_sin, logit angle_cos, period, logit lumen fraction or logit angle, period, logit lumen fraction
331
+ preds_direction = self.angle_parser(preds)
332
+ preds_period = preds[:,-2]
333
+ preds_lumen_width = preds[:,-1] #lumen fraction is between 0 and 1, so we take sigmoid fo this
334
+
335
+ outputs = [preds_direction, preds_period, preds_lumen_width]
336
+ if return_raw:
337
+ outputs.append(preds)
338
+
339
+ return tuple(outputs)
340
+
341
+ def configure_optimizers(self):
342
+ # AdamW is Adam with a correct implementation of weight decay (see here
343
+ # for details: https://arxiv.org/pdf/1711.05101.pdf)
344
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, **self.hparams.optimizer_hparams)
345
+ # scheduler = getattr(torch.optim.lr_scheduler, self.hparams.lr_hparams['classname'])(optimizer, **self.hparams.lr_hparams['kwargs'])
346
+ scheduler = instantiate({**self.hparams.lr_hparams, '_partial_': True})(optimizer)
347
+ return [optimizer], [scheduler]
348
+
349
+ def process_batch_supervised(self, batch):
350
+ """get predictions, losses and mean errors (MAE)"""
351
+
352
+ # get predictions
353
+ preds = {}
354
+ preds['direction'], preds['period'], preds['lumen_width'], preds_raw = self.forward(batch['image'], return_raw=True) # preds: angle, period, lumen fraction, raw preds
355
+
356
+ # calculate losses
357
+ losses = {
358
+ 'direction': self.losses['direction'](2*batch['direction'], 2*preds['direction']),
359
+ 'period': self.losses['period'](batch['period'], preds['period']),
360
+ 'lumen_width': self.losses['lumen_width'](batch['lumen_width'], preds['lumen_width']),
361
+ }
362
+ if self.regularizer is not None:
363
+ losses['regularization'] = self.regularizer(preds_raw[:,:2])
364
+
365
+ losses['final'] = \
366
+ losses['direction']*self.losses_weights['direction'] + \
367
+ losses['period']*self.losses_weights['period'] + \
368
+ losses['lumen_width']*self.losses_weights['lumen_width'] + \
369
+ losses.get('regularization', 0.)*self.losses_weights.get('regularization', 0.)
370
+
371
+ # calculate mean errors
372
+ period_difference = np.mean(abs(
373
+ batch['period'].detach().cpu().numpy() - \
374
+ preds['period'].detach().cpu().numpy()
375
+ ))
376
+
377
+ a1 = batch['direction'].detach().cpu().numpy()
378
+ a2 = preds['direction'].detach().cpu().numpy()
379
+ angle_difference = np.mean(0.5*np.degrees(np.arccos(np.cos(2*np.radians(a2-a1)))))
380
+
381
+ lumen_width_difference = np.mean(abs(preds['lumen_width'].detach().cpu().numpy()-batch['lumen_width'].detach().cpu().numpy()))
382
+
383
+ lumen_fraction_pred = preds['lumen_width'].detach().cpu().numpy()/preds['period'].detach().cpu().numpy()
384
+ lumen_fraction_gt = batch['lumen_width'].detach().cpu().numpy()/batch['period'].detach().cpu().numpy()
385
+ lumen_fraction_difference = np.mean(abs(lumen_fraction_pred-lumen_fraction_gt))
386
+
387
+ mae = {
388
+ 'period': period_difference,
389
+ 'direction': angle_difference,
390
+ 'lumen_width': lumen_width_difference,
391
+ 'lumen_fraction': lumen_fraction_difference
392
+ }
393
+
394
+ return preds, losses, mae
395
+
396
+ def log_all(self, losses, mae, prefix=''):
397
+ for k, v in losses.items():
398
+ self.log(f'{prefix}{k}_loss', v.item() if isinstance(v, torch.Tensor) else v)
399
+ for k, v in mae.items():
400
+ self.log(f'{prefix}{k}_difference', v.item() if isinstance(v, torch.Tensor) else v)
401
+
402
+ def training_step(self, batch, batch_idx):
403
+ # "batch" is the output of the training data loader.
404
+ preds, losses, mae = self.process_batch_supervised(batch)
405
+ self.log_all(losses, mae, prefix='train_')
406
+
407
+ return losses['final']
408
+
409
+ def validation_step(self, batch, batch_idx):
410
+ preds, losses, mae = self.process_batch_supervised(batch)
411
+ self.log_all(losses, mae, prefix='val_')
412
+
413
+ def test_step(self, batch, batch_idx):
414
+ preds, losses, mae = self.process_batch_supervised(batch)
415
+ self.log_all(losses, mae, prefix='test_')
416
+
417
+
418
+
419
+ # class StripsModel(StripsModelGeneral):
420
+ # def __init__(self, model_name, *args, **kwargs):
421
+ # super().__init__( *args, **kwargs)
422
+ # self.model = timm.create_model(model_name, in_chans=1, num_classes=4)
423
+ # def forward(self, x):
424
+ # """get predictions from image batch"""
425
+ # preds = self.model(x) # preds: logit angle_sin, logit angle_cos, period, logit lumen fraction
426
+ # preds_sin = 1. - 2*torch.sigmoid(preds[:,0])
427
+ # preds_cos = 1. - 2*torch.sigmoid(preds[:,1])
428
+ # preds_direction = 0.5*torch.rad2deg(torch.arctan2(preds_sin, preds_cos))
429
+ # preds_period = preds[:,2]
430
+ # preds_lumen_fraction = torch.sigmoid(preds[:,3]) #lumen fraction is between 0 and 1, so we take sigmoid fo this
431
+ # return preds_direction, preds_period, preds_lumen_fraction
432
+
433
+ # class StripsModelAngle1(StripsModelGeneral):
434
+ # def __init__(self, model_name, *args, **kwargs):
435
+ # super().__init__( *args, **kwargs)
436
+ # self.model = timm.create_model(model_name, in_chans=1, num_classes=3)
437
+ # def forward(self, x):
438
+ # """get predictions from image batch"""
439
+ # preds = self.model(x) # preds: logit angle_sin, logit angle
440
+ # preds_direction = torch.pi * torch.sigmoid(preds[:,0])
441
+ # preds_period = preds[:,1]
442
+ # preds_lumen_fraction = torch.sigmoid(preds[:,2]) #lumen fraction is between 0 and 1, so we take sigmoid fo this
443
+ # return preds_direction, preds_period, preds_lumen_fraction
444
+
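Note on how these pieces are wired at inference time: a trained Lightning checkpoint is loaded and wrapped in PatchedPredictor, which samples crops, runs the network, and aggregates the per-crop angles. A sketch only -- it assumes the checkpoint tracked in .gitattributes was trained with StripsModelLumenWidth, and the input loading is hypothetical:

    import numpy as np
    model = StripsModelLumenWidth.load_from_checkpoint(
        'weights/AS_square_v16.ckpt', map_location='cpu'  # assumption: checkpoint matches this class
    ).eval()
    predictor = PatchedPredictor(model, crop_size=96, n_samples=32)
    img = np.load('granum_image.npy')   # hypothetical grayscale image
    mask = np.load('granum_mask.npy')   # hypothetical boolean granum mask
    print(predictor(img, mask))         # {'est_angle': ..., 'est_angle_confidence': ...}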
angle_calculation/classic.py ADDED
@@ -0,0 +1,349 @@
+ import numpy as np
+ from scipy import signal
+ from scipy import ndimage
+ from scipy.fftpack import next_fast_len
+ from skimage.transform import rotate
+ from skimage._shared.utils import convert_to_float
+ from skimage.transform import warp
+ import matplotlib.pyplot as plt
+ import cv2
+ from copy import deepcopy
+
+ def get_directional_std(image, theta=None, *, preserve_range=False):
+
+     if image.ndim != 2:
+         raise ValueError('The input image must be 2-D')
+     if theta is None:
+         theta = np.arange(180)
+
+     image = convert_to_float(image.copy(), preserve_range)  # TODO: needed?
+
+     shape_min = min(image.shape)
+     img_shape = np.array(image.shape)
+
+     # Crop image to make it square
+     slices = tuple(slice(int(np.ceil(excess / 2)),
+                          int(np.ceil(excess / 2) + shape_min))
+                    if excess > 0 else slice(None)
+                    for excess in (img_shape - shape_min))
+     image = image[slices]
+     shape_min = min(image.shape)
+     img_shape = np.array(image.shape)
+
+     radius = shape_min // 2
+     coords = np.array(np.ogrid[:image.shape[0], :image.shape[1]],
+                       dtype=object)
+     dist = ((coords - img_shape // 2) ** 2).sum(0)
+     outside_reconstruction_circle = dist > radius ** 2
+     image[outside_reconstruction_circle] = 0
+
+     valid_square_slice = slice(int(np.ceil(radius*(1-1/np.sqrt(2)))), int(np.ceil(radius*(1+1/np.sqrt(2)))))
+
+     # padded_image is always square
+     if image.shape[0] != image.shape[1]:
+         raise ValueError('padded_image must be a square')
+     center = image.shape[0] // 2
+     result = np.zeros(len(theta))
+
+     for i, angle in enumerate(np.deg2rad(theta)):
+         cos_a, sin_a = np.cos(angle), np.sin(angle)
+         R = np.array([[cos_a, sin_a, -center * (cos_a + sin_a - 1)],
+                       [-sin_a, cos_a, -center * (cos_a - sin_a - 1)],
+                       [0, 0, 1]])
+         rotated = warp(image, R, clip=False)
+         result[i] = rotated[valid_square_slice, valid_square_slice].std(axis=0).mean()
+     return result
+
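Note: get_directional_std rotates the image over the candidate angles and records the mean per-column standard deviation inside the inscribed circle; for a striped texture this dips when the stripes align with the image columns, which is how get_rotation_with_confidence below picks the orientation. A small check on synthetic stripes (made up for illustration):

    import numpy as np
    img = np.tile(np.sin(2*np.pi*np.arange(200)/10), (200, 1))  # vertical stripes, period 10 px
    std_by_angle = get_directional_std(img)
    print(np.argmin(std_by_angle))  # ~0: columns are constant when the stripes are unrotated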
+ def acf2d(x, nlags=None):
+     xo = x - x.mean(axis=0)
+     n = len(x)
+     if nlags is None:
+         nlags = n - 1
+     lag_len = nlags
+
+     xi = np.arange(1, n + 1)
+     d = np.expand_dims(np.hstack((xi, xi[:-1][::-1])), 1)
+
+     nobs = len(xo)
+     n = next_fast_len(2 * nobs + 1)
+     Frf = np.fft.fft(xo, n=n, axis=0)
+
+     acov = np.fft.ifft(Frf * np.conjugate(Frf), axis=0)[:nobs] / d[nobs - 1:]
+     acov = acov.real
+     ac = acov[: nlags + 1] / acov[:1]
+     return ac
+
+ def get_period(acf_table, n_samples=50):
+     # TODO: use peak heights to select best candidates. use std to eliminate outliers
+     period_candidates = []
+     period_candidates_hights = []
+     for i in np.random.randint(0, acf_table.shape[1], min(acf_table.shape[1], n_samples)):
+         peaks = signal.find_peaks(acf_table[:,i])[0]
+         if len(peaks) == 0:
+             continue
+         peak_idx = peaks[0]
+         period_candidates.append(peak_idx)
+         period_candidates_hights.append(acf_table[peak_idx,i])
+     period_candidates = np.array(period_candidates)
+     period_candidates_hights = np.array(period_candidates_hights)
+
+     if len(period_candidates) == 0:
+         return np.nan, np.nan
+     elif len(period_candidates) == 1:
+         return period_candidates[0], np.nan
+     q1, q3 = np.quantile(period_candidates, [0.25, 0.75])
+     candidates_std = np.std(period_candidates[(period_candidates>=q1)&(period_candidates<=q3)])
+     # return period_candidates, period_candidates_hights
+     return np.median(period_candidates), candidates_std
+
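Note: acf2d is a column-wise autocorrelation computed via FFT (the Wiener-Khinchin route, with the triangular divisor d correcting for overlap length), and get_period takes the lag of the first autocorrelation peak over a random sample of columns. A quick round trip on a synthetic signal of period 12 px:

    import numpy as np
    cols = np.tile(np.cos(2*np.pi*np.arange(120)/12)[:, None], (1, 40))  # 40 identical columns
    period, period_std = get_period(acf2d(cols), n_samples=20)
    print(period)  # ~12.0: lag of the first peak in each column's autocorrelation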
+ def get_rotation_with_confidence(padded_image, blur_size=4, make_plots=True):
+     std_by_angle = get_directional_std(cv2.blur(padded_image, (blur_size, blur_size)))
+     rotation_angle = np.argmin(std_by_angle)
+
+     rotation_quality = 1 - np.min(std_by_angle)/np.median(std_by_angle)
+     if make_plots:
+         plt.plot(std_by_angle)
+         plt.axvline(rotation_angle, c='k')
+         plt.title(f'quality: {rotation_quality:0.2f}')
+     return rotation_angle, rotation_quality
+
+ def calculate_autocorrelation(oriented_img, blur_kernel=(7,1), make_plots=True):
+     autocorrelation = acf2d(cv2.blur(oriented_img.T, blur_kernel))
+     if make_plots:
+         fig, axs = plt.subplots(ncols=2, figsize=(12,6))
+         axs[0].imshow(autocorrelation)
+         axs[1].plot(autocorrelation.sum(axis=1))
+     return autocorrelation
+
+ def get_period_with_confidence(autocorrelation_tab, n_samples=30):
+     period, period_std = get_period(autocorrelation_tab, n_samples=n_samples)
+     if np.isnan(period_std):  # fixed: the original `period_std == np.nan` is always False
+         period_confidence = 0.001
+     else:
+         period_confidence = period/(period + 2*period_std)
+     return period, period_confidence
+
+ def calculate_white_fraction(img, blur_size=4, make_plots=True):  # TODO: add mask
+     blurred = cv2.blur(img, (blur_size, blur_size))
+     blurred_sum = blurred.sum(axis=0)
+     lower, upper = np.quantile(blurred_sum, [0.15, 0.85])
+     sign = blurred_sum > (lower+upper)/2
+
+     sign_change = sign[:-1] != sign[1:]
+     sign_change_indices = np.where(sign_change)[0]
+
+     if len(sign_change_indices) >= 2 + (sign[-1] == sign[0]):
+         cut_first = sign_change_indices[0]+1
+
+         if sign[-1] == sign[0]:
+             cut_last = sign_change_indices[-2]
+         else:
+             cut_last = sign_change_indices[-1]
+
+         white_fraction = np.mean(sign[cut_first:cut_last])
+     else:
+         white_fraction = np.nan
+         cut_first, cut_last = None, None
+     if make_plots:
+         fig, axs = plt.subplots(ncols=3, figsize=(16,6))
+         blurred_sum_normalized = blurred_sum - blurred_sum.min()
+         blurred_sum_normalized /= blurred_sum_normalized.max()
+         axs[0].plot(blurred_sum_normalized)
+         axs[0].plot(sign)
+         axs[1].plot(blurred_sum_normalized[cut_first:cut_last])
+         axs[1].plot(sign[cut_first:cut_last])
+         axs[2].imshow(img, cmap='gray')
+         for i, idx in enumerate(sign_change_indices):
+             plt.axvline(idx, c=['r', 'lime'][i%2])
+         fig.suptitle(f'fraction: {white_fraction:0.2f}')
+
+     return white_fraction
+
+ def process_img_crop(img, nm_per_px=1, make_plots=False, return_extra=False):
+
+     # image must be square
+     assert img.shape[0] == img.shape[1]
+     crop_size = img.shape[0]
+
+     # find orientation
+     rotation_angle, rotation_quality = get_rotation_with_confidence(img, blur_size=4, make_plots=make_plots)
+
+     # rotate and crop image
+     crop_margin = int((1 - 1/np.sqrt(2))*crop_size*0.5)
+     oriented_img = rotate(img, -rotation_angle)[2*crop_margin:-crop_margin, crop_margin:-crop_margin]
+
+     # calculate autocorrelation
+     autocorrelation = calculate_autocorrelation(oriented_img, blur_kernel=(7,1), make_plots=make_plots)
+
+     # find period
+     period, period_confidence = get_period_with_confidence(autocorrelation)
+     if make_plots:
+         print(f'period: {period}, confidence: {period_confidence}')
+
+     # find white fraction
+     white_fraction = calculate_white_fraction(oriented_img, make_plots=make_plots)
+     white_width = white_fraction*period
+
+     result = {
+         'direction': rotation_angle,
+         'direction confidence': rotation_quality,
+         'period': period*nm_per_px,
+         'period confidence': period_confidence,
+         'lumen width': white_width*nm_per_px
+     }
+     if return_extra:
+         result['extra'] = {
+             'autocorrelation': autocorrelation,
+             'oriented_img': oriented_img
+         }
+
+     return result
+
+ def get_top_k(a, k):
+     ind = np.argpartition(a, -k)[-k:]
+     return a[ind]
+
+ def get_crops(img, distance_map, crop_size, N_sample):
+     crop_r = np.sqrt(2)*crop_size / 2
+     possible_positions_y, possible_positions_x = np.where(distance_map >= crop_r)
+     no_edge_mask = (possible_positions_y > crop_r) & \
+                    (possible_positions_x > crop_r) & \
+                    (possible_positions_y < (distance_map.shape[0]-crop_r)) & \
+                    (possible_positions_x < (distance_map.shape[1]-crop_r))
+
+     possible_positions_x = possible_positions_x[no_edge_mask]
+     possible_positions_y = possible_positions_y[no_edge_mask]
+     N_available = len(possible_positions_x)
+     positions_indices = np.random.choice(np.arange(N_available), min(N_sample, N_available), replace=False)
+
+     for idx in positions_indices:
+         yield img[possible_positions_y[idx]-crop_size//2:possible_positions_y[idx]+crop_size//2, possible_positions_x[idx]-crop_size//2:possible_positions_x[idx]+crop_size//2].copy()
+
+ def sliced_mean(x, slice_size):
+     cs_y = np.cumsum(x, axis=0)
+     cs_y = np.concatenate((np.zeros((1, cs_y.shape[1]), dtype=cs_y.dtype), cs_y), axis=0)
+     slices_y = (cs_y[slice_size:] - cs_y[:-slice_size])/slice_size
+     cs_xy = np.cumsum(slices_y, axis=1)
+     cs_xy = np.concatenate((np.zeros((cs_xy.shape[0], 1), dtype=cs_xy.dtype), cs_xy), axis=1)
+     slices_xy = (cs_xy[:,slice_size:] - cs_xy[:,:-slice_size])/slice_size
+     return slices_xy
+
+ def sliced_var(x, slice_size):
+     x = x.astype('float64')
+     return sliced_mean(x**2, slice_size) - sliced_mean(x, slice_size)**2
+
+ def select_samples(granum_image, granum_mask, crop_size=96, n_samples=64, granum_fraction_min=1.0, variance_p=2):
+     granum_occupancy = sliced_mean(granum_mask, crop_size)
+     possible_indices = np.stack(np.where(granum_occupancy >= granum_fraction_min), axis=1)
+
+     if variance_p == 0:
+         p = np.ones(len(possible_indices))
+     else:
+         variance_map = sliced_var(granum_image, crop_size)
+         p = variance_map[possible_indices[:,0], possible_indices[:,1]]**variance_p
+     p /= np.sum(p)  # normalize in both branches so np.random.choice accepts it as a probability vector
+
+     chosen_indices = np.random.choice(
+         np.arange(len(possible_indices)),
+         min(len(possible_indices), n_samples),
+         replace=False,
+         p=p
+     )
+
+     crops = []
+     for crop_idx, idx in enumerate(chosen_indices):
+         crops.append(
+             granum_image[
+                 possible_indices[idx,0]:possible_indices[idx,0]+crop_size,
+                 possible_indices[idx,1]:possible_indices[idx,1]+crop_size
+             ]
+         )
+     return np.array(crops)
+
+ def calculate_distance_map(mask):
+     padded = np.pad(mask, pad_width=1, mode='constant', constant_values=False)
+     distance_map_padded = ndimage.distance_transform_edt(padded)
+     return distance_map_padded[1:-1,1:-1]
+
+
+ def measure_object(
+     img, mask,
+     nm_per_px=1, n_tries=3,
+     direction_thr_min=0.07, direction_thr_enough=0.1,
+     crop_size=200,
+     **kwargs):
+
+     distance_map = calculate_distance_map(mask)
+     crop_size = min(crop_size, int(min(get_top_k(distance_map.flatten(), n_tries)*0.5**0.5)))
+
+     direction_confidence = 0
+     best_stripes_data = {}
+     for i, img_crop in enumerate(get_crops(img, distance_map, crop_size, n_tries)):
+         stripes_data = process_img_crop(img_crop, nm_per_px=nm_per_px)
+         if stripes_data['direction confidence'] >= direction_confidence:
+             best_stripes_data = deepcopy(stripes_data)
+             direction_confidence = stripes_data['direction confidence']
+         if direction_confidence > direction_thr_enough:
+             break
+
+     result = best_stripes_data
+
+     if direction_confidence >= direction_thr_min:
+
+         mask_oriented = rotate(mask, 90-result['direction'], resize=True).astype('bool')
+         idx_begin_x, idx_end_x = np.where(np.any(mask_oriented, axis=0))[0][np.array([0, -1])]
+         idx_begin_y, idx_end_y = np.where(np.any(mask_oriented, axis=1))[0][np.array([0, -1])]
+         result['mask_oriented'] = mask_oriented[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x]
+         result['img_oriented'] = rotate(img, 90-result['direction'], resize=True)[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x]
+
+         # measurements = measure_granum_shape(result['mask_oriented'], nm_per_px=nm_per_px, oriented=True)
+     # else:
+     #     measurements = measure_granum_shape(mask, nm_per_px=nm_per_px, oriented=False)
+
+     # result.update(**measurements)
+     # N_layers = result['height'] / result['period']
+     # if np.isfinite(N_layers):
+     #     N_layers = round(N_layers)
+
+     return result
+
+ # def measure_object(
+ #     img, mask,
+ #     nm_per_px=1, n_tries=3,
+ #     direction_thr_min=0.07, direction_thr_enough=0.1,
+ #     crop_size=200,
+ #     **kwargs):
+
+ #     distance_map = calculate_distance_map(mask)
+ #     crop_size = min(crop_size, int((min(get_top_k(distance_map.flatten(), n_tries)*0.5)**0.5)))
+
+ #     direction_confidence = 0
+ #     best_stripes_data = {}
+ #     for i, img_crop in enumerate(select_samples(img, mask, crop_size=crop_size, n_samples=n_tries)):
+ #         stripes_data = process_img_crop(img_crop, nm_per_px=nm_per_px)
+ #         if stripes_data['direction_confidence'] >= direction_confidence:
+ #             best_stripes_data = deepcopy(stripes_data)
+ #             direction_confidence = stripes_data['direction_confidence']
+ #         if direction_confidence > direction_thr_enough:
+ #             break
+
+ #     result = best_stripes_data
+
+ #     if direction_confidence >= direction_thr_min:
+
+ #         mask_oriented = rotate(mask, 90-result['direction'], resize=True).astype('bool')
+ #         idx_begin_x, idx_end_x = np.where(np.any(mask_oriented, axis=0))[0][np.array([0, -1])]
+ #         idx_begin_y, idx_end_y = np.where(np.any(mask_oriented, axis=1))[0][np.array([0, -1])]
+ #         result['mask_oriented'] = mask_oriented[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x]
+ #         result['img_oriented'] = rotate(img, 90-result['direction'], resize=True)[idx_begin_y:idx_end_y, idx_begin_x:idx_end_x]
+
+ #         # measurements = measure_granum_shape(result['mask_oriented'], nm_per_px=nm_per_px, oriented=True)
+ #     # else:
+ #     #     measurements = measure_granum_shape(mask, nm_per_px=nm_per_px, oriented=False)
+
+ #     # result.update(**measurements)
+ #     # N_layers = result['height'] / result['period']
+ #     # if np.isfinite(N_layers):
+ #     #     N_layers = round(N_layers)
+
+ #     return result  # {**measurements, **best_stripes_data, 'N layers': N_layers}
+
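Note: measure_object is the entry point of this classic pipeline: it uses the distance transform to pick square crops that sit fully inside the mask, estimates direction/period/lumen width per crop, and keeps the most confident result. A usage sketch (the input loading and pixel scale are hypothetical):

    import numpy as np
    img = np.load('granum_image.npy')      # hypothetical grayscale micrograph
    mask = np.load('granum_mask.npy') > 0  # hypothetical boolean granum mask
    result = measure_object(img, mask, nm_per_px=0.8)  # made-up pixel scale
    print(result['direction'], result['period'], result['lumen width'])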
angle_calculation/envelope_correction.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
+ import numpy as np
+ import scipy.ndimage
+
+ def detect_boundaries(mask, axis):
+     """Return the first and last nonzero indices of the mask profile along the given axis.
+     axis=0 yields (x_from, x_to); axis=1 yields (y_from, y_to)."""
+     profile = mask.sum(axis=axis)
+     nonzero = profile.nonzero()[0]
+     ind_from = nonzero.min()
+     ind_to = nonzero.max()
+     return ind_from, ind_to
+
+ def area(mask):
+     """Return the bounding-box area of the mask together with the box boundaries."""
+     x_from, x_to = detect_boundaries(mask, 0)
+     width = x_to - x_from
+     y_from, y_to = detect_boundaries(mask, 1)
+     height = y_to - y_from
+
+     return (width * height, x_from, x_to, y_from, y_to)
+
+ def calculate_best_angle_from_mask(mask, angles=np.arange(-10, 10, 0.5)):
+     """Return the rotation angle that minimizes the bounding-box area of the mask."""
+     areas = []
+     for angle in angles:
+         # order=0 is nearest-neighbour interpolation, so the binary mask is not smoothed
+         rotated_mask = scipy.ndimage.rotate(mask, angle, reshape=True, order=0)
+         this_area = area(rotated_mask)
+         areas.append(this_area[0])
+
+     best_angle = angles[np.argmin(areas)]
+     return best_angle
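For reference, a minimal usage sketch of calculate_best_angle_from_mask (editorial, not part of the commit). The elliptical test mask is purely illustrative, and the import assumes the angle_calculation directory is on sys.path.

import numpy as np
import scipy.ndimage
from envelope_correction import calculate_best_angle_from_mask

# Illustrative binary mask: an ellipse tilted by ~5 degrees.
yy, xx = np.mgrid[:200, :200]
ellipse = ((xx - 100) / 80) ** 2 + ((yy - 100) / 30) ** 2 <= 1
mask = scipy.ndimage.rotate(ellipse.astype(np.uint8), 5, reshape=True, order=0)

# The search minimizes the bounding-box area, so the recovered angle
# should roughly cancel the applied tilt (about -5 degrees here).
best_angle = calculate_best_angle_from_mask(mask)
leveled = scipy.ndimage.rotate(mask, best_angle, reshape=True, order=0)
print(f"estimated correction: {best_angle:.1f} degrees")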
angle_calculation/granum_utils.py ADDED
@@ -0,0 +1,80 @@
+ from PIL import Image, ImageDraw, ImageFont
+ import numpy as np
+ from scipy import ndimage
+
+ from dataclasses import dataclass
+ from typing import Any, List
+ from zipfile import ZipFile
+
+ def add_text(image: Image.Image, text: str, location=(0.5, 0.5), color='red', size=40) -> Image.Image:
+     draw = ImageDraw.Draw(image)
+     font = ImageFont.load_default(size=size)
+     draw.text((int(image.size[0]*location[0]), int(image.size[1]*location[1])), text, font=font, fill=color)
+     return image
+
+
+ def select_unique_mask(mask):
+     """If the mask consists of multiple parts, select the largest."""
+     blobs = ndimage.label(mask)[0]
+     blob_labels, blob_sizes = np.unique(blobs, return_counts=True)
+     best_blob_label = blob_labels[1:][np.argmax(blob_sizes[1:])]
+     return blobs == best_blob_label
+
+ def object_slice(mask, margin=128):
+     rows = np.any(mask, axis=1)
+     cols = np.any(mask, axis=0)
+     row_min, row_max = np.where(rows)[0][[0, -1]]
+     col_min, col_max = np.where(cols)[0][[0, -1]]
+
+     # Create a slice object for the bounding box, extended by the margin
+     bounding_box_slice = (
+         slice(max(0, row_min-margin), min(row_max + 1 + margin, len(rows))),
+         slice(max(0, col_min-margin), min(col_max + 1 + margin, len(cols)))
+     )
+
+     return bounding_box_slice
+
+ def resize_to(image: Image.Image, s=4032) -> Image.Image:
+     w, h = image.size
+     longest_size = max(h, w)
+
+     resize_factor = longest_size / s
+
+     resized_image = image.resize((int(w/resize_factor), int(h/resize_factor)))
+     return resized_image
+
+ def rolling_mean(x, window):
+     cs = np.r_[0, np.cumsum(x)]
+     rolling_sum = cs[window:] - cs[:-window]
+     return rolling_sum/window
+
+ @dataclass
+ class Granum:
+     image: Any = None  # Optional[np.ndarray]
+     mask: Any = None  # Optional[np.ndarray]
+     scaler: Any = None
+     nm_per_px: float = float('nan')
+     detection_confidence: float = float('nan')
+
+ def zip_files(files: List[str], output_name: str) -> None:
+     with ZipFile(output_name, "w") as zipObj:
+         for file in files:
+             zipObj.write(file)
+
+ def filter_boundary_detections(masks, scaler=None):
+     """Return a boolean mask of detections that do not touch the (unpadded) image boundary."""
+     last_index_right = -1 if scaler is None else masks.shape[1]-1-scaler.pad_right
+     last_index_bottom = -1 if scaler is None else masks.shape[2]-1-scaler.pad_bottom
+     doesnt_touch_boundary_mask = ~(
+         np.any(masks[:, 0, :] != 0, axis=1)
+         | np.any(masks[:, last_index_right:, :] != 0, axis=(1, 2))
+         | np.any(masks[:, :, 0] != 0, axis=1)
+         | np.any(masks[:, :, last_index_bottom:] != 0, axis=(1, 2))
+     )
+     return doesnt_touch_boundary_mask
+
+ def get_circle_mask(shape, r=None):
+     if isinstance(shape, int):
+         shape = (shape, shape)
+     if r is None:
+         r = min(shape)/2
+     X, Y = np.meshgrid(np.arange(shape[1]), np.arange(shape[0]))
+     center_x = shape[1] / 2 - 0.5
+     center_y = shape[0] / 2 - 0.5
+
+     mask = ((X-center_x)**2 + (Y-center_y)**2) >= r**2
+     return mask
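A couple of the helpers above are easy to misread, so here is a short editorial sketch (not part of the commit; it assumes angle_calculation is on sys.path). Note in particular that get_circle_mask is True *outside* the radius, i.e. it is an exclusion mask.

import numpy as np
from granum_utils import get_circle_mask, rolling_mean

# Exclusion mask: True outside a centered circle of radius 16 in a 64x64 grid.
outside = get_circle_mask(64, r=16)
print(outside.shape, bool(outside[32, 32]))  # (64, 64) False at the center

# rolling_mean uses the cumulative-sum trick; output length is len(x) - window + 1.
print(rolling_mean(np.array([1.0, 2.0, 3.0, 4.0]), 3))  # [2. 3.]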
angle_calculation/image_transforms.py ADDED
@@ -0,0 +1,34 @@
+ import numpy as np
+ import cv2
+
+ def batched_radon(image_batch):
+     batch_size, img_size = image_batch.shape[:2]
+     if batch_size > 512:  # cv2.warpAffine fails for more than 512 channels, so recurse in chunks
+         return np.concatenate([batched_radon(image_batch[i:i+512]) for i in range(0, batch_size, 512)], axis=0)
+     theta = np.arange(180)
+     radon_image = np.zeros((image_batch.shape[0], img_size, len(theta)),
+                            dtype='float32')
+
+     for i, angle in enumerate(theta):
+         M = cv2.getRotationMatrix2D(((img_size-1)/2.0, (img_size-1)/2.0), angle, 1)
+         rotated = cv2.warpAffine(np.transpose(image_batch, (1, 2, 0)), M, (img_size, img_size))
+         if batch_size == 1:  # cv2.warpAffine drops the batch dimension if it equals 1
+             rotated = rotated[:, :, np.newaxis]
+         rotated = np.transpose(rotated, (2, 0, 1))
+         rotated = rotated / np.array(255, dtype='float32')
+         radon_image[:, :, i] = rotated.sum(axis=1)
+     return radon_image
+
+ def get_center_crop_coords(height: int, width: int, crop_height: int, crop_width: int):
+     """from https://github.com/albumentations-team/albumentations/blob/master/albumentations/augmentations/crops/functional.py"""
+     y1 = (height - crop_height) // 2
+     y2 = y1 + crop_height
+     x1 = (width - crop_width) // 2
+     x2 = x1 + crop_width
+     return x1, y1, x2, y2
+
+ def center_crop(img: np.ndarray, crop_height: int, crop_width: int):
+     height, width = img.shape[:2]
+     x1, y1, x2, y2 = get_center_crop_coords(height, width, crop_height, crop_width)
+     img = img[y1:y2, x1:x2]
+     return img
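batched_radon above computes a discrete Radon transform by rotating the whole batch once per angle (0-179 degrees) and summing columns. A small editorial sketch of the expected shapes, with made-up data:

import numpy as np
from image_transforms import batched_radon, center_crop

# Two square uint8 crops; the /255 scaling inside batched_radon assumes 0..255 input.
batch = np.random.randint(0, 256, size=(2, 96, 96), dtype=np.uint8)
sinograms = batched_radon(batch)
print(sinograms.shape)  # (2, 96, 180): one 96-sample projection per degree

# center_crop trims a larger array down to the working size.
print(center_crop(np.zeros((128, 120)), 96, 96).shape)  # (96, 96)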
angle_calculation/sampling.py ADDED
@@ -0,0 +1,142 @@
+ from PIL import Image
+ from typing import Tuple
+ import numpy as np
+ from scipy import ndimage
+ import torch
+ from torchvision.transforms import functional as tvf
+
+ from pathlib import Path
+
+ def sliced_mean(x, slice_size):
+     cs_y = np.cumsum(x, axis=0)
+     cs_y = np.concatenate((np.zeros((1, cs_y.shape[1]), dtype=cs_y.dtype), cs_y), axis=0)
+     slices_y = (cs_y[slice_size:] - cs_y[:-slice_size])/slice_size
+     cs_xy = np.cumsum(slices_y, axis=1)
+     cs_xy = np.concatenate((np.zeros((cs_xy.shape[0], 1), dtype=cs_xy.dtype), cs_xy), axis=1)
+     slices_xy = (cs_xy[:, slice_size:] - cs_xy[:, :-slice_size])/slice_size
+     return slices_xy
+
+ def sliced_var(x, slice_size):
+     x = x.astype('float64')
+     return sliced_mean(x**2, slice_size) - sliced_mean(x, slice_size)**2
+
+ def calculate_local_variance(img, var_window):
+     """Return a local variance map with the same size as the input image."""
+     var = sliced_var(img, var_window)
+
+     left_pad = var_window // 2 - 1
+     right_pad = var_window - 1 - left_pad
+     var_padded = np.pad(
+         var,
+         pad_width=(
+             (left_pad, right_pad),
+             (left_pad, right_pad)
+         ))
+     return var_padded
+
+ def get_crop_batch(img: np.ndarray, mask: np.ndarray, crop_size=96, crop_scales=np.geomspace(0.5, 2, 7), samples_per_scale=32, use_variance_threshold=False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     """
+     Generate a batch of cropped images from an input image and corresponding mask, at various scales and rotations.
+
+     Parameters
+     ----------
+     img : np.ndarray
+         The input image from which crops are generated.
+     mask : np.ndarray
+         The binary mask indicating the region of interest in the image.
+     crop_size : int, optional
+         The size of the square crop.
+     crop_scales : np.ndarray, optional
+         An array of scale factors to apply to the crop size.
+     samples_per_scale : int, optional
+         Number of samples to generate per scale factor.
+     use_variance_threshold : bool, optional
+         Flag to use variance thresholding for selecting crop locations.
+
+     Returns
+     -------
+     Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+         A tuple containing the tensor of crops, their rotation angles, and scale factors.
+     """
+
+     # pad
+     pad_size = int(np.ceil(0.5*crop_size*max(crop_scales)*(np.sqrt(2)-1)))
+     img_padded = np.pad(img, pad_size)
+     mask_padded = np.pad(mask, pad_size)
+
+     # distance map
+     distance_map_padded = ndimage.distance_transform_edt(mask_padded)
+     # TODO: adjust scales and samples_per_scale
+
+     if use_variance_threshold:
+         variance_window = min(crop_size//2, min(img.shape))
+         variance_map_padded = np.pad(calculate_local_variance(img, variance_window), pad_size)
+         variance_median = np.ma.median(np.ma.masked_where(distance_map_padded < 0.5*variance_window, variance_map_padded))
+         variance_mask = variance_map_padded >= variance_median
+     else:
+         variance_mask = np.ones_like(mask_padded)
+
+     # initialize output
+     crops_granum = []
+     angles_granum = []
+     scales_granum = []
+     # loop over scales
+     for scale in crop_scales:
+         half_crop_size_scaled = int(np.floor(scale*0.5*crop_size))  # half of the crop size after scaling
+         crop_pad = int(np.ceil((np.sqrt(2) - 1)*half_crop_size_scaled))  # pad added in order to allow rotation
+         half_crop_size_external = half_crop_size_scaled + crop_pad  # size of the "external crop" which will be rotated
+
+         possible_indices = np.stack(np.where(variance_mask & (distance_map_padded >= 2*half_crop_size_scaled)), axis=1)
+         if len(possible_indices) == 0:
+             continue
+         chosen_indices = np.random.choice(np.arange(len(possible_indices)), min(len(possible_indices), samples_per_scale), replace=False)
+
+         crops = [
+             img_padded[y-half_crop_size_external:y+half_crop_size_external, x-half_crop_size_external:x+half_crop_size_external] for y, x in possible_indices[chosen_indices]
+         ]
+
+         # rotate
+         rotation_angles = np.random.rand(len(crops))*180 - 90
+         crops = [
+             ndimage.rotate(crop, angle, reshape=False)[crop_pad:-crop_pad, crop_pad:-crop_pad] for crop, angle in zip(crops, rotation_angles)
+         ]
+         # add to output
+         crops_granum.append(tvf.resize(torch.tensor(np.array(crops)), (crop_size, crop_size), antialias=True))  # resize crops to crop_size
+         angles_granum.extend(rotation_angles.tolist())
+         scales_granum.extend([scale]*len(crops))
+
+     if len(angles_granum) == 0:
+         return [], [], []
+
+     crops_granum = torch.concat(crops_granum)
+     angles_granum = torch.tensor(angles_granum, dtype=torch.float)
+     scales_granum = torch.tensor(scales_granum, dtype=torch.float)
+
+     return crops_granum, angles_granum, scales_granum
+
+ def get_crop_batch_from_path(img_path, mask_path=None, use_variance_threshold=False):
+     """
+     Load an image and its mask from file paths and generate a batch of cropped images.
+
+     Parameters
+     ----------
+     img_path : str
+         Path to the input image.
+     mask_path : str, optional
+         Path to the binary mask image. If None, the mask path is assumed by replacing the image extension with '.npy'.
+     use_variance_threshold : bool, optional
+         Flag to use variance thresholding for selecting crop locations.
+
+     Returns
+     -------
+     Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
+         A tuple containing the tensor of crops, their rotation angles, and scale factors, obtained from the specified image path.
+     """
+     if mask_path is None:
+         mask_path = str(Path(img_path).with_suffix('.npy'))
+     mask = np.load(mask_path)
+     img = np.array(Image.open(img_path))[:, :, 0]
+
+     return get_crop_batch(img, mask, use_variance_threshold=use_variance_threshold)
+
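A minimal editorial sketch of get_crop_batch on synthetic data (the striped image stands in for a granum micrograph, and the import assumes angle_calculation is on sys.path):

import numpy as np
from sampling import get_crop_batch

# Synthetic striped image with a large central mask.
rng = np.random.default_rng(0)
stripes = 127 + 100 * np.sin(np.arange(512) / 4)
img = (np.tile(stripes, (512, 1)) + rng.normal(0, 10, (512, 512))).astype('float32')
mask = np.zeros((512, 512), dtype=bool)
mask[64:448, 64:448] = True

crops, angles, scales = get_crop_batch(img, mask, crop_size=96)
# Crops are randomly rotated and rescaled; angles (degrees) and scales stay aligned with them.
print(crops.shape, angles.shape, scales.shape)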
app.py ADDED
@@ -0,0 +1,602 @@
+ import os
+ import shutil
+ import time
+ import uuid
+ from datetime import datetime
+ from decimal import Decimal
+
+ import gradio as gr
+ import matplotlib.pyplot as plt
+
+ from settings import DEMO
+
+ plt.switch_backend("agg")  # fix for "RuntimeError: main thread is not in main loop"
+ import numpy as np
+ import pandas as pd
+ from PIL import Image
+
+ from model import GranaAnalyser
+
+ ga = GranaAnalyser(
+     "weights/yolo/20240604_yolov8_segm_ABRCR1_all_train4_best.pt",
+     "weights/AS_square_v16.ckpt",
+     "weights/period_measurer_weights-1.298_real_full-fa12970.ckpt",
+ )
+
+
+ def calc_ratio(pixels, nano):
+     """
+     Calculates the ratio of pixels to nanometers; the result populates ratio_input.
+     Returns None when either value is missing.
+     :param pixels: scale-bar length in pixels
+     :param nano: scale-bar length in nanometers
+     :return: pixels-per-nanometer ratio, or None
+     """
+     if not (pixels and nano):
+         return None
+     return pixels / nano
+
+
+ # https://jakevdp.github.io/PythonDataScienceHandbook/05.13-kernel-density-estimation.html
+ def KDE(dataset, h):
+     # the Gaussian kernel
+     def K(x):
+         return np.exp(-(x ** 2) / 2) / np.sqrt(2 * np.pi)
+
+     n_samples = dataset.size
+
+     x_range = dataset  # evaluate the KDE at the data points themselves
+
+     total_sum = 0
+     # iterate over datapoints
+     for xi in dataset:
+         total_sum += K((x_range - xi) / h)
+
+     y_range = total_sum / (h * n_samples)
+
+     return y_range
+
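A quick editorial sketch of the KDE helper above: it evaluates a Gaussian kernel density estimate at the sample points themselves, which is exactly what draw_violin_plot needs for its jitter widths (bandwidth rule copied from that call):

import numpy as np

# Assumes KDE as defined above is in scope.
y = np.array([10.0, 11.0, 11.5, 12.0, 15.0])
h = (y.max() - y.min()) / y.size / 2  # bandwidth rule used by draw_violin_plot
density = KDE(y, h)
print(density)  # largest where samples cluster (around 11-12), smallest at 15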
+ def prepare_files_for_download(
+     dir_name,
+     grana_data,
+     aggregated_data,
+     detection_visualizations_dict,
+     images_grana_dict,
+ ):
+     """
+     Save and zip result files for download
+     :param dir_name: directory in which the zip archive is assembled
+     :param grana_data: DataFrame containing all grana measurements
+     :param aggregated_data: DataFrame containing aggregated measurements
+     :return: path to the zip file
+     """
+     dir_to_zip = f"{dir_name}/to_zip"
+
+     # raw data
+     grana_data_csv_path = f"{dir_to_zip}/grana_raw_data.csv"
+     grana_data.to_csv(grana_data_csv_path, index=False)
+
+     # aggregated measurements
+     aggregated_csv_path = f"{dir_to_zip}/grana_aggregated_data.csv"
+     aggregated_data.to_csv(aggregated_csv_path)
+
+     # annotated pictures
+     masked_images_dir = f"{dir_to_zip}/annotated_images"
+     os.makedirs(masked_images_dir)
+     for img_name, img in detection_visualizations_dict.items():
+         filename_split = img_name.split(".")
+         extension = filename_split[-1]
+         filename = ".".join(filename_split[:-1])
+         filename = f"{filename}_annotated.{extension}"
+         img.save(f"{masked_images_dir}/{filename}")
+
+     # single grana images
+     grana_images_dir = f"{dir_to_zip}/single_grana_images"
+     os.makedirs(grana_images_dir)
+     org_images_dict = pd.Series(
+         grana_data["source image"].values, index=grana_data["granum ID"]
+     ).to_dict()
+     for img_name, img in images_grana_dict.items():
+         org_filename = org_images_dict[img_name]
+         org_filename_split = org_filename.split(".")
+         org_filename_no_ext = ".".join(org_filename_split[:-1])
+         img_name_ext = f"{org_filename_no_ext}_granum_{str(img_name)}.png"
+         img.save(f"{grana_images_dir}/{img_name_ext}")
+
+     # zip all files
+     date_str = datetime.today().strftime("%Y-%m-%d")
+     zip_name = f"GRANA_results_{date_str}"
+     zip_path = f"{dir_name}/{zip_name}"
+     shutil.make_archive(zip_path, "zip", dir_to_zip)
+
+     # delete the to_zip dir
+     zip_dir_path = os.path.join(os.getcwd(), dir_to_zip)
+     shutil.rmtree(zip_dir_path)
+
+     download_file_path = f"{zip_path}.zip"
+     return download_file_path
+
+
+ def show_info_on_submit(s):
+     return (
+         gr.Button(interactive=False),
+         gr.Button(interactive=False),
+         gr.Row(visible=True),
+         gr.Row(visible=False),
+     )
+
+
+ def load_css():
+     with open("styles.css", "r") as f:
+         css_content = f.read()
+     return css_content
+
+
+ primary_hue = gr.themes.Color(
+     c50="#e1f8ee",
+     c100="#b7efd5",
+     c200="#8de6bd",
+     c300="#63dda5",
+     c400="#39d48d",
+     c500="#27b373",
+     c600="#1e8958",
+     c700="#155f3d",
+     c800="#0c3522",
+     c900="#030b07",
+     c950="#000",
+ )
+
+
+ theme = gr.themes.Default(
+     primary_hue=primary_hue,
+     font=[gr.themes.GoogleFont("Ubuntu"), "ui-sans-serif", "system-ui", "sans-serif"],
+ )
+
+
+ def draw_violin_plot(y, ylabel, title):
+     # only generate a plot for 3 or more values
+     if y.count() < 3:
+         return None
+
+     # Colors
+     RED_DARK = "#850e00"
+     DARK_GREEN = "#0c3522"
+     BRIGHT_GREEN = "#8de6bd"
+
+     # Create a jittered version of "x" (there is only one category)
+     x_jittered = []
+     kde = KDE(y, (y.max() - y.min()) / y.size / 2)
+     kde = kde / kde.max() * 0.2
+     for y_val in kde:
+         x_jittered.append(1 + np.random.uniform(-y_val, y_val, 1))
+
+     fig = plt.figure()
+     ax = fig.add_subplot(1, 1, 1)
+     ax.scatter(x=x_jittered, y=y, s=20, alpha=0.4, c=DARK_GREEN)
+
+     violins = ax.violinplot(
+         y,
+         widths=0.45,
+         bw_method="silverman",
+         showmeans=False,
+         showmedians=False,
+         showextrema=False,
+     )
+
+     # change violin color
+     for pc in violins["bodies"]:
+         pc.set_facecolor(BRIGHT_GREEN)
+
+     # Add a boxplot whose whiskers span mean +- 1 SD (instead of the usual
+     # IQR-proportional length), so they act as a standard-deviation marker
+     # extending from the mean past the box boundary.
+     lower = np.mean(y) - 1 * np.std(y)
+     upper = np.mean(y) + 1 * np.std(y)
+
+     medianprops = dict(linewidth=1, color="black", solid_capstyle="butt")
+     boxplot_stats = [
+         {
+             "med": np.median(y),
+             "q1": np.percentile(y, 25),
+             "q3": np.percentile(y, 75),
+             "whislo": lower,
+             "whishi": upper,
+         }
+     ]
+
+     ax.bxp(
+         boxplot_stats,  # data for the boxplot
+         showfliers=False,  # do not show outliers beyond the caps
+         showcaps=True,  # show the caps
+         medianprops=medianprops,
+     )
+
+     # Add the mean value point
+     ax.scatter(1, y.mean(), s=30, color=RED_DARK, zorder=3)
+
+     ax.set_xticks([])
+     ax.set_ylabel(ylabel)
+     ax.set_title(title)
+     fig.tight_layout()
+
+     return fig
+
+
+ def transform_aggregated_results_table(results_dict):
+     MEASUREMENT_HEADER = "measurement [unit]"
+     VALUE_HEADER = "value ±SD"
+
+     def get_value_str(value, std):
+         if np.isnan(value) or np.isnan(std):
+             return "-"
+         value_str = str(Decimal(str(value)).quantize(Decimal("0.01")))
+         std_str = str(Decimal(str(std)).quantize(Decimal("0.01")))
+         return f"{value_str} ±{std_str}"
+
+     def append_to_dict(new_key, old_val_key, old_sd_key):
+         aggregated_dict[MEASUREMENT_HEADER].append(new_key)
+         value_str = get_value_str(results_dict[old_val_key], results_dict[old_sd_key])
+         aggregated_dict[VALUE_HEADER].append(value_str)
+
+     aggregated_dict = {MEASUREMENT_HEADER: [], VALUE_HEADER: []}
+
+     # area
+     append_to_dict("area [nm^2]", "area nm^2", "area nm^2 std")
+
+     # perimeter
+     append_to_dict("perimeter [nm]", "perimeter nm", "perimeter nm std")
+
+     # diameter
+     append_to_dict("diameter [nm]", "diameter nm", "diameter nm std")
+
+     # height
+     append_to_dict("height [nm]", "height nm", "height nm std")
+
+     # number of layers
+     append_to_dict("number of thylakoids", "Number of layers", "Number of layers std")
+
+     # SRD
+     append_to_dict("SRD [nm]", "period nm", "period nm std")
+
+     # GSI
+     append_to_dict("GSI", "GSI", "GSI std")
+
+     # N grana
+     aggregated_dict[MEASUREMENT_HEADER].append("number of grana")
+     aggregated_dict[VALUE_HEADER].append(str(int(results_dict["N grana"])))
+
+     return aggregated_dict
+
+
+ def rename_columns_in_results_table(results_table):
+     column_names = {
+         "Granum ID": "granum ID",
+         "File name": "source image",
+         "area nm^2": "area [nm^2]",
+         "perimeter nm": "perimeter [nm]",
+         "diameter nm": "diameter [nm]",
+         "height nm": "height [nm]",
+         "Number of layers": "number of thylakoids",
+         "period nm": "SRD [nm]",
+         "period SD nm": "SRD SD [nm]",
+     }
+     results_table = results_table.rename(columns=column_names)
+     return results_table
+
+
+ with gr.Blocks(css=load_css(), theme=theme) as demo:
+
+     svg = """
+     <svg id="Layer_1" data-name="Layer 1" xmlns="http://www.w3.org/2000/svg" viewBox="0 0 30.73 33.38">
+         <defs>
+             <style>
+                 .cls-1 {
+                     fill: #27b373;
+                     stroke-width: 0px;
+                 }
+             </style>
+         </defs>
+         <path class="cls-1" d="M19.69,11.73h-3.22c-2.74,0-4.96,2.22-4.96,4.96h0c0,2.74,2.22,4.96,4.96,4.96h3.43c.56,0,1,.51.89,1.09-.08.43-.49.72-.92.72h-8.62c-.74,0-1.34-.6-1.34-1.34v-10.87c0-.74.6-1.34,1.34-1.34h13.44c2.73,0,4.95-2.22,4.95-4.95h0c0-2.75-2.22-4.97-4.96-4.97h-13.85C4.85,0,0,4.85,0,10.83v11.71c0,5.98,4.85,10.83,10.83,10.83h9.07c5.76,0,10.49-4.52,10.81-10.21.35-6.29-4.72-11.44-11.02-11.44ZM19.9,31.4h-9.07c-4.89,0-8.85-3.96-8.85-8.85v-11.71C1.98,5.95,5.95,1.98,10.83,1.98h13.81c1.64,0,2.97,1.33,2.97,2.97h0c0,1.65-1.33,2.97-2.96,2.97h-13.4c-1.83,0-3.32,1.49-3.32,3.32v10.87c0,1.83,1.49,3.32,3.32,3.32h8.56c1.51,0,2.83-1.12,2.97-2.62.16-1.72-1.2-3.16-2.88-3.16h-3.52c-1.64,0-2.97-1.33-2.97-2.97h0c0-1.64,1.33-2.97,2.97-2.97h3.34c4.83,0,8.9,3.81,9.01,8.64s-3.9,9.04-8.84,9.04Z"/>
+         <path class="cls-1" d="M19.9,29.41h-9.07c-3.79,0-6.87-3.07-6.87-6.87v-11.71c0-3.79,3.07-6.87,6.87-6.87h13.81c.55,0,.99.44.99.99h0c0,.55-.44.99-.99.99h-13.81c-2.7,0-4.88,2.19-4.88,4.88v11.71c0,2.7,2.19,4.88,4.88,4.88h8.94c2.64,0,4.91-2.05,5-4.7s-2.12-5.05-4.87-5.05h-3.52c-.55,0-.99-.44-.99-.99h0c0-.55.44-.99.99-.99h3.36c3.74,0,6.9,2.92,7.01,6.66.11,3.87-3.01,7.06-6.85,7.06Z"/>
+     </svg>
+     """
+
+     gr.HTML(
+         f'<div class="header"><div id="header-logo">{svg}</div><div id="header-text">GRANA</div></div>'
+     )
+
+     with gr.Row(elem_classes="input-row"):  # input
+         with gr.Column():
+             gr.HTML(
+                 "<h1>1. Choose images to upload. All the images need to be of the same scale and experimental variant.</h1>"
+             )
+             img_input = gr.File(file_count="multiple")
+
+             gr.HTML("<h1>2. Set the scale of the images for the measurements.</h1>")
+             with gr.Row():
+                 with gr.Column():
+                     gr.HTML("Either provide the pixel-per-nanometer ratio...")
+                     ratio_input = gr.Number(
+                         label="pixel per nm", precision=3, step=0.001
+                     )
+
+                 with gr.Column():
+                     gr.HTML("...or the length of the scale bar in pixels and nanometers.")
+                     pixels_input = gr.Number(label="Length in pixels")
+                     nano_input = gr.Number(label="Length in nanometers")
+
+             pixels_input.change(
+                 calc_ratio,
+                 inputs=[pixels_input, nano_input],
+                 outputs=ratio_input,
+             )
+             nano_input.change(
+                 calc_ratio,
+                 inputs=[pixels_input, nano_input],
+                 outputs=ratio_input,
+             )
+
+             with gr.Row():
+                 clear_btn = gr.ClearButton(img_input, "Clear")
+                 submit_btn = gr.Button("Submit", variant="primary")
+
+     with gr.Row(visible=False) as loading_row:
+         with gr.Column():
+             gr.HTML(
+                 "<div class='processed-info'>Images are being processed. This may take a while...</div>"
+             )
+
+     with gr.Row(visible=False) as output_row:
+         with gr.Column():
+             gr.HTML(
+                 '<div class="results-header">Results</div>'
+                 "<p>Full results are a zip file containing:</p>"
+                 "<ul><li>grana_raw_data.csv: a table with full grana measurements,</li>"
+                 "<li>grana_aggregated_data.csv: a table with aggregated measurements,</li>"
+                 '<li>directory "annotated_images" with all submitted images with masks over detected grana,</li>'
+                 '<li>directory "single_grana_images" with images of all detected grana.</li></ul>'
+                 "<p>Note that GRANA only stores the result files for 1 hour.</p>",
+                 elem_classes="input-row",
+             )
+             with gr.Row(elem_classes="input-row"):
+                 download_file_out = gr.DownloadButton(
+                     label="Download results",
+                     variant="primary",
+                     elem_classes="margin-bottom",
+                 )
+             with gr.Row():
+                 gr.HTML(
+                     '<h2 class="title">Annotated images</h2>'
+                     "Gallery of uploaded images with masks of recognized grana structures. "
+                     "Each granum mask is labeled with its number. "
+                     "Note that only grana fully visible in the image are masked."
+                 )
+             with gr.Row(elem_classes="margin-bottom"):
+                 gallery_out = gr.Gallery(
+                     columns=4,
+                     rows=2,
+                     object_fit="contain",
+                     label="Detection visualizations",
+                     show_download_button=False,
+                 )
+
+             with gr.Row(elem_classes="input-row"):
+                 gr.HTML(
+                     '<h2 class="title">Aggregated results for all uploaded images</h2>'
+                 )
+             with gr.Row(elem_classes=["input-row", "margin-bottom"]):
+                 table_out = gr.Dataframe(label="Aggregated data")
+
+             with gr.Row():
+                 gr.HTML(
+                     '<h2 class="title">Violin graphs</h2>'
+                     "These graphs present aggregated results for selected structural parameters. "
+                     "The graph for each parameter is only generated if three or more values are available. "
+                     "Each graph displays individual data points, a box plot indicating the first and third "
+                     "quartiles, whiskers marking the standard deviation (SD), the median value (horizontal "
+                     "line on the box plot), the mean value (red dot), and a density plot whose width "
+                     "represents the frequency."
+                 )
+
+             with gr.Row():
+                 area_plot_out = gr.Plot(label="Area")
+                 perimeter_plot_out = gr.Plot(label="Perimeter")
+                 gsi_plot_out = gr.Plot(label="GSI")
+
+             with gr.Row(elem_classes="margin-bottom"):
+                 diameter_plot_out = gr.Plot(label="Diameter")
+                 height_plot_out = gr.Plot(label="Height")
+                 srd_plot_out = gr.Plot(label="SRD")
+
+             with gr.Row():
+                 gr.HTML(
+                     '<h2 class="title">Recognized and rotated grana structures</h2>'
+                 )
+
+             with gr.Row(elem_classes="margin-bottom"):
+                 gallery_single_grana_out = gr.Gallery(
+                     columns=4,
+                     rows=2,
+                     object_fit="contain",
+                     label="Single grana images",
+                     show_download_button=False,
+                 )
+
+             with gr.Row():
+                 gr.HTML(
+                     '<h2 class="title">Full results</h2>'
+                     "Note that structural parameters other than area and perimeter are only calculated for grana "
+                     "whose direction and/or SRD could be estimated."
+                 )
+             with gr.Row():
+                 table_full_out = gr.Dataframe(label="Full measurements data")
+
+     submit_btn.click(
+         show_info_on_submit,
+         inputs=[submit_btn],
+         outputs=[submit_btn, clear_btn, loading_row, output_row],
+     )
+
+     def enable_submit():
+         return (
+             gr.Button(interactive=True),
+             gr.Button(interactive=True),
+             gr.Row(visible=False),
+         )
+
+     def gradio_analyze_image(images, scale):
+         """
+         The model accepts the following parameters:
+         :param images: list of images to be processed, in tiff, png or jpeg format
+         :param scale: float, pixels per nanometer
+
+         The model returns the following objects:
+         - detection_visualizations: list of images with masks, displayed as a gallery and served for
+           download as a zip of images
+         - grana_data: dataframe with measurements for each image, served for download as a csv file
+         - images_grana: list of images of single grana, served for download as a zip of images
+         - aggregated_data: dataframe with aggregated measurements for all images, displayed as a table
+           and served for download as a csv
+         """
+
+         # validate that at least one image has been uploaded
+         if images is None or len(images) == 0:
+             raise gr.Error("Please upload at least one image")
+
+         # the demo instance limits the number of images to 5
+         if DEMO:
+             if len(images) > 5:
+                 raise gr.Error("In the demo version it is possible to analyze up to 5 images.")
+
+         # validate that the scale has been provided correctly
+         if scale is None or scale == 0:
+             raise gr.Error("Please provide the scale. Use a dot as the decimal separator")
+
+         # validate that all images are png, tiff or jpeg
+         for image in images:
+             if not image.name.lower().endswith((".png", ".tif", ".tiff", ".jpg", ".jpeg")):
+                 raise gr.Error("Only png, tif(f), jpg and jpeg images are supported")
+
+         # clean up previous results:
+         # find all directories in the current working directory that start with "results_"
+         # and were created more than 1 hour ago, and delete them with all contents
+         for directory_name in os.listdir():
+             if directory_name.startswith("results_"):
+                 dir_path = os.path.join(os.getcwd(), directory_name)
+                 if os.path.isdir(dir_path):
+                     if time.time() - os.path.getctime(dir_path) > 60 * 60:
+                         shutil.rmtree(dir_path)
+
+         # create a directory for results
+         results_dir_name = "results_{uuid}".format(uuid=uuid.uuid4().hex)
+         os.makedirs(results_dir_name)
+         zip_dir_name = f"{results_dir_name}/to_zip"
+         os.makedirs(zip_dir_name)
+
+         # the model takes a dict of images, so convert the input to PIL.PngImagePlugin.PngImageFile or
+         # PIL.TiffImagePlugin.TiffImageFile objects keyed by file name
+         images_dict = {
+             image.name.split("/")[-1]: Image.open(image.name)
+             for image in images
+         }
+
+         # the model works here
+         (
+             detection_visualizations_dict,
+             grana_data,
+             images_grana_dict,
+             aggregated_data,
+         ) = ga.predict(images_dict, scale)
+         detection_visualizations = list(detection_visualizations_dict.values())
+         images_grana = list(images_grana_dict.values())
+
+         # rearrange aggregated data to be displayed as a table
+         aggregated_dict = transform_aggregated_results_table(aggregated_data)
+         aggregated_df_transposed = pd.DataFrame.from_dict(aggregated_dict)
+
+         # rename columns in the full results
+         grana_data = rename_columns_in_results_table(grana_data)
+
+         # save files returned by the model to disk so they can be retrieved for downloading
+         download_file_path = prepare_files_for_download(
+             results_dir_name,
+             grana_data,
+             aggregated_df_transposed,
+             detection_visualizations_dict,
+             images_grana_dict,
+         )
+
+         # generate plots
+         area_fig = draw_violin_plot(
+             grana_data["area [nm^2]"].dropna(),
+             "Granum area [nm^2]",
+             "Grana areas from all uploaded images",
+         )
+         perimeter_fig = draw_violin_plot(
+             grana_data["perimeter [nm]"].dropna(),
+             "Granum perimeter [nm]",
+             "Grana perimeters from all uploaded images",
+         )
+         gsi_fig = draw_violin_plot(
+             grana_data["GSI"].dropna(),
+             "GSI",
+             "GSI from all uploaded images",
+         )
+         diameter_fig = draw_violin_plot(
+             grana_data["diameter [nm]"].dropna(),
+             "Granum diameter [nm]",
+             "Grana diameters from all uploaded images",
+         )
+         height_fig = draw_violin_plot(
+             grana_data["height [nm]"].dropna(),
+             "Granum height [nm]",
+             "Grana heights from all uploaded images",
+         )
+         srd_fig = draw_violin_plot(
+             grana_data["SRD [nm]"].dropna(), "SRD [nm]", "SRD from all uploaded images"
+         )
+
+         return [
+             gr.Row(visible=True),
+             gr.Row(visible=True),
+             download_file_path,
+             detection_visualizations,
+             aggregated_df_transposed,
+             area_fig,
+             perimeter_fig,
+             gsi_fig,
+             diameter_fig,
+             height_fig,
+             srd_fig,
+             images_grana,
+             grana_data,
+         ]
+
+     submit_btn.click(
+         fn=gradio_analyze_image,
+         inputs=[
+             img_input,
+             ratio_input,
+         ],
+         outputs=[
+             loading_row,
+             output_row,
+             download_file_out,
+             gallery_out,
+             table_out,
+             area_plot_out,
+             perimeter_plot_out,
+             gsi_plot_out,
+             diameter_plot_out,
+             height_plot_out,
+             srd_plot_out,
+             gallery_single_grana_out,
+             table_full_out,
+         ],
+     ).then(fn=enable_submit, inputs=[], outputs=[submit_btn, clear_btn, loading_row])
+
+ demo.launch(
+     share=False, debug=True, server_name="0.0.0.0", allowed_paths=["images/logo.svg"]
+ )
grana_detection/mmwrapper.py ADDED
@@ -0,0 +1,42 @@
+ import torch
+ import numpy as np
+ from PIL import Image
+ from mmdet.apis import DetInferencer
+ from ultralytics.engine.results import Results
+ import warnings
+
+ class MMDetector(DetInferencer):
+     def __call__(
+         self,
+         inputs,
+     ) -> Results:
+         """Call the inferencer as in DetInferencer, but for a single image.
+
+         Args:
+             inputs (np.ndarray | str): Inputs for the inferencer.
+
+         Returns:
+             Results: YOLO-like result
+         """
+
+         ori_inputs = self._inputs_to_list(inputs)
+
+         data = list(self.preprocess(
+             ori_inputs, batch_size=1))[0][1]
+
+         preds = self.forward(data)[0]
+
+         yolo_result = Results(
+             orig_img=ori_inputs[0], path="", names=[""],
+             boxes=torch.cat((preds.pred_instances.bboxes, preds.pred_instances.scores.unsqueeze(-1), preds.pred_instances.labels.unsqueeze(-1)), dim=1),
+             masks=preds.pred_instances.masks
+         )
+
+         return yolo_result
+
+     def predict(self, source: Image.Image, conf=None):
+         """YOLO-style interface"""
+         if conf is not None:
+             warnings.warn(f"confidence value {conf} ignored")
+         return [self.__call__(np.array(source.convert("RGB")))]
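A hedged editorial sketch: MMDetector wraps mmdet's DetInferencer behind a YOLO-like predict() so it could stand in for the YOLO detector in GranaAnalyser. The config and checkpoint paths below are hypothetical, and as the comment in model.py notes, mmdet may not be installed in the Docker image.

from PIL import Image
from grana_detection.mmwrapper import MMDetector

# Hypothetical mmdet config/weights pair.
detector = MMDetector(model="configs/mask_rcnn_grana.py", weights="weights/mask_rcnn_grana.pth")

image = Image.open("example_micrograph.png")  # hypothetical input file
result = detector.predict(source=image)[0]    # same call shape as YOLO's predict()
print(result.boxes, result.masks)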
model.py ADDED
@@ -0,0 +1,629 @@
+ import itertools
+ import warnings
+ from io import BytesIO
+ from copy import deepcopy
+ from pathlib import Path
+ from typing import List, Tuple, Dict, Optional, Any, Union
+ from dataclasses import dataclass, field
+
+ from PIL import Image
+ import numpy as np
+ import pandas as pd
+ import matplotlib.pyplot as plt
+ from scipy import ndimage
+ from skimage import measure
+ import torchvision.transforms.functional as tvf
+ from ultralytics import YOLO
+ import torch
+ import cv2
+ import gradio
+
+ import sys, os
+ sys.path.append(os.path.abspath('angle_calculation'))
+ # from classic import measure_object
+ from sampling import get_crop_batch
+ from angle_model import PatchedPredictor, StripsModelLumenWidth
+
+ from period_calculation.period_measurer import PeriodMeasurer
+ # from grana_detection.mmwrapper import MMDetector  # mmdet installation in docker is problematic for now
+
+ @dataclass
+ class Granum:
+     id: Optional[int] = None
+     image: Any = None
+     mask: Any = None
+     scaler: Any = None
+     nm_per_px: float = float('nan')
+     detection_confidence: float = float('nan')
+     img_oriented: Optional[np.ndarray] = None  # oriented fragment of the image
+     mask_oriented: Optional[np.ndarray] = None  # oriented fragment of the mask
+     measurements: dict = field(default_factory=dict)  # dict with grana measurements
+
+ class ScalerPadder:
+     """Resize and pad an image into a specific size range.
+     minimal_pad: obligatory padding, e.g. required by the detector
+     """
+     def __init__(self, target_size=1024, target_short_edge_min=640, minimal_pad=16, pad_to_multiply=32):
+         self.minimal_pad = minimal_pad
+         self.target_size = target_size - 2*self.minimal_pad  # the obligatory detector padding is reserved up front
+         self.target_short_edge_min = target_short_edge_min - 2*self.minimal_pad
+
+         self.max_size_nm = 6000  # training images cover ~3100 nm
+         self.min_size_nm = 2400
+         self.pad_to_multiply = pad_to_multiply
+
+     def transform(self, image: Image.Image, px_per_nm: float = 1.298) -> Image.Image:
+         self.original_size = image.size
+         self.original_px_per_nm = px_per_nm
+         w, h = self.original_size
+         longest_size = max(h, w)
+         img_size_nm = longest_size / px_per_nm
+         if img_size_nm > self.max_size_nm:
+             error_message = f'too large image, image size: {img_size_nm:0.1f}nm, max allowed: {self.max_size_nm}nm'
+             gradio.Warning(error_message)
+
+         self.resize_factor = self.target_size / (max(self.min_size_nm, img_size_nm) * px_per_nm)
+         self.px_per_nm_transformed = px_per_nm * self.resize_factor
+
+         resized_image = resize_with_cv2(image, (int(h*self.resize_factor), int(w*self.resize_factor)))
+
+         if w >= h:
+             pad_w = self.target_size - resized_image.size[0]
+             pad_h = max(0, self.target_short_edge_min - resized_image.size[1])
+         else:
+             pad_w = max(0, self.target_short_edge_min - resized_image.size[0])
+             pad_h = self.target_size - resized_image.size[1]
+
+         # apply minimal padding
+         pad_w += 2*self.minimal_pad
+         pad_h += 2*self.minimal_pad
+
+         # round the padded size up to a multiple of pad_to_multiply
+         pad_w += (self.pad_to_multiply - resized_image.size[0] % self.pad_to_multiply) % self.pad_to_multiply
+         pad_h += (self.pad_to_multiply - resized_image.size[1] % self.pad_to_multiply) % self.pad_to_multiply
+
+         self.pad_right = pad_w // 2
+         self.pad_left = pad_w - self.pad_right
+
+         self.pad_up = pad_h // 2
+         self.pad_bottom = pad_h - self.pad_up
+
+         padded_image = tvf.pad(resized_image, [self.pad_left, self.pad_up, self.pad_right, self.pad_bottom], padding_mode='reflect')  # reflect padding (YOLO itself fills with 114)
+         return padded_image
+
+     @property
+     def unpad_slice(self) -> Tuple[slice, slice]:
+         return slice(self.pad_up, -self.pad_bottom if self.pad_bottom > 0 else None), slice(self.pad_left, -self.pad_right if self.pad_right > 0 else None)
+
+     def inverse_transform(self, image: Union[np.ndarray, Image.Image], output_size: Optional[Tuple[int]] = None, output_nm_per_px: Optional[float] = None, return_pil: bool = True) -> Image.Image:
+         if isinstance(image, Image.Image):
+             image = np.array(image)
+         unpadded_image = image[self.unpad_slice]
+
+         if output_size is not None and output_nm_per_px is not None:
+             raise ValueError("at most one of output_size or output_nm_per_px may be given")
+         elif output_nm_per_px is not None:
+             # the original scale is stored as px per nm, so nm_per_px = 1 / original_px_per_nm
+             resize_factor = 1 / (self.original_px_per_nm * output_nm_per_px)
+             output_size = (int(self.original_size[0]*resize_factor), int(self.original_size[1]*resize_factor))
+         elif output_size is None:
+             output_size = self.original_size
+         resized_image = resize_with_cv2(unpadded_image, (output_size[1], output_size[0]), return_pil=return_pil)
+
+         return resized_image
+
+ def close_contour(contour):
+     if not np.array_equal(contour[0], contour[-1]):
+         contour = np.vstack((contour, contour[0]))
+     return contour
+
+ def binary_mask_to_polygon(binary_mask, tol=0.01):
+     padded_binary_mask = np.pad(binary_mask, pad_width=1, mode='constant', constant_values=0)
+     contours = measure.find_contours(padded_binary_mask, 0.5)
+     # assert len(contours) == 1  # would raise if there were more than 1 contour
+     contour = contours[0]
+     contour -= 1  # correct for padding
+     contour = close_contour(contour)
+
+     polygon = measure.approximate_polygon(contour, tol)
+
+     polygon = np.flip(polygon, axis=1)
+     # after padding and subtracting 1 we may get -0.5 points in the polygon; replace them with 0
+     polygon = np.where(polygon >= 0, polygon, 0)
+
+     return polygon
+
+ def measure_shape(binary_mask):
+     contour = binary_mask_to_polygon(binary_mask)
+     perimeter = np.sum(np.linalg.norm(contour[:-1] - contour[1:], axis=1))
+     area = np.sum(binary_mask)
+
+     return perimeter, area
+
+ def calculate_gsi(perimeter, height, area):
+     a = 0.5*(perimeter - 2*height)
+     return 1 - area/(a*height)
+
+ def object_slice(mask):
+     rows = np.any(mask, axis=1)
+     cols = np.any(mask, axis=0)
+     row_min, row_max = np.where(rows)[0][[0, -1]]
+     col_min, col_max = np.where(cols)[0][[0, -1]]
+
+     # Create a slice object for the bounding box
+     bounding_box_slice = (slice(row_min, row_max + 1), slice(col_min, col_max + 1))
+
+     return bounding_box_slice
+
+
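A short editorial check of calculate_gsi: a = (perimeter - 2*height)/2 recovers the side length of the ideal rectangle, so a perfectly rectangular granum gets GSI = 0, and any loss of area at constant perimeter and height pushes GSI above 0.

# Perfect 300 x 100 rectangle: perimeter = 2*300 + 2*100, area = 300*100.
perimeter, height, area = 800, 100, 30000
a = 0.5 * (perimeter - 2 * height)    # 300, the recovered side length
print(1 - area / (a * height))        # 0.0 -> GSI of an ideal rectangle

# Same perimeter and height but 20% less area (a concave outline):
print(1 - 0.8 * area / (a * height))  # 0.2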
+ def figure_to_pil(fig):
+     buf = BytesIO()
+     fig.savefig(buf, format='png')
+     buf.seek(0)
+
+     # Load the image from the buffer as a PIL Image
+     image = deepcopy(Image.open(buf))
+
+     # Close the buffer
+     buf.close()
+     return image
+
+ def resize_to(image: Image.Image, s: int = 4032, return_factor: bool = False) -> Image.Image:
+     w, h = image.size
+     longest_size = max(h, w)
+
+     resize_factor = longest_size / s
+
+     resized_image = image.resize((int(w/resize_factor), int(h/resize_factor)))
+     if return_factor:
+         return resized_image, resize_factor
+     return resized_image
+
+ def resize_with_cv2(image, shape, return_pil=True):
+     """Resize using cv2 with cv2.INTER_LINEAR - consistent with YOLO."""
+     h, w = shape
+     if isinstance(image, Image.Image):
+         image = np.array(image)
+
+     resized = cv2.resize(image, (w, h), interpolation=cv2.INTER_LINEAR)
+     if return_pil:
+         return Image.fromarray(resized)
+     else:
+         return resized
+
+ def select_unique_mask(mask):
+     """If the mask consists of multiple parts, select the largest."""
+     if not np.any(mask):  # if the mask is empty, return it unchanged
+         return mask
+     blobs = ndimage.label(mask)[0]
+     blob_labels, blob_sizes = np.unique(blobs, return_counts=True)
+     best_blob_label = blob_labels[1:][np.argmax(blob_sizes[1:])]
+     return blobs == best_blob_label
+
+ def sliced_mean(x, slice_size):
+     cs_y = np.cumsum(x, axis=0)
+     cs_y = np.concatenate((np.zeros((1, cs_y.shape[1]), dtype=cs_y.dtype), cs_y), axis=0)
+     slices_y = (cs_y[slice_size:] - cs_y[:-slice_size])/slice_size
+     cs_xy = np.cumsum(slices_y, axis=1)
+     cs_xy = np.concatenate((np.zeros((cs_xy.shape[0], 1), dtype=cs_xy.dtype), cs_xy), axis=1)
+     slices_xy = (cs_xy[:, slice_size:] - cs_xy[:, :-slice_size])/slice_size
+     return slices_xy
+
+ def sliced_var(x, slice_size):
+     x = x.astype('float64')
+     return sliced_mean(x**2, slice_size) - sliced_mean(x, slice_size)**2
+
+ def calculate_distance_map(mask):
+     padded = np.pad(mask, pad_width=1, mode='constant', constant_values=False)
+     distance_map_padded = ndimage.distance_transform_edt(padded)
+     return distance_map_padded[1:-1, 1:-1]
+
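sliced_mean is the summed-area-table (integral image) trick applied once per axis: each output cell is the mean over a slice_size x slice_size window at O(1) cost per window, and sliced_var then follows from E[x^2] - E[x]^2. A brute-force cross-check (editorial):

import numpy as np

# Assumes sliced_mean as defined above is in scope.
x = np.arange(16, dtype='float64').reshape(4, 4)
window = 2

fast = sliced_mean(x, window)
brute = np.array([
    [x[i:i+window, j:j+window].mean() for j in range(x.shape[1] - window + 1)]
    for i in range(x.shape[0] - window + 1)
])
assert np.allclose(fast, brute)  # identical, but without the O(window^2) inner loop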
+ def select_samples(granum_image, granum_mask, crop_size=96, n_samples=64, granum_fraction_min=0.75, variance_p=0.):
+     granum_occupancy = sliced_mean(granum_mask, crop_size)
+     possible_indices = np.stack(np.where(granum_occupancy >= granum_fraction_min), axis=1)
+
+     if variance_p == 0:
+         p = np.ones(len(possible_indices))
+     else:
+         variance_map = sliced_var(granum_image, crop_size)
+         p = variance_map[possible_indices[:, 0], possible_indices[:, 1]]**variance_p
+     p /= np.sum(p)  # np.random.choice requires probabilities that sum to 1
+
+     chosen_indices = np.random.choice(
+         np.arange(len(possible_indices)),
+         min(len(possible_indices), n_samples),
+         replace=False,
+         p=p
+     )
+
+     crops = []
+     for crop_idx, idx in enumerate(chosen_indices):
+         crops.append(
+             granum_image[
+                 possible_indices[idx, 0]:possible_indices[idx, 0]+crop_size,
+                 possible_indices[idx, 1]:possible_indices[idx, 1]+crop_size
+             ]
+         )
+     return np.array(crops)
+
+ def calculate_height(mask_oriented):  # HACK: the 0.8 quantile makes the estimate robust to stray columns
+     span = mask_oriented.shape[0] - np.argmax(mask_oriented[::-1], axis=0) - np.argmax(mask_oriented, axis=0)
+     return np.quantile(span, 0.8)
+
+ def calculate_diameter(mask_oriented):
+     """Return the mean diameter."""
+     # restrict to the rows between the 0.25 and 0.75 lines of the granum
+     vertical_mask = np.any(mask_oriented, axis=1)
+     upper_granum_bound = np.argmax(vertical_mask)
+     lower_granum_bound = mask_oriented.shape[0] - np.argmax(vertical_mask[::-1])
+     upper = round(0.75*upper_granum_bound + 0.25*lower_granum_bound)
+     lower = max(upper+1, round(0.25*upper_granum_bound + 0.75*lower_granum_bound))
+     valid_rows_slice = slice(upper, lower)
+
+     # calculate the diameters row by row
+     span = mask_oriented.shape[1] - np.argmax(mask_oriented[valid_rows_slice, ::-1], axis=1) - np.argmax(mask_oriented[valid_rows_slice], axis=1)
+     return np.mean(span)
+
+ def robust_mean(x, q=0.1):
+     x_med = np.median(x)
+     deviations = abs(x - x_med)
+     if max(deviations) == 0:
+         mask = np.ones(len(x), dtype='bool')
+     else:
+         threshold = np.quantile(deviations, 1-q)
+         mask = deviations <= threshold  # boolean mask of inliers
+
+     return np.mean(x[mask])
+
+ def rotate_image_and_mask(image, mask, direction):
+     mask_oriented = ndimage.rotate(mask.astype('int'), -direction, reshape=True).astype('bool')
+     img_oriented = ndimage.rotate(image, -direction, reshape=True)  # cropping to the mask extent is left to the caller
+     return img_oriented, mask_oriented
+
+ class GranaAnalyser:
+     def __init__(self, weights_detector: str, weights_orientation: str, weights_period: str, period_sd_threshold_nm: float = 2.5) -> None:
+         """
+         Initializes the GranaAnalyser with specified weights for detection and measuring.
+
+         This method loads the weights for the grana detection and measuring algorithms
+         from the specified file paths.
+
+         Parameters:
+             weights_detector (str): The file path to the weights file for the grana detection algorithm.
+             weights_orientation (str): The file path to the weights file for the grana orientation algorithm.
+             weights_period (str): The file path to the weights file for the grana period algorithm.
+         """
+         self.detector = YOLO(weights_detector)
+
+         self.orienter = PatchedPredictor(
+             StripsModelLumenWidth.load_from_checkpoint(weights_orientation, map_location='cpu').eval(),
+             normalization=dict(mean=0.250, std=0.135),
+             n_samples=32,
+             mask=None,
+             crop_size=64,
+             angle_confidence_threshold=0.2
+         )
+
+         self.measurement_px_per_nm = 1/0.768  # image scale required for measurement
+
+         self.period_measurer = PeriodMeasurer(
+             weights_period,
+             px_per_nm=self.measurement_px_per_nm,
+             sd_threshold_nm=period_sd_threshold_nm,
+             period_threshold_nm_min=14, period_threshold_nm_max=30
+         )
+
+     def get_grana_data(self, image, detections, scaler, border_margin=1, min_count=1) -> List[Granum]:
+         """Filter detections and create grana data."""
+         image_numpy = np.array(image)
+         if image_numpy.ndim == 3:
+             image_numpy = image_numpy[:, :, 0]
+
+         mask_all = None
+         grana = []
+         for mask, confidence in zip(
+             detections.masks.data.cpu().numpy().astype('bool'),
+             detections.boxes.conf.cpu().numpy()
+         ):
+             granum_mask = select_unique_mask(mask[scaler.unpad_slice])
+             # check if the mask is empty after unpadding
+             if not np.any(granum_mask):
+                 continue
+             granum_mask = ndimage.binary_fill_holes(granum_mask)
+
+             # skip masks that touch the image boundary
+             if (np.sum(granum_mask[:border_margin]) > min_count) or \
+                     (np.sum(granum_mask[-border_margin:]) > min_count) or \
+                     (np.sum(granum_mask[:, :border_margin]) > min_count) or \
+                     (np.sum(granum_mask[:, -border_margin:]) > min_count):
+                 continue
+
+             # skip grana that overlap already accepted grana by 20% or more
+             if mask_all is None:
+                 mask_all = granum_mask
+             else:
+                 intersection = mask_all & granum_mask
+                 if intersection.sum() >= (granum_mask.sum() * 0.2):
+                     continue
+                 mask_all = mask_all | granum_mask
+
+             granum = Granum(
+                 image=image,
+                 mask=granum_mask,
+                 scaler=scaler,
+                 detection_confidence=float(confidence)
+             )
+
+             granum.image_numpy = image_numpy
+             grana.append(granum)
+         return grana
+
+     def measure_grana(self, grana: List[Granum], measurement_image: np.ndarray) -> List[Granum]:
+         """Measure grana: includes orientation detection, period detection and geometric measurements."""
+         for granum in grana:
+             measurement_mask = resize_with_cv2(granum.mask.astype(np.uint8), measurement_image.shape[:2], return_pil=False).astype('bool')
+
+             granum.bounding_box_slice = object_slice(measurement_mask)
+             granum.image_crop = measurement_image[granum.bounding_box_slice][:, :]
+             granum.mask_crop = measurement_mask[granum.bounding_box_slice]
+
+             # initialize measurements
+             granum.measurements = {}
+
+             # measure shape
+             granum.measurements['perimeter px'], granum.measurements['area px'] = measure_shape(granum.mask_crop)
+
+             # measure orientation
+             orienter_predictions = self.orienter(granum.image_crop, granum.mask_crop)
+             granum.measurements['direction'] = orienter_predictions["est_angle"]
+             granum.measurements['direction confidence'] = orienter_predictions["est_angle_confidence"]
+
+             if not np.isnan(granum.measurements["direction"]):
+                 img_oriented, mask_oriented = rotate_image_and_mask(granum.image_crop, granum.mask_crop, granum.measurements["direction"])
+                 oriented_granum_slice = object_slice(mask_oriented)
+                 granum.img_oriented = img_oriented[oriented_granum_slice]
+                 granum.mask_oriented = mask_oriented[oriented_granum_slice]
+                 granum.measurements['height px'] = calculate_height(granum.mask_oriented)
+                 granum.measurements['GSI'] = calculate_gsi(
+                     granum.measurements['perimeter px'],
+                     granum.measurements['height px'],
+                     granum.measurements['area px']
+                 )
+                 granum.measurements['diameter px'] = calculate_diameter(granum.mask_oriented)
+
+                 granum.measurements["period nm"], granum.measurements["period SD nm"] = self.period_measurer(granum.img_oriented, granum.mask_oriented)
+
+                 if not pd.isna(granum.measurements['period nm']):
+                     granum.measurements['Number of layers'] = round(granum.measurements['height px'] / self.measurement_px_per_nm / granum.measurements['period nm'])
+
+         return grana
+
+     def extract_grana_data(self, grana: List[Granum]) -> pd.DataFrame:
+         """Collect and scale grana data."""
+         grana_data = []
+         for granum in grana:
+             granum_entry = {
+                 'Granum ID': granum.id,
+                 'detection confidence': granum.detection_confidence
+             }
+             # fill with None if absent:
+             for key in ['direction', 'Number of layers', 'GSI', 'period nm', 'period SD nm']:
+                 granum_entry[key] = granum.measurements.get(key, None)
+             # scale linearly:
+             for key in ['height px', 'diameter px', 'perimeter px']:
+                 granum_entry[f"{key[:-3]} nm"] = granum.measurements.get(key, np.nan) / self.measurement_px_per_nm
+             # scale quadratically
+             granum_entry['area nm^2'] = granum.measurements['area px'] / self.measurement_px_per_nm**2
+
+             grana_data.append(granum_entry)
+
+         return pd.DataFrame(grana_data)
+
+     def visualize_detections(self, grana: List[Granum], image: Image.Image) -> Image.Image:
+         visualization_longer_edge = 1024
+         scale = visualization_longer_edge / max(image.size)
+         visualization_size = (round(scale*image.size[0]), round(scale*image.size[1]))
+         visualization_image = np.array(image.resize(visualization_size).convert('RGB'))
+
+         if len(grana) > 0:
+             grana_mask = resize_with_cv2(
+                 np.any(np.array([granum.mask for granum in grana]), axis=0).astype(np.uint8),
+                 visualization_size[::-1],
+                 return_pil=False
+             ).astype('bool')
+             visualization_image[grana_mask] = (0.7*visualization_image[grana_mask] + 0.3*np.array([[[39, 179, 115]]])).astype(np.uint8)
+
+         for granum in grana:
+             scale = visualization_longer_edge / max(granum.mask.shape)
+             y, x = ndimage.center_of_mass(granum.mask)
+             cv2.putText(visualization_image, f'{granum.id}', org=(int(x*scale)-10, int(y*scale)+10), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(39, 179, 115), thickness=2)
+
+         return Image.fromarray(visualization_image)
+
+
+     def generate_grana_images(self, grana: List[Granum], image_name: str = "") -> Dict[int, Image.Image]:
+         grana_images = {}
+         for granum in grana:
+             fig, ax = plt.subplots()
+             if granum.img_oriented is None:
+                 image_to_plot = granum.image_crop
+                 mask_to_plot = granum.mask_crop
+                 extra_caption = " orientation and period unknown"
+             else:
+                 image_to_plot = granum.img_oriented
+                 mask_to_plot = granum.mask_oriented
+                 extra_caption = ""
+
+             # dim everything outside the mask to 50% grey
+             ax.imshow(0.5*255*(~mask_to_plot) + image_to_plot*(1 - 0.5*(~mask_to_plot)), cmap='gray', vmin=0, vmax=255)
+             ax.axis('off')
+             ax.set_title(f'[{granum.id}]{image_name}\n{extra_caption}')
+             granum_image = figure_to_pil(fig)
+             grana_images[granum.id] = granum_image
+             plt.close('all')
+
+         return grana_images
+
+     def format_data(self, grana_data: pd.DataFrame) -> pd.DataFrame:
+         rounding_rules = {'area nm^2': 0, 'perimeter nm': 1, 'diameter nm': 1, 'height nm': 1, 'period nm': 1, 'period SD nm': 2, 'GSI': 2, 'direction': 1}
+         rounded_data = grana_data.round(rounding_rules)
+         columns_order = ['Granum ID', 'File name', 'area nm^2', 'perimeter nm', 'GSI', 'diameter nm', 'height nm', 'Number of layers', 'period nm', 'period SD nm', 'direction']
+         return rounded_data[columns_order]
+
+     def aggregate_data(self, grana_data: pd.DataFrame, confidence: Optional[float] = None) -> Dict:
+         if confidence is None:
+             filtered = grana_data
+         else:
+             filtered = grana_data.loc[grana_data['aggregated confidence'] >= confidence]
+         aggregation = filtered[['area nm^2', 'perimeter nm', 'diameter nm', 'height nm', 'Number of layers', 'period nm', 'GSI']].mean().to_dict()
+         aggregation_std = filtered[['area nm^2', 'perimeter nm', 'diameter nm', 'height nm', 'Number of layers', 'period nm', 'GSI']].std().to_dict()
+         aggregation_std = {f"{k} std": v for k, v in aggregation_std.items()}
+         aggregation_result = {**aggregation, **aggregation_std, 'N grana': len(filtered)}
+         return aggregation_result
+
497
+ def predict_on_single(self, image: Image.Image, scale: float, detection_confidence: float=0.25, granum_id_start=1, image_name: str = "") -> Tuple[Image.Image, pd.DataFrame, Dict[int, Image.Image]]:
498
+ """
499
+ Predicts and measures grana in a single image.
500
+
501
+ Parameters:
502
+ image (Image.Image): PIL Image object to be analyzed
503
+ scale (float): scale of the image: px per nm.
504
+ detection_confidence (float): the confidence threshold applied to grana detections
505
+
506
+ Returns:
507
+ Tuple[Image.Image, pandas.DataFrame, Dict[int, Image.Image]]:
508
+ A tuple containing:
509
+ - detection_visualization (Image.Image): PIL image representing
510
+ the detection visualizations.
511
+ - grana_data (pandas.DataFrame): A DataFrame containing the granum measurement data.
512
+ - grana_images (Dict[int, Image.Image]): PIL images of the grana, keyed by granum ID.
513
+ """
514
+ # convert to grayscale
515
+ image = image.convert("L")
516
+
517
+ # detect
518
+ scaler = ScalerPadder(target_size=1024, target_short_edge_min=640)
519
+ scaled_image = scaler.transform(image, px_per_nm=scale)
520
+ detections = self.detector.predict(source=scaled_image, conf=detection_confidence)[0]
521
+
522
+ # get grana data
523
+ grana = self.get_grana_data(image, detections, scaler)
524
+ for granum_id, granum in enumerate(grana, start=granum_id_start):
525
+ granum.id = granum_id
526
+
527
+
528
+ # visualize detections
529
+ detection_visualization = self.visualize_detections(grana, image)
530
+
531
+ # measure grana
532
+ measurement_image_resize_factor = self.measurement_px_per_nm / scale
533
+ measurement_image_shape = (
534
+ int(image.size[1]*measurement_image_resize_factor),
535
+ int(image.size[0]*measurement_image_resize_factor)
536
+ )
537
+ measurement_image = resize_with_cv2( # numpy image at the scale required for measurement
538
+ image, measurement_image_shape, return_pil=False
539
+ )
540
+ grana = self.measure_grana(grana, measurement_image)
541
+
542
+ # pandas DataFrame
543
+ grana_data = self.extract_grana_data(grana)
544
+
545
+ # list of PIL images
546
+ grana_images = self.generate_grana_images(grana, image_name=image_name)
547
+
548
+ return detection_visualization, grana_data, grana_images
549
+
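For orientation, a minimal usage sketch of predict_on_single (assuming an already-constructed GranaAnalyser instance named analyser and a local micrograph sample.png; both names are illustrative, and crops is a dict keyed by granum ID):

    from PIL import Image

    image = Image.open("sample.png")              # hypothetical input micrograph
    overlay, table, crops = analyser.predict_on_single(
        image,
        scale=0.5,                                # px per nm of this micrograph
        detection_confidence=0.25,
    )
    overlay.save("detections.png")                # annotated detection overlay
    print(table[["Granum ID", "area nm^2"]])      # per-granum measurements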
550
+ def predict(self, images: Dict[str, Image.Image], scale: float, detection_confidence: float=0.25, parameter_confidence: Optional[float]=None) -> Tuple[Dict[str, Image.Image], pd.DataFrame, Dict[int, Image.Image], Dict]:
551
+ """
552
+ Predicts and aggregates data related to grana using a dictionary of images.
553
+
554
+ Parameters:
555
+ images (Dict[str, Image.Image]): A dictionary of PIL Image objects to be analyzed,
556
+ keyed by their names.
557
+ scale (float): scale of the image: px per nm
558
+ detection_confidence (float): the confidence threshold applied to grana detections
559
+ parameter_confidence (float): The confidence threshold used for data aggregation. Only
560
+ data with aggregated confidence above this threshold will
561
+ be considered.
562
+
563
+ Returns:
564
+ Tuple[Dict[str, Image.Image], pandas.DataFrame, Dict[int, Image.Image], Dict]:
565
+ A tuple containing:
566
+ - detection_visualizations (Dict[str, Image.Image]): detection visualizations,
567
+ keyed by image name.
568
+ - grana_data (pandas.DataFrame): A DataFrame containing the granum measurement data.
569
+ - grana_images (Dict[int, Image.Image]): PIL images of the grana, keyed by granum ID.
570
+ - aggregated_data (Dict): A dictionary containing the aggregated data results.
571
+ """
572
+ detection_visualizations_all = {}
573
+ grana_data_all = None
574
+ grana_images_all = {}
575
+
576
+ granum_id_start = 1
577
+ for image_name, image in images.items():
578
+ detection_visualization, grana_data, grana_images = self.predict_on_single(image, scale=scale, detection_confidence=detection_confidence, granum_id_start=granum_id_start, image_name=image_name)
579
+ granum_id_start += len(grana_data)
580
+ detection_visualizations_all[image_name] = detection_visualization
581
+ grana_images_all.update(grana_images)
582
+
583
+ grana_data['File name'] = image_name
584
+ if grana_data_all is None:
585
+ grana_data_all = grana_data
586
+ else:
587
+ # grana_data['Granum ID'] += len(grana_data_all)
588
+ grana_data_all = pd.concat([grana_data_all, grana_data])
589
+
590
+ # dict
591
+ # grana_data_all.to_csv('grana_data_all.csv', index=False)
592
+ aggregated_data = self.aggregate_data(grana_data_all, parameter_confidence)
593
+
594
+ formatted_grana_data = self.format_data(grana_data_all)
595
+
596
+ return detection_visualizations_all, formatted_grana_data, grana_images_all, aggregated_data
597
+
598
+
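An end-to-end sketch of predict over several images (file names and scale are illustrative, and analyser is assumed to be an existing GranaAnalyser instance):

    from PIL import Image

    images = {name: Image.open(name) for name in ("leaf_01.png", "leaf_02.png")}
    visualizations, table, crops, summary = analyser.predict(images, scale=0.5)
    print(summary["N grana"])                     # grana counted across all images
    table.to_csv("grana_measurements.csv", index=False)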
599
+ class GranaDetector(GranaAnalyser):
600
+ """supplementary class for grana detection only
601
+ """
602
+ def __init__(self, weights_detector: str, detector_config: Optional[str] = None, model_type="yolo") -> None:
603
+
604
+ if model_type == "yolo":
605
+ self.detector = YOLO(weights_detector)
606
+ elif model_type == "mmdetection":
607
+ self.detector = MMDetector(model=detector_config, weights=weights_detector)
608
+ else:
609
+ raise NotImplementedError()
610
+
611
+ def predict_on_single(self, image: Image.Image, scale: float, detection_confidence: float=0.25, granum_id_start=1, use_scaling=True, granum_border_margin=1, granum_border_min_count=1, scaler_sizes=(1024, 640)) -> List[Granum]:
612
+ # convert to grayscale
613
+ image = image.convert("L")
614
+
615
+ # detect
616
+ if use_scaling:
617
+ scaler = ScalerPadder(target_size=scaler_sizes[0], target_short_edge_min=scaler_sizes[1])
618
+ else:
619
+ # dummy scaler that keeps the original image size
620
+ scaler = ScalerPadder(target_size=max(image.size), target_short_edge_min=min(image.size), minimal_pad=0, pad_to_multiply=1)
621
+ scaled_image = scaler.transform(image, scale=scale)
622
+ detections = self.detector.predict(source=scaled_image, conf=detection_confidence)[0]
623
+
624
+ # get grana data
625
+ grana = self.get_grana_data(image, detections, scaler, border_margin=granum_border_margin, min_count=granum_border_min_count)
626
+ for i_granum, granum in enumerate(grana, start=1):
627
+ granum.id = i_granum
628
+
629
+ return grana
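A sketch of detection-only use via GranaDetector (the weights path and image name are placeholders):

    from PIL import Image

    detector = GranaDetector(weights_detector="weights/detector.pt")  # placeholder path
    image = Image.open("sample.png")                                  # placeholder image
    grana = detector.predict_on_single(image, scale=0.5, detection_confidence=0.25)
    for granum in grana:
        print(granum.id, granum.detection_confidence)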
period_calculation/config.py ADDED
@@ -0,0 +1,19 @@
1
+ from pathlib import Path
2
+ import albumentations as A
3
+ from albumentations.pytorch import ToTensorV2
4
+
5
+
6
+ transforms = [
7
+ A.Normalize(**{'mean': 0.2845, 'std': 0.1447}, max_pixel_value=1.0),
8
+ # Applies the formula (img - mean * max_pixel_value) / (std * max_pixel_value)
9
+ ToTensorV2()
10
+ ]
11
+
12
+ model_config = {
13
+ 'receptive_field_height': 220,
14
+ 'receptive_field_width': 38,
15
+ 'stride_height': 64,
16
+ 'stride_width': 2,
17
+ 'image_height': 476,
18
+ 'image_width': 476}
19
+
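To make the normalization formula concrete: with mean=0.2845, std=0.1447 and max_pixel_value=1.0, a pixel value of 0.5 maps to (0.5 - 0.2845) / 0.1447 ≈ 1.489. A quick check (array shape is arbitrary):

    import numpy as np
    import albumentations as A

    norm = A.Normalize(mean=0.2845, std=0.1447, max_pixel_value=1.0)
    img = np.full((4, 4, 1), 0.5, dtype=np.float32)
    print(norm(image=img)["image"][0, 0, 0])  # ~1.489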
period_calculation/data_reader.py ADDED
@@ -0,0 +1,861 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ import skimage.io
4
+ from pathlib import Path
5
+ import torch
6
+ import scipy
7
+ from PIL import Image, ImageFilter, ImageChops
8
+ # from config import model_config
9
+ from period_calculation.config import model_config
10
+
11
+ # Function to add Gaussian noise
12
+
13
+ def add_microscope_noise(base_image_as_numpy, noise_intensity):
14
+ ###### The code below adds synthetic microscope-like noise to the image
15
+ # noise_intensity scales the amount of noise (the std devs below are multiplied by it)
16
+ # --- original implementation was provided by Michał Bykowski
17
+ # --- and adapted
18
+ # This routine works with PIL images and numpy internally (changing formats as it goes)
19
+ # but the input and output are numpy arrays
20
+
21
+ def add_noise(image, mean=0, std_dev=50): # std_dev impacts the amount of noise
22
+ # Generating noise
23
+ noise = np.random.normal(mean, std_dev, (image.height, image.width))
24
+ # Adding noise to the image
25
+ noisy_image = np.array(image) + noise
26
+ # Ensuring the values remain within valid grayscale range
27
+ noisy_image = np.clip(noisy_image, 0, 255)
28
+ return Image.fromarray(noisy_image.astype('uint8'))
29
+
30
+
31
+ base_image = Image.fromarray(base_image_as_numpy)
32
+ gray_value = 128
33
+ gray = Image.new('L', base_image.size, color=gray_value)
34
+
35
+
36
+ gray = add_noise(gray, std_dev=noise_intensity * 76)
37
+ gray = gray.filter(ImageFilter.GaussianBlur(radius=3))
38
+ gray = add_noise(gray, std_dev=noise_intensity * 23)
39
+ gray = gray.filter(ImageFilter.GaussianBlur(radius=2))
40
+ gray = add_noise(gray, std_dev=noise_intensity * 15)
41
+
42
+ # soft light works as in Photoshop
43
+ # Superimposes two images on top of each other using the Soft Light algorithm
44
+ result = ImageChops.soft_light(base_image, gray)
45
+
46
+ return np.array(result)
47
+
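A small sketch exercising add_microscope_noise on a flat synthetic grayscale array (size and intensity are arbitrary):

    import numpy as np

    base = np.full((64, 64), 128, dtype=np.uint8)    # flat mid-gray test image
    noisy = add_microscope_noise(base, noise_intensity=0.5)
    print(noisy.shape, noisy.dtype)                  # (64, 64) uint8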
48
+ def detect_boundaries(mask, axis):
49
+ # calculate the boundaries of the mask
50
+ #axis = 0 results in x_from, x_to
51
+ #axis = 1 results in y_from, y_to
52
+
53
+
54
+ projection = mask.sum(axis=axis)  # per-row or per-column counts of mask pixels
55
+
56
+ ind_from = min(projection.nonzero()[0])
57
+ ind_to = max(projection.nonzero()[0])
58
+ return ind_from, ind_to
59
+
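For example, on a small mask containing a single vertical bar, detect_boundaries returns the tight bounds along each axis:

    import numpy as np

    m = np.zeros((5, 7), dtype=bool)
    m[1:4, 2] = True
    print(detect_boundaries(m, axis=0))  # (2, 2)  -> x_from, x_to
    print(detect_boundaries(m, axis=1))  # (1, 3)  -> y_from, y_to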
60
+ def add_symmetric_filling_beyond_mask(img, mask):
61
+ for x in range(img.shape[1]):
62
+ if sum(mask[:, x]) != 0: #if there is at least one nonzero index
63
+ nonzero_indices = mask[:, x].nonzero()[0]
64
+
65
+ y_min = min(nonzero_indices)
66
+ y_max = max(nonzero_indices)
67
+
68
+ if y_max == y_min: #there is only one point
69
+ img[:, x] = img[y_min, x]
70
+ else:
71
+ next = y_min + 1
72
+ step = +1 # we start by going upwards
73
+ for y in reversed(range(y_min)):
74
+ img[y, x] = img[next, x]
75
+ if next == y_max or next == y_min: #we hit the boundaries - we reverse
76
+ step *= -1 #reverse direction
77
+ next += step
78
+
79
+ next = y_max - 1
80
+ step = -1 # we start by going downwards
81
+ for y in range(y_max + 1, img.shape[0]): #we hit the boundaries - we reverse
82
+ img[y, x] = img[next, x]
83
+ if next == y_max or next == y_min:
84
+ step *= -1 # reverse direction
85
+ next += step
86
+ return img
87
+ class AbstractDataset(torch.utils.data.Dataset):
88
+
89
+ def __init__(self,
90
+ model = None,
91
+ transforms=[],
92
+ #### distortions during training ####
93
+ hv_symmetry=True, # True or False
94
+
95
+ min_horizontal_subsampling = 50, # None to turn off; or minimal percentage of horizontal size of the image
96
+ min_vertical_subsampling = 70, # None to turn off; or minimal percentage of vertical size of the image
97
+ max_random_tilt = 3, # None to turn off; or maximum tilt in degrees
98
+ max_add_colors_to_histogram = 10, # 0 to turn off; or points of the histogram to be added
99
+ max_remove_colors_from_histogram = 30, # 0 to turn off; or points of the histogram to be removed
100
+ max_noise_intensity = 3.0, # 0.0 to turn off; or max intensity of the noise
101
+
102
+ gaussian_phase_transforms_epoch=None, # None to turn off; or number of the epoch when the gaussian phase starts
103
+ min_horizontal_subsampling_gaussian_phase = 30, # None to turn off; or minimal percentage of horizontal size of the image
104
+ min_vertical_subsampling_gaussian_phase = 70, # None to turn off; or minimal percentage of vertical size of the image
105
+ max_random_tilt_gaussian_phase = 2, # None to turn off; or maximum tilt in degrees
106
+ max_add_colors_to_histogram_gaussian_phase = 10, # 0 to turn off; or points of the histogram to be added
107
+ max_remove_colors_from_histogram_gaussian_phase = 60, # 0 to turn off; or points of the histogram to be removed
108
+ max_noise_intensity_gaussian_phase = 3.5, # 0.0 to turn off; or max intensity of the noise
109
+
110
+ #### controling variables ####
111
+ transform_level=2, # 0 - no transforms, 1 - only the basic transform, 2 - all transforms, -1 - subsampling for high images
112
+ retain_raw_images=False,
113
+ retain_masks=False):
114
+
115
+
116
+ self.model = model # we need that to check epoch number during training
117
+
118
+ self.hv_symmetry = hv_symmetry
119
+
120
+ self.min_horizontal_subsampling = min_horizontal_subsampling
121
+ self.min_vertical_subsampling = min_vertical_subsampling
122
+ self.max_random_tilt = max_random_tilt
123
+ self.max_add_colors_to_histogram = max_add_colors_to_histogram
124
+ self.max_remove_colors_from_histogram = max_remove_colors_from_histogram
125
+ self.max_noise_intensity = max_noise_intensity
126
+
127
+ self.gaussian_phase_transforms_epoch = gaussian_phase_transforms_epoch
128
+ self.min_horizontal_subsampling_gaussian_phase = min_horizontal_subsampling_gaussian_phase
129
+ self.min_vertical_subsampling_gaussian_phase = min_vertical_subsampling_gaussian_phase
130
+ self.max_random_tilt_gaussian_phase = max_random_tilt_gaussian_phase
131
+ self.max_add_colors_to_histogram_gaussian_phase = max_add_colors_to_histogram_gaussian_phase
132
+ self.max_remove_colors_from_histogram_gaussian_phase = max_remove_colors_from_histogram_gaussian_phase
133
+ self.max_noise_intensity_gaussian_phase = max_noise_intensity_gaussian_phase
134
+
135
+ self.image_height = model_config['image_height']
136
+ self.image_width = model_config['image_width']
137
+
138
+ self.transform_level = transform_level
139
+ self.retain_raw_images = retain_raw_images
140
+ self.retain_masks = retain_masks
141
+ self.transforms = transforms
142
+
143
+
144
+ def get_image_and_mask(self, row):
145
+ raise NotImplementedError("Subclass needs to implement this method")
146
+
147
+ def load_and_transform_image_and_mask(self, row):
148
+ img, mask = self.get_image_and_mask(row)
149
+
150
+ angle = row['angle']
151
+ #check if gaussian phase is on
152
+ if self.gaussian_phase_transforms_epoch is not None and self.model.current_epoch >= self.gaussian_phase_transforms_epoch:
153
+ max_random_tilt = self.max_random_tilt_gaussian_phase
154
+ max_noise_intensity = self.max_noise_intensity_gaussian_phase
155
+ min_horizontal_subsampling = self.min_horizontal_subsampling_gaussian_phase
156
+ min_vertical_subsampling = self.min_vertical_subsampling_gaussian_phase
157
+ max_add_colors_to_histogram = self.max_add_colors_to_histogram_gaussian_phase
158
+ max_remove_colors_from_histogram = self.max_remove_colors_from_histogram_gaussian_phase
159
+ else:
160
+ max_random_tilt = self.max_random_tilt
161
+ max_noise_intensity = self.max_noise_intensity
162
+ min_horizontal_subsampling = self.min_horizontal_subsampling
163
+ min_vertical_subsampling = self.min_vertical_subsampling
164
+ max_add_colors_to_histogram = self.max_add_colors_to_histogram
165
+ max_remove_colors_from_histogram = self.max_remove_colors_from_histogram
166
+
167
+
168
+
169
+
170
+
171
+
172
+ if self.transform_level >= 2 and max_random_tilt is not None:
173
+ ####### RANDOM TILT
174
+ angle += np.random.uniform(-max_random_tilt, max_random_tilt)
175
+
176
+ img = scipy.ndimage.rotate(img, 90 - angle, reshape=True, order=3) # HORIZONTAL POSITION
177
+ ###the part of the image that is added after rotation is all black (0s)
178
+ mask = scipy.ndimage.rotate(mask, 90 - angle, reshape=True, order = 0) # HORIZONTAL POSITION
179
+ #order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
180
+
181
+ ############# CROP
182
+ x_from, x_to = detect_boundaries(mask, axis=0)
183
+ y_from, y_to = detect_boundaries(mask, axis=1)
184
+
185
+ # crop the image to the vertical and horizontal limits
186
+ img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
187
+ mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
188
+
189
+
190
+ img_raw = img.copy()
191
+
192
+
193
+ if self.transform_level >= 2:
194
+ ########## ADDING NOISE
195
+
196
+ if max_noise_intensity > 0.0:
197
+ noise_intensity = np.random.random() * max_noise_intensity
198
+ noisy_img = add_microscope_noise(img, noise_intensity=noise_intensity)
199
+ img[mask] = noisy_img[mask]
200
+
201
+ if self.transform_level == -1:
202
+ #special case where we take at most 300 middle pixels from the image
203
+ # (vertical subsampling)
204
+ # to handle very large images correctly
205
+ x_from, x_to = detect_boundaries(mask, axis=0)
206
+ y_from, y_to = detect_boundaries(mask, axis=1)
207
+
208
+ y_size = y_to - y_from + 1
209
+
210
+ random_size = 300 #not so random, ay?
211
+
212
+ if y_size > random_size:
213
+ random_start = y_size // 2 - random_size // 2
214
+
215
+ y_from = random_start
216
+ y_to = random_start + random_size - 1
217
+
218
+ img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
219
+ mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
220
+
221
+ # recrop the image if necessary
222
+ # -- even after only horizontal subsampling it may be necessary to recrop the image
223
+
224
+ x_from, x_to = detect_boundaries(mask, axis=0)
225
+ y_from, y_to = detect_boundaries(mask, axis=1)
226
+
227
+ img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
228
+ mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
229
+
230
+ if self.transform_level >= 1:
231
+ ############## HORIZONTAL SUBSAMPLING
232
+ if min_horizontal_subsampling is not None:
233
+ x_size = x_to - x_from + 1
234
+
235
+ # add some random horizontal shift
236
+ random_size = np.random.randint(x_size * min_horizontal_subsampling / 100.0, x_size + 1)
237
+ random_start = np.random.randint(0, x_size - random_size + 1) + x_from
238
+
239
+ img = img[:, random_start:(random_start + random_size)]
240
+ mask = mask[:, random_start:(random_start + random_size)]
241
+
242
+ ############ VERTICAL SUBSAMPLING
243
+ if min_vertical_subsampling is not None:
244
+
245
+ x_from, x_to = detect_boundaries(mask, axis=0)
246
+ y_from, y_to = detect_boundaries(mask, axis=1)
247
+
248
+ y_size = y_to - y_from + 1
249
+
250
+ random_size = np.random.randint(y_size * min_vertical_subsampling / 100.0, y_size + 1)
251
+ random_start = np.random.randint(0, y_size - random_size + 1) + y_from
252
+
253
+ y_from = random_start
254
+ y_to = random_start + random_size - 1
255
+
256
+ img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
257
+ mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
258
+
259
+ if min_horizontal_subsampling is not None or min_vertical_subsampling is not None:
260
+ #recrop the image if necessary
261
+ # -- even after only horizontal subsampling it may be necessary to recrop the image
262
+
263
+ x_from, x_to = detect_boundaries(mask, axis=0)
264
+ y_from, y_to = detect_boundaries(mask, axis=1)
265
+
266
+ img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
267
+ mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
268
+
269
+
270
+ ######### ADD SYMMETRIC FILLING OF THE IMAGE BEYOND THE MASK
271
+ #img = add_symmetric_filling_beyond_mask(img, mask)
272
+ #This leaves holes in the image, so we will not use it
273
+
274
+ #plt.imshow(img)
275
+ #plt.show()
276
+ ######### HORIZONTAL AND VERTICAL SYMMETRY.
277
+ # When superimposed, the result is 180 degree rotation
278
+ if self.transform_level >= 1 and self.hv_symmetry:
279
+ for axis in range(2):
280
+ if np.random.randint(0, 2) % 2 == 0:
281
+ img = np.flip(img, axis = axis)
282
+ mask = np.flip(mask, axis = axis)
283
+ #plt.imshow(img)
284
+ #plt.show()
285
+
286
+ if self.transform_level >= 2 and (max_add_colors_to_histogram > 0 or max_remove_colors_from_histogram > 0):
287
+ lower_bound = np.random.randint(-max_add_colors_to_histogram, max_remove_colors_from_histogram + 1)
288
+ upper_bound = np.random.randint(255 - max_remove_colors_from_histogram, 255 + max_add_colors_to_histogram + 1)
289
+ # first clip the values outstanding from the range (lower_bound -- upper_bound)
290
+ img[mask] = np.clip(img[mask], lower_bound, upper_bound)
291
+ # the range (lower_bound -- upper_bound) gets mapped to the range (0--255)
292
+ # but only in a portion of the image where mask = True
293
+ img[mask] = np.interp(img[mask], (lower_bound, upper_bound), (0, 255)).astype(np.uint8)
294
+
295
+ #### since preserve_range in skimage.transform.resize is set to False, the image
296
+ #### will be converted to float. Consult:
297
+ # https://scikit-image.org/docs/stable/api/skimage.transform.html#skimage.transform.resize
298
+ # https://scikit-image.org/docs/dev/user_guide/data_types.html
299
+
300
+ # In our case the image gets converted to floats in the range 0-1
301
+ old_height = img.shape[0]
302
+ img = skimage.transform.resize(img, (self.image_height, self.image_width), order=3)
303
+ new_height = img.shape[0]
304
+ mask = skimage.transform.resize(mask, (self.image_height, self.image_width), order=0, preserve_range=True)
305
+ # order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
306
+
307
+ scale_factor = new_height / old_height
308
+
309
+
310
+ #plt.imshow(img)
311
+ #plt.show()
312
+ #plt.imshow(mask)
313
+ #plt.show()
314
+ return img, mask, scale_factor, img_raw
315
+
316
+ def get_annotations_row(self, idx):
317
+ raise NotImplementedError("Subclass needs to implement this method")
318
+
319
+ def __getitem__(self, idx):
320
+ row = self.get_annotations_row(idx)
321
+
322
+ image, mask, scale_factor, image_raw = self.load_and_transform_image_and_mask(row)
323
+
324
+ image_data = {
325
+ 'image': image,
326
+ }
327
+
328
+ for transform in self.transforms:
329
+ image_data = transform(**image_data)
330
+ # transform operates on image field ONLY of image_data, and returns a dictionary with the same keys
331
+
332
+ ret_dict = {
333
+ 'image': image_data['image'],
334
+ 'period_px': torch.tensor(row['period_nm'] * scale_factor * row['px_per_nm'], dtype=torch.float32),
335
+ 'filename': row['granum_image'],
336
+ 'px_per_nm': row['px_per_nm'],
337
+ 'scale': scale_factor, # the scale factor is used to calculate the true period error
338
+ # (before scale) in losses and metrics
339
+ 'neutral': -self.transforms[0].mean/self.transforms[0].std #value of 0 after the scale transform
340
+ }
341
+
342
+ if self.retain_raw_images:
343
+ ret_dict['image_raw'] = image_raw
344
+
345
+ if self.retain_masks:
346
+ ret_dict['mask'] = mask
347
+
348
+ return ret_dict
349
+
350
+ def __len__(self):
351
+ raise NotImplementedError("Subclass needs to implement this method")
352
+
353
+ class ImageDataset(AbstractDataset):
354
+ def __init__(self, annotations, data_dir: Path, *args, **kwargs):
355
+ super().__init__(*args, **kwargs)
356
+ self.data_dir = Path(data_dir)
357
+
358
+ self.id = 1
359
+
360
+ if isinstance(annotations, str):
361
+ annotations = Path(annotations) # interpreted relative to data_dir when read below
362
+
363
+ if isinstance(annotations, Path):
364
+ self.annotations = pd.read_csv(data_dir / annotations)
365
+ no_period = ['27_k7 [1]_4.png']
366
+ del_img = ['38_k42[1]_19.png', 'n6363_araLL_60kx_6 [1]_0.png', '27_hs8 [1]_5.png', '27_k7 [1]_20.png',
367
+ 'F1_1_60kx_01 [1]_2.png']
368
+ self.annotations = self.annotations[~self.annotations['granum_image'].isin(no_period)]
369
+ self.annotations = self.annotations[~self.annotations['granum_image'].isin(del_img)]
370
+ else:
371
+ self.annotations = annotations
372
+
373
+ def get_image_and_mask(self, row):
374
+ filename = row['granum_image']
375
+ img_path = self.data_dir / filename
376
+ img_raw = skimage.io.imread(img_path)
377
+
378
+ img = img_raw[:, :, 0] # all three channels are equal, with the exception
379
+ # of pixels outside the mask, which are pure blue (0, 0, 255): blue is 255, red and green are 0
380
+ mask = (img_raw != (0, 0, 255)).any(axis=2)
381
+ return img, mask
382
+
383
+ def get_annotations_row(self, idx):
384
+ row = self.annotations.iloc[idx].to_dict()
385
+ row['idx'] = idx
386
+ return row
387
+
388
+ def __len__(self):
389
+ return len(self.annotations)
390
+
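A construction sketch for ImageDataset (the CSV name and directory are placeholders; the CSV is assumed to contain the columns used above: granum_image, period_nm, px_per_nm, angle):

    from pathlib import Path
    from period_calculation.config import transforms

    dataset = ImageDataset(
        annotations="annotations.csv",     # placeholder file
        data_dir=Path("data/grana"),       # placeholder directory
        transforms=transforms,
        transform_level=0,                 # no augmentations, e.g. for evaluation
    )
    sample = dataset[0]
    print(sample["filename"], sample["period_px"])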
391
+ class ArtificialDataset(AbstractDataset):
392
+ def __init__(self,
393
+ min_period = 20,
394
+ max_period = 140,
395
+ white_fraction_min = 0.15,
396
+ white_fraction_max=0.45,
397
+
398
+ noise_min_sd = 0.0,
399
+ noise_max_sd = 100.0,
400
+ noise_max_sd_everywhere = 20.0, # 20.0
401
+ leftovers_max = 5,
402
+
403
+ get_real_masks_dataset = None, #None or instance of ImageDataset
404
+ *args, **kwargs):
405
+ super().__init__(*args, **kwargs)
406
+ self.id = 0
407
+ self.min_period = min_period
408
+ self.max_period = max_period
409
+ self.white_fraction_min = white_fraction_min
410
+ self.white_fraction_max = white_fraction_max
411
+
412
+ self.receptive_field_height = model_config['receptive_field_height']
413
+ self.stride_height = model_config['stride_height']
414
+ self.receptive_field_width = model_config['receptive_field_width']
415
+ self.stride_width = model_config['stride_width']
416
+
417
+ self.noise_min_sd = noise_min_sd
418
+ self.noise_max_sd = noise_max_sd
419
+ self.noise_max_sd_everywhere = noise_max_sd_everywhere
420
+
421
+ self.leftovers_max = leftovers_max
422
+
423
+ self.get_real_masks_dataset = get_real_masks_dataset
424
+
425
+
426
+ def get_image_and_mask(self, row):
427
+ # generate a rectangular image of black and white horizontal stripes
428
+ # with black stripes varying with white stripes
429
+
430
+ period_px = row['period_nm'] * row['px_per_nm']
431
+ # white occupying the fraction white_fraction_min..white_fraction_max of a total period (white+black)
432
+ white_px = np.random.randint(period_px * self.white_fraction_min, period_px * self.white_fraction_max + 1)
433
+
434
+
435
+ # mask is rectangle of True values
436
+ img = np.zeros((self.image_height, self.image_width), dtype=np.uint8)
437
+ mask = np.ones((self.image_height, self.image_width), dtype=bool)
438
+ black_px = period_px - white_px
439
+ random_start = np.random.randint(0, period_px+1)
440
+ for i in range(self.image_height):
441
+ if (random_start+i) % (black_px + white_px) < black_px:
442
+ # fill the row with random dark values from 0 to 100
443
+ img[i, :] = np.random.randint(0, 101, self.image_width)
444
+ else:
445
+ # fill the row with random bright values from 156 to 255
446
+ img[i, :] = np.random.randint(156, 256, self.image_width)
447
+
448
+ if self.noise_max_sd_everywhere > self.noise_min_sd:
449
+ sd = np.random.uniform(self.noise_min_sd, self.noise_max_sd_everywhere)
450
+ noise = np.random.normal(0, sd, (self.image_height, self.image_width))
451
+ img = np.clip(img+noise.astype(img.dtype), 0, 255)
452
+
453
+ if self.noise_max_sd > self.noise_min_sd:
454
+ # there is also a metagrid in the image
455
+ # consisting of overlapping receptive fields of size 220x38 (see model_config)
456
+ # with stride 64x2
457
+ # which gives a 5x220 metagrid for a 476x476 image
458
+ overlapping_fields_count_height = (self.image_height - self.receptive_field_height) // self.stride_height + 1
459
+ overlapping_fields_count_width = (self.image_width - self.receptive_field_width) // self.stride_width + 1
460
+
461
+
462
+ sd = np.random.uniform(self.noise_min_sd, self.noise_max_sd)
463
+ noise = np.random.normal(0, sd, (self.image_height, self.image_width))
464
+
465
+ #there will be some left-over metagrid rectangles
466
+ leftovers_count = np.random.randint(1, self.leftovers_max)
467
+ for i in range(leftovers_count):
468
+ metagrid_row = np.random.randint(0, overlapping_fields_count_height)
469
+ metagrid_col = np.random.randint(0, overlapping_fields_count_width)
470
+ #zero-out the noise inside the selected metagrid
471
+ noise[metagrid_row * self.stride_height:metagrid_row * self.stride_height + self.receptive_field_height + 1, \
472
+ metagrid_col * self.stride_width :metagrid_col * self.stride_width + self.receptive_field_width + 1] = 0
473
+
474
+ #add noise to the image
475
+ img = np.clip(img+noise.astype(img.dtype), 0, 255)
476
+
477
+ if self.get_real_masks_dataset is not None:
478
+ ret_dict = self.get_real_masks_dataset.__getitem__(row['idx'] % len(self.get_real_masks_dataset))
479
+ mask = ret_dict['mask'] #this mask is already sized target height-by-width
480
+
481
+ img[~mask] = 0 # zero out pixels outside the mask
482
+
483
+ return img, mask
484
+
485
+ def get_annotations_row(self, idx):
486
+ return {'idx': idx,
487
+ 'period_nm': np.random.randint(self.min_period, self.max_period),
488
+ 'px_per_nm': 1.0,
489
+ 'granum_image': 'artificial_%d.png' % idx,
490
+ 'angle': 90}
491
+
492
+
493
+ def __len__(self):
494
+ return 237 # number of samples matching the real train set (70% of 339 ≈ 237.3)
495
+
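A minimal sketch of batching ArtificialDataset with a PyTorch DataLoader (batch size is arbitrary; the exact tensor layout depends on the albumentations version):

    import torch
    from period_calculation.config import transforms

    art = ArtificialDataset(transforms=transforms, transform_level=1)
    loader = torch.utils.data.DataLoader(art, batch_size=8, shuffle=True)
    batch = next(iter(loader))
    print(batch["image"].shape)  # e.g. torch.Size([8, 1, 476, 476])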
496
+
497
+ class AdHocDataset(AbstractDataset):
498
+ def __init__(self, images_masks_pxpernm: list[tuple[np.ndarray, np.ndarray, float]], *args, **kwargs):
499
+ super().__init__(*args, **kwargs)
500
+ self.data = images_masks_pxpernm
501
+
502
+ def __len__(self):
503
+ return len(self.data)
504
+
505
+ def __getitem__(self, idx):
506
+ image, mask, px_per_nm = self.data[idx]
507
+
508
+ image, mask, scale_factor, image_raw = self.load_and_transform_image_and_mask(image, mask)
509
+
510
+ image_data = {
511
+ 'image': image,
512
+ }
513
+
514
+ for transform in self.transforms:
515
+ image_data = transform(**image_data)
516
+ # transform operates on image field ONLY of image_data, and returns a dictionary with the same keys
517
+
518
+ ret_dict = {
519
+ 'image': image_data['image'],
520
+ 'period_px': torch.tensor(0, dtype=torch.float32),
521
+ 'filename': str(idx),
522
+ 'px_per_nm': px_per_nm,
523
+ 'scale': scale_factor, # the scale factor is used to calculate the true period error
524
+ # (before scale) in losses and metrics
525
+ 'neutral': -self.transforms[0].mean/self.transforms[0].std #value of 0 after the scale transform
526
+ }
527
+
528
+ if self.retain_raw_images:
529
+ ret_dict['image_raw'] = image_raw
530
+
531
+ if self.retain_masks:
532
+ ret_dict['mask'] = mask
533
+
534
+ return ret_dict
535
+
536
+
537
+ def load_and_transform_image_and_mask(self, img, mask):
538
+
539
+ angle = 90
540
+ #check if gaussian phase is on
541
+ if self.gaussian_phase_transforms_epoch is not None and self.model.current_epoch >= self.gaussian_phase_transforms_epoch:
542
+ max_random_tilt = self.max_random_tilt_gaussian_phase
543
+ max_noise_intensity = self.max_noise_intensity_gaussian_phase
544
+ min_horizontal_subsampling = self.min_horizontal_subsampling_gaussian_phase
545
+ min_vertical_subsampling = self.min_vertical_subsampling_gaussian_phase
546
+ max_add_colors_to_histogram = self.max_add_colors_to_histogram_gaussian_phase
547
+ max_remove_colors_from_histogram = self.max_remove_colors_from_histogram_gaussian_phase
548
+ else:
549
+ max_random_tilt = self.max_random_tilt
550
+ max_noise_intensity = self.max_noise_intensity
551
+ min_horizontal_subsampling = self.min_horizontal_subsampling
552
+ min_vertical_subsampling = self.min_vertical_subsampling
553
+ max_add_colors_to_histogram = self.max_add_colors_to_histogram
554
+ max_remove_colors_from_histogram = self.max_remove_colors_from_histogram
555
+
556
+
557
+ if self.transform_level >= 2 and max_random_tilt is not None:
558
+ ####### RANDOM TILT
559
+ angle += np.random.uniform(-max_random_tilt, max_random_tilt)
560
+
561
+
562
+ img = scipy.ndimage.rotate(img, 90 - angle, reshape=True, order=3) # HORIZONTAL POSITION
563
+ ###the part of the image that is added after rotation is all black (0s)
564
+ mask = scipy.ndimage.rotate(mask, 90 - angle, reshape=True, order = 0) # HORIZONTAL POSITION
565
+ #order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
566
+
567
+ ############# CROP
568
+ x_from, x_to = detect_boundaries(mask, axis=0)
569
+ y_from, y_to = detect_boundaries(mask, axis=1)
570
+
571
+ # crop the image to the vertical and horizontal limits
572
+ img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
573
+ mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
574
+
575
+
576
+ img_raw = img.copy()
577
+
578
+
579
+ if self.transform_level >= 2:
580
+ ########## ADDING NOISE
581
+
582
+ if max_noise_intensity > 0.0:
583
+ noise_intensity = np.random.random() * max_noise_intensity
584
+ noisy_img = add_microscope_noise(img, noise_intensity=noise_intensity)
585
+ img[mask] = noisy_img[mask]
586
+
587
+ if self.transform_level == -1:
588
+ #special case where we take at most 300 middle pixels from the image
589
+ # (vertical subsampling)
590
+ # to handle very large images correctly
591
+ x_from, x_to = detect_boundaries(mask, axis=0)
592
+ y_from, y_to = detect_boundaries(mask, axis=1)
593
+
594
+ y_size = y_to - y_from + 1
595
+
596
+ random_size = 300 #not so random, ay?
597
+
598
+ if y_size > random_size:
599
+ random_start = y_size // 2 - random_size // 2
600
+
601
+ y_from = random_start
602
+ y_to = random_start + random_size - 1
603
+
604
+ img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
605
+ mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
606
+
607
+ # recrop the image if necessary
608
+ # -- even after only horizontal subsampling it may be necessary to recrop the image
609
+
610
+ x_from, x_to = detect_boundaries(mask, axis=0)
611
+ y_from, y_to = detect_boundaries(mask, axis=1)
612
+
613
+ img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
614
+ mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
615
+
616
+ if self.transform_level >= 1:
617
+ ############## HORIZONTAL SUBSAMPLING
618
+ if min_horizontal_subsampling is not None:
619
+ x_size = x_to - x_from + 1
620
+
621
+ # add some random horizontal shift
622
+ random_size = np.random.randint(x_size * min_horizontal_subsampling / 100.0, x_size + 1)
623
+ random_start = np.random.randint(0, x_size - random_size + 1) + x_from
624
+
625
+ img = img[:, random_start:(random_start + random_size)]
626
+ mask = mask[:, random_start:(random_start + random_size)]
627
+
628
+ ############ VERTICAL SUBSAMPLING
629
+ if min_vertical_subsampling is not None:
630
+
631
+ x_from, x_to = detect_boundaries(mask, axis=0)
632
+ y_from, y_to = detect_boundaries(mask, axis=1)
633
+
634
+ y_size = y_to - y_from + 1
635
+
636
+ random_size = np.random.randint(y_size * min_vertical_subsampling / 100.0, y_size + 1)
637
+ random_start = np.random.randint(0, y_size - random_size + 1) + y_from
638
+
639
+ y_from = random_start
640
+ y_to = random_start + random_size - 1
641
+
642
+ img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
643
+ mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
644
+
645
+ if min_horizontal_subsampling is not None or min_vertical_subsampling is not None:
646
+ #recrop the image if necessary
647
+ # -- even after only horizontal subsampling it may be necessary to recrop the image
648
+
649
+ x_from, x_to = detect_boundaries(mask, axis=0)
650
+ y_from, y_to = detect_boundaries(mask, axis=1)
651
+
652
+ img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
653
+ mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
654
+
655
+
656
+ ######### ADD SYMMETRIC FILLING OF THE IMAGE BEYOND THE MASK
657
+ #img = add_symmetric_filling_beyond_mask(img, mask)
658
+ #This leaves holes in the image, so we will not use it
659
+
660
+ #plt.imshow(img)
661
+ #plt.show()
662
+ ######### HORIZONTAL AND VERTICAL SYMMETRY.
663
+ # When superimposed, the result is 180 degree rotation
664
+ if self.transform_level >= 1 and self.hv_symmetry:
665
+ for axis in range(2):
666
+ if np.random.randint(0, 2) % 2 == 0:
667
+ img = np.flip(img, axis = axis)
668
+ mask = np.flip(mask, axis = axis)
669
+ #plt.imshow(img)
670
+ #plt.show()
671
+
672
+ if self.transform_level >= 2 and (max_add_colors_to_histogram > 0 or max_remove_colors_from_histogram > 0):
673
+ lower_bound = np.random.randint(-max_add_colors_to_histogram, max_remove_colors_from_histogram + 1)
674
+ upper_bound = np.random.randint(255 - max_remove_colors_from_histogram, 255 + max_add_colors_to_histogram + 1)
675
+ # first clip the values outstanding from the range (lower_bound -- upper_bound)
676
+ img[mask] = np.clip(img[mask], lower_bound, upper_bound)
677
+ # the range (lower_bound -- upper_bound) gets mapped to the range (0--255)
678
+ # but only in a portion of the image where mask = True
679
+ img[mask] = np.interp(img[mask], (lower_bound, upper_bound), (0, 255)).astype(np.uint8)
680
+
681
+ #### since preserve_range in skimage.transform.resize is set to False, the image
682
+ #### will be converted to float. Consult:
683
+ # https://scikit-image.org/docs/stable/api/skimage.transform.html#skimage.transform.resize
684
+ # https://scikit-image.org/docs/dev/user_guide/data_types.html
685
+
686
+ # In our case the image gets converted to floats in the range 0-1
687
+ old_height = img.shape[0]
688
+ img = skimage.transform.resize(img, (self.image_height, self.image_width), order=3)
689
+ new_height = img.shape[0]
690
+ mask = skimage.transform.resize(mask, (self.image_height, self.image_width), order=0, preserve_range=True)
691
+ # order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
692
+
693
+ scale_factor = new_height / old_height
694
+
695
+
696
+ #plt.imshow(img)
697
+ #plt.show()
698
+ #plt.imshow(mask)
699
+ #plt.show()
700
+ return img, mask, scale_factor, img_raw
701
+
702
+
703
+ class AdHocDataset2(AbstractDataset):
704
+ def __init__(self, images_masks_pxpernm: list[tuple[np.ndarray, np.ndarray, float]], *args, **kwargs):
705
+ super().__init__(*args, **kwargs)
706
+ self.data = images_masks_pxpernm
707
+
708
+ def __len__(self):
709
+ return len(self.data)
710
+
711
+ def __getitem__(self, idx):
712
+ image, mask, px_per_nm = self.data[idx]
713
+
714
+ image, mask, scale_factor, image_raw = self.load_and_transform_image_and_mask(image, mask)
715
+
716
+ image_data = {
717
+ 'image': image,
718
+ }
719
+
720
+ for transform in self.transforms:
721
+ image_data = transform(**image_data)
722
+ # transform operates on image field ONLY of image_data, and returns a dictionary with the same keys
723
+
724
+ ret_dict = {
725
+ 'image': image_data['image'],
726
+ 'scale': scale_factor, # the scale factor is used to calculate the true period error
727
+ # (before scale) in losses and metrics
728
+ 'neutral': -self.transforms[0].mean/self.transforms[0].std #value of 0 after the scale transform
729
+ }
730
+
731
+ return ret_dict
732
+
733
+
734
+ def load_and_transform_image_and_mask(self, img, mask):
735
+
736
+ img_raw = img.copy()
737
+
738
+ if self.transform_level == -1:
739
+ #special case where we take at most 300 middle pixels from the image
740
+ # (vertical subsampling)
741
+ # to handle very large images correctly
742
+ x_from, x_to = detect_boundaries(mask, axis=0)
743
+ y_from, y_to = detect_boundaries(mask, axis=1)
744
+
745
+ y_size = y_to - y_from + 1
746
+
747
+ max_size = 300
748
+
749
+ if y_size > max_size:
750
+ random_start = y_size // 2 - max_size // 2
751
+
752
+ y_from = random_start
753
+ y_to = random_start + max_size - 1
754
+
755
+ img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
756
+ mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
757
+
758
+ # recrop the image if necessary
759
+ # -- even after only horizontal subsampling it may be necessary to recrop the image
760
+
761
+ x_from, x_to = detect_boundaries(mask, axis=0)
762
+ y_from, y_to = detect_boundaries(mask, axis=1)
763
+
764
+ img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
765
+ mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
766
+
767
+
768
+ #### since preserve_range in skimage.transform.resize is set to False, the image
769
+ #### will be converted to float. Consult:
770
+ # https://scikit-image.org/docs/stable/api/skimage.transform.html#skimage.transform.resize
771
+ # https://scikit-image.org/docs/dev/user_guide/data_types.html
772
+
773
+ # In our case the image gets converted to floats in the range 0-1
774
+ old_height = img.shape[0]
775
+ img = skimage.transform.resize(img, (self.image_height, self.image_width), order=3)
776
+ new_height = img.shape[0]
777
+ mask = skimage.transform.resize(mask, (self.image_height, self.image_width), order=0, preserve_range=True)
778
+ # order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
779
+
780
+ scale_factor = new_height / old_height
781
+
782
+ return img, mask, scale_factor, img_raw
783
+
784
+ class AdHocDataset3(AbstractDataset):
785
+ def __init__(self, images_and_masks: list[tuple[np.ndarray, np.ndarray]], *args, **kwargs):
786
+ super().__init__(*args, **kwargs)
787
+ self.data = images_and_masks
788
+
789
+ def __len__(self):
790
+ return len(self.data)
791
+
792
+ def __getitem__(self, idx):
793
+ image, mask = self.data[idx]
794
+
795
+ image, mask, scale_factor = self.load_and_transform_image_and_mask(image, mask)
796
+
797
+ image_data = {
798
+ 'image': image,
799
+ }
800
+
801
+ for transform in self.transforms:
802
+ image_data = transform(**image_data)
803
+ # transform operates on image field ONLY of image_data, and returns a dictionary with the same keys
804
+
805
+ ret_dict = {
806
+ 'image': image_data['image'],
807
+ 'scale': scale_factor, # the scale factor is used to calculate the true period error
808
+ # (before scale) in losses and metrics
809
+ #value of 0 after the scale transform
810
+ }
811
+
812
+ return ret_dict
813
+
814
+
815
+ def load_and_transform_image_and_mask(self, img, mask):
816
+
817
+ if self.transform_level == -1:
818
+ #special case where we take at most 300 middle pixels from the image
819
+ # (vertical subsampling)
820
+ # to handle very large images correctly
821
+ x_from, x_to = detect_boundaries(mask, axis=0)
822
+ y_from, y_to = detect_boundaries(mask, axis=1)
823
+
824
+ y_size = y_to - y_from + 1
825
+
826
+ max_size = 300
827
+
828
+ if y_size > max_size:
829
+ random_start = y_size // 2 - max_size // 2
830
+
831
+ y_from = random_start
832
+ y_to = random_start + max_size - 1
833
+
834
+ img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
835
+ mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
836
+
837
+ # recrop the image if necessary
838
+ # -- even after only horizontal subsampling it may be necessary to recrop the image
839
+
840
+ x_from, x_to = detect_boundaries(mask, axis=0)
841
+ y_from, y_to = detect_boundaries(mask, axis=1)
842
+
843
+ img = img[y_from:(y_to + 1), x_from:(x_to + 1)]
844
+ mask = mask[y_from:(y_to + 1), x_from:(x_to + 1)]
845
+
846
+
847
+ #### since preserve_range in skimage.transform.resize is set to False, the image
848
+ #### will be converted to float. Consult:
849
+ # https://scikit-image.org/docs/stable/api/skimage.transform.html#skimage.transform.resize
850
+ # https://scikit-image.org/docs/dev/user_guide/data_types.html
851
+
852
+ # In our case the image gets converted to floats in the range 0-1
853
+ old_height = img.shape[0]
854
+ img = skimage.transform.resize(img, (self.image_height, self.image_width), order=3)
855
+ new_height = img.shape[0]
856
+ mask = skimage.transform.resize(mask, (self.image_height, self.image_width), order=0, preserve_range=True)
857
+ # order = 0 is the nearest neighbor interpolation, so the mask is not interpolated
858
+
859
+ scale_factor = new_height / old_height
860
+
861
+ return img, mask, scale_factor
period_calculation/image_transforms.py ADDED
@@ -0,0 +1,79 @@
1
+ import numpy as np
2
+ import torch
3
+ import cv2
4
+
5
+
6
+ def np_batched_radon(image_batch):
7
+ #image_batch: torch tensor, batch_size x 1 x img_size x img_size
8
+ # squeeze the channel dimension (dim 1) and convert to numpy
9
+
10
+ image_batch = image_batch.squeeze(1).cpu().numpy()
11
+
12
+ batch_size, img_size = image_batch.shape[:2]
13
+ if batch_size > 512: # limit batch size to 512 because cv2.warpAffine fails for batches > 512
14
+ return np.concatenate([np_batched_radon(image_batch[i:i+512]) for i in range(0,batch_size,512)], axis=0)
15
+ theta = np.arange(180)
16
+ radon_image = np.zeros((image_batch.shape[0], img_size, len(theta)),
17
+ dtype='float32')
18
+
19
+ for i, angle in enumerate(theta):
20
+ M = cv2.getRotationMatrix2D(((img_size-1)/2.0,(img_size-1)/2.0),angle,1)
21
+ rotated = cv2.warpAffine(np.transpose(image_batch, (1, 2, 0)),M,(img_size,img_size))
22
+
23
+ #plt.imshow(rotated[:,:,0])
24
+ #plt.show()
25
+
26
+ if batch_size == 1: # cv2.warpAffine cancels batch dimension if equal to 1
27
+ rotated = rotated[:,:, np.newaxis]
28
+ rotated = np.transpose(rotated, (2, 0, 1)) / 224.0
29
+ #rotated = rotated / np.array(255, dtype='float32')
30
+ radon_image[:, :, i] = rotated.sum(axis=1)
31
+
32
+ #plot the image
33
+
34
+ # plt.imshow(radon_image[0])
35
+ # plt.show()
36
+
37
+ return radon_image
38
+
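A quick shape check for np_batched_radon (random input; values are arbitrary):

    import torch

    batch = torch.rand(2, 1, 476, 476)
    sinograms = np_batched_radon(batch)
    print(sinograms.shape)  # (2, 476, 180): one 180-angle sinogram per image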
39
+
40
+ def torch_batched_radon(image_batch, neutral_value):
41
+ #image_batch: batch_size x 1 x img_size x img_size
42
+ #np_batched_radon(image_batch - neutral_value)
43
+
44
+ image_batch = image_batch - neutral_value # so the 0 value is neutral
45
+
46
+
47
+ batch_size = image_batch.shape[0]
48
+ img_size = image_batch.shape[2]
49
+
50
+ theta = np.arange(180) # we don't need torch here, we will evaluate individual angles below
51
+
52
+ radon_image = torch.zeros((batch_size, 1, img_size, len(theta)), dtype=torch.float, device=image_batch.device)
53
+
54
+
55
+ for i, angle in enumerate(theta):
56
+ #M = cv2.getRotationMatrix2D(((img_size-1)/2.0,(img_size-1)/2.0),angle,1)
57
+ #calculate the same rotation matrix but with torch:
58
+ M = torch.tensor(cv2.getRotationMatrix2D(((img_size-1)/2.0,(img_size-1)/2.0),angle,1)).to(image_batch.device, dtype=torch.float32)
59
+ angle = torch.tensor((angle+90)/180.0*np.pi)
60
+ M1 = torch.tensor([[torch.sin(angle), torch.cos(angle), 0],
61
+ [torch.cos(angle), -torch.sin(angle), 0]]).to(image_batch.device, dtype=torch.float32)
62
+
63
+
64
+ # we need to add a batch dimension to the rotation matrix
65
+ M1 = M1.repeat(batch_size, 1, 1)
66
+
67
+ grid = torch.nn.functional.affine_grid(M1, image_batch.shape, align_corners=False)
68
+ rotated = torch.nn.functional.grid_sample(image_batch, grid, mode='bilinear', padding_mode='zeros', align_corners=False)
69
+ rotated = rotated.squeeze(1)
70
+
71
+ #plt.imshow(rotated[0].cpu().numpy())
72
+ #plt.show()
73
+
74
+ radon_image[:, 0, :, i] = rotated.sum(axis=1) / 224.0 + neutral_value
75
+
76
+ #plt.imshow(radon_image[0, 0].cpu().numpy())
77
+ #plt.show()
78
+
79
+ return radon_image
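Both variants produce one sinogram per image; the torch version keeps the channel dimension, stays on-device, and is differentiable. A shape-only sanity check (numeric agreement between the two resampling paths is not asserted here):

    import torch

    x = torch.rand(1, 1, 476, 476)
    print(np_batched_radon(x).shape)                        # (1, 476, 180)
    print(torch_batched_radon(x, neutral_value=0.0).shape)  # torch.Size([1, 1, 476, 180])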
period_calculation/models/abstract_model.py ADDED
@@ -0,0 +1,61 @@
1
+ import pytorch_lightning as pl
2
+ import torch
3
+
4
+ from hydra.utils import instantiate
5
+
6
+
7
+ class AbstractModel(pl.LightningModule):
8
+ def __init__(self,
9
+ lr=0.001,
10
+ optimizer_hparams=dict(),
11
+ scheduler=dict(classname='MultiStepLR', kwargs=dict(milestones=[100, 150], gamma=0.1))
12
+ ):
13
+ super().__init__()
14
+ # Exports the hyperparameters to a YAML file and creates the "self.hparams" namespace
15
+ self.save_hyperparameters()
16
+
17
+ def forward(self, x):
18
+ raise NotImplementedError("Subclass needs to implement this method")
19
+
20
+ def configure_optimizers(self):
21
+ # AdamW is Adam with a correct implementation of weight decay (see here
22
+ # for details: https://arxiv.org/pdf/1711.05101.pdf)
23
+ print("configuring the optimizer and lr scheduler with learning rate=%.5f"%self.hparams.lr)
24
+ optimizer = torch.optim.AdamW(self.parameters(), lr=self.hparams.lr, **self.hparams.optimizer_hparams)
25
+ # scheduler = getattr(torch.optim.lr_scheduler, self.hparams.lr_hparams['classname'])(optimizer, **self.hparams.lr_hparams['kwargs'])
26
+ if self.hparams.scheduler is not None:
27
+ scheduler = instantiate({**self.hparams.scheduler, '_partial_': True})(optimizer)
28
+
29
+ return [optimizer], [scheduler]
30
+ else:
31
+ return optimizer
32
+
33
+ def additional_losses(self):
34
+ """get additional_losses"""
35
+ return torch.zeros((1))
36
+
37
+ def process_batch_supervised(self, batch):
38
+ """get predictions, losses and mean errors (MAE)"""
39
+ raise NotImplementedError("Subclass needs to implement this method")
40
+
41
+ def log_all(self, losses, metrics, prefix=''):
42
+ for k, v in losses.items():
43
+ self.log(f'{prefix}{k}_loss', v.item() if isinstance(v, torch.Tensor) else v)
44
+
45
+ for k, v in metrics.items():
46
+ self.log(f'{prefix}{k}', v.item() if isinstance(v, torch.Tensor) else v)
47
+
48
+ def training_step(self, batch, batch_idx):
49
+ # "batch" is the output of the training data loader.
50
+ preds, losses, metrics = self.process_batch_supervised(batch)
51
+ self.log_all(losses, metrics, prefix='train_')
52
+
53
+ return losses['final']
54
+
55
+ def validation_step(self, batch, batch_idx):
56
+ preds, losses, metrics = self.process_batch_supervised(batch)
57
+ self.log_all(losses, metrics, prefix='val_')
58
+
59
+ def test_step(self, batch, batch_idx):
60
+ preds, losses, metrics = self.process_batch_supervised(batch)
61
+ self.log_all(losses, metrics, prefix='test_')
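A minimal, hypothetical subclass sketch showing the contract AbstractModel expects from process_batch_supervised (the toy linear head is a placeholder, not the project's model):

    import torch

    class ToyPeriodModel(AbstractModel):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.head = torch.nn.Linear(476 * 476, 1)

        def forward(self, x):
            return self.head(torch.flatten(x, 1))[:, 0]

        def process_batch_supervised(self, batch):
            preds = self.forward(batch["image"])
            loss = torch.nn.functional.mse_loss(preds, batch["period_px"])
            losses = {"final": loss}   # training_step returns losses['final']
            metrics = {"mae": (preds - batch["period_px"]).abs().mean()}
            return preds, losses, metrics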
period_calculation/models/gauss_model.py ADDED
@@ -0,0 +1,237 @@
1
+ import pytorch_lightning as pl
2
+ import torch
3
+ import numpy as np
4
+
5
+ from period_calculation.models.abstract_model import AbstractModel
6
+
7
+ from period_calculation.config import model_config # this is a dictionary with the model configuration
8
+
9
+ class GaussPeriodModel(AbstractModel):
10
+ def __init__(self,
11
+ *args, **kwargs
12
+ ):
13
+ super().__init__(*args, **kwargs)
14
+
15
+ self.seq = torch.nn.Sequential(
16
+ torch.nn.Conv2d(1, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
17
+ torch.nn.ReLU(),
18
+ torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
19
+ torch.nn.MaxPool2d((2, 2), stride=(2, 2)),
20
+
21
+ torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
22
+ torch.nn.ReLU(),
23
+ torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
24
+ torch.nn.MaxPool2d((2, 1), stride=(2, 1)),
25
+
26
+ torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
27
+ torch.nn.ReLU(),
28
+ torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
29
+ torch.nn.MaxPool2d((2, 1), stride=(2, 1)),
30
+
31
+ torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
32
+ torch.nn.ReLU(),
33
+ torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
34
+ torch.nn.MaxPool2d((2, 1), stride=(2, 1)),
35
+
36
+ torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
37
+ torch.nn.MaxPool2d((2, 1), stride=(2, 1)),
38
+
39
+ torch.nn.Conv2d(32, 32, (3, 3), stride=(1, 1), padding=(0, 0)),
40
+ torch.nn.MaxPool2d((2, 1), stride=(2, 1)),
41
+
42
+ torch.nn.Dropout(0.1)
43
+ )
44
+ self.query = torch.nn.Parameter(torch.empty(1, 2, 32)) #two heads only
45
+ torch.nn.init.xavier_normal_(self.query)
46
+
47
+ self.linear1 = torch.nn.Linear(64, 8)
48
+ self.linear2 = torch.nn.Linear(8, 1)
49
+
50
+ self.query_sd = torch.nn.Parameter(torch.empty(1, 2, 32))
51
+ torch.nn.init.xavier_normal_(self.query_sd)
52
+
53
+ self.linear_sd1 = torch.nn.Linear(64, 8)
54
+ self.linear_sd2 = torch.nn.Linear(8, 1)
55
+ self.relu = torch.nn.ReLU()
56
+
57
+
58
+ def copy_network_trunk(self, model):
59
+ # https://discuss.pytorch.org/t/copy-weights-from-only-one-layer-of-one-model-to-another-model-with-different-structure/153419
60
+ with torch.no_grad():
61
+ for i, layer in enumerate(model.seq):
62
+ if i % 2 == 0 and i != 20: # convolutional layers are the ones with even indices, with the exception of the 20th (= dropout)
63
+ self.seq[i].weight.copy_(layer.weight)
64
+ self.seq[i].bias.copy_(layer.bias)
65
+
66
+
67
+ def copy_final_layers(self, model):
68
+ # https://discuss.pytorch.org/t/copy-weights-from-only-one-layer-of-one-model-to-another-model-with-different-structure/153419
69
+
70
+ with torch.no_grad():
71
+ self.linear1.weight.copy_(model.linear1.weight)
72
+ self.linear1.bias.copy_(model.linear1.bias)
73
+
74
+ self.linear2.weight.copy_(model.linear2.weight)
75
+ self.linear2.bias.copy_(model.linear2.bias)
76
+
77
+ self.query.copy_(model.query)
78
+
79
+ def duplicate_final_layers(self):
80
+ # https://discuss.pytorch.org/t/copy-weights-from-only-one-layer-of-one-model-to-another-model-with-different-structure/153419
81
+
82
+ with torch.no_grad():
83
+ self.linear_sd1.weight.copy_(self.linear1.weight)
84
+ self.linear_sd1.bias.copy_(self.linear1.bias)
85
+
86
+ self.linear_sd2.weight.copy_(self.linear2.weight/10)
87
+ self.linear_sd2.bias.copy_(self.linear2.bias/10)
88
+
89
+ self.query_sd.copy_(self.query)
90
+
91
+ def forward(self, x, neutral=None, return_raw=False):
92
+ #https://www.nature.com/articles/s41598-023-43852-x
93
+
94
+ # x is sized # batch x 1 x 476 x 476
95
+
96
+ preds = self.seq(x) # batch x 32 x 5 x 220
97
+ features = torch.flatten(preds, 2) # batch x 32 x 1100
98
+
99
+ # attention
100
+ energy = self.query @ features # batch x 2 x 1100
101
+ weights = torch.nn.functional.softmax(energy, 2) # batch x 2 x 1100
102
+ response = features @ weights.transpose(1, 2) # batch x 32 x 2
103
+ response = torch.flatten(response, 1) # batch x 64
104
+
105
+ preds = self.linear1(response) # batch x 8
106
+ preds = self.linear2(self.relu(preds)) # batch x 1
107
+
108
+ # attention sd
109
+
110
+ energy_sd = self.query_sd @ features # batch x 2 x 1100
111
+ weights_sd = torch.nn.functional.softmax(energy_sd, 2) # batch x 2 x 1100
112
+ response_sd = features @ weights_sd.transpose(1, 2) # batch x 32 x 2
113
+ response_sd = torch.flatten(response_sd, 1) # batch x 64
114
+
115
+ preds_sd = self.linear_sd1(response_sd) # batch x 8
116
+ preds_sd = self.linear_sd2(self.relu(preds_sd)) # batch x 1
117
+
118
+ outputs = [model_config['receptive_field_height'] / preds[:, 0], torch.exp(preds_sd[:, 0])]
119
+ if return_raw:
120
+ outputs.append(preds)
121
+ outputs.append(preds_sd)
122
+ outputs.append(weights)
123
+ outputs.append(weights_sd)
124
+
125
+ return tuple(outputs)
126
+
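The shape annotations above can be sanity-checked with a random input (eval mode so the dropout layer is inactive):

    import torch

    model = GaussPeriodModel()
    model.eval()
    with torch.no_grad():
        period, sd = model(torch.rand(2, 1, 476, 476))
    print(period.shape, sd.shape)  # torch.Size([2]) torch.Size([2])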
127
+ def additional_losses(self):
128
+ """get additional_losses"""
129
+ # additional (orthogonal) loss
130
+ # the two heads are multiplied elementwise and the MSE loss (towards zero) penalizes the products
131
+ # the idea is that the scalar product of two orthogonal vectors is zero
132
+ scalar_product = torch.cat((self.query[0, 0] * self.query[0, 1], self.query_sd[0, 0] * self.query_sd[0, 1]), dim=0)
133
+ orthogonal_loss = torch.nn.functional.mse_loss(scalar_product, torch.zeros_like(scalar_product))
134
+ return orthogonal_loss
135
+
136
+
137
+
138
+ def process_batch_supervised(self, batch):
139
+ """get predictions, losses and mean errors (metrics)"""
140
+
141
+ # get predictions
142
+ preds = {}
143
+ preds['period_px'], preds['sd'] = self.forward(batch['image'], batch['neutral'][0], return_raw=False) # preds: period, sd, orto, preds_raw
144
+
145
+ # https://johaupt.github.io/blog/NN_prediction_uncertainty.html
146
+ # calculate losses
147
+ mse_period_px = torch.nn.functional.mse_loss(batch['period_px'],
148
+ preds['period_px'])
149
+
150
+ gaussian_nll = torch.nn.functional.gaussian_nll_loss(batch['period_px'],
151
+ preds['period_px'],
152
+ (preds['sd']) ** 2)
153
+
154
+ orthogonal_weight = 0.1
155
+ orthogonal_loss = self.additional_losses()
156
+ length_of_the_first_phase = 0
157
+ if self.current_epoch < length_of_the_first_phase:
158
+ #transition from MSE to Gaussian Negative Log Likelihood with sin/cos over first epochs
159
+ angle = torch.tensor((self.current_epoch) / (length_of_the_first_phase) * np.pi / 2)
160
+ total_loss = (gaussian_nll) * torch.sin(angle) + (mse_period_px) * torch.cos(angle) + orthogonal_weight * orthogonal_loss
161
+ else:
162
+ total_loss = gaussian_nll + orthogonal_weight * orthogonal_loss
163
+
164
+ losses = {
165
+ 'gaussian_nll': gaussian_nll,
166
+ 'mse_period_px': mse_period_px,
167
+ 'orthogonal': orthogonal_loss,
168
+ 'final': total_loss
169
+ }
170
+
171
+ # calculate mean errors
172
+ ground_truth_detached = batch['period_px'].detach().cpu().numpy()
173
+ print(ground_truth_detached)
174
+ mean_detached = preds['period_px'].detach().cpu().numpy()
175
+ print(mean_detached)
176
+ sd_detached = preds['sd'].detach().cpu().numpy()
177
+ print("==>", sd_detached)
178
+ px_per_nm_detached = batch['px_per_nm'].detach().cpu().numpy()
179
+ scale_detached = batch['scale'].detach().cpu().numpy()
180
+
181
+ period_px_difference = np.mean(abs(
182
+ ground_truth_detached - mean_detached
183
+ ))
184
+
185
+ #initiate both with python array with 5 zeros
186
+ true_period_px_difference = [0.0] * 5
187
+ true_period_nm_difference = [0.0] * 5
188
+
189
+ for i, dist in enumerate([1.0, 2.0, 3.0, 4.0, 5.0]):
190
+ true_period_px_difference[i] = (np.sum(abs(
191
+ ((ground_truth_detached - mean_detached) / scale_detached) * (sd_detached / scale_detached <dist))) \
192
+ / np.sum(sd_detached / scale_detached < dist)) if np.sum(sd_detached / scale_detached < dist) > 0 else 0
193
+
194
+ for i, dist in enumerate([1.0, 2.0, 3.0, 4.0, 5.0]):
195
+ true_period_nm_difference[i] = (np.sum(abs(
196
+ ((ground_truth_detached - mean_detached) / (scale_detached * px_per_nm_detached)) * (sd_detached / scale_detached <dist))) \
197
+ / np.sum(sd_detached / scale_detached < dist)) if np.sum(sd_detached / scale_detached < dist) > 0 else 0
198
+
199
+
200
+ true_period_px_difference_all = np.mean(abs(
201
+ ((ground_truth_detached - mean_detached) / scale_detached)
202
+ ))
203
+
204
+ true_period_nm_difference_all = np.mean(abs(
205
+ ((ground_truth_detached - mean_detached) / (scale_detached * px_per_nm_detached))
206
+ ))
207
+
208
+ metrics = {
209
+ 'period_px': period_px_difference,
210
+ 'true_period_px_1': true_period_px_difference[0],
211
+ 'true_period_px_2': true_period_px_difference[1],
212
+ 'true_period_px_3': true_period_px_difference[2],
213
+ 'true_period_px_4': true_period_px_difference[3],
214
+ 'true_period_px_5': true_period_px_difference[4],
215
+ 'true_period_px_all': true_period_px_difference_all,
216
+
217
+ 'true_period_nm_1': true_period_nm_difference[0],
218
+ 'true_period_nm_2': true_period_nm_difference[1],
219
+ 'true_period_nm_3': true_period_nm_difference[2],
220
+ 'true_period_nm_4': true_period_nm_difference[3],
221
+ 'true_period_nm_5': true_period_nm_difference[4],
222
+ 'true_period_nm_all': true_period_nm_difference_all,
223
+
224
+ 'count_1': np.sum(sd_detached / scale_detached < 1.0),
225
+ 'count_2': np.sum(sd_detached / scale_detached < 2.0),
226
+ 'count_3': np.sum(sd_detached / scale_detached < 3.0),
227
+ 'count_4': np.sum(sd_detached / scale_detached < 4.0),
228
+ 'count_5': np.sum(sd_detached / scale_detached < 5.0),
229
+
230
+ 'count_all': np.sum(sd_detached > 0.0),
231
+ 'mean_sd': np.mean(sd_detached),
232
+ 'sd_sd': np.std(sd_detached),
233
+ }
234
+
235
+ return preds, losses, metrics
236
+
237
+
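For reference, torch.nn.functional.gaussian_nll_loss above averages the per-sample heteroscedastic Gaussian negative log-likelihood 0.5 * (log sigma^2 + (y - mu)^2 / sigma^2), and the (currently disabled) warm-up blends it with the MSE via a sin/cos cross-fade. A minimal standalone sketch of both quantities; the function names here are chosen for illustration and are not part of the codebase:

import math

def gaussian_nll(mu, sd, y, eps=1e-6):
    # per-sample NLL, matching torch's documented formula with the constant term dropped
    var = max(sd ** 2, eps)  # clamp the variance, as the torch implementation does
    return 0.5 * (math.log(var) + (y - mu) ** 2 / var)

def blended_loss(mse, nll, epoch, phase_len):
    # pure MSE at epoch 0, pure NLL once epoch >= phase_len
    if phase_len <= 0 or epoch >= phase_len:
        return nll
    angle = epoch / phase_len * math.pi / 2
    return math.sin(angle) * nll + math.cos(angle) * mse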
period_calculation/period_measurer.py ADDED
@@ -0,0 +1,54 @@
+ import torch
+ from pathlib import Path
+ import albumentations as A
+ from albumentations.pytorch import ToTensorV2
+ import skimage
+ import scipy
+ import numpy as np
+ from pytorch_lightning import seed_everything
+
+ from period_calculation.data_reader import AdHocDataset3
+ from period_calculation.models.gauss_model import GaussPeriodModel
+
+
+ transforms = [
+     A.Normalize(**{'mean': 0.2845, 'std': 0.1447}, max_pixel_value=1.0),
+     # applies the formula (img - mean * max_pixel_value) / (std * max_pixel_value)
+     ToTensorV2()
+ ]
+
+
+ class PeriodMeasurer:
+     """Measures the granum period; returns (period_nm, sd_nm)."""
+
+     def __init__(
+             self, weights_file, image_height=476, image_width=476,
+             px_per_nm=1,
+             sd_threshold_nm=np.inf,
+             period_threshold_nm_min=0, period_threshold_nm_max=np.inf):
+
+         # eval() disables dropout/batch-norm updates for inference
+         self.model = GaussPeriodModel.load_from_checkpoint(weights_file).to("cpu").eval()
+         self.px_per_nm = px_per_nm
+         self.sd_threshold_nm = sd_threshold_nm
+         self.period_threshold_nm_min = period_threshold_nm_min
+         self.period_threshold_nm_max = period_threshold_nm_max
+
+     def __call__(self, img: np.ndarray, mask: np.ndarray) -> tuple:
+         seed_everything(44)
+         dataset = AdHocDataset3(
+             images_and_masks=[(img, mask)],
+             transform_level=-1,
+             retain_raw_images=False,
+             transforms=transforms
+         )
+
+         image_data = dataset[0]
+         with torch.no_grad():
+             y_hat, sd_hat = self.model(image_data["image"].unsqueeze(0), return_raw=False)
+
+         y_hat_nm = (y_hat / image_data["scale"]).item() / self.px_per_nm
+         sd_hat_nm = (sd_hat / image_data["scale"]).item() / self.px_per_nm
+
+         # reject predictions that are too uncertain or outside the plausible period range
+         if (sd_hat_nm > self.sd_threshold_nm) or (y_hat_nm < self.period_threshold_nm_min) or (y_hat_nm > self.period_threshold_nm_max):
+             y_hat_nm = np.nan
+
+         return y_hat_nm, sd_hat_nm
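A usage sketch for PeriodMeasurer; the array contents, dtypes and threshold values below are assumptions for illustration, and AdHocDataset3 defines the real input expectations:

import numpy as np
from period_calculation.period_measurer import PeriodMeasurer

measurer = PeriodMeasurer(
    "weights/period_measurer_weights-1.298_real_full-fa12970.ckpt",
    px_per_nm=1.0,        # assumed micrograph scale
    sd_threshold_nm=5.0,  # assumed rejection threshold
)
img = np.random.rand(476, 476).astype(np.float32)  # placeholder grayscale crop
mask = np.ones((476, 476), dtype=np.uint8)         # placeholder granum mask
period_nm, sd_nm = measurer(img, mask)             # period_nm is np.nan when rejected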
requirements.txt ADDED
@@ -0,0 +1,11 @@
+ pandas==2.1.1
+ torch==2.1.0
+ torchvision==0.16
+ ultralytics==8.0.216
+ scikit-image==0.22.0
+ pytorch-lightning==2.1.2
+ timm==0.9.11
+ albumentations==1.4.10
+ hydra-core==1.3.2
+ gradio==4.44.0
+ albucore==0.0.16
settings.py ADDED
@@ -0,0 +1 @@
+ DEMO = True
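DEMO = True is a single global switch. A sketch of how such a flag is typically consumed; the call site and the limit below are hypothetical, not taken from the app code:

from settings import DEMO

MAX_IMAGES_PER_REQUEST = 3 if DEMO else 100  # hypothetical demo-mode restriction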
styles.css ADDED
@@ -0,0 +1,47 @@
+ .header {
+     display: flex;
+     padding: 30px;
+     text-align: center;
+     justify-content: center;
+ }
+
+ #header-text {
+     font-size: 50px;
+     line-height: 50px;
+ }
+
+ #header-logo {
+     width: 50px;
+     height: 50px;
+     margin-right: 10px;
+     /*background-image: url("file=images/logo.svg");*/
+ }
+
+ .input-row {
+     max-width: 900px;
+     margin: 0 auto;
+ }
+
+ .margin-bottom {
+     margin-bottom: 48px;
+ }
+
+ .results-header {
+     margin-top: 48px;
+     text-align: center;
+     font-size: 45px;
+     margin-bottom: 12px;
+ }
+
+ .processed-info {
+     display: flex;
+     padding: 30px;
+     text-align: center;
+     justify-content: center;
+     font-size: 26px;
+ }
+
+ .title {
+     margin-bottom: 8px !important;
+     font-size: 22px;
+ }
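These selectors line up with Gradio's elem_id/elem_classes hooks. A plausible hookup in a Gradio 4.x app (a sketch only; the real app layout is not part of this diff, and the component labels are invented):

import gradio as gr

with gr.Blocks(css=open("styles.css").read()) as demo:  # inject the stylesheet
    with gr.Row(elem_classes=["header"]):
        gr.Markdown("GranaMeasure", elem_id="header-text")  # hypothetical title
    with gr.Row(elem_classes=["input-row"]):
        gr.File(label="TEM images")  # hypothetical upload widget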
weights/AS_square_v16.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9c1d7d4e56f0ea28b34dd2457807e9266d0d5539cf5adfb0719fed10791f79c5
+ size 44771917
weights/model_weights_detector.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:37fe3d98789572cf147d5f0d1ea99d50a57e5c2028454c3825295e89b6350fd5
+ size 23926765
weights/period_measurer_weights-1.298_real_full-fa12970.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:90133b01545ae9d30eafcbdec480410e0f8b9fae3ba1aabb280450ce9589100a
+ size 350396
weights/yolo/20240604_yolov8_segm_ABRCR1_all_train4_best.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:038555e9b9f3900ec29c9c79d7b3d5b50aa3ca37b3e665ee7aa2394facc7e20e
+ size 23926765
weights/yolo/current_yolo.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:37fe3d98789572cf147d5f0d1ea99d50a57e5c2028454c3825295e89b6350fd5
+ size 23926765
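The weights/* files above are Git LFS pointer stubs rather than the binaries themselves: each records the LFS spec version, the SHA-256 of the real blob, and its size in bytes. A throwaway sketch for parsing such a pointer, e.g. to verify a downloaded checkpoint (the function name is illustrative):

def parse_lfs_pointer(text: str) -> dict:
    # each pointer line is "key value"
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    return {
        "version": fields["version"],
        "sha256": fields["oid"].removeprefix("sha256:"),
        "size_bytes": int(fields["size"]),
    }

# e.g. parse_lfs_pointer(open("weights/AS_square_v16.ckpt").read())["size_bytes"] == 44771917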