fadindashfr commited on
Commit
4c5329d
·
1 Parent(s): 00edf39

initial commit all file

Browse files
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
32
  *.zip filter=lfs diff=lfs merge=lfs -text
33
  *.zst filter=lfs diff=lfs merge=lfs -text
34
  *tfevents* filter=lfs diff=lfs merge=lfs -text
35
+ *.ts filter=lfs diff=lfs merge=lfs -text
Description.md ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Overview
2
+ Nuclei classification within Haematoxylin & Eosi stained histology images. Classifying nuclei cells as the following types:
3
+ - Other
4
+ - Inflammatory
5
+ - Epithelial
6
+ - Spindle-Shaped
7
+
8
+ References: https://doi.org/10.1016/j.media.2019.101563
9
+
10
+ ## Dataset
11
+ The model is trained with Colorectal Nuclear Segmentation and Phenotypes (CoNSeP) dataset https://warwick.ac.uk/fac/cross_fac/tia/data/hovernet. Images were extracted from 16 colorectal adenocarcinoma (CRA) WSIs.
12
+
13
+ - Target: Nuclei
14
+ - Task: Nuclei Cells Class Classification
15
+ - Modality: Image (RGB)
16
+
17
+ ## Model Architecture
18
+ The model is trained using DenseNet121 over CoNSep dataset.
19
+ ![alt text](file/figures/architecture.png)
20
+
21
+ ## Demo
22
+ Please select or upload a nuclei histology image and label image to see Nuclei Cells Classification capabilities of this model
app.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from monai.bundle import ConfigParser
3
+ import gradio as gr
4
+
5
+ parser = ConfigParser() # load configuration files that specify various parameters for running the MONAI workflow.
6
+ parser.read_config(f="configs/inference.json") # read the config from specified JSON file
7
+ parser.read_meta(f="configs/metadata.json") # read the metadata from specified JSON file
8
+
9
+ inference = parser.get_parsed_content("inferer")
10
+ network = parser.get_parsed_content("network_def")
11
+ preprocess = parser.get_parsed_content("preprocessing")
12
+ state_dict = torch.load("models/model.pt")
13
+ network.load_state_dict(state_dict, strict=True) # Loads a model’s parameter dictionary
14
+ class_names = {
15
+ 0: "Other",
16
+ 1: "Inflammatory",
17
+ 2: "Epithelial",
18
+ 3: "Spindle-Shaped",
19
+ }
20
+
21
+ def classify_image(image_file, label_file):
22
+ data = {"image":image_file, "label":label_file}
23
+ batch = preprocess(data)
24
+ network.eval()
25
+ with torch.no_grad():
26
+ pred = inference(batch['image'].unsqueeze(dim=0), network) # expect 4 channels input (3 RGB, 1 Label mask)
27
+ prob = pred.softmax(-1).detach().cpu().numpy()[0]
28
+ confidences = {class_names[i]: float(prob[i]) for i in range(len(class_names))}
29
+ return confidences
30
+
31
+ example_files1 = [
32
+ [r'sample_data\Images\test_11_2_0628.png',
33
+ r'sample_data\Labels\test_11_2_0628.png'],
34
+ [r'sample_data\Images\test_9_4_0149.png',
35
+ r'sample_data\Labels\test_9_4_0149.png'],
36
+ [r'sample_data\Images\test_12_3_0292.png',
37
+ r'sample_data\Labels\test_12_3_0292.png'],
38
+ [r'sample_data\Images\test_9_4_0019.png',
39
+ r'sample_data\Labels\test_9_4_0019.png']
40
+ ]
41
+
42
+ example_files2 = [
43
+ [r'sample_data\Images\test_14_3_0433.png',
44
+ r'sample_data\Labels\test_14_3_0433.png'],
45
+ [r'sample_data\Images\test_14_4_0544.png',
46
+ r'sample_data\Labels\test_14_4_0544.png'],
47
+ [r'sample_data\Images\train_1_1_0095.png',
48
+ r'sample_data\Labels\train_1_1_0095.png'],
49
+ [r'sample_data\Images\train_1_3_0020.png',
50
+ r'sample_data\Labels\train_1_3_0020.png'],
51
+ ]
52
+
53
+ with open('Description.md','r') as file:
54
+ markdown_content = file.read()
55
+ with gr.Blocks() as app:
56
+ gr.Markdown("# Pathology Nuclei Classification")
57
+ gr.Markdown(markdown_content)
58
+ with gr.Row():
59
+ with gr.Column():
60
+ with gr.Row():
61
+ inp_img = gr.Image(type="filepath", image_mode="RGB")
62
+ label_img = gr.Image(type="filepath", image_mode="L")
63
+ with gr.Row():
64
+ process_btn = gr.Button(value="Process")
65
+ clear_btn = gr.Button(value="Clear")
66
+ out_txt = gr.Label(label="Probabilities", num_top_classes=4)
67
+
68
+ process_btn.click(fn=classify_image, inputs=[inp_img, label_img], outputs=out_txt)
69
+ clear_btn.click(lambda:(
70
+ gr.update(value=None),
71
+ gr.update(value=None),
72
+ gr.update(value=None)
73
+ ),
74
+ inputs=None,
75
+ outputs=[inp_img, label_img,out_txt]
76
+ )
77
+
78
+ gr.Markdown("## Image Examples")
79
+ with gr.Row():
80
+ for file in example_files1:
81
+ gr.Examples(
82
+ [file], inputs=[inp_img, label_img]
83
+ )
84
+ with gr.Row():
85
+ for file in example_files2:
86
+ gr.Examples(
87
+ [file], inputs=[inp_img, label_img]
88
+ )
89
+ app.launch()
configs/evaluate.json ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "validate#dataset#cache_rate": 0,
3
+ "validate#postprocessing": {
4
+ "_target_": "Compose",
5
+ "transforms": [
6
+ {
7
+ "_target_": "Activationsd",
8
+ "keys": "pred",
9
+ "softmax": true
10
+ },
11
+ {
12
+ "_target_": "AsDiscreted",
13
+ "keys": [
14
+ "pred",
15
+ "label"
16
+ ],
17
+ "argmax": [
18
+ true,
19
+ false
20
+ ],
21
+ "to_onehot": 4
22
+ },
23
+ {
24
+ "_target_": "ToTensord",
25
+ "keys": [
26
+ "pred",
27
+ "label"
28
+ ],
29
+ "device": "@device"
30
+ },
31
+ {
32
+ "_target_": "SaveImaged",
33
+ "_disabled_": true,
34
+ "keys": "pred",
35
+ "meta_keys": "pred_meta_dict",
36
+ "output_dir": "@output_dir",
37
+ "output_ext": ".json"
38
+ }
39
+ ]
40
+ },
41
+ "validate#handlers": [
42
+ {
43
+ "_target_": "CheckpointLoader",
44
+ "load_path": "$@ckpt_dir + '/model.pt'",
45
+ "load_dict": {
46
+ "model": "@network"
47
+ }
48
+ },
49
+ {
50
+ "_target_": "StatsHandler",
51
+ "iteration_log": false
52
+ },
53
+ {
54
+ "_target_": "MetricsSaver",
55
+ "save_dir": "@output_dir",
56
+ "metrics": [
57
+ "val_f1",
58
+ "val_accuracy"
59
+ ],
60
+ "metric_details": [
61
+ "val_f1"
62
+ ],
63
+ "batch_transform": "$monai.handlers.from_engine(['image_meta_dict'])",
64
+ "summary_ops": "*"
65
+ }
66
+ ],
67
+ "evaluating": [
68
+ "$import sys",
69
+ "$sys.path.append(@bundle_root)",
70
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
71
+ "$import scripts",
72
+ "$monai.data.register_writer('json', scripts.ClassificationWriter)",
73
+ "$@validate#evaluator.run()"
74
+ ]
75
+ }
configs/inference.json ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "imports": [
3
+ "$import glob",
4
+ "$import json",
5
+ "$import pathlib",
6
+ "$import os"
7
+ ],
8
+ "bundle_root": "/workspace/data/pathology_nuclei_classification",
9
+ "output_dir": "$@bundle_root + '/eval'",
10
+ "dataset_dir": "/workspace/data/CoNSePNuclei",
11
+ "images": "$list(sorted(glob.glob(@dataset_dir + '/Test/Images/*.png')))[:1]",
12
+ "labels": "$list(sorted(glob.glob(@dataset_dir + '/Test/Labels/*.png')))[:1]",
13
+ "input_data": "$[{'image': i, 'label': l} for i,l in zip(@images, @labels)]",
14
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
15
+ "network_def": {
16
+ "_target_": "DenseNet121",
17
+ "spatial_dims": 2,
18
+ "in_channels": 4,
19
+ "out_channels": 4
20
+ },
21
+ "network": "$@network_def.to(@device)",
22
+ "preprocessing": {
23
+ "_target_": "Compose",
24
+ "transforms": [
25
+ {
26
+ "_target_": "LoadImaged",
27
+ "keys": [
28
+ "image",
29
+ "label"
30
+ ],
31
+ "dtype": "uint8"
32
+ },
33
+ {
34
+ "_target_": "EnsureChannelFirstd",
35
+ "keys": [
36
+ "image",
37
+ "label"
38
+ ]
39
+ },
40
+ {
41
+ "_target_": "ScaleIntensityRanged",
42
+ "keys": "image",
43
+ "a_min": 0.0,
44
+ "a_max": 255.0,
45
+ "b_min": -1.0,
46
+ "b_max": 1.0
47
+ },
48
+ {
49
+ "_target_": "AddLabelAsGuidanced",
50
+ "keys": "image",
51
+ "source": "label"
52
+ }
53
+ ]
54
+ },
55
+ "dataset": {
56
+ "_target_": "Dataset",
57
+ "data": "@input_data",
58
+ "transform": "@preprocessing"
59
+ },
60
+ "dataloader": {
61
+ "_target_": "DataLoader",
62
+ "dataset": "@dataset",
63
+ "batch_size": 1,
64
+ "shuffle": false,
65
+ "num_workers": 4
66
+ },
67
+ "inferer": {
68
+ "_target_": "SimpleInferer"
69
+ },
70
+ "postprocessing": {
71
+ "_target_": "Compose",
72
+ "transforms": [
73
+ {
74
+ "_target_": "Activationsd",
75
+ "keys": "pred",
76
+ "softmax": true
77
+ },
78
+ {
79
+ "_target_": "SaveImaged",
80
+ "keys": "pred",
81
+ "meta_keys": "pred_meta_dict",
82
+ "output_dir": "@output_dir",
83
+ "output_ext": ".json"
84
+ }
85
+ ]
86
+ },
87
+ "handlers": [
88
+ {
89
+ "_target_": "CheckpointLoader",
90
+ "load_path": "$@bundle_root + '/models/model.pt'",
91
+ "load_dict": {
92
+ "model": "@network"
93
+ }
94
+ },
95
+ {
96
+ "_target_": "StatsHandler",
97
+ "iteration_log": false
98
+ }
99
+ ],
100
+ "evaluator": {
101
+ "_target_": "SupervisedEvaluator",
102
+ "device": "@device",
103
+ "val_data_loader": "@dataloader",
104
+ "network": "@network",
105
+ "inferer": "@inferer",
106
+ "postprocessing": "@postprocessing",
107
+ "val_handlers": "@handlers",
108
+ "amp": true
109
+ },
110
+ "evaluating": [
111
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
112
+ "$import scripts",
113
+ "$monai.data.register_writer('json', scripts.ClassificationWriter)",
114
115
+ ]
116
+ }
configs/logging.conf ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [loggers]
2
+ keys=root
3
+
4
+ [handlers]
5
+ keys=consoleHandler
6
+
7
+ [formatters]
8
+ keys=fullFormatter
9
+
10
+ [logger_root]
11
+ level=INFO
12
+ handlers=consoleHandler
13
+
14
+ [handler_consoleHandler]
15
+ class=StreamHandler
16
+ level=INFO
17
+ formatter=fullFormatter
18
+ args=(sys.stdout,)
19
+
20
+ [formatter_fullFormatter]
21
+ format=%(asctime)s - %(name)s - %(levelname)s - %(message)s
configs/metadata.json ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
3
+ "version": "0.0.5",
4
+ "changelog": {
5
+ "0.0.5": "add name tag",
6
+ "0.0.4": "Fix evaluation",
7
+ "0.0.3": "Update to use MONAI 1.1.0",
8
+ "0.0.2": "Update The Torch Vision Transform",
9
+ "0.0.1": "initialize the model package structure"
10
+ },
11
+ "monai_version": "1.1.0",
12
+ "pytorch_version": "1.13.0",
13
+ "numpy_version": "1.21.2",
14
+ "optional_packages_version": {
15
+ "nibabel": "4.0.1",
16
+ "pytorch-ignite": "0.4.9"
17
+ },
18
+ "name": "Pathology nuclei classification",
19
+ "task": "Pathology Nuclei classification",
20
+ "description": "A pre-trained model for Nuclei Classification within Haematoxylin & Eosin stained histology images",
21
+ "authors": "MONAI team",
22
+ "copyright": "Copyright (c) MONAI Consortium",
23
+ "data_source": "consep_dataset.zip from https://warwick.ac.uk/fac/cross_fac/tia/data/hovernet",
24
+ "data_type": "png",
25
+ "image_classes": "RGB channel data, intensity scaled to [0, 1]",
26
+ "label_classes": "single channel data",
27
+ "pred_classes": "4 channels OneHot data, channel 0 is Other, channel 1 is Inflammatory, channel 2 is Epithelial, channel 3 is Spindle-Shaped",
28
+ "eval_metrics": {
29
+ "f1_score": 0.85
30
+ },
31
+ "intended_use": "This is an example, not to be used for diagnostic purposes",
32
+ "references": [
33
+ "S. Graham, Q. D. Vu, S. E. A. Raza, A. Azam, Y-W. Tsang, J. T. Kwak and N. Rajpoot. \"HoVer-Net: Simultaneous Segmentation and Classification of Nuclei in Multi-Tissue Histology Images.\" Medical Image Analysis, Sept. 2019. https://doi.org/10.1016/j.media.2019.101563"
34
+ ],
35
+ "network_data_format": {
36
+ "inputs": {
37
+ "image": {
38
+ "type": "magnitude",
39
+ "format": "RGB",
40
+ "modality": "regular",
41
+ "num_channels": 4,
42
+ "spatial_shape": [
43
+ 128,
44
+ 128
45
+ ],
46
+ "dtype": "float32",
47
+ "value_range": [
48
+ 0,
49
+ 1
50
+ ],
51
+ "is_patch_data": false,
52
+ "channel_def": {
53
+ "0": "R",
54
+ "1": "G",
55
+ "2": "B",
56
+ "3": "Mask"
57
+ }
58
+ }
59
+ },
60
+ "outputs": {
61
+ "pred": {
62
+ "type": "probabilities",
63
+ "format": "classes",
64
+ "num_channels": 4,
65
+ "spatial_shape": [
66
+ 1,
67
+ 4
68
+ ],
69
+ "dtype": "float32",
70
+ "value_range": [
71
+ 0,
72
+ 1,
73
+ 2,
74
+ 3
75
+ ],
76
+ "is_patch_data": false,
77
+ "channel_def": {
78
+ "0": "Other",
79
+ "1": "Inflammatory",
80
+ "2": "Epithelial",
81
+ "3": "Spindle-Shaped"
82
+ }
83
+ }
84
+ }
85
+ }
86
+ }
configs/multi_gpu_evaluate.json ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "device": "$torch.device(f'cuda:{dist.get_rank()}')",
3
+ "network": {
4
+ "_target_": "torch.nn.parallel.DistributedDataParallel",
5
+ "module": "$@network_def.to(@device)",
6
+ "device_ids": [
7
+ "@device"
8
+ ]
9
+ },
10
+ "validate#sampler": {
11
+ "_target_": "DistributedSampler",
12
+ "dataset": "@validate#dataset",
13
+ "even_divisible": false,
14
+ "shuffle": false
15
+ },
16
+ "validate#dataloader#sampler": "@validate#sampler",
17
+ "validate#handlers#1#_disabled_": "$dist.get_rank() > 0",
18
+ "evaluating": [
19
+ "$import sys",
20
+ "$sys.path.append(@bundle_root)",
21
+ "$import torch.distributed as dist",
22
+ "$dist.init_process_group(backend='nccl')",
23
+ "$torch.cuda.set_device(@device)",
24
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
25
+ "$import logging",
26
+ "$@validate#evaluator.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)",
27
+ "$import scripts",
28
+ "$monai.data.register_writer('json', scripts.ClassificationWriter)",
29
+ "$@validate#evaluator.run()",
30
+ "$dist.destroy_process_group()"
31
+ ]
32
+ }
configs/multi_gpu_train.json ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "device": "$torch.device(f'cuda:{dist.get_rank()}')",
3
+ "network": {
4
+ "_target_": "torch.nn.parallel.DistributedDataParallel",
5
+ "module": "$@network_def.to(@device)",
6
+ "device_ids": [
7
+ "@device"
8
+ ]
9
+ },
10
+ "train#sampler": {
11
+ "_target_": "DistributedSampler",
12
+ "dataset": "@train#dataset",
13
+ "even_divisible": true,
14
+ "shuffle": true
15
+ },
16
+ "train#dataloader#sampler": "@train#sampler",
17
+ "train#dataloader#shuffle": false,
18
+ "train#trainer#train_handlers": "$@train#handlers[: -2 if dist.get_rank() > 0 else None]",
19
+ "validate#sampler": {
20
+ "_target_": "DistributedSampler",
21
+ "dataset": "@validate#dataset",
22
+ "even_divisible": false,
23
+ "shuffle": false
24
+ },
25
+ "validate#dataloader#sampler": "@validate#sampler",
26
+ "validate#evaluator#val_handlers": "$None if dist.get_rank() > 0 else @validate#handlers",
27
+ "training": [
28
+ "$import sys",
29
+ "$sys.path.append(@bundle_root)",
30
+ "$import torch.distributed as dist",
31
+ "$dist.init_process_group(backend='nccl')",
32
+ "$torch.cuda.set_device(@device)",
33
+ "$monai.utils.set_determinism(seed=123)",
34
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
35
+ "$import logging",
36
+ "$@train#trainer.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)",
37
+ "$@validate#evaluator.logger.setLevel(logging.WARNING if dist.get_rank() > 0 else logging.INFO)",
38
+ "$@train#trainer.run()",
39
+ "$dist.destroy_process_group()"
40
+ ]
41
+ }
configs/train.json ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "imports": [
3
+ "$import glob",
4
+ "$import ignite",
5
+ "$import json",
6
+ "$import pathlib",
7
+ "$import os"
8
+ ],
9
+ "bundle_root": "/workspace/data/pathology_nuclei_classification",
10
+ "ckpt_dir": "$@bundle_root + '/models'",
11
+ "output_dir": "$@bundle_root + '/eval'",
12
+ "dataset_dir": "/workspace/data/CoNSePNuclei",
13
+ "dataset_json": "$@dataset_dir + '/dataset.json'",
14
+ "train_datalist": "$json.loads(pathlib.Path(@dataset_json).read_text())['training']",
15
+ "val_datalist": "$json.loads(pathlib.Path(@dataset_json).read_text())['validation']",
16
+ "val_interval": 1,
17
+ "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
18
+ "network_def": {
19
+ "_target_": "DenseNet121",
20
+ "spatial_dims": 2,
21
+ "in_channels": 4,
22
+ "out_channels": 4
23
+ },
24
+ "network": "$@network_def.to(@device)",
25
+ "loss": {
26
+ "_target_": "torch.nn.CrossEntropyLoss"
27
+ },
28
+ "optimizer": {
29
+ "_target_": "torch.optim.Adam",
30
+ "params": "[email protected]()",
31
+ "lr": 0.0001
32
+ },
33
+ "max_epochs": 50,
34
+ "train": {
35
+ "preprocessing": {
36
+ "_target_": "Compose",
37
+ "transforms": [
38
+ {
39
+ "_target_": "LoadImaged",
40
+ "keys": [
41
+ "image",
42
+ "label"
43
+ ],
44
+ "dtype": "uint8"
45
+ },
46
+ {
47
+ "_target_": "EnsureChannelFirstd",
48
+ "keys": [
49
+ "image",
50
+ "label"
51
+ ]
52
+ },
53
+ {
54
+ "_target_": "SplitLabeld",
55
+ "keys": "label",
56
+ "mask_value": "",
57
+ "others_value": 255,
58
+ "to_binary_mask": false
59
+ },
60
+ {
61
+ "_target_": "RandTorchVisiond",
62
+ "keys": "image",
63
+ "name": "ColorJitter",
64
+ "brightness": 0.25,
65
+ "contrast": 0.75,
66
+ "saturation": 0.25,
67
+ "hue": 0.04
68
+ },
69
+ {
70
+ "_target_": "RandFlipd",
71
+ "keys": [
72
+ "image",
73
+ "label",
74
+ "others"
75
+ ],
76
+ "prob": 0.5
77
+ },
78
+ {
79
+ "_target_": "RandRotate90d",
80
+ "keys": [
81
+ "image",
82
+ "label",
83
+ "others"
84
+ ],
85
+ "prob": 0.5
86
+ },
87
+ {
88
+ "_target_": "ScaleIntensityRanged",
89
+ "keys": "image",
90
+ "a_min": 0.0,
91
+ "a_max": 255.0,
92
+ "b_min": -1.0,
93
+ "b_max": 1.0
94
+ },
95
+ {
96
+ "_target_": "AddLabelAsGuidanced",
97
+ "keys": "image",
98
+ "source": "label"
99
+ },
100
+ {
101
+ "_target_": "SetLabelClassd",
102
+ "keys": "label",
103
+ "offset": -1
104
+ },
105
+ {
106
+ "_target_": "SelectItemsd",
107
+ "keys": [
108
+ "image",
109
+ "label"
110
+ ]
111
+ }
112
+ ]
113
+ },
114
+ "dataset": {
115
+ "_target_": "CacheDataset",
116
+ "data": "@train_datalist",
117
+ "transform": "@train#preprocessing",
118
+ "cache_rate": 1.0,
119
+ "num_workers": 4
120
+ },
121
+ "dataloader": {
122
+ "_target_": "DataLoader",
123
+ "dataset": "@train#dataset",
124
+ "batch_size": 64,
125
+ "shuffle": true,
126
+ "num_workers": 4
127
+ },
128
+ "inferer": {
129
+ "_target_": "SimpleInferer"
130
+ },
131
+ "postprocessing": {
132
+ "_target_": "Compose",
133
+ "transforms": [
134
+ {
135
+ "_target_": "Activationsd",
136
+ "keys": "pred",
137
+ "softmax": true
138
+ },
139
+ {
140
+ "_target_": "AsDiscreted",
141
+ "keys": [
142
+ "pred",
143
+ "label"
144
+ ],
145
+ "argmax": [
146
+ true,
147
+ false
148
+ ],
149
+ "to_onehot": 4
150
+ },
151
+ {
152
+ "_target_": "ToTensord",
153
+ "keys": [
154
+ "pred",
155
+ "label"
156
+ ],
157
+ "device": "@device"
158
+ }
159
+ ]
160
+ },
161
+ "handlers": [
162
+ {
163
+ "_target_": "ValidationHandler",
164
+ "validator": "@validate#evaluator",
165
+ "epoch_level": true,
166
+ "interval": "@val_interval"
167
+ },
168
+ {
169
+ "_target_": "StatsHandler",
170
+ "tag_name": "train_loss",
171
+ "output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
172
+ },
173
+ {
174
+ "_target_": "TensorBoardStatsHandler",
175
+ "log_dir": "@output_dir",
176
+ "tag_name": "train_loss",
177
+ "output_transform": "$monai.handlers.from_engine(['loss'], first=True)"
178
+ },
179
+ {
180
+ "_target_": "scripts.TensorBoardImageHandler",
181
+ "class_names": {
182
+ "0": "Other",
183
+ "1": "Inflammatory",
184
+ "2": "Epithelial",
185
+ "3": "Spindle-Shaped"
186
+ },
187
+ "log_dir": "@output_dir",
188
+ "batch_limit": 4,
189
+ "tag_name": "train"
190
+ }
191
+ ],
192
+ "key_metric": {
193
+ "train_f1": {
194
+ "_target_": "ConfusionMatrix",
195
+ "metric_name": "f1 score",
196
+ "output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
197
+ }
198
+ },
199
+ "trainer": {
200
+ "_target_": "SupervisedTrainer",
201
+ "max_epochs": "@max_epochs",
202
+ "device": "@device",
203
+ "train_data_loader": "@train#dataloader",
204
+ "network": "@network",
205
+ "loss_function": "@loss",
206
+ "optimizer": "@optimizer",
207
+ "inferer": "@train#inferer",
208
+ "postprocessing": "@train#postprocessing",
209
+ "key_train_metric": "@train#key_metric",
210
+ "train_handlers": "@train#handlers",
211
+ "amp": true
212
+ }
213
+ },
214
+ "validate": {
215
+ "preprocessing": {
216
+ "_target_": "Compose",
217
+ "transforms": [
218
+ {
219
+ "_target_": "LoadImaged",
220
+ "keys": [
221
+ "image",
222
+ "label"
223
+ ],
224
+ "dtype": "uint8"
225
+ },
226
+ {
227
+ "_target_": "EnsureChannelFirstd",
228
+ "keys": [
229
+ "image",
230
+ "label"
231
+ ]
232
+ },
233
+ {
234
+ "_target_": "SplitLabeld",
235
+ "keys": "label",
236
+ "mask_value": "",
237
+ "others_value": 255,
238
+ "to_binary_mask": false
239
+ },
240
+ {
241
+ "_target_": "ScaleIntensityRanged",
242
+ "keys": "image",
243
+ "a_min": 0.0,
244
+ "a_max": 255.0,
245
+ "b_min": -1.0,
246
+ "b_max": 1.0
247
+ },
248
+ {
249
+ "_target_": "AddLabelAsGuidanced",
250
+ "keys": "image",
251
+ "source": "label"
252
+ },
253
+ {
254
+ "_target_": "SetLabelClassd",
255
+ "keys": "label",
256
+ "offset": -1
257
+ },
258
+ {
259
+ "_target_": "SelectItemsd",
260
+ "keys": [
261
+ "image",
262
+ "label",
263
+ "image_meta_dict"
264
+ ]
265
+ }
266
+ ]
267
+ },
268
+ "dataset": {
269
+ "_target_": "CacheDataset",
270
+ "data": "@val_datalist",
271
+ "transform": "@validate#preprocessing",
272
+ "cache_rate": 1.0
273
+ },
274
+ "dataloader": {
275
+ "_target_": "DataLoader",
276
+ "dataset": "@validate#dataset",
277
+ "batch_size": 64,
278
+ "shuffle": false,
279
+ "num_workers": 4
280
+ },
281
+ "inferer": {
282
+ "_target_": "SimpleInferer"
283
+ },
284
+ "postprocessing": "%train#postprocessing",
285
+ "handlers": [
286
+ {
287
+ "_target_": "StatsHandler",
288
+ "iteration_log": false
289
+ },
290
+ {
291
+ "_target_": "TensorBoardStatsHandler",
292
+ "log_dir": "@output_dir",
293
+ "iteration_log": false
294
+ },
295
+ {
296
+ "_target_": "CheckpointSaver",
297
+ "save_dir": "@ckpt_dir",
298
+ "save_dict": {
299
+ "model": "@network"
300
+ },
301
+ "save_key_metric": true,
302
+ "key_metric_filename": "model.pt"
303
+ },
304
+ {
305
+ "_target_": "scripts.TensorBoardImageHandler",
306
+ "class_names": {
307
+ "0": "Other",
308
+ "1": "Inflammatory",
309
+ "2": "Epithelial",
310
+ "3": "Spindle-Shaped"
311
+ },
312
+ "log_dir": "@output_dir",
313
+ "batch_limit": 8,
314
+ "tag_name": "val"
315
+ }
316
+ ],
317
+ "key_metric": {
318
+ "val_f1": {
319
+ "_target_": "ConfusionMatrix",
320
+ "metric_name": "f1 score",
321
+ "output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
322
+ }
323
+ },
324
+ "additional_metrics": {
325
+ "val_accuracy": {
326
+ "_target_": "ignite.metrics.Accuracy",
327
+ "output_transform": "$monai.handlers.from_engine(['pred', 'label'])"
328
+ }
329
+ },
330
+ "evaluator": {
331
+ "_target_": "SupervisedEvaluator",
332
+ "device": "@device",
333
+ "val_data_loader": "@validate#dataloader",
334
+ "network": "@network",
335
+ "inferer": "@validate#inferer",
336
+ "postprocessing": "@validate#postprocessing",
337
+ "key_val_metric": "@validate#key_metric",
338
+ "additional_metrics": "@validate#additional_metrics",
339
+ "val_handlers": "@validate#handlers",
340
+ "amp": true
341
+ }
342
+ },
343
+ "training": [
344
+ "$import sys",
345
+ "$sys.path.append(@bundle_root)",
346
+ "$monai.utils.set_determinism(seed=123)",
347
+ "$setattr(torch.backends.cudnn, 'benchmark', True)",
348
+ "$@train#trainer.run()"
349
+ ]
350
+ }
figures/architecture.png ADDED
models/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f254ae6b0318e1375d48c1e9d6056d236d4a1a32957afc4aeafba0e047c46b2b
3
+ size 28419489
requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ monai==1.1.0
2
+ gradio
sample_data/Images/test_11_2_0628.png ADDED
sample_data/Images/test_12_3_0292.png ADDED
sample_data/Images/test_14_3_0433.png ADDED
sample_data/Images/test_14_4_0544.png ADDED
sample_data/Images/test_9_4_0019.png ADDED
sample_data/Images/test_9_4_0149.png ADDED
sample_data/Images/train_1_1_0095.png ADDED
sample_data/Images/train_1_3_0020.png ADDED
sample_data/Labels/test_11_2_0628.png ADDED
sample_data/Labels/test_12_3_0292.png ADDED
sample_data/Labels/test_14_3_0433.png ADDED
sample_data/Labels/test_14_4_0544.png ADDED
sample_data/Labels/test_9_4_0019.png ADDED
sample_data/Labels/test_9_4_0149.png ADDED
sample_data/Labels/train_1_1_0095.png ADDED
sample_data/Labels/train_1_3_0020.png ADDED