Upload 118 files

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- README.md +31 -0
- config.yml +157 -0
- deployment.ipynb +995 -0
- handler.py +11 -0
- inference.py +341 -0
- inference2.py +169 -0
- internals/__init__.py +0 -0
- internals/data/__init__.py +0 -0
- internals/data/dataAccessor.py +104 -0
- internals/data/result.py +19 -0
- internals/data/task.py +125 -0
- internals/pipelines/commons.py +119 -0
- internals/pipelines/controlnets.py +221 -0
- internals/pipelines/img_classifier.py +24 -0
- internals/pipelines/img_to_text.py +31 -0
- internals/pipelines/inpainter.py +41 -0
- internals/pipelines/object_remove.py +82 -0
- internals/pipelines/prompt_modifier.py +54 -0
- internals/pipelines/remove_background.py +16 -0
- internals/pipelines/safety_checker.py +163 -0
- internals/pipelines/twoStepPipeline.py +252 -0
- internals/pipelines/upscaler.py +91 -0
- internals/util/__init__.py +0 -0
- internals/util/args.py +13 -0
- internals/util/avatar.py +59 -0
- internals/util/cache.py +31 -0
- internals/util/commons.py +203 -0
- internals/util/config.py +66 -0
- internals/util/failure_hander.py +40 -0
- internals/util/image.py +18 -0
- internals/util/lora_style.py +154 -0
- internals/util/slack.py +58 -0
- models/ade20k/.DS_Store +0 -0
- models/ade20k/__init__.py +1 -0
- models/ade20k/base.py +627 -0
- models/ade20k/color150.mat +0 -0
- models/ade20k/mobilenet.py +154 -0
- models/ade20k/object150_info.csv +151 -0
- models/ade20k/resnet.py +181 -0
- models/ade20k/segm_lib/.DS_Store +0 -0
- models/ade20k/segm_lib/nn/.DS_Store +0 -0
- models/ade20k/segm_lib/nn/__init__.py +2 -0
- models/ade20k/segm_lib/nn/modules/__init__.py +12 -0
- models/ade20k/segm_lib/nn/modules/batchnorm.py +329 -0
- models/ade20k/segm_lib/nn/modules/comm.py +131 -0
- models/ade20k/segm_lib/nn/modules/replicate.py +94 -0
- models/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py +56 -0
- models/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py +111 -0
- models/ade20k/segm_lib/nn/modules/unittest.py +29 -0
- models/ade20k/segm_lib/nn/parallel/__init__.py +1 -0
README.md
ADDED
@@ -0,0 +1,31 @@
+# creco-inference
+
+Unified inference code for SageMaker and Hugging Face endpoints
+
+## Deployment
+
+- The inference code from this repo should be placed inside the model folder of the respective platform, as shown below:
+
+### SageMaker
+
+```
+model/
+  code/
+    (repo) <-- The repo inference code as direct child (no sub-folder)
+  vae
+  unet
+  ...
+```
+
+- Refer to `deployment.ipynb` for creating the endpoint.
+
+### Hugging Face
+
+```
+model/
+  (repo) <-- The repo inference code as direct child (no sub-folder)
+  vae
+  unet
+  ...
+```
+
+- Refer to the [doc](https://huggingface.co/docs/inference-endpoints/guides/create_endpoint) to create an endpoint.
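Once a `model.tar.gz` built with the SageMaker layout above is deployed, the endpoint is called through the standard SageMaker runtime API. The snippet below is a minimal sketch only: the endpoint name and the payload fields are hypothetical, since the actual request schema is defined by the `Task` parsing in `inference.py`, not by this README.

```python
# Minimal sketch of invoking a deployed endpoint with boto3.
# Assumptions: endpoint name and payload keys are illustrative placeholders.
import json

import boto3

runtime = boto3.client("sagemaker-runtime")

response = runtime.invoke_endpoint(
    EndpointName="prod-10000-2023-05-16-14-55",  # hypothetical endpoint name
    ContentType="application/json",
    Body=json.dumps({"prompt": "a comic panel of a city at night"}),  # hypothetical schema
)
print(json.loads(response["Body"].read()))
```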
config.yml
ADDED
@@ -0,0 +1,157 @@
+run_title: b18_ffc075_batch8x15
+training_model:
+  kind: default
+  visualize_each_iters: 1000
+  concat_mask: true
+  store_discr_outputs_for_vis: true
+losses:
+  l1:
+    weight_missing: 0
+    weight_known: 10
+  perceptual:
+    weight: 0
+  adversarial:
+    kind: r1
+    weight: 10
+    gp_coef: 0.001
+    mask_as_fake_target: true
+    allow_scale_mask: true
+  feature_matching:
+    weight: 100
+  resnet_pl:
+    weight: 30
+    weights_path: ${env:TORCH_HOME}
+
+optimizers:
+  generator:
+    kind: adam
+    lr: 0.001
+  discriminator:
+    kind: adam
+    lr: 0.0001
+visualizer:
+  key_order:
+    - image
+    - predicted_image
+    - discr_output_fake
+    - discr_output_real
+    - inpainted
+  rescale_keys:
+    - discr_output_fake
+    - discr_output_real
+  kind: directory
+  outdir: /group-volume/User-Driven-Content-Generation/r.suvorov/inpainting/experiments/r.suvorov_2021-04-30_14-41-12_train_simple_pix2pix2_gap_sdpl_novgg_large_b18_ffc075_batch8x15/samples
+location:
+  data_root_dir: /group-volume/User-Driven-Content-Generation/datasets/inpainting_data_root_large
+  out_root_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/experiments
+  tb_dir: /group-volume/User-Driven-Content-Generation/${env:USER}/inpainting/tb_logs
+data:
+  batch_size: 15
+  val_batch_size: 2
+  num_workers: 3
+  train:
+    indir: ${location.data_root_dir}/train
+    out_size: 256
+    mask_gen_kwargs:
+      irregular_proba: 1
+      irregular_kwargs:
+        max_angle: 4
+        max_len: 200
+        max_width: 100
+        max_times: 5
+        min_times: 1
+      box_proba: 1
+      box_kwargs:
+        margin: 10
+        bbox_min_size: 30
+        bbox_max_size: 150
+        max_times: 3
+        min_times: 1
+      segm_proba: 0
+      segm_kwargs:
+        confidence_threshold: 0.5
+        max_object_area: 0.5
+        min_mask_area: 0.07
+        downsample_levels: 6
+        num_variants_per_mask: 1
+        rigidness_mode: 1
+        max_foreground_coverage: 0.3
+        max_foreground_intersection: 0.7
+        max_mask_intersection: 0.1
+        max_hidden_area: 0.1
+        max_scale_change: 0.25
+        horizontal_flip: true
+        max_vertical_shift: 0.2
+        position_shuffle: true
+    transform_variant: distortions
+    dataloader_kwargs:
+      batch_size: ${data.batch_size}
+      shuffle: true
+      num_workers: ${data.num_workers}
+  val:
+    indir: ${location.data_root_dir}/val
+    img_suffix: .png
+    dataloader_kwargs:
+      batch_size: ${data.val_batch_size}
+      shuffle: false
+      num_workers: ${data.num_workers}
+  visual_test:
+    indir: ${location.data_root_dir}/korean_test
+    img_suffix: _input.png
+    pad_out_to_modulo: 32
+    dataloader_kwargs:
+      batch_size: 1
+      shuffle: false
+      num_workers: ${data.num_workers}
+generator:
+  kind: ffc_resnet
+  input_nc: 4
+  output_nc: 3
+  ngf: 64
+  n_downsampling: 3
+  n_blocks: 18
+  add_out_act: sigmoid
+  init_conv_kwargs:
+    ratio_gin: 0
+    ratio_gout: 0
+    enable_lfu: false
+  downsample_conv_kwargs:
+    ratio_gin: ${generator.init_conv_kwargs.ratio_gout}
+    ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin}
+    enable_lfu: false
+  resnet_conv_kwargs:
+    ratio_gin: 0.75
+    ratio_gout: ${generator.resnet_conv_kwargs.ratio_gin}
+    enable_lfu: false
+discriminator:
+  kind: pix2pixhd_nlayer
+  input_nc: 3
+  ndf: 64
+  n_layers: 4
+evaluator:
+  kind: default
+  inpainted_key: inpainted
+  integral_kind: ssim_fid100_f1
+trainer:
+  kwargs:
+    gpus: -1
+    accelerator: ddp
+    max_epochs: 200
+    gradient_clip_val: 1
+    log_gpu_memory: None
+    limit_train_batches: 25000
+    val_check_interval: ${trainer.kwargs.limit_train_batches}
+    log_every_n_steps: 1000
+    precision: 32
+    terminate_on_nan: false
+    check_val_every_n_epoch: 1
+    num_sanity_val_steps: 8
+    limit_val_batches: 1000
+    replace_sampler_ddp: false
+  checkpoint_kwargs:
+    verbose: true
+    save_top_k: 5
+    save_last: true
+    period: 1
+    monitor: val_ssim_fid100_f1_total_mean
+    mode: max
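The `${...}` references in this config are OmegaConf-style interpolations (the file matches the LaMa inpainting training config, whose stack uses Hydra/OmegaConf). One subtlety: `downsample_conv_kwargs.ratio_gout` points at its own sibling `ratio_gin`, which in turn points at `init_conv_kwargs.ratio_gout`, so both resolve to 0. A small sketch of how such references resolve, under the assumption that OmegaConf is the loader:

```python
# Sketch of OmegaConf interpolation resolution; assumes the training code loads
# config.yml with OmegaConf. Only the chained-reference behaviour is shown.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    "generator:\n"
    "  init_conv_kwargs:\n"
    "    ratio_gin: 0\n"
    "    ratio_gout: 0\n"
    "  downsample_conv_kwargs:\n"
    "    ratio_gin: ${generator.init_conv_kwargs.ratio_gout}\n"
    "    ratio_gout: ${generator.downsample_conv_kwargs.ratio_gin}\n"
)
# Both references chain back to init_conv_kwargs.ratio_gout == 0.
print(cfg.generator.downsample_conv_kwargs.ratio_gin)   # 0
print(cfg.generator.downsample_conv_kwargs.ratio_gout)  # 0
```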
deployment.ipynb
ADDED
@@ -0,0 +1,995 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "5af7e53b-80ff-4058-888d-fe41804f64ba",
+   "metadata": {
+    "scrolled": true,
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Looking in indexes: https://pypi.org/simple, https://pip.repos.neuron.amazonaws.com\n",
+      "Requirement already satisfied: pip in /home/ec2-user/anaconda3/envs/pytorch_p39/lib/python3.9/site-packages (23.1.2)\n"
+     ]
+    }
+   ],
+   "source": [
+    "!pip install --upgrade pip\n",
+    "!pip install \"sagemaker==2.116.0\" \"huggingface_hub==0.10.1\" --upgrade --quiet"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "93ee3d96-400f-46b4-8eb3-0f3f3c853a7e",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "from distutils.dir_util import copy_tree\n",
+    "from pathlib import Path\n",
+    "from huggingface_hub import snapshot_download\n",
+    "import random\n",
+    "import os\n",
+    "import tarfile\n",
+    "import time\n",
+    "import sagemaker\n",
+    "from datetime import datetime\n",
+    "from sagemaker.s3 import S3Uploader\n",
+    "import boto3\n",
+    "from sagemaker.huggingface.model import HuggingFaceModel\n",
+    "from threading import Thread\n",
+    "import subprocess\n",
+    "import shutil"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "2db37b03-b517-46bc-8602-4999a64399c0",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "# ------------------------------------------------\n",
+    "# Configuration\n",
+    "# ------------------------------------------------\n",
+    "STAGE = \"prod\"\n",
+    "model_configs = [\n",
+    "    # {\n",
+    "    #     \"inference_2\": False, \n",
+    "    #     \"path\": \"icbinp\",\n",
+    "    #     \"endpoint_name\": \"gamma-10000-2023-05-16-14-55\"\n",
+    "    #     #\"endpoint_name\": f\"{STAGE}-10000-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
+    "    # },\n",
+    "    # {\n",
+    "    #     \"inference_2\": False, \n",
+    "    #     \"path\": \"icb_with_epi\",\n",
+    "    #     \"endpoint_name\": \"gamma-10000-2023-05-16-14-55\"\n",
+    "    #     # \"endpoint_name\": f\"{STAGE}-10000-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
+    "    # },\n",
+    "    {\n",
+    "        \"inference_2\": False, \n",
+    "        \"path\": \"model_v9\",\n",
+    "        # \"endpoint_name\": \"gamma-10000-2023-05-16-14-55\"\n",
+    "        \"endpoint_name\": f\"{STAGE}-10000-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
+    "    },\n",
+    "    {\n",
+    "        \"inference_2\": False, \n",
+    "        \"path\": \"model_v8\",\n",
+    "        #\"endpoint_name\": \"gamma-10001-2023-05-08-06-14\"\n",
+    "        \"endpoint_name\": f\"{STAGE}-10001-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
+    "    },\n",
+    "    # {\n",
+    "    #     \"inference_2\": False, \n",
+    "    #     \"path\": \"model_v5_anime\",\n",
+    "    #     \"endpoint_name\": \"gamma-10001-2023-05-08-06-14\"\n",
+    "    #     #\"endpoint_name\": f\"{STAGE}-10001-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
+    "    # },\n",
+    "    # {\n",
+    "    #     \"inference_2\": False, \n",
+    "    #     \"path\": \"model_v5.3_comic\",\n",
+    "    #     #\"endpoint_name\": \"gamma-10002-2023-05-08-07-22\"\n",
+    "    #     \"endpoint_name\": f\"{STAGE}-10002-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
+    "    # },\n",
+    "    {\n",
+    "        \"inference_2\": False, \n",
+    "        \"path\": \"model_v10\",\n",
+    "        # \"endpoint_name\": \"gamma-10002-2023-05-08-07-22\"\n",
+    "        \"endpoint_name\": f\"{STAGE}-10002-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
+    "    },\n",
+    "    {\n",
+    "        \"inference_2\": True, \n",
+    "        \"path\": \"model_v5.2_other\",\n",
+    "        # \"endpoint_name\": \"gamma-other-2023-05-04-09-33\"\n",
+    "        \"endpoint_name\": f\"{STAGE}-other-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
+    "    }\n",
+    "    # {\n",
+    "    #     \"inference_2\": False, \n",
+    "    #     \"path\": \"model_v6_bheem\",\n",
+    "    #     \"endpoint_name\": f\"{STAGE}-10003-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
+    "    # },\n",
+    "    # {\n",
+    "    #     \"inference_2\": False, \n",
+    "    #     \"path\": \"model_v12\",\n",
+    "    #     \"endpoint_name\": \"gamma-10003-2023-05-04-05-20\"\n",
+    "    #     # \"endpoint_name\": f\"{STAGE}-10003-\" + datetime.now().strftime(\"%Y-%m-%d-%H-%M\")\n",
+    "    # }\n",
+    "]\n",
+    "\n",
+    "VpcConfig = {\n",
+    "    \"Subnets\": [\n",
+    "        \"subnet-0df3f71df4c7b29e5\",\n",
+    "        \"subnet-0d753b7fc74b5ee68\"\n",
+    "    ],\n",
+    "    \"SecurityGroupIds\": [\n",
+    "        \"sg-033a7948e79a501cd\"\n",
+    "    ]\n",
+    "}"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "d7322ac4-aeeb-4a72-a662-5f3fa74e6454",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "def compress(tar_dir=None, output_file=\"model.tar.gz\"):\n",
+    "    parent_dir = os.getcwd()\n",
+    "    os.chdir(parent_dir + \"/\" + tar_dir)\n",
+    "    with tarfile.open(os.path.join(parent_dir, output_file), \"w:gz\") as tar:\n",
+    "        for item in os.listdir('.'):\n",
+    "            print(\"- \" + item)\n",
+    "            tar.add(item, arcname=item)\n",
+    "    os.chdir(parent_dir)\n",
+    "\n",
+    "\n",
+    "def create_model_tar(config):\n",
+    "    print(\"Copying inference 'code': \" + config.get(\"path\"))\n",
+    "\n",
+    "    model_tar = Path(config.get(\"path\"))\n",
+    "    if os.path.exists(model_tar.joinpath(\"code\")):\n",
+    "        shutil.rmtree(model_tar.joinpath(\"code\"))\n",
+    "    out_tar = config.get(\"path\") + \".tar.gz\"\n",
+    "    model_tar.mkdir(exist_ok=True)\n",
+    "    copy_tree(\"code/\", str(model_tar.joinpath(\"code\")))\n",
+    "    copy_tree(\"laur_style/\", str(model_tar.joinpath(\"laur_style\")))\n",
+    "\n",
+    "    if config.get(\"inference_2\"):\n",
+    "        os.remove(model_tar.joinpath(\"code\").joinpath(\"inference.py\"))\n",
+    "        os.rename(model_tar.joinpath(\"code\").joinpath(\"inference2.py\"), model_tar.joinpath(\"code\").joinpath(\"inference.py\"))\n",
+    "\n",
+    "    print(\"Compressing: \" + config.get(\"path\"))\n",
+    "\n",
+    "    if os.path.exists(out_tar):\n",
+    "        os.remove(out_tar)\n",
+    "\n",
+    "    compress(str(model_tar), out_tar)\n",
+    "\n",
+    "def upload_to_s3(config):\n",
+    "    out_tar = config.get(\"path\") + \".tar.gz\"\n",
+    "    print(\"Uploading model to S3: \" + out_tar)\n",
+    "    s3_model_uri = S3Uploader.upload(local_path=out_tar, desired_s3_uri=f\"s3://comic-assets/stable-diffusion-v1-4/v2/\")\n",
+    "    return s3_model_uri\n",
+    "\n",
+    "\n",
+    "def deploy_and_create_endpoint(config, s3_model_uri):\n",
+    "    sess = sagemaker.Session()\n",
+    "    # sagemaker session bucket -> used for uploading data, models and logs\n",
+    "    # sagemaker will automatically create this bucket if it not exists\n",
+    "    sagemaker_session_bucket = None\n",
+    "    if sagemaker_session_bucket is None and sess is not None:\n",
+    "        # set to default bucket if a bucket name is not given\n",
+    "        sagemaker_session_bucket = sess.default_bucket()\n",
+    "    try:\n",
+    "        role = sagemaker.get_execution_role()\n",
+    "    except ValueError:\n",
+    "        iam = boto3.client('iam')\n",
+    "        role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n",
+    "\n",
+    "    sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)\n",
+    "\n",
+    "    huggingface_model = HuggingFaceModel(\n",
+    "        model_data=s3_model_uri,       # path to your model and script\n",
+    "        role=role,                     # iam role with permissions to create an Endpoint\n",
+    "        transformers_version=\"4.17\",   # transformers version used\n",
+    "        pytorch_version=\"1.10\",        # pytorch version used\n",
+    "        py_version='py38',             # python version used\n",
+    "        vpc_config=VpcConfig,\n",
+    "    )\n",
+    "\n",
+    "    print(\"Creating endpoint: \" + config.get(\"endpoint_name\"))\n",
+    "\n",
+    "    predictor = huggingface_model.deploy(\n",
+    "        initial_instance_count=1,\n",
+    "        instance_type=\"ml.g4dn.xlarge\",\n",
+    "        endpoint_name=config.get(\"endpoint_name\")\n",
+    "    )\n",
+    "\n",
+    "\n",
+    "def start_process(config):\n",
+    "    try:\n",
+    "        create_model_tar(config)\n",
+    "        s3_model_uri = upload_to_s3(config)\n",
+    "        #s3_model_uri = \"s3://comic-assets/stable-diffusion-v1-4/v2//model_v5.2_other.tar.gz\"\n",
+    "        deploy_and_create_endpoint(config, s3_model_uri)\n",
+    "    except Exception as e:\n",
+    "        print(\"Failed to deploy: \" + config.get(\"path\") + \"\\n\" + str(e))"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "cdc04669-90a5-4b43-8499-ad1d2dd63a4c",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Copying inference 'code': model_v9\n",
+      "Compressing: model_v9\n",
+      "- scheduler\n",
+      "- vae\n",
+      "- .ipynb_checkpoints\n",
+      "- feature_extractor\n",
+      "- tokenizer\n",
+      "- text_encoder\n",
+      "- model_index.json\n",
+      "- laur_style\n",
+      "- code\n",
+      "- unet\n",
+      "- args.json\n",
+      "Uploading model to S3: model_v9.tar.gz\n",
+      "Creating endpoint: gamma-10000-2023-05-16-14-55\n",
+      "-----------------!\n",
+      "\n",
+      "Completed in : 992.3517553806305s\n"
+     ]
+    }
+   ],
+   "source": [
+    "threads = []\n",
+    "\n",
+    "os.chdir(\"/home/ec2-user/SageMaker\")\n",
+    "\n",
+    "start_time = time.time()\n",
+    "\n",
+    "for config in model_configs:\n",
+    "    thread = Thread(target=start_process, args=(config,))\n",
+    "    thread.start()\n",
+    "    thread.join()\n",
+    "    threads.append(thread)\n",
+    "\n",
+    "for thread in threads:\n",
+    "    thread.join()\n",
+    "\n",
+    "print(\"\\n\\nCompleted in : \" + str(time.time() - start_time) + \"s\")\n",
+    "\n",
+    "# For redeploying gamma endpoints or promoting gamma endpoints to prod\n",
+    "\n",
+    "# thread1 = Thread(target=deploy_and_create_endpoint, args=(model_configs[0],\"s3://comic-assets/stable-diffusion-v1-4/v2//model_v9.tar.gz\",))\n",
+    "# thread2 = Thread(target=deploy_and_create_endpoint, args=(model_configs[1],\"s3://comic-assets/stable-diffusion-v1-4/v2//anime_mode_with_lora.tar.gz\",))\n",
+    "# thread3 = Thread(target=deploy_and_create_endpoint, args=(model_configs[0],\"s3://comic-assets/stable-diffusion-v1-4/v2//model_v5.3_comic.tar.gz\",))\n",
+    "# thread4 = Thread(target=deploy_and_create_endpoint, args=(model_configs[3],\"s3://comic-assets/stable-diffusion-v1-4/v2//model_v5.2_other.tar.gz\",))\n",
+    "\n",
+    "# thread1.start()\n",
+    "# thread2.start()\n",
+    "# thread3.start()\n",
+    "# thread4.start()\n",
+    "\n",
+    "# thread1.join()\n",
+    "# thread2.join()\n",
+    "# thread3.join()\n",
+    "# thread4.join()\n",
+    "\n",
+    "# print(\"Done\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "39f007f2-0ff8-487c-b5d7-158f0947b7fd",
+   "metadata": {
+    "collapsed": true,
+    "jupyter": {
+     "outputs_hidden": true
+    },
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "\n",
+    "# import sagemaker\n",
+    "# import boto3\n",
+    "# import time \n",
+    "\n",
+    "# start = time.time()\n",
+    "\n",
+    "# sess = sagemaker.Session()\n",
+    "# # sagemaker session bucket -> used for uploading data, models and logs\n",
+    "# # sagemaker will automatically create this bucket if it not exists\n",
+    "# sagemaker_session_bucket=None\n",
+    "# if sagemaker_session_bucket is None and sess is not None:\n",
+    "#     # set to default bucket if a bucket name is not given\n",
+    "#     sagemaker_session_bucket = sess.default_bucket()\n",
+    "\n",
+    "# try:\n",
+    "#     role = sagemaker.get_execution_role()\n",
+    "# except ValueError:\n",
+    "#     iam = boto3.client('iam')\n",
+    "#     role = iam.get_role(RoleName='sagemaker_execution_role')['Role']['Arn']\n",
+    "\n",
+    "# sess = sagemaker.Session(default_bucket=sagemaker_session_bucket)\n",
+    "\n",
+    "# print(f\"sagemaker role arn: {role}\")\n",
+    "# print(f\"sagemaker bucket: {sess.default_bucket()}\")\n",
+    "# print(f\"sagemaker session region: {sess.boto_region_name}\")\n",
+    "# print(sagemaker.get_execution_role())\n",
+    "\n",
+    "# from sagemaker.s3 import S3Uploader\n",
+    "\n",
+    "# print(\"Uploading model to S3\")\n",
+    "\n",
+    "# # upload model.tar.gz to s3\n",
+    "# s3_model_uri=S3Uploader.upload(local_path=\"model.tar.gz\", desired_s3_uri=f\"s3://comic-assets/stable-diffusion-v1-4/v2/\")\n",
+    "\n",
+    "# print(f\"model uploaded to: {s3_model_uri}\")\n",
+    "\n",
+    "\n",
+    "# from sagemaker.huggingface.model import HuggingFaceModel\n",
+    "\n",
+    "# VpcConfig = {\n",
+    "#     \"Subnets\": [\n",
+    "#         \"subnet-0df3f71df4c7b29e5\",\n",
+    "#         \"subnet-0d753b7fc74b5ee68\"\n",
+    "#     ],\n",
+    "#     \"SecurityGroupIds\": [\n",
+    "#         \"sg-033a7948e79a501cd\"\n",
+    "#     ]\n",
+    "# }\n",
+    "\n",
+    "# # create Hugging Face Model Class\n",
+    "# huggingface_model = HuggingFaceModel(\n",
+    "#     model_data=s3_model_uri,       # path to your model and script\n",
+    "#     role=role,                     # iam role with permissions to create an Endpoint\n",
+    "#     transformers_version=\"4.17\",   # transformers version used\n",
+    "#     pytorch_version=\"1.10\",        # pytorch version used\n",
+    "#     py_version='py38',             # python version used\n",
+    "#     vpc_config=VpcConfig,\n",
+    "# )\n",
+    "\n",
+    "# print(\"Deploying model\")\n",
+    "\n",
+    "# predictor = huggingface_model.deploy(\n",
+    "#     initial_instance_count=1,\n",
+    "#     instance_type=\"ml.g4dn.xlarge\",\n",
+    "#     # endpoint_name=endpoint_name\n",
+    "# )\n",
+    "\n",
+    "# print(f\"Done {time.time() - start}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "aa95a262-d6ba-4e61-8657-6f8e5bab74a1",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "!curl -s https://packagecloud.io/install/repositories/github/git-lfs/script.rpm.sh | sudo bash"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "524ca546-2a67-4b51-9cda-a1b51a49c339",
+   "metadata": {
+    "tags": []
+   },
+   "outputs": [],
+   "source": [
+    "!sudo yum install git-lfs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3c7e661f-5eee-4357-80f6-e7563941a812",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "availableInstances": [ "...55 auto-generated SageMaker Studio instance-type entries (ml.t3.medium through ml.g5.48xlarge), elided..." ],
+  "instance_type": "ml.t3.medium",
+  "kernelspec": {
+   "display_name": "conda_pytorch_p39",
+   "language": "python",
+   "name": "conda_pytorch_p39"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.15"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
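Two things worth flagging in the deployment cells above. First, `distutils.dir_util.copy_tree` is deprecated (distutils was removed in Python 3.12); `shutil.copytree(src, dst, dirs_exist_ok=True)` is the standard replacement. Second, the deploy loop calls `thread.start()` and then immediately `thread.join()` in the same iteration, so the deployments actually run one at a time rather than in parallel (the 992 s wall time reflects that). A sketch of the presumably intended parallel form, assuming concurrent deploys are acceptable and reusing `model_configs` and `start_process` from the cells above:

```python
# Sketch of a genuinely parallel deploy loop. Assumption: the per-iteration
# start()/join() pairing in the notebook was unintentional serialization.
from threading import Thread

threads = []
for config in model_configs:
    thread = Thread(target=start_process, args=(config,))
    thread.start()          # launch every deployment first...
    threads.append(thread)

for thread in threads:
    thread.join()           # ...then wait for all of them together
```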
handler.py
ADDED
@@ -0,0 +1,11 @@
+from typing import Any, Dict, List
+
+from inference import model_fn, predict_fn
+
+
+class EndpointHandler:
+    def __init__(self, path=""):
+        return model_fn(path)
+
+    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
+        return predict_fn(data, None)
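As committed, `EndpointHandler.__init__` returns the result of `model_fn(path)`, but Python requires `__init__` to return None, so constructing this handler raises a TypeError, and the loaded model is never kept for `__call__`. A minimal corrected sketch, keeping the same `model_fn`/`predict_fn` contract from `inference.py`:

```python
# Minimal corrected sketch of handler.py: store the loaded model instead of
# (illegally) returning it from __init__. Whether predict_fn needs the model
# passed as its second argument depends on inference.py's signature, which is
# not fully shown here; None is kept to match the original call.
from typing import Any, Dict, List

from inference import model_fn, predict_fn


class EndpointHandler:
    def __init__(self, path: str = ""):
        self.model = model_fn(path)  # keep a reference for later calls

    def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
        return predict_fn(data, None)
```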
inference.py
ADDED
@@ -0,0 +1,341 @@
+from typing import List, Optional
+
+import torch
+
+from internals.data.dataAccessor import update_db
+from internals.data.task import Task, TaskType
+from internals.pipelines.commons import Img2Img, Text2Img
+from internals.pipelines.controlnets import ControlNet
+from internals.pipelines.img_classifier import ImageClassifier
+from internals.pipelines.img_to_text import Image2Text
+from internals.pipelines.prompt_modifier import PromptModifier
+from internals.pipelines.safety_checker import SafetyChecker
+from internals.util.args import apply_style_args
+from internals.util.avatar import Avatar
+from internals.util.cache import auto_clear_cuda_and_gc
+from internals.util.commons import pickPoses, upload_image, upload_images
+from internals.util.config import set_configs_from_task, set_root_dir
+from internals.util.failure_hander import FailureHandler
+from internals.util.lora_style import LoraStyle
+from internals.util.slack import Slack
+
+torch.backends.cudnn.benchmark = True
+torch.backends.cuda.matmul.allow_tf32 = True
+
+num_return_sequences = 4  # the number of results to generate
+auto_mode = False
+
+prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
+img2text = Image2Text()
+img_classifier = ImageClassifier()
+controlnet = ControlNet()
+lora_style = LoraStyle()
+text2img_pipe = Text2Img()
+img2img_pipe = Img2Img()
+safety_checker = SafetyChecker()
+slack = Slack()
+avatar = Avatar()
+
+
+def get_patched_prompt(task: Task):
+    def add_style_and_character(prompt: List[str], additional: Optional[str] = None):
+        for i in range(len(prompt)):
+            prompt[i] = avatar.add_code_names(prompt[i])
+            prompt[i] = lora_style.prepend_style_to_prompt(prompt[i], task.get_style())
+            if additional:
+                prompt[i] = additional + " " + prompt[i]
+
+    prompt = task.get_prompt()
+
+    if task.is_prompt_engineering():
+        prompt = prompt_modifier.modify(prompt)
+    else:
+        prompt = [prompt] * num_return_sequences
+
+    ori_prompt = [task.get_prompt()] * num_return_sequences
+
+    class_name = None
+    # if task.get_imageUrl():
+    #     class_name = img_classifier.classify(
+    #         task.get_imageUrl(), task.get_width(), task.get_height()
+    #     )
+    add_style_and_character(ori_prompt, class_name)
+    add_style_and_character(prompt, class_name)
+
+    print({"prompts": prompt})
+
+    return (prompt, ori_prompt)
+
+
+def get_patched_prompt_tile_upscale(task: Task):
+    if task.get_prompt():
+        prompt = task.get_prompt()
+    else:
+        prompt = img2text.process(task.get_imageUrl())
+
+    prompt = avatar.add_code_names(prompt)
+    prompt = lora_style.prepend_style_to_prompt(prompt, task.get_style())
+
+    class_name = img_classifier.classify(
+        task.get_imageUrl(), task.get_width(), task.get_height()
+    )
+    prompt = class_name + " " + prompt
+
+    print({"prompt": prompt})
+
+    return prompt
+
+
+@update_db
+@auto_clear_cuda_and_gc(controlnet)
+@slack.auto_send_alert
+def canny(task: Task):
+    prompt, _ = get_patched_prompt(task)
+
+    controlnet.load_canny()
+
+    # pipe2 is used for canny and pose
+    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
+    lora_patcher.patch()
+
+    images, has_nsfw = controlnet.process_canny(
+        prompt=prompt,
+        imageUrl=task.get_imageUrl(),
+        seed=task.get_seed(),
+        steps=task.get_steps(),
+        width=task.get_width(),
+        height=task.get_height(),
+        guidance_scale=task.get_cy_guidance_scale(),
+        negative_prompt=[
+            f"monochrome, neon, x-ray, negative image, oversaturated, {task.get_negative_prompt()}"
+        ]
+        * num_return_sequences,
+        **lora_patcher.kwargs(),
+    )
+
+    generated_image_urls = upload_images(images, "_canny", task.get_taskId())
+
+    lora_patcher.cleanup()
+    controlnet.cleanup()
+
+    return {
+        "modified_prompts": prompt,
+        "generated_image_urls": generated_image_urls,
+        "has_nsfw": has_nsfw,
+    }
+
+
+@update_db
+@auto_clear_cuda_and_gc(controlnet)
+@slack.auto_send_alert
+def tile_upscale(task: Task):
+    output_key = "crecoAI/{}_tile_upscaler.png".format(task.get_taskId())
+
+    prompt = get_patched_prompt_tile_upscale(task)
+
+    controlnet.load_tile_upscaler()
+
+    lora_patcher = lora_style.get_patcher(controlnet.pipe, task.get_style())
+    lora_patcher.patch()
+
+    images, has_nsfw = controlnet.process_tile_upscaler(
+        imageUrl=task.get_imageUrl(),
+        seed=task.get_seed(),
+        steps=task.get_steps(),
+        width=task.get_width(),
+        height=task.get_height(),
+        prompt=prompt,
+        resize_dimension=task.get_resize_dimension(),
+        negative_prompt=task.get_negative_prompt(),
+        guidance_scale=task.get_ti_guidance_scale(),
+    )
+
+    generated_image_url = upload_image(images[0], output_key)
+
+    lora_patcher.cleanup()
+    controlnet.cleanup()
+
+    return {
+        "modified_prompts": prompt,
+        "generated_image_url": generated_image_url,
+        "has_nsfw": has_nsfw,
+    }
+
+
+@update_db
+@auto_clear_cuda_and_gc(controlnet)
+@slack.auto_send_alert
+def pose(task: Task, s3_outkey: str = "_pose", poses: Optional[list] = None):
+    prompt, _ = get_patched_prompt(task)
+
+    controlnet.load_pose()
+
+    # pipe2 is used for canny and pose
+    lora_patcher = lora_style.get_patcher(controlnet.pipe2, task.get_style())
+    lora_patcher.patch()
+
+    if poses is None:
+        poses = [controlnet.detect_pose(task.get_imageUrl())] * num_return_sequences
+
+    images, has_nsfw = controlnet.process_pose(
+        prompt=prompt,
+        image=poses,
+        seed=task.get_seed(),
+        steps=task.get_steps(),
+        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
+        width=task.get_width(),
+        height=task.get_height(),
+        guidance_scale=task.get_po_guidance_scale(),
+        **lora_patcher.kwargs(),
+    )
+
+    generated_image_urls = upload_images(images, s3_outkey, task.get_taskId())
+
+    lora_patcher.cleanup()
+    controlnet.cleanup()
+
+    return {
+        "modified_prompts": prompt,
+        "generated_image_urls": generated_image_urls,
"has_nsfw": has_nsfw,
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
@update_db
|
| 205 |
+
@auto_clear_cuda_and_gc(controlnet)
|
| 206 |
+
@slack.auto_send_alert
|
| 207 |
+
def text2img(task: Task):
|
| 208 |
+
prompt, ori_prompt = get_patched_prompt(task)
|
| 209 |
+
|
| 210 |
+
lora_patcher = lora_style.get_patcher(text2img_pipe.pipe, task.get_style())
|
| 211 |
+
lora_patcher.patch()
|
| 212 |
+
|
| 213 |
+
torch.manual_seed(task.get_seed())
|
| 214 |
+
|
| 215 |
+
images, has_nsfw = text2img_pipe.process(
|
| 216 |
+
prompt=ori_prompt,
|
| 217 |
+
modified_prompts=prompt,
|
| 218 |
+
num_inference_steps=task.get_steps(),
|
| 219 |
+
guidance_scale=7.5,
|
| 220 |
+
height=task.get_height(),
|
| 221 |
+
width=task.get_width(),
|
| 222 |
+
negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
|
| 223 |
+
iteration=task.get_iteration(),
|
| 224 |
+
**lora_patcher.kwargs(),
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
generated_image_urls = upload_images(images, "", task.get_taskId())
|
| 228 |
+
|
| 229 |
+
lora_patcher.cleanup()
|
| 230 |
+
|
| 231 |
+
return {
|
| 232 |
+
"modified_prompts": prompt,
|
| 233 |
+
"generated_image_urls": generated_image_urls,
|
| 234 |
+
"has_nsfw": has_nsfw,
|
| 235 |
+
}
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
@update_db
|
| 239 |
+
@auto_clear_cuda_and_gc(controlnet)
|
| 240 |
+
@slack.auto_send_alert
|
| 241 |
+
def img2img(task: Task):
|
| 242 |
+
prompt, _ = get_patched_prompt(task)
|
| 243 |
+
|
| 244 |
+
lora_patcher = lora_style.get_patcher(img2img_pipe.pipe, task.get_style())
|
| 245 |
+
lora_patcher.patch()
|
| 246 |
+
|
| 247 |
+
torch.manual_seed(task.get_seed())
|
| 248 |
+
|
| 249 |
+
images, has_nsfw = img2img_pipe.process(
|
| 250 |
+
prompt=prompt,
|
| 251 |
+
imageUrl=task.get_imageUrl(),
|
| 252 |
+
negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
|
| 253 |
+
steps=task.get_steps(),
|
| 254 |
+
width=task.get_width(),
|
| 255 |
+
height=task.get_height(),
|
| 256 |
+
strength=task.get_i2i_strength(),
|
| 257 |
+
guidance_scale=task.get_i2i_guidance_scale(),
|
| 258 |
+
**lora_patcher.kwargs(),
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
generated_image_urls = upload_images(images, "_imgtoimg", task.get_taskId())
|
| 262 |
+
|
| 263 |
+
lora_patcher.cleanup()
|
| 264 |
+
|
| 265 |
+
return {
|
| 266 |
+
"modified_prompts": prompt,
|
| 267 |
+
"generated_image_urls": generated_image_urls,
|
| 268 |
+
"has_nsfw": has_nsfw,
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
|
def model_fn(model_dir):
    print("Logs: model loaded .... starts")

    set_root_dir(__file__)

    FailureHandler.register()

    avatar.load_local()

    prompt_modifier.load()
    img2text.load()
    img_classifier.load()

    lora_style.load(model_dir)
    safety_checker.load()

    controlnet.load(model_dir)
    text2img_pipe.load(model_dir)
    img2img_pipe.create(text2img_pipe)

    safety_checker.apply(text2img_pipe)
    safety_checker.apply(img2img_pipe)
    safety_checker.apply(controlnet)

    print("Logs: model loaded ....")
    return


@FailureHandler.clear
def predict_fn(data, pipe):
    task = Task(data)
    print("task is ", data)

    FailureHandler.handle(task)

    try:
        # Set environment configs from the task
        set_configs_from_task(task)

        # Apply arguments
        apply_style_args(data)

        # Re-fetch styles
        lora_style.fetch_styles()

        # Fetch avatars
        avatar.fetch_from_network(task.get_model_id())

        task_type = task.get_type()

        if task_type == TaskType.TEXT_TO_IMAGE:
            # character sheet
            if "character sheet" in task.get_prompt().lower():
                return pose(task, s3_outkey="", poses=pickPoses())
            else:
                return text2img(task)
        elif task_type == TaskType.IMAGE_TO_IMAGE:
            return img2img(task)
        elif task_type == TaskType.CANNY:
            return canny(task)
        elif task_type == TaskType.POSE:
            return pose(task)
        elif task_type == TaskType.TILE_UPSCALE:
            return tile_upscale(task)
        else:
            raise Exception("Invalid task type")
    except Exception as e:
        print(f"Error: {e}")
        slack.error_alert(task, e)
        return None
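For reference, the SageMaker contract above (`model_fn` loads once, `predict_fn` routes per request) can be exercised with a minimal local sketch like the one below. The payload keys mirror the accessors in `internals/data/task.py`; the concrete values (model dir, ids, prompt) are hypothetical placeholders.

```python
# Hypothetical local smoke test for the handlers above; values are made up.
from inference import model_fn, predict_fn

model_fn("/opt/ml/model")  # loads all pipelines once, as SageMaker would

payload = {
    "task_type": "GENERATE_AI_IMAGE",  # routes to text2img via TaskType.TEXT_TO_IMAGE
    "task_id": "demo-123",
    "source_id": "src-1",
    "userId": "u-1",
    "prompt": "a lighthouse at dawn",
    "width": 512,
    "height": 512,
    "steps": 30,
    "auto_mode": True,  # enables prompt engineering via PromptModifier
}
result = predict_fn(payload, None)  # the `pipe` argument is unused by this handler
print(result and result.get("generated_image_urls"))
```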
inference2.py
ADDED
@@ -0,0 +1,169 @@
from io import BytesIO

import torch

from internals.data.dataAccessor import update_db
from internals.data.task import ModelType, Task, TaskType
from internals.pipelines.inpainter import InPainter
from internals.pipelines.object_remove import ObjectRemoval
from internals.pipelines.prompt_modifier import PromptModifier
from internals.pipelines.remove_background import RemoveBackground
from internals.pipelines.safety_checker import SafetyChecker
from internals.pipelines.upscaler import Upscaler
from internals.util.avatar import Avatar
from internals.util.cache import clear_cuda
from internals.util.commons import (construct_default_s3_url, upload_image,
                                    upload_images)
from internals.util.config import set_configs_from_task, set_root_dir
from internals.util.failure_hander import FailureHandler
from internals.util.slack import Slack

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

num_return_sequences = 4
auto_mode = False

slack = Slack()

prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
upscaler = Upscaler()
inpainter = InPainter()
safety_checker = SafetyChecker()
object_removal = ObjectRemoval()
avatar = Avatar()


@update_db
@slack.auto_send_alert
def remove_bg(task: Task):
    remove_background = RemoveBackground()
    output_image = remove_background.remove(task.get_imageUrl())

    output_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
    upload_image(output_image, output_key)

    return {"generated_image_url": construct_default_s3_url(output_key)}


@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    prompt = avatar.add_code_names(task.get_prompt())
    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(prompt)
    else:
        prompt = [prompt] * num_return_sequences

    print({"prompts": prompt})

    images = inpainter.process(
        prompt=prompt,
        image_url=task.get_imageUrl(),
        mask_image_url=task.get_maskImageUrl(),
        width=task.get_width(),
        height=task.get_height(),
        seed=task.get_seed(),
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
    )
    generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())

    clear_cuda()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}


@update_db
@slack.auto_send_alert
def remove_object(task: Task):
    output_key = "crecoAI/{}_object_remove.png".format(task.get_taskId())

    images = object_removal.process(
        image_url=task.get_imageUrl(),
        mask_image_url=task.get_maskImageUrl(),
        seed=task.get_seed(),
        width=task.get_width(),
        height=task.get_height(),
    )
    generated_image_urls = upload_image(images[0], output_key)

    clear_cuda()

    return {"generated_image_urls": generated_image_urls}


@update_db
@slack.auto_send_alert
def upscale_image(task: Task):
    output_key = "crecoAI/{}_upscale.png".format(task.get_taskId())
    out_img = None
    if task.get_modelType() == ModelType.ANIME:
        print("Using Anime model")
        out_img = upscaler.upscale_anime(
            image=task.get_imageUrl(), resize_dimension=task.get_resize_dimension()
        )
    else:
        print("Using Real model")
        out_img = upscaler.upscale(
            image=task.get_imageUrl(), resize_dimension=task.get_resize_dimension()
        )

    upload_image(BytesIO(out_img), output_key)
    return {"generated_image_url": construct_default_s3_url(output_key)}


def model_fn(model_dir):
    print("Logs: model loaded .... starts")

    set_root_dir(__file__)

    FailureHandler.register()

    avatar.load_local()

    prompt_modifier.load()
    safety_checker.load()

    object_removal.load(model_dir)
    upscaler.load()
    inpainter.load()

    safety_checker.apply(inpainter)

    print("Logs: model loaded ....")
    return


@FailureHandler.clear
def predict_fn(data, pipe):
    task = Task(data)
    print("task is ", data)

    FailureHandler.handle(task)

    try:
        # Set environment configs from the task
        set_configs_from_task(task)

        # Fetch avatars
        avatar.fetch_from_network(task.get_model_id())

        task_type = task.get_type()

        if task_type == TaskType.REMOVE_BG:
            return remove_bg(task)
        elif task_type == TaskType.INPAINT:
            return inpaint(task)
        elif task_type == TaskType.UPSCALE_IMAGE:
            return upscale_image(task)
        elif task_type == TaskType.OBJECT_REMOVAL:
            return remove_object(task)
        else:
            raise Exception("Invalid task type")
    except Exception as e:
        print(f"Error: {e}")
        slack.error_alert(task, e)
        return None
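`inference2.py` follows the same entry-point contract as `inference.py` but serves the auxiliary tasks (background removal, inpainting, upscaling, object removal). A hypothetical upscale request would look like this; the URL and ids are placeholders, not values from the repo.

```python
payload = {
    "task_type": "UPSCALE_IMAGE",                 # routes to upscale_image(task)
    "task_id": "demo-456",
    "source_id": "src-2",
    "userId": "u-1",
    "imageUrl": "https://example.com/input.png",  # placeholder URL
    "modelId": 10001,                             # ModelType.ANIME -> upscale_anime branch
    "resize_dimension": 1024,
}
# predict_fn(payload, None) would upscale the image and return
# {"generated_image_url": "<s3 url for crecoAI/demo-456_upscale.png>"}
```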
internals/__init__.py
ADDED
File without changes

internals/data/__init__.py
ADDED
File without changes
internals/data/dataAccessor.py
ADDED
@@ -0,0 +1,104 @@
import traceback
from typing import Dict, List, Optional

import requests

from internals.data.task import Task
from internals.util.config import api_endpoint, api_headers
from internals.util.slack import Slack


def updateSource(sourceId, userId, state):
    print("update source is called")
    url = api_endpoint() + f"/comic-crecoai/source/{sourceId}"
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }

    data = {"state": state}

    try:
        response = requests.patch(url, headers=headers, json=data, timeout=10)
        print("update source response", response)
    except requests.exceptions.Timeout:
        print("Request timed out while updating source")
    except requests.exceptions.RequestException as e:
        print(f"Error while updating source: {e}")

    return


def saveGeneratedImages(sourceId, userId, has_nsfw: bool):
    print("save generation called")
    url = api_endpoint() + "/comic-crecoai/source/" + str(sourceId) + "/generatedImages"
    headers = {
        "Content-Type": "application/json",
        "user-id": str(userId),
        **api_headers(),
    }
    data = {"state": "ACTIVE", "has_nsfw": has_nsfw}

    try:
        # timeout added so the Timeout handler below can actually fire
        requests.patch(url, headers=headers, json=data, timeout=10)
        # print("save generation response", response)
    except requests.exceptions.Timeout:
        print("Request timed out while saving image")
    except requests.exceptions.RequestException as e:
        print("Failed to mark source as active: ", e)
    return


def getStyles() -> Optional[Dict]:
    url = api_endpoint() + "/comic-crecoai/style"
    try:
        response = requests.get(
            url,
            timeout=10,
            headers={"x-api-key": "kGyEMp)oHB(zf^E5>-{o]I%go", **api_headers()},
        )
        return response.json()
    except requests.exceptions.Timeout:
        print("Request timed out while fetching styles")
    except requests.exceptions.RequestException as e:
        print(f"Error while fetching styles: {e}")
    return None


def getCharacters(model_id: str) -> Optional[List]:
    url = api_endpoint() + "/comic-crecoai/model/{}".format(model_id)
    try:
        response = requests.get(url, timeout=10, headers=api_headers())
        response = response.json()
        response = response["data"]["characters"]
        return response
    except requests.exceptions.Timeout:
        print("Request timed out while fetching characters")
    except Exception as e:
        print(f"Error while fetching characters: {e}")
    return None


def update_db(func):
    def caller(*args, **kwargs):
        if type(args[0]) is not Task:
            raise Exception("First argument must be a Task object")
        task = args[0]
        try:
            updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
            rargs = func(*args, **kwargs)
            has_nsfw = rargs.get("has_nsfw", False)
            updateSource(task.get_sourceId(), task.get_userId(), "COMPLETED")
            saveGeneratedImages(task.get_sourceId(), task.get_userId(), has_nsfw)
            return rargs
        except Exception as e:
            print("Error processing image: {}".format(str(e)))
            traceback.print_exc()
            slack = Slack()
            slack.error_alert(task, e)
            updateSource(task.get_sourceId(), task.get_userId(), "FAILED")

    return caller
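The `update_db` decorator gives every handler the same source lifecycle: INPROGRESS before the call, COMPLETED plus `saveGeneratedImages` on success, FAILED plus a Slack alert on any exception. A minimal sketch of that control flow with the handler abstracted out (for reference only, not a second implementation):

```python
# Equivalent lifecycle of an @update_db-wrapped handler such as text2img.
def run_with_lifecycle(task, handler):
    updateSource(task.get_sourceId(), task.get_userId(), "INPROGRESS")
    try:
        result = handler(task)                       # e.g. inpaint, upscale_image, ...
        updateSource(task.get_sourceId(), task.get_userId(), "COMPLETED")
        saveGeneratedImages(task.get_sourceId(), task.get_userId(),
                            result.get("has_nsfw", False))
        return result
    except Exception:
        updateSource(task.get_sourceId(), task.get_userId(), "FAILED")
        # the real decorator also prints the traceback and fires slack.error_alert
```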
internals/data/result.py
ADDED
@@ -0,0 +1,19 @@
from internals.util.config import get_nsfw_access


class Result:
    images, nsfw = None, None

    def __init__(self, images, nsfw):
        self.images = images
        self.nsfw = nsfw

    @staticmethod
    def from_result(result):
        has_nsfw = result.nsfw_content_detected
        if has_nsfw and isinstance(has_nsfw, list):
            has_nsfw = any(has_nsfw)

        # boolean `not`, not bitwise `~`: ~bool operates on the underlying int
        # and is always truthy, which would disable the NSFW gate entirely
        has_nsfw = not get_nsfw_access() and has_nsfw
        return (result.images, bool(has_nsfw))
        # return Result(result.images, result.has_nsfw_concepts)
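One subtlety in `from_result`: the NSFW gate must use boolean `not` rather than bitwise `~`, because `~` on a Python bool complements the underlying int and never yields a falsy value:

```python
>>> ~True, ~False        # bitwise complement of ints 1 and 0
(-2, -1)                 # both truthy, so `~get_nsfw_access() and x` is just `x`
>>> not True, not False
(False, True)
```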
internals/data/task.py
ADDED
@@ -0,0 +1,125 @@
from enum import Enum
from typing import Union

import numpy as np


class TaskType(Enum):
    TEXT_TO_IMAGE = "GENERATE_AI_IMAGE"
    IMAGE_TO_IMAGE = "IMAGE_TO_IMAGE"
    POSE = "POSE"
    CANNY = "CANNY"
    REMOVE_BG = "REMOVE_BG"
    INPAINT = "INPAINT"
    UPSCALE_IMAGE = "UPSCALE_IMAGE"
    TILE_UPSCALE = "TILE_UPSCALE"
    OBJECT_REMOVAL = "OBJECT_REMOVAL"


class ModelType(Enum):
    REAL = 10000
    ANIME = 10001
    COMIC = 10002


class Task:
    def __init__(self, data):
        self.__data = data
        if data.get("seed", -1) is None or self.get_seed() == -1:
            self.__data["seed"] = np.random.randint(0, np.iinfo(np.int64).max)
        prompt = data.get("prompt", "")
        if prompt is None:
            self.__data["prompt"] = ""
        else:
            self.__data["prompt"] = data.get("prompt", "")[:200]

    def get_taskId(self) -> str:
        return self.__data.get("task_id")

    def get_sourceId(self) -> str:
        return self.__data.get("source_id")

    def get_imageUrl(self) -> str:
        return self.__data.get("imageUrl", None)

    def get_prompt(self) -> str:
        return self.__data.get("prompt", "")

    def get_userId(self) -> str:
        return self.__data.get("userId", "")

    def get_email(self) -> str:
        return self.__data.get("email", "")

    def get_style(self) -> str:
        return self.__data.get("style", None)

    def get_iteration(self) -> float:
        return float(self.__data.get("iteration", 3.0))

    def get_modelType(self) -> ModelType:
        model_id = self.get_model_id()
        return ModelType(model_id)

    def get_model_id(self) -> int:
        return int(self.__data.get("modelId", 10000))

    def get_width(self) -> int:
        return int(self.__data.get("width", 512))

    def get_height(self) -> int:
        return int(self.__data.get("height", 512))

    def get_seed(self) -> int:
        return int(self.__data.get("seed", -1))

    def get_steps(self) -> int:
        return int(self.__data.get("steps", 75))

    def get_type(self) -> Union[TaskType, None]:
        try:
            return TaskType(self.__data.get("task_type"))
        except ValueError:
            return None

    def get_maskImageUrl(self) -> str:
        return self.__data.get("maskImageUrl")

    def get_negative_prompt(self) -> str:
        return self.__data.get("negative_prompt", "")

    def is_prompt_engineering(self) -> bool:
        return self.__data.get("auto_mode", True)

    def get_queue_name(self) -> str:
        return self.__data.get("queue_name", "")

    def get_resize_dimension(self) -> int:
        return self.__data.get("resize_dimension", 1024)

    def get_ti_guidance_scale(self) -> float:
        return self.__data.get("ti_guidance_scale", 7.5)

    def get_i2i_guidance_scale(self) -> float:
        return self.__data.get("i2i_guidance_scale", 7.5)

    def get_i2i_strength(self) -> float:
        return self.__data.get("i2i_strength", 0.75)

    def get_cy_guidance_scale(self) -> float:
        return self.__data.get("cy_guidance_scale", 9)

    def get_po_guidance_scale(self) -> float:
        return self.__data.get("po_guidance_scale", 7.5)

    def get_nsfw_threshold(self) -> float:
        return self.__data.get("nsfw_threshold", 0.03)

    def can_access_nsfw(self) -> bool:
        return self.__data.get("can_access_nsfw", False)

    def get_access_token(self) -> str:
        return self.__data.get("access_token", "")

    def get_raw(self) -> dict:
        return self.__data.copy()
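`Task` is a thin, defaulting wrapper over the raw request dict: missing fields fall back to sane defaults, a random seed is drawn when none is supplied, and prompts are truncated to 200 characters on construction. A quick illustration (payload values are hypothetical):

```python
task = Task({"task_type": "POSE", "task_id": "t-1", "prompt": "knight " * 100})
task.get_type()          # TaskType.POSE
task.get_width()         # 512 (default)
task.get_seed()          # random int drawn at construction, since no seed was sent
len(task.get_prompt())   # <= 200: prompts are truncated in __init__
```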
internals/pipelines/commons.py
ADDED
@@ -0,0 +1,119 @@
from typing import Any, Callable, Dict, List, Optional, Union

import torch
from diffusers import StableDiffusionImg2ImgPipeline

from internals.data.result import Result
from internals.pipelines.twoStepPipeline import two_step_pipeline
from internals.util.commons import disable_safety_checker, download_image


class AbstractPipeline:
    def load(self, model_dir: str):
        pass

    def create(self, pipe):
        pass


class Text2Img(AbstractPipeline):
    def load(self, model_dir: str):
        self.pipe = two_step_pipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16
        ).to("cuda")
        self.__patch()

    def create(self, pipeline: AbstractPipeline):
        self.pipe = two_step_pipeline(**pipeline.pipe.components).to("cuda")
        self.__patch()

    def __patch(self):
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        prompt: Union[str, List[str]] = None,
        modified_prompts: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        iteration: float = 3.0,
    ):
        result = self.pipe.two_step_pipeline(
            prompt=prompt,
            modified_prompts=modified_prompts,
            height=height,
            width=width,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_images_per_prompt,
            eta=eta,
            generator=generator,
            latents=latents,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            output_type=output_type,
            return_dict=return_dict,
            callback=callback,
            callback_steps=callback_steps,
            cross_attention_kwargs=cross_attention_kwargs,
            iteration=iteration,
        )
        return Result.from_result(result)


class Img2Img(AbstractPipeline):
    def load(self, model_dir: str):
        self.pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
            model_dir, torch_dtype=torch.float16
        ).to("cuda")
        self.__patch()

    def create(self, pipeline: AbstractPipeline):
        self.pipe = StableDiffusionImg2ImgPipeline(**pipeline.pipe.components).to(
            "cuda"
        )
        self.__patch()

    def __patch(self):
        self.pipe.enable_xformers_memory_efficient_attention()

    @torch.inference_mode()
    def process(
        self,
        prompt: List[str],
        imageUrl: str,
        negative_prompt: List[str],
        strength: float,
        guidance_scale: float,
        steps: int,
        width: int,
        height: int,
    ):
        image = download_image(imageUrl).resize((width, height))

        result = self.pipe.__call__(
            prompt=prompt,
            image=image,
            strength=strength,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            num_inference_steps=steps,
        )
        return Result.from_result(result)
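Note the `create` path: rather than loading the model weights twice, `Img2Img` is built from the components of an already-loaded `Text2Img` pipeline, which is the standard diffusers pattern for sharing a UNet/VAE/text encoder across task-specific pipelines. A sketch of what that buys you (model dir is a placeholder):

```python
text2img_pipe = Text2Img()
text2img_pipe.load("/opt/ml/model")    # hypothetical model dir

img2img_pipe = Img2Img()
img2img_pipe.create(text2img_pipe)     # reuses text2img_pipe.pipe.components

# Both pipelines now point at the same modules on the GPU; only the
# surrounding scheduling/call logic differs.
assert img2img_pipe.pipe.unet is text2img_pipe.pipe.unet
```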
internals/pipelines/controlnets.py
ADDED
@@ -0,0 +1,221 @@
from typing import List

import cv2
import numpy as np
import torch
from controlnet_aux import OpenposeDetector
from diffusers import (
    ControlNetModel,
    DiffusionPipeline,
    StableDiffusionControlNetPipeline,
    UniPCMultistepScheduler,
)
from PIL import Image

from internals.data.result import Result
from internals.pipelines.commons import AbstractPipeline
from internals.util.cache import clear_cuda_and_gc
from internals.util.commons import download_image


class ControlNet(AbstractPipeline):
    __current_task_name = ""

    def load(self, model_dir: str):
        # we will load canny by default
        self.load_canny()

        # controlnet img2img pipeline, used by the tile upscaler
        pipe = DiffusionPipeline.from_pretrained(
            model_dir,
            controlnet=self.controlnet,
            torch_dtype=torch.float16,
            custom_pipeline="stable_diffusion_controlnet_img2img",
        ).to("cuda")
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.enable_model_cpu_offload()
        pipe.enable_xformers_memory_efficient_attention()
        self.pipe = pipe

        # controlnet text2img pipeline, used by canny and pose
        pipe2 = StableDiffusionControlNetPipeline(**pipe.components).to("cuda")
        pipe2.scheduler = UniPCMultistepScheduler.from_config(pipe2.scheduler.config)
        pipe2.enable_xformers_memory_efficient_attention()
        self.pipe2 = pipe2

    def load_canny(self):
        if self.__current_task_name == "canny":
            return
        canny = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_canny", torch_dtype=torch.float16
        ).to("cuda")
        self.__current_task_name = "canny"
        self.controlnet = canny
        if hasattr(self, "pipe"):
            self.pipe.controlnet = canny
        if hasattr(self, "pipe2"):
            self.pipe2.controlnet = canny
        clear_cuda_and_gc()

    def load_pose(self):
        if self.__current_task_name == "pose":
            return
        pose = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11p_sd15_openpose", torch_dtype=torch.float16
        ).to("cuda")
        self.__current_task_name = "pose"
        self.controlnet = pose
        if hasattr(self, "pipe"):
            self.pipe.controlnet = pose
        if hasattr(self, "pipe2"):
            self.pipe2.controlnet = pose
        clear_cuda_and_gc()

    def load_tile_upscaler(self):
        if self.__current_task_name == "tile_upscaler":
            return
        tile_upscaler = ControlNetModel.from_pretrained(
            "lllyasviel/control_v11f1e_sd15_tile", torch_dtype=torch.float16
        ).to("cuda")
        self.__current_task_name = "tile_upscaler"
        self.controlnet = tile_upscaler
        if hasattr(self, "pipe"):
            self.pipe.controlnet = tile_upscaler
        if hasattr(self, "pipe2"):
            self.pipe2.controlnet = tile_upscaler
        clear_cuda_and_gc()

    def cleanup(self):
        self.pipe.controlnet = None
        self.pipe2.controlnet = None
        self.controlnet = None
        self.__current_task_name = ""

        clear_cuda_and_gc()

    @torch.inference_mode()
    def process_canny(
        self,
        prompt: List[str],
        imageUrl: str,
        seed: int,
        steps: int,
        negative_prompt: List[str],
        guidance_scale: float,
        height: int,
        width: int,
    ):
        if self.__current_task_name != "canny":
            raise Exception("ControlNet is not loaded with canny model")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        init_image = self.__canny_detect_edge(init_image)

        result = self.pipe2.__call__(
            prompt=prompt,
            image=init_image,
            guidance_scale=guidance_scale,
            num_images_per_prompt=1,
            negative_prompt=negative_prompt,
            num_inference_steps=steps,
            height=height,
            width=width,
        )
        return Result.from_result(result)

    @torch.inference_mode()
    def process_pose(
        self,
        prompt: List[str],
        image: List[Image.Image],
        seed: int,
        steps: int,
        guidance_scale: float,
        negative_prompt: List[str],
        height: int,
        width: int,
    ):
        if self.__current_task_name != "pose":
            raise Exception("ControlNet is not loaded with pose model")

        torch.manual_seed(seed)

        result = self.pipe2.__call__(
            prompt=prompt,
            image=image,
            num_images_per_prompt=1,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            guidance_scale=guidance_scale,
            height=height,
            width=width,
        )
        return Result.from_result(result)

    @torch.inference_mode()
    def process_tile_upscaler(
        self,
        imageUrl: str,
        prompt: str,
        negative_prompt: str,
        steps: int,
        seed: int,
        height: int,
        width: int,
        resize_dimension: int,
        guidance_scale: float,
    ):
        if self.__current_task_name != "tile_upscaler":
            raise Exception("ControlNet is not loaded with tile_upscaler model")

        torch.manual_seed(seed)

        init_image = download_image(imageUrl).resize((width, height))
        condition_image = self.__resize_for_condition_image(
            init_image, resize_dimension
        )

        result = self.pipe.__call__(
            image=condition_image,
            prompt=prompt,
            controlnet_conditioning_image=condition_image,
            num_inference_steps=steps,
            negative_prompt=negative_prompt,
            height=condition_image.size[1],
            width=condition_image.size[0],
            strength=1.0,
            guidance_scale=guidance_scale,
        )
        return Result.from_result(result)

    def detect_pose(self, imageUrl: str) -> Image.Image:
        detector = OpenposeDetector.from_pretrained("lllyasviel/ControlNet")
        image = download_image(imageUrl)
        image = detector.__call__(image, hand_and_face=True)
        return image

    def __canny_detect_edge(self, image: Image.Image) -> Image.Image:
        image_array = np.array(image)

        low_threshold = 100
        high_threshold = 200

        image_array = cv2.Canny(image_array, low_threshold, high_threshold)
        image_array = image_array[:, :, None]
        image_array = np.concatenate([image_array, image_array, image_array], axis=2)
        canny_image = Image.fromarray(image_array)
        return canny_image

    def __resize_for_condition_image(self, image: Image.Image, resolution: int):
        input_image = image.convert("RGB")
        W, H = input_image.size
        k = float(resolution) / min(W, H)
        H *= k
        W *= k
        H = int(round(H / 64.0)) * 64
        W = int(round(W / 64.0)) * 64
        img = input_image.resize((W, H), resample=Image.LANCZOS)
        return img
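The `__resize_for_condition_image` helper scales the shorter side up to `resolution` and snaps both sides to multiples of 64, the granularity the SD UNet expects. A worked example for a 512×768 input with `resolution=1024`:

```python
W, H = 512, 768
k = 1024 / min(W, H)             # 2.0
W, H = W * k, H * k              # 1024.0, 1536.0
W = int(round(W / 64.0)) * 64    # 1024
H = int(round(H / 64.0)) * 64    # 1536 -> the tile upscale runs at 1024x1536
```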
internals/pipelines/img_classifier.py
ADDED
@@ -0,0 +1,24 @@
from typing import List

from transformers import pipeline

from internals.util.commons import download_image


class ImageClassifier:
    def __init__(self, candidates: List[str] = ["realistic", "anime", "comic"]):
        self.__candidates = candidates

    def load(self):
        self.pipe = pipeline(
            "zero-shot-image-classification",
            model="philschmid/clip-zero-shot-image-classification",
        )

    def classify(self, image_url: str, width: int, height: int) -> str:
        image = download_image(image_url).resize((width, height))
        results = self.pipe.__call__([image], candidate_labels=self.__candidates)
        results = results[0]
        if len(results) > 0:
            return results[0]["label"]
        return ""
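Usage is a single call: the zero-shot CLIP pipeline scores the image against the configured candidate labels and the top label is returned. The URL below is a placeholder.

```python
clf = ImageClassifier()
clf.load()
label = clf.classify("https://example.com/panel.png", 512, 512)
# -> "anime", "comic", or "realistic"; "" if the pipeline returns nothing
```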
internals/pipelines/img_to_text.py
ADDED
@@ -0,0 +1,31 @@
import re

import torch
from transformers import BlipForConditionalGeneration, BlipProcessor

from internals.util.commons import download_image


class Image2Text:
    def load(self):
        self.processor = BlipProcessor.from_pretrained(
            "Salesforce/blip-image-captioning-large"
        )
        self.model = BlipForConditionalGeneration.from_pretrained(
            "Salesforce/blip-image-captioning-large", torch_dtype=torch.float16
        ).to("cuda")

    def process(self, imageUrl: str) -> str:
        image = download_image(imageUrl).resize((512, 512))
        inputs = self.processor.__call__(image, return_tensors="pt").to(
            "cuda", torch.float16
        )
        output_ids = self.model.generate(
            **inputs, do_sample=False, top_p=0.9, max_length=128
        )
        output_text = self.processor.batch_decode(output_ids)
        print(output_text)
        output_text = output_text[0]
        output_text = re.sub(r"</.>|\n|\[SEP\]", "", output_text)
        return output_text
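`Image2Text` backs the tile-upscale path when no prompt is supplied: it downsizes the image to 512×512, runs BLIP captioning, and strips the `[SEP]` marker and stray tags from the decoded text. A minimal call (placeholder URL, illustrative output):

```python
captioner = Image2Text()
captioner.load()
caption = captioner.process("https://example.com/photo.jpg")
print(caption)  # e.g. "a photo of a dog sitting on a bench"
```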
internals/pipelines/inpainter.py
ADDED
@@ -0,0 +1,41 @@
from typing import List, Union

import torch
from diffusers import StableDiffusionInpaintPipeline

from internals.pipelines.commons import AbstractPipeline
from internals.util.commons import disable_safety_checker, download_image


class InPainter(AbstractPipeline):
    def load(self):
        self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
            "jayparmr/icbinp_v8_inpaint_v2",
            torch_dtype=torch.float16,
        ).to("cuda")
        disable_safety_checker(self.pipe)

    @torch.inference_mode()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        width: int,
        height: int,
        seed: int,
        prompt: Union[str, List[str]],
        negative_prompt: Union[str, List[str]],
    ):
        torch.manual_seed(seed)

        input_img = download_image(image_url).resize((width, height))
        mask_img = download_image(mask_image_url).resize((width, height))

        return self.pipe.__call__(
            prompt=prompt,
            image=input_img,
            mask_image=mask_img,
            height=height,
            width=width,
            negative_prompt=negative_prompt,
        ).images
internals/pipelines/object_remove.py
ADDED
@@ -0,0 +1,82 @@
import os
from pathlib import Path
from typing import List

import cv2
import numpy as np
import torch
import tqdm
from omegaconf import OmegaConf
from PIL import Image
from torch.utils.data._utils.collate import default_collate

from internals.util.commons import download_file, download_image
from internals.util.config import get_root_dir
from saicinpainting.evaluation.utils import move_to_device
from saicinpainting.training.data.datasets import make_default_val_dataset
from saicinpainting.training.trainers import load_checkpoint


class ObjectRemoval:
    def load(self, model_dir):
        print("Downloading LAMA model...")

        self.lama_path = Path.home() / ".cache" / "lama"

        out_file = self.lama_path / "models" / "best.ckpt"
        os.makedirs(os.path.dirname(out_file), exist_ok=True)
        download_file(
            "https://huggingface.co/akhaliq/lama/resolve/main/best.ckpt", out_file
        )
        config = OmegaConf.load(get_root_dir() + "/config.yml")
        config.training_model.predict_only = True
        self.model = load_checkpoint(
            config, str(out_file), strict=False, map_location="cuda"
        )
        self.model.freeze()
        self.model.to("cuda")

    @torch.no_grad()
    def process(
        self,
        image_url: str,
        mask_image_url: str,
        seed: int,
        width: int,
        height: int,
    ) -> List:
        torch.manual_seed(seed)

        img_folder = self.lama_path / "images"
        indir = img_folder / "input"

        img_folder.mkdir(parents=True, exist_ok=True)
        indir.mkdir(parents=True, exist_ok=True)

        download_image(image_url).resize((width, height)).save(indir / "data.png")
        download_image(mask_image_url).resize((width, height)).save(
            indir / "data_mask.png"
        )

        dataset = make_default_val_dataset(
            img_folder / "input", img_suffix=".png", pad_out_to_modulo=8
        )

        out_images = []
        for img_i in tqdm.trange(len(dataset)):
            batch = move_to_device(default_collate([dataset[img_i]]), "cuda")
            batch["mask"] = (batch["mask"] > 0) * 1
            batch = self.model(batch)
            out_path = str(img_folder / "out.png")

            cur_res = batch["inpainted"][0].permute(1, 2, 0).detach().cpu().numpy()

            cur_res = np.clip(cur_res * 255, 0, 255).astype("uint8")
            cur_res = cv2.cvtColor(cur_res, cv2.COLOR_RGB2BGR)
            cv2.imwrite(out_path, cur_res)

            image = Image.open(out_path).convert("RGB")
            out_images.append(image)
            os.remove(out_path)

        return out_images
internals/pipelines/prompt_modifier.py
ADDED
@@ -0,0 +1,54 @@
from typing import List, Optional

from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig


class PromptModifier:
    def __init__(self, num_of_sequences: Optional[int] = 4):
        self.__blacklist = {"alphonse mucha": "", "adolphe bouguereau": ""}
        self.__num_of_sequences = num_of_sequences

    def load(self):
        self.prompter_model = AutoModelForCausalLM.from_pretrained(
            "Gustavosta/MagicPrompt-Stable-Diffusion"
        )
        self.prompter_tokenizer = AutoTokenizer.from_pretrained(
            "Gustavosta/MagicPrompt-Stable-Diffusion"
        )
        self.prompter_tokenizer.pad_token = self.prompter_tokenizer.eos_token
        self.prompter_tokenizer.padding_side = "left"

    def modify(self, text: str) -> List[str]:
        eos_id = self.prompter_tokenizer.eos_token_id
        # restricted_words_list = ["octane", "cyber"]
        # restricted_words_token_ids = prompter_tokenizer(
        #     restricted_words_list, add_special_tokens=False
        # ).input_ids

        generation_config = GenerationConfig(
            do_sample=False,
            max_new_tokens=75,
            num_beams=4,
            num_return_sequences=self.__num_of_sequences,
            eos_token_id=eos_id,
            pad_token_id=eos_id,
            length_penalty=-1.0,
        )

        input_ids = self.prompter_tokenizer(text.strip(), return_tensors="pt").input_ids
        outputs = self.prompter_model.generate(
            input_ids, generation_config=generation_config
        )
        output_texts = self.prompter_tokenizer.batch_decode(
            outputs, skip_special_tokens=True
        )
        output_texts = self.__patch_blacklist_words(output_texts)
        return output_texts

    def __patch_blacklist_words(self, texts: List[str]):
        def replace_all(text, dic):
            for i, j in dic.items():
                text = text.replace(i, j)
            return text

        return [replace_all(text, self.__blacklist) for text in texts]
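`modify` always returns `num_of_sequences` beam-search expansions of the input prompt, each with the blacklisted artist names stripped out. A quick sketch of the call shape (input text is hypothetical):

```python
pm = PromptModifier(num_of_sequences=4)
pm.load()
variants = pm.modify("a castle on a cliff")
assert len(variants) == 4   # one expanded prompt per returned sequence
```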
internals/pipelines/remove_background.py
ADDED
@@ -0,0 +1,16 @@
import io
from typing import Union

from PIL import Image
from rembg import remove

from internals.util.commons import read_url


class RemoveBackground:
    def remove(self, image: Union[str, Image.Image]) -> Image.Image:
        if type(image) is str:
            image = Image.open(io.BytesIO(read_url(image)))

        output = remove(image)
        return output
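`RemoveBackground.remove` accepts either a URL or a `PIL.Image` and returns the rembg cutout, with transparency in the alpha channel. For example (placeholder URL):

```python
rb = RemoveBackground()
cutout = rb.remove("https://example.com/portrait.png")
cutout.save("portrait_rmbg.png")  # the alpha channel holds the background mask
```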
internals/pipelines/safety_checker.py
ADDED
@@ -0,0 +1,163 @@
| 1 |
+
from re import L
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
```python
import torch
import torch.nn as nn
from transformers import CLIPConfig, CLIPVisionModel, PreTrainedModel

from internals.pipelines.commons import AbstractPipeline
from internals.util.config import get_nsfw_access, get_nsfw_threshold


def cosine_distance(image_embeds, text_embeds):
    normalized_image_embeds = nn.functional.normalize(image_embeds)
    normalized_text_embeds = nn.functional.normalize(text_embeds)
    return torch.mm(normalized_image_embeds, normalized_text_embeds.t())


class SafetyChecker:
    def load(self):
        self.model = StableDiffusionSafetyCheckerV2.from_pretrained(
            "CompVis/stable-diffusion-safety-checker", torch_dtype=torch.float16
        ).to("cuda")

    def apply(self, pipeline: AbstractPipeline):
        if hasattr(pipeline, "pipe"):
            pipeline.pipe.safety_checker = self.model
        if hasattr(pipeline, "pipe2"):
            pipeline.pipe2.safety_checker = self.model


class StableDiffusionSafetyCheckerV2(PreTrainedModel):
    config_class = CLIPConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPConfig):
        super().__init__(config)

        self.vision_model = CLIPVisionModel(config.vision_config)
        self.visual_projection = nn.Linear(
            config.vision_config.hidden_size, config.projection_dim, bias=False
        )

        self.concept_embeds = nn.Parameter(
            torch.ones(17, config.projection_dim), requires_grad=False
        )
        self.special_care_embeds = nn.Parameter(
            torch.ones(3, config.projection_dim), requires_grad=False
        )

        self.concept_embeds_weights = nn.Parameter(torch.ones(17), requires_grad=False)
        self.special_care_embeds_weights = nn.Parameter(
            torch.ones(3), requires_grad=False
        )

    @torch.no_grad()
    def forward(self, clip_input, images):
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        special_cos_dist = (
            cosine_distance(image_embeds, self.special_care_embeds)
            .cpu()
            .float()
            .numpy()
        )
        cos_dist = (
            cosine_distance(image_embeds, self.concept_embeds).cpu().float().numpy()
        )

        result = []
        batch_size = image_embeds.shape[0]
        for i in range(batch_size):
            result_img = {
                "special_scores": {},
                "special_care": [],
                "concept_scores": {},
                "bad_concepts": [],
            }

            # increase this value to create a stronger `nsfw` filter
            # at the cost of increasing the possibility of filtering benign images
            adjustment = 0.0

            for concept_idx in range(len(special_cos_dist[0])):
                concept_cos = special_cos_dist[i][concept_idx]
                concept_threshold = self.special_care_embeds_weights[concept_idx].item()
                result_img["special_scores"][concept_idx] = round(
                    concept_cos - concept_threshold + adjustment, 3
                )
                if result_img["special_scores"][concept_idx] > 0:
                    result_img["special_care"].append(
                        {concept_idx, result_img["special_scores"][concept_idx]}
                    )
                    adjustment = 0.01

            for concept_idx in range(len(cos_dist[0])):
                concept_cos = cos_dist[i][concept_idx]
                concept_threshold = self.concept_embeds_weights[concept_idx].item()
                result_img["concept_scores"][concept_idx] = round(
                    concept_cos - concept_threshold + adjustment, 3
                )
                if result_img["concept_scores"][concept_idx] > get_nsfw_threshold():
                    result_img["bad_concepts"].append(concept_idx)

            result.append(result_img)

        has_nsfw_concepts = [len(res["bad_concepts"]) > 0 for res in result]

        # Blur images based on NSFW score
        # -------------------------------
        # note: if any image in the batch is flagged, every image in the batch is blurred
        for idx, has_nsfw_concept in enumerate(has_nsfw_concepts):
            if any(has_nsfw_concepts) and not get_nsfw_access():
                if torch.is_tensor(images) or torch.is_tensor(images[0]):
                    image = images[idx].cpu().numpy().astype(np.float32)
                    image = cv2.blur(image, (30, 30))
                    image = torch.from_numpy(image)
                    images[idx] = image
                else:
                    images[idx] = cv2.blur(images[idx], (30, 30))

        if any(has_nsfw_concepts):
            print("NSFW")

        return images, has_nsfw_concepts

    @torch.no_grad()
    def forward_onnx(self, clip_input: torch.FloatTensor, images: torch.FloatTensor):
        pooled_output = self.vision_model(clip_input)[1]  # pooled_output
        image_embeds = self.visual_projection(pooled_output)

        special_cos_dist = cosine_distance(image_embeds, self.special_care_embeds)
        cos_dist = cosine_distance(image_embeds, self.concept_embeds)

        # increase this value to create a stronger `nsfw` filter
        # at the cost of increasing the possibility of filtering benign images
        adjustment = 0.0

        special_scores = (
            special_cos_dist - self.special_care_embeds_weights + adjustment
        )
        # special_scores = special_scores.round(decimals=3)
        special_care = torch.any(special_scores > 0, dim=1)
        special_adjustment = special_care * 0.01
        special_adjustment = special_adjustment.unsqueeze(1).expand(
            -1, cos_dist.shape[1]
        )

        concept_scores = (cos_dist - self.concept_embeds_weights) + special_adjustment
        # concept_scores = concept_scores.round(decimals=3)
        has_nsfw_concepts = torch.any(concept_scores > get_nsfw_threshold(), dim=1)

        # Blur images based on NSFW score
        # -------------------------------
        if not get_nsfw_access():
            image = images[has_nsfw_concepts].cpu().numpy().astype(np.float32)
            image = cv2.blur(image, (30, 30))
            image = torch.from_numpy(image)
            images[has_nsfw_concepts] = image

        return images, has_nsfw_concepts
```
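A minimal sketch of how the checker is wired into a pipeline. The `pipeline` object below is a stand-in: `apply` only cares that it exposes a diffusers pipeline as `.pipe` (and optionally a second one as `.pipe2`). `load()` needs CUDA and will download the CompVis checkpoint.

```python
from types import SimpleNamespace

from internals.pipelines.safety_checker import SafetyChecker

checker = SafetyChecker()
checker.load()  # pulls CompVis/stable-diffusion-safety-checker in fp16 onto CUDA

# Stand-in for an AbstractPipeline: anything exposing a diffusers pipeline as `.pipe`.
pipeline = SimpleNamespace(pipe=SimpleNamespace(safety_checker=None))
checker.apply(pipeline)
assert pipeline.pipe.safety_checker is checker.model
# From here on, calls through `pipe` run the V2 checker: flagged images are
# blurred with a 30x30 box filter unless get_nsfw_access() is True.
```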
internals/pipelines/twoStepPipeline.py
ADDED
@@ -0,0 +1,252 @@
```python
import torch
from diffusers import StableDiffusionPipeline

torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

from typing import Any, Callable, Dict, List, Optional, Union

from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput


class two_step_pipeline(StableDiffusionPipeline):
    @torch.no_grad()
    def two_step_pipeline(
        self,
        prompt: Union[str, List[str]] = None,
        modified_prompts: Union[str, List[str]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        iteration: float = 3.0,
    ):
        r"""
        Function invoked when calling the pipeline for generation.
        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
                instead.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated image.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                expense of slower inference.
            guidance_scale (`float`, *optional*, defaults to 7.5):
                Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                `guidance_scale` is defined as `w` of equation 2. of [Imagen
                Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
                1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
                usually at the expense of lower image quality.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide the image generation. If not defined, one has to pass
                `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
                less than `1`).
            num_images_per_prompt (`int`, *optional*, defaults to 1):
                The number of images to generate per prompt.
            eta (`float`, *optional*, defaults to 0.0):
                Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
                [`schedulers.DDIMScheduler`], will be ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                to make generation deterministic.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor will be generated by sampling using the supplied random `generator`.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                provided, text embeddings will be generated from `prompt` input argument.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.
            callback (`Callable`, *optional*):
                A function that will be called every `callback_steps` steps during inference. The function will be
                called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
            callback_steps (`int`, *optional*, defaults to 1):
                The frequency at which the `callback` function will be called. If not specified, the callback will be
                called at every step.
            cross_attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
        Examples:
        Returns:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
            [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
            When returning a tuple, the first element is a list with the generated images, and the second element is a
            list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
            (nsfw) content, according to the `safety_checker`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            height,
            width,
            callback_steps,
            negative_prompt,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        modified_embeds = self._encode_prompt(
            modified_prompts,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )
        print("mod prompt size : ", modified_embeds.size(), modified_embeds.dtype)

        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
        )

        print("prompt size : ", prompt_embeds.size(), prompt_embeds.dtype)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = (
                    torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                )
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t
                )

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs
                ).prev_sample

                # call the callback, if provided
                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
                ):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

                # two-step trick: after 1/iteration of the schedule, continue
                # denoising with the modified prompt's embeddings
                if i == int(len(timesteps) / iteration):
                    print("modified prompts")
                    prompt_embeds = modified_embeds

        if output_type == "latent":
            image = latents
            has_nsfw_concept = None
        elif output_type == "pil":
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, prompt_embeds.dtype
            )

            # 10. Convert to PIL
            image = self.numpy_to_pil(image)
        else:
            # 8. Post-processing
            image = self.decode_latents(latents)

            # 9. Run safety checker
            image, has_nsfw_concept = self.run_safety_checker(
                image, device, prompt_embeds.dtype
            )

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (image, has_nsfw_concept)

        return StableDiffusionPipelineOutput(
            images=image, nsfw_content_detected=has_nsfw_concept
        )
```
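A minimal usage sketch of the two-step call: the first `1/iteration` fraction of the denoising steps runs on the original prompt's embeddings, then the loop swaps in the `modified_prompts` embeddings for the remainder. The checkpoint path is illustrative.

```python
import torch

from internals.pipelines.twoStepPipeline import two_step_pipeline

pipe = two_step_pipeline.from_pretrained("model/", torch_dtype=torch.float16).to("cuda")

out = pipe.two_step_pipeline(
    prompt="a portrait of a knight",
    modified_prompts="a portrait of a knight, intricate armor, cinematic lighting",
    num_inference_steps=30,
    iteration=3.0,  # swap to the modified embeddings after ~1/3 of the steps
)
out.images[0].save("knight.png")
```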
internals/pipelines/upscaler.py
ADDED
@@ -0,0 +1,91 @@
```python
import math
import os
from pathlib import Path
from typing import Union

import cv2
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from basicsr.utils.download_util import load_file_from_url
from PIL import Image
from realesrgan import RealESRGANer

import internals.util.image as ImageUtil
from internals.util.commons import download_image


class Upscaler:
    __model_esrgan_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
    __model_esrgan_anime_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"

    def load(self):
        download_dir = Path(Path.home() / ".cache" / "realesrgan")
        download_dir.mkdir(parents=True, exist_ok=True)

        self.__model_path = self.__preload_model(self.__model_esrgan_url, download_dir)
        self.__model_path_anime = self.__preload_model(
            self.__model_esrgan_anime_url, download_dir
        )

    def upscale(self, image: Union[str, Image.Image], resize_dimension: int) -> bytes:
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=4,
        )
        return self.__internal_upscale(
            image, resize_dimension, self.__model_path, model
        )

    def upscale_anime(
        self, image: Union[str, Image.Image], resize_dimension: int
    ) -> bytes:
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            # the x4plus_anime_6B checkpoint has 6 RRDB blocks; with the original
            # num_block=23 its state dict would fail to load
            num_block=6,
            num_grow_ch=32,
            scale=4,
        )
        return self.__internal_upscale(
            image, resize_dimension, self.__model_path_anime, model
        )

    def __preload_model(self, url: str, download_dir: Path):
        name = url.split("/")[-1]
        if not os.path.exists(str(download_dir / name)):
            return load_file_from_url(
                url=url,
                model_dir=str(download_dir),
                progress=True,
                file_name=None,
            )
        else:
            return str(download_dir / name)

    def __internal_upscale(
        self,
        image,
        resize_dimension: int,
        model_path: str,
        rrbdnet: RRDBNet,
    ) -> bytes:
        if type(image) is str:
            image = download_image(image)
        image = ImageUtil.resize_image_to512(image)
        image = ImageUtil.to_bytes(image)

        upsampler = RealESRGANer(
            scale=4, model_path=model_path, model=rrbdnet, half="fp16", gpu_id="0"
        )
        image_array = np.frombuffer(image, dtype=np.uint8)
        input_image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
        dimension = min(input_image.shape[0], input_image.shape[1])
        scale = max(math.floor(resize_dimension / dimension), 2)
        output, _ = upsampler.enhance(input_image, outscale=scale)
        out_bytes = cv2.imencode(".png", output)[1].tobytes()
        return out_bytes
```
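Illustrative usage (the input URL is a placeholder): the image is first clamped to a 512px long side, then Real-ESRGAN upsamples it by `max(floor(resize_dimension / short_side), 2)`, and PNG bytes come back.

```python
from internals.pipelines.upscaler import Upscaler

upscaler = Upscaler()
upscaler.load()  # caches the Real-ESRGAN weights under ~/.cache/realesrgan

png_bytes = upscaler.upscale(
    image="https://example.com/input.png",  # or a PIL.Image.Image
    resize_dimension=1024,
)
with open("upscaled.png", "wb") as f:
    f.write(png_bytes)
```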
internals/util/__init__.py
ADDED
File without changes
internals/util/args.py
ADDED
@@ -0,0 +1,13 @@
```python
import re
from typing import Dict


def apply_style_args(data: Dict):
    prompt = data.get("prompt", None)
    if prompt is None:
        return
    result = re.match(r"\[style:(.*?)\]", prompt)
    if result is not None:
        style = result.group(1)
        data["style"] = style
        data["prompt"] = prompt.replace(f"[style:{style}]", "").strip()
```
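A small sketch of the tag extraction (the prompt text is illustrative). Note that `re.match` anchors at the start of the string, so the `[style:...]` tag is only recognized as a prefix of the prompt.

```python
from internals.util.args import apply_style_args

data = {"prompt": "[style:ghibli] a cottage in a forest"}
apply_style_args(data)
print(data)  # -> {'prompt': 'a cottage in a forest', 'style': 'ghibli'}
```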
internals/util/avatar.py
ADDED
@@ -0,0 +1,59 @@
```python
import json
import os
import re

from internals.data.dataAccessor import getCharacters
from internals.util.config import root_dir


class Avatar:
    __avatars = {}

    def load_local(self):
        self.__find_available_characters(root_dir)
        if len(self.__avatars.items()) > 0:
            print("Local characters", self.__avatars)

    def fetch_from_network(self, model_id: int):
        characters = getCharacters(str(model_id))
        if characters is not None:
            for character in characters:
                item = {
                    "avatarName": str(character["title"]).lower(),
                    "codename": character["tag"],
                    "extraPrompt": character["extraData"]["extraPrompt"],
                }
                self.__avatars[item["avatarName"]] = item

    def add_code_names(self, prompt):
        array_of_objects = self.__avatars.values()

        for obj in array_of_objects:
            prompt = (
                re.sub(
                    r"\b" + obj["avatarName"] + r"\b",
                    obj["extraPrompt"],
                    prompt,
                    flags=re.IGNORECASE,
                )
                + " "
            )
        print(prompt)
        return prompt

    def __find_available_characters(self, path: str):
        if os.path.exists(path + "/characters.json"):
            print(path)
            try:
                print("Loading characters")
                with open(path + "/characters.json") as f:
                    data = json.load(f)
                    print("Characters: ", data)
                    if "avatarName" in data[0]:
                        for item in data:
                            self.__avatars[item["avatarName"]] = item
                        print("Avatars", self.__avatars)
                    else:
                        print("Invalid characters.json file")
            except Exception as e:
                print("Error Loading characters", e)
```
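For reference, this is the substitution `add_code_names` performs per registered avatar, extracted as a standalone snippet (the avatar values are illustrative): each whole-word, case-insensitive mention of the avatar name is replaced by its `extraPrompt`.

```python
import re

avatar = {"avatarName": "arya", "extraPrompt": "ch01 young girl, short hair"}
prompt = "Arya riding a horse"
prompt = re.sub(
    r"\b" + avatar["avatarName"] + r"\b",
    avatar["extraPrompt"],
    prompt,
    flags=re.IGNORECASE,
) + " "
print(prompt)  # -> "ch01 young girl, short hair riding a horse "
```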
internals/util/cache.py
ADDED
@@ -0,0 +1,31 @@
```python
import gc

import torch


def clear_cuda_and_gc():
    clear_cuda()
    clear_gc()


def clear_cuda():
    torch.cuda.empty_cache()


def clear_gc():
    gc.collect()


def auto_clear_cuda_and_gc(controlnet):
    def auto_clear_cuda_and_gc_wrapper(func):
        def wrapper(*args, **kwargs):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                controlnet.cleanup()
                clear_cuda_and_gc()
                raise e

        return wrapper

    return auto_clear_cuda_and_gc_wrapper
```
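A sketch of the decorator flow. `auto_clear_cuda_and_gc` takes any object exposing `cleanup()` (in this repo that is the ControlNet wrapper); the `FakeControlNet` below is a stand-in so the snippet runs on its own.

```python
from internals.util.cache import auto_clear_cuda_and_gc


class FakeControlNet:  # stand-in: any object with cleanup(), like this repo's ControlNet
    def cleanup(self):
        print("controlnet released")


@auto_clear_cuda_and_gc(FakeControlNet())
def generate():
    raise RuntimeError("simulated CUDA OOM")


try:
    generate()
except RuntimeError:
    pass  # cleanup() plus the CUDA cache and GC clears already ran in the wrapper
```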
internals/util/commons.py
ADDED
@@ -0,0 +1,203 @@
```python
import json
import os
import pprint
import random
import re
from io import BytesIO
from pathlib import Path
from typing import Union

import boto3
import requests

from internals.util.config import api_endpoint, api_headers

s3 = boto3.client("s3")
import io
import urllib.request

from PIL import Image

black_list = {"alphonse mucha": "", "adolphe bouguereau": ""}
pp = pprint.PrettyPrinter(indent=4)

webhook_url = (
    "https://hooks.slack.com/services/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"
)
error_webhook = (
    "https://hooks.slack.com/services/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"
)

characterSheets = [
    "character+sheets/1.1.png",
    "character+sheets/10.1.png",
    "character+sheets/11.1.png",
    "character+sheets/12.1.png",
    "character+sheets/13.1.png",
    "character+sheets/14.1.png",
    "character+sheets/16.1.png",
    "character+sheets/17.1.png",
    "character+sheets/18.1.png",
    "character+sheets/19.1.png",
    "character+sheets/2.1.png",
    "character+sheets/20.1.png",
    "character+sheets/21.1.png",
    "character+sheets/22.1.png",
    "character+sheets/23.1.png",
    "character+sheets/24.1.png",
    "character+sheets/25.1.png",
    "character+sheets/26.1.png",
    "character+sheets/27.1.png",
    "character+sheets/28.1.png",
    "character+sheets/29.1.png",
    "character+sheets/3.1.png",
    "character+sheets/30.1.png",
    "character+sheets/31.1.png",
    "character+sheets/32.1.png",
    "character+sheets/33.1.png",
    "character+sheets/34.1.png",
    "character+sheets/35.1.png",
    "character+sheets/36.1.png",
    "character+sheets/38.1.png",
    "character+sheets/39.1.png",
    "character+sheets/4.1.png",
    "character+sheets/40.1.png",
    "character+sheets/42.1.png",
    "character+sheets/43.1.png",
    "character+sheets/44.1.png",
    "character+sheets/45.1.png",
    "character+sheets/46.1.png",
    "character+sheets/47.1.png",
    "character+sheets/48.1.png",
    "character+sheets/49.1.png",
    "character+sheets/5.1.png",
    "character+sheets/50.1.png",
    "character+sheets/51.1.png",
    "character+sheets/52.1.png",
    "character+sheets/53.1.png",
    "character+sheets/54.1.png",
    "character+sheets/55.1.png",
    "character+sheets/56.1.png",
    "character+sheets/57.1.png",
    "character+sheets/58.1.png",
    "character+sheets/59.1.png",
    "character+sheets/60.1.png",
    "character+sheets/61.1.png",
    "character+sheets/62.1.png",
    "character+sheets/63.1.png",
    "character+sheets/64.1.png",
    "character+sheets/65.1.png",
    "character+sheets/66.1.png",
    "character+sheets/7.1.png",
    "character+sheets/8.1.png",
    "character+sheets/9.1.png",
]


def upload_images(images, processName: str, taskId: str):
    imageUrls = []
    for i, image in enumerate(images):
        img_io = BytesIO()
        image.save(img_io, "JPEG", quality=100)
        img_io.seek(0)
        key = "crecoAI/{}{}_{}.png".format(taskId, processName, i)
        requests.post(
            api_endpoint()
            + "/comic-content/v1.0/upload/crecoai-assets-2?fileName="
            + "{}{}_{}.png".format(taskId, processName, i),
            headers=api_headers(),
            files={"file": ("image.png", img_io, "image/png")},
        )
        # t = s3.put_object(
        #     Bucket="comic-assets", Key=key, Body=img_io.getvalue(), ACL="public-read"
        # )
        # print("uploading done to s3", key, t)
        imageUrls.append(
            "https://comic-assets.s3.ap-south-1.amazonaws.com/crecoAI/{}{}_{}.png".format(
                taskId, processName, i
            )
        )

    print({"promptImages": imageUrls})

    return imageUrls


def upload_image(image: Union[Image.Image, BytesIO], out_path):
    if type(image) is Image.Image:
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        image = buffer

    image.seek(0)
    requests.post(
        api_endpoint()
        + "/comic-content/v1.0/upload/crecoai-assets-2?fileName="
        + str(out_path).replace("crecoAI/", ""),
        headers=api_headers(),
        files={"file": ("image.png", image, "image/png")},
    )
    # s3.upload_fileobj(image, "comic-assets", out_path, ExtraArgs={"ACL": "public-read"})
    image.close()

    image_url = "https://comic-assets.s3.ap-south-1.amazonaws.com/" + out_path
    print({"promptImages": image_url})

    return image_url


def download_image(url) -> Image.Image:
    response = requests.get(url)
    return Image.open(BytesIO(response.content)).convert("RGB")


def download_file(url, out_path: Path):
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        with open(out_path, "wb") as f:
            for chunk in r.iter_content(chunk_size=8192):
                f.write(chunk)


def pickPoses():
    random_images = random.sample(characterSheets, 4)
    poses = []
    prefix = "https://comic-assets.s3.ap-south-1.amazonaws.com/"

    # Use list comprehension to add prefix to all elements in the array
    random_images_with_prefix = [prefix + img for img in random_images]

    print(random_images_with_prefix)
    for imageUrl in random_images_with_prefix:
        # Download and resize the image
        init_image = download_image(imageUrl).resize((512, 512))

        # Open the pose image
        imageUrlPose = imageUrl
        # print(imageUrl)
        input_image_bytes = read_url(imageUrlPose)
        # print(input_image_bytes)
        pose_image = Image.open(io.BytesIO(input_image_bytes)).convert("RGB")
        # print(pose_image)
        pose_image = pose_image.resize((512, 512))
        # print(pose_image)
        # Append the result to the poses array
        poses.append(pose_image)

    return poses


def construct_default_s3_url(key):
    return "https://comic-assets.s3.ap-south-1.amazonaws.com/" + key


def read_url(url: str):
    with urllib.request.urlopen(url) as u:
        return u.read()


def disable_safety_checker(pipe):
    def dummy(images, **kwargs):
        return images, False

    pipe.safety_checker = None
```
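The asset key and URL convention used by `upload_images`, worked through with illustrative values:

```python
taskId, processName, i = "12345", "_text2img", 0  # illustrative values
key = "crecoAI/{}{}_{}.png".format(taskId, processName, i)
print("https://comic-assets.s3.ap-south-1.amazonaws.com/" + key)
# -> https://comic-assets.s3.ap-south-1.amazonaws.com/crecoAI/12345_text2img_0.png
```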
internals/util/config.py
ADDED
@@ -0,0 +1,66 @@
```python
import os

from internals.data.task import Task

env = "gamma"
nsfw_threshold = 0.0
nsfw_access = False
access_token = ""
root_dir = ""


def set_root_dir(main_file: str):
    global root_dir
    root_dir = os.path.dirname(os.path.abspath(main_file))


def set_configs_from_task(task: Task):
    global env, nsfw_threshold, nsfw_access, access_token
    name = task.get_queue_name()
    if name.startswith("prod"):
        env = "prod"
    else:
        env = "gamma"
    nsfw_threshold = task.get_nsfw_threshold()
    nsfw_access = task.can_access_nsfw()
    access_token = task.get_access_token()


def get_root_dir():
    global root_dir
    return root_dir


def get_environment():
    global env
    return env


def get_nsfw_threshold():
    global nsfw_threshold
    return nsfw_threshold


def get_nsfw_access():
    global nsfw_access
    return nsfw_access


def api_headers():
    return {
        "Access-Token": access_token,
    }


def api_endpoint():
    if env == "prod":
        return "https://prod.pratilipicomics.com"
    else:
        return "https://gamma.pratilipicomics.com"


def comic_url():
    if env == "prod":
        return "http://internal-k8s-prod-internal-bb9c57a6bb-1524739074.ap-south-1.elb.amazonaws.com:80"
    else:
        return "http://internal-k8s-gamma-internal-ea8e32da94-1997933257.ap-south-1.elb.amazonaws.com:80"
```
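A quick sketch of the environment switch. Normally `env` is set by `set_configs_from_task(task)` from the queue name prefix; mutating the module global directly, as below, is for illustration only.

```python
import internals.util.config as config

config.env = "prod"           # normally derived from a queue name starting with "prod"
print(config.api_endpoint())  # -> https://prod.pratilipicomics.com

config.env = "gamma"
print(config.api_endpoint())  # -> https://gamma.pratilipicomics.com
```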
internals/util/failure_hander.py
ADDED
@@ -0,0 +1,40 @@
```python
import json
import os
from pathlib import Path

from internals.data.dataAccessor import updateSource
from internals.data.task import Task
from internals.util.config import set_configs_from_task
from internals.util.slack import Slack


class FailureHandler:
    __task_path = Path.home() / ".cache" / "inference" / "task.json"

    @staticmethod
    def register():
        path = FailureHandler.__task_path
        path.parent.mkdir(parents=True, exist_ok=True)
        if path.exists():
            task = Task(json.loads(path.read_text()))
            set_configs_from_task(task)
            # Slack().error_alert(task, Exception("CATASTROPHIC FAILURE"))
            updateSource(task.get_sourceId(), task.get_userId(), "FAILED")
            os.remove(path)

    @staticmethod
    def clear(func):
        def wrapper(*args, **kwargs):
            result = func(*args, **kwargs)
            if result is not None:
                path = FailureHandler.__task_path
                if path.exists():
                    os.remove(path)
            return result

        return wrapper

    @staticmethod
    def handle(task: Task):
        path = FailureHandler.__task_path
        path.write_text(json.dumps(task.get_raw()))
```
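A sketch of the intended lifecycle, assuming `task` is an `internals.data.task.Task`; `run_inference` is a hypothetical stand-in for the actual pipeline call. The raw task is checkpointed before the risky work, reported as FAILED on the next boot if the process died, and the checkpoint is dropped once a result comes back.

```python
from internals.util.failure_hander import FailureHandler

FailureHandler.register()  # on boot: mark any task that crashed the previous run as FAILED


@FailureHandler.clear  # removes the checkpoint once a non-None result is returned
def predict(task):
    FailureHandler.handle(task)   # persist the raw task before doing the work
    result = run_inference(task)  # run_inference is a stand-in, not part of this repo
    return result
```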
internals/util/image.py
ADDED
@@ -0,0 +1,18 @@
```python
import io

from PIL import Image


def to_bytes(image: Image.Image) -> bytes:
    with io.BytesIO() as output:
        image.save(output, format="JPEG")
        return output.getvalue()


def resize_image_to512(image: Image.Image) -> Image.Image:
    iw, ih = image.size
    if iw > ih:
        image = image.resize((512, int(512 * ih / iw)))
    else:
        image = image.resize((int(512 * iw / ih), 512))
    return image
```
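A worked example of the aspect-preserving resize: the long side is clamped to 512 and the short side scales proportionally, so a 2048x1024 input becomes 512x256.

```python
from PIL import Image

import internals.util.image as ImageUtil

img = Image.new("RGB", (2048, 1024))
print(ImageUtil.resize_image_to512(img).size)  # -> (512, 256)
```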
internals/util/lora_style.py
ADDED
@@ -0,0 +1,154 @@
```python
import json
import os
from pathlib import Path
from typing import Any, Dict, List, Union

import boto3
import torch
from lora_diffusion import patch_pipe, tune_lora_scale
from pydash import chain

from internals.data.dataAccessor import getStyles
from internals.util.commons import download_file


class LoraStyle:
    class LoraPatcher:
        def __init__(self, pipe, style: Dict[str, Any]):
            self.__style = style
            self.pipe = pipe

        @torch.inference_mode()
        def patch(self):
            path = self.__style["path"]
            if str(path).endswith((".pt", ".safetensors")):
                patch_pipe(self.pipe, self.__style["path"])
            tune_lora_scale(self.pipe.unet, self.__style["weight"])
            tune_lora_scale(self.pipe.text_encoder, self.__style["weight"])

        def kwargs(self):
            return {}

        def cleanup(self):
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)

    class EmptyLoraPatcher:
        def __init__(self, pipe):
            self.pipe = pipe

        def patch(self):
            "Patch will act as cleanup, to tune down any corrupted lora"
            self.cleanup()

        def kwargs(self):
            return {}

        def cleanup(self):
            tune_lora_scale(self.pipe.unet, 0.0)
            tune_lora_scale(self.pipe.text_encoder, 0.0)

    def load(self, model_dir: str):
        self.model = model_dir
        self.fetch_styles()

    def fetch_styles(self):
        model_dir = self.model
        result = getStyles()
        if result is not None:
            self.__styles = self.__parse_styles(model_dir, result["data"])
        else:
            self.__styles = self.__get_default_styles(model_dir)
        self.__verify()

    def prepend_style_to_prompt(self, prompt: str, key: str) -> str:
        if key in self.__styles:
            style = self.__styles[key]
            return f"{', '.join(style['text'])}, {prompt}"
        return prompt

    def get_patcher(self, pipe, key: str) -> Union[LoraPatcher, EmptyLoraPatcher]:
        if key in self.__styles:
            style = self.__styles[key]
            return self.LoraPatcher(pipe, style)
        return self.EmptyLoraPatcher(pipe)

    def __parse_styles(self, model_dir: str, data: List[Dict]) -> Dict:
        styles = {}
        download_dir = Path(Path.home() / ".cache" / "lora")
        download_dir.mkdir(exist_ok=True)
        data = chain(data).uniq_by(lambda x: x["tag"]).value()
        for item in data:
            if item["attributes"] is not None:
                attr = json.loads(item["attributes"])
                if "path" in attr:
                    file_path = Path(download_dir / attr["path"].split("/")[-1])

                    if not file_path.exists():
                        s3_uri = attr["path"]
                        download_file(s3_uri, file_path)

                    styles[item["tag"]] = {
                        "path": str(file_path),
                        "weight": attr["weight"],
                        "type": attr["type"],
                        "text": attr["text"],
                        "negativePrompt": attr["negativePrompt"],
                    }
        if len(styles) == 0:
            return self.__get_default_styles(model_dir)
        return styles

    def __get_default_styles(self, model_dir: str) -> Dict:
        return {
            "nq6akX1CIp": {
                "path": model_dir + "/laur_style/nq6akX1CIp/final_lora.safetensors",
                "text": ["nq6akX1CIp style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "ghibli": {
                "path": model_dir + "/laur_style/nq6akX1CIp/ghibli.bin",
                "text": ["ghibli style"],
                "weight": 1,
                "negativePrompt": [""],
                "type": "custom",
            },
            "eQAmnK2kB2": {
                "path": model_dir + "/laur_style/eQAmnK2kB2/final_lora.safetensors",
                "text": ["eQAmnK2kB2 style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "to8contrast": {
                "path": model_dir + "/laur_style/rpjgusOgqD/final_lora.bin",
                "text": ["to8contrast style"],
                "weight": 0.5,
                "negativePrompt": [""],
                "type": "custom",
            },
            "sfrrfz8vge": {
                "path": model_dir + "/laur_style/replicate/sfrrfz8vge.safetensors",
                "text": ["sfrrfz8vge style"],
                "weight": 1.2,
                "negativePrompt": [""],
                "type": "custom",
            },
        }

    def __verify(self):
        "A method to verify if lora exists within the required path otherwise throw error"

        for item in self.__styles.keys():
            if not os.path.exists(self.__styles[item]["path"]):
                raise Exception(
                    "Lora style model "
                    + item
                    + " not found at path: "
                    + self.__styles[item]["path"]
                )
```
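A sketch of patching a pipeline for one request, assuming a loaded diffusers pipeline and a local model directory (`"model/"` is illustrative; `"ghibli"` is one of the default style keys above):

```python
from diffusers import StableDiffusionPipeline

from internals.util.lora_style import LoraStyle

pipe = StableDiffusionPipeline.from_pretrained("model/")

lora_style = LoraStyle()
lora_style.load(model_dir="model/")  # falls back to the default styles if getStyles() fails

prompt = lora_style.prepend_style_to_prompt("a castle on a hill", "ghibli")
patcher = lora_style.get_patcher(pipe, "ghibli")  # EmptyLoraPatcher for unknown keys
patcher.patch()
try:
    image = pipe(prompt).images[0]
finally:
    patcher.cleanup()  # always tune LoRA scales back to 0 for the next request
```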
internals/util/slack.py
ADDED
@@ -0,0 +1,58 @@
```python
from time import sleep
from typing import Optional

import requests

from internals.data.task import Task
from internals.util.config import get_environment


class Slack:
    def __init__(self):
        # self.webhook_url = "https://hooks.slack.com/services/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"
        self.webhook_url = "https://hooks.slack.com/services/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"
        self.error_webhook = "https://hooks.slack.com/services/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"

    def send_alert(self, task: Task, args: Optional[dict]):
        raw = task.get_raw().copy()

        raw["environment"] = get_environment()
        raw.pop("queue_name", None)
        raw.pop("attempt", None)
        raw.pop("timestamp", None)
        raw.pop("task_id", None)
        raw.pop("maskImageUrl", None)

        if args is not None:
            raw.update(args.items())

        message = ""
        for key, value in raw.items():
            if value:
                if type(value) == list:
                    message += f"*{key}*: {', '.join(value)}\n"
                else:
                    message += f"*{key}*: {value}\n"

        requests.post(
            self.webhook_url,
            headers={"Content-Type": "application/json"},
            json={"text": message},
        )

    def error_alert(self, task: Task, e: Exception):
        requests.post(
            self.error_webhook,
            headers={"Content-Type": "application/json"},
            json={
                "text": "Task failed:\n{} \n error is: \n {}".format(task.get_raw(), e)
            },
        )

    def auto_send_alert(self, func):
        def inner(*args, **kwargs):
            rargs = func(*args, **kwargs)
            self.send_alert(args[0], rargs)
            return rargs

        return inner
```
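A sketch of the decorator flow: the wrapped function's first positional argument must be the `Task` (it is forwarded to `send_alert`), and whatever dict it returns is merged into the alert payload. The return value and URL below are illustrative.

```python
from internals.util.slack import Slack

slack = Slack()


@slack.auto_send_alert
def generate(task):
    ...  # run inference
    return {"generatedImages": ["https://..."]}  # merged into the Slack message


# generate(task) now posts the task payload plus the returned fields to the webhook.
```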
models/ade20k/.DS_Store
ADDED
Binary file (6.15 kB)
models/ade20k/__init__.py
ADDED
@@ -0,0 +1 @@
```python
from .base import *
```
models/ade20k/base.py
ADDED
|
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch"""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
import pandas as pd
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from scipy.io import loadmat
|
| 10 |
+
from torch.nn.modules import BatchNorm2d
|
| 11 |
+
|
| 12 |
+
from . import resnet
|
| 13 |
+
from . import mobilenet
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
NUM_CLASS = 150
|
| 17 |
+
base_path = os.path.dirname(os.path.abspath(__file__)) # current file path
|
| 18 |
+
colors_path = os.path.join(base_path, 'color150.mat')
|
| 19 |
+
classes_path = os.path.join(base_path, 'object150_info.csv')
|
| 20 |
+
|
| 21 |
+
segm_options = dict(colors=loadmat(colors_path)['colors'],
|
| 22 |
+
classes=pd.read_csv(classes_path),)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class NormalizeTensor:
|
| 26 |
+
def __init__(self, mean, std, inplace=False):
|
| 27 |
+
"""Normalize a tensor image with mean and standard deviation.
|
| 28 |
+
.. note::
|
| 29 |
+
This transform acts out of place by default, i.e., it does not mutates the input tensor.
|
| 30 |
+
See :class:`~torchvision.transforms.Normalize` for more details.
|
| 31 |
+
Args:
|
| 32 |
+
tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
|
| 33 |
+
mean (sequence): Sequence of means for each channel.
|
| 34 |
+
std (sequence): Sequence of standard deviations for each channel.
|
| 35 |
+
inplace(bool,optional): Bool to make this operation inplace.
|
| 36 |
+
Returns:
|
| 37 |
+
Tensor: Normalized Tensor image.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
self.mean = mean
|
| 41 |
+
self.std = std
|
| 42 |
+
self.inplace = inplace
|
| 43 |
+
|
| 44 |
+
def __call__(self, tensor):
|
| 45 |
+
if not self.inplace:
|
| 46 |
+
tensor = tensor.clone()
|
| 47 |
+
|
| 48 |
+
dtype = tensor.dtype
|
| 49 |
+
mean = torch.as_tensor(self.mean, dtype=dtype, device=tensor.device)
|
| 50 |
+
std = torch.as_tensor(self.std, dtype=dtype, device=tensor.device)
|
| 51 |
+
tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
|
| 52 |
+
return tensor
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
# Model Builder
class ModelBuilder:
    # custom weights initialization
    @staticmethod
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.kaiming_normal_(m.weight.data)
        elif classname.find('BatchNorm') != -1:
            m.weight.data.fill_(1.)
            m.bias.data.fill_(1e-4)

    @staticmethod
    def build_encoder(arch='resnet50dilated', fc_dim=512, weights=''):
        pretrained = len(weights) == 0
        arch = arch.lower()
        if arch == 'mobilenetv2dilated':
            orig_mobilenet = mobilenet.__dict__['mobilenetv2'](pretrained=pretrained)
            net_encoder = MobileNetV2Dilated(orig_mobilenet, dilate_scale=8)
        elif arch == 'resnet18':
            orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
            net_encoder = Resnet(orig_resnet)
        elif arch == 'resnet18dilated':
            orig_resnet = resnet.__dict__['resnet18'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
        elif arch == 'resnet50dilated':
            orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
            net_encoder = ResnetDilated(orig_resnet, dilate_scale=8)
        elif arch == 'resnet50':
            orig_resnet = resnet.__dict__['resnet50'](pretrained=pretrained)
            net_encoder = Resnet(orig_resnet)
        else:
            raise Exception('Architecture undefined!')

        # encoders are usually pretrained
        # net_encoder.apply(ModelBuilder.weights_init)
        if len(weights) > 0:
            print('Loading weights for net_encoder')
            net_encoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
        return net_encoder

    @staticmethod
    def build_decoder(arch='ppm_deepsup',
                      fc_dim=512, num_class=NUM_CLASS,
                      weights='', use_softmax=False, drop_last_conv=False):
        arch = arch.lower()
        if arch == 'ppm_deepsup':
            net_decoder = PPMDeepsup(
                num_class=num_class,
                fc_dim=fc_dim,
                use_softmax=use_softmax,
                drop_last_conv=drop_last_conv)
        elif arch == 'c1_deepsup':
            net_decoder = C1DeepSup(
                num_class=num_class,
                fc_dim=fc_dim,
                use_softmax=use_softmax,
                drop_last_conv=drop_last_conv)
        else:
            raise Exception('Architecture undefined!')

        net_decoder.apply(ModelBuilder.weights_init)
        if len(weights) > 0:
            print('Loading weights for net_decoder')
            net_decoder.load_state_dict(
                torch.load(weights, map_location=lambda storage, loc: storage), strict=False)
        return net_decoder

    @staticmethod
    def get_decoder(weights_path, arch_encoder, arch_decoder, fc_dim, drop_last_conv, *arts, **kwargs):
        path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/decoder_epoch_20.pth')
        return ModelBuilder.build_decoder(arch=arch_decoder, fc_dim=fc_dim, weights=path, use_softmax=True, drop_last_conv=drop_last_conv)

    @staticmethod
    def get_encoder(weights_path, arch_encoder, arch_decoder, fc_dim, segmentation,
                    *arts, **kwargs):
        if segmentation:
            path = os.path.join(weights_path, 'ade20k', f'ade20k-{arch_encoder}-{arch_decoder}/encoder_epoch_20.pth')
        else:
            path = ''
        return ModelBuilder.build_encoder(arch=arch_encoder, fc_dim=fc_dim, weights=path)


def conv3x3_bn_relu(in_planes, out_planes, stride=1):
    return nn.Sequential(
        nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False),
        BatchNorm2d(out_planes),
        nn.ReLU(inplace=True),
    )

class SegmentationModule(nn.Module):
    def __init__(self,
                 weights_path,
                 num_classes=150,
                 arch_encoder="resnet50dilated",
                 drop_last_conv=False,
                 net_enc=None,  # None for default encoder
                 net_dec=None,  # None for default decoder
                 encode=None,  # {None, 'binary', 'color', 'sky'}
                 use_default_normalization=False,
                 return_feature_maps=False,
                 return_feature_maps_level=3,  # {0, 1, 2, 3}
                 return_feature_maps_only=True,
                 **kwargs,
                 ):
        super().__init__()
        self.weights_path = weights_path
        self.drop_last_conv = drop_last_conv
        self.arch_encoder = arch_encoder
        if self.arch_encoder == "resnet50dilated":
            self.arch_decoder = "ppm_deepsup"
            self.fc_dim = 2048
        elif self.arch_encoder == "mobilenetv2dilated":
            self.arch_decoder = "c1_deepsup"
            self.fc_dim = 320
        else:
            raise NotImplementedError(f"No such arch_encoder={self.arch_encoder}")
        model_builder_kwargs = dict(arch_encoder=self.arch_encoder,
                                    arch_decoder=self.arch_decoder,
                                    fc_dim=self.fc_dim,
                                    drop_last_conv=drop_last_conv,
                                    weights_path=self.weights_path)

        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.encoder = ModelBuilder.get_encoder(**model_builder_kwargs) if net_enc is None else net_enc
        self.decoder = ModelBuilder.get_decoder(**model_builder_kwargs) if net_dec is None else net_dec
        self.use_default_normalization = use_default_normalization
        self.default_normalization = NormalizeTensor(mean=[0.485, 0.456, 0.406],
                                                     std=[0.229, 0.224, 0.225])

        self.encode = encode

        self.return_feature_maps = return_feature_maps

        assert 0 <= return_feature_maps_level <= 3
        self.return_feature_maps_level = return_feature_maps_level

    def normalize_input(self, tensor):
        if tensor.min() < 0 or tensor.max() > 1:
            raise ValueError("Tensor should be 0..1 before using normalize_input")
        return self.default_normalization(tensor)

    @property
    def feature_maps_channels(self):
        return 256 * 2 ** self.return_feature_maps_level  # 256, 512, 1024, 2048

    def forward(self, img_data, segSize=None):
        if segSize is None:
            raise NotImplementedError("Please pass segSize param. By default: (300, 300)")

        fmaps = self.encoder(img_data, return_feature_maps=True)
        pred = self.decoder(fmaps, segSize=segSize)

        if self.return_feature_maps:
            return pred, fmaps
        # print("BINARY", img_data.shape, pred.shape)
        return pred

    def multi_mask_from_multiclass(self, pred, classes):
        def isin(ar1, ar2):
            return (ar1[..., None] == ar2).any(-1).float()
        return isin(pred, torch.LongTensor(classes).to(self.device))

    @staticmethod
    def multi_mask_from_multiclass_probs(scores, classes):
        res = None
        for c in classes:
            if res is None:
                res = scores[:, c]
            else:
                res += scores[:, c]
        return res

    def predict(self, tensor, imgSizes=(-1,),  # (300, 375, 450, 525, 600)
                segSize=None):
        """Entry-point for segmentation. Use this method instead of forward.
        Arguments:
            tensor {torch.Tensor} -- BCHW
        Keyword Arguments:
            imgSizes {tuple or list} -- scales at which the segmentation input is evaluated.
                default: (-1,), i.e. the input is used at its original size
                original implementation: (300, 375, 450, 525, 600)
        """
        if segSize is None:
            segSize = (tensor.shape[2], tensor.shape[3])
        with torch.no_grad():
            if self.use_default_normalization:
                tensor = self.normalize_input(tensor)
            scores = torch.zeros(1, NUM_CLASS, segSize[0], segSize[1]).to(self.device)
            features = torch.zeros(1, self.feature_maps_channels, segSize[0], segSize[1]).to(self.device)

            result = []
            for img_size in imgSizes:
                if img_size != -1:
                    img_data = F.interpolate(tensor.clone(), size=img_size)
                else:
                    img_data = tensor.clone()

                if self.return_feature_maps:
                    pred_current, fmaps = self.forward(img_data, segSize=segSize)
                else:
                    pred_current = self.forward(img_data, segSize=segSize)

                result.append(pred_current)
                # Average class scores over all evaluated scales.
                scores = scores + pred_current / len(imgSizes)

                # Disclaimer: We use and aggregate only last fmaps: fmaps[3]
                if self.return_feature_maps:
                    features = features + F.interpolate(fmaps[self.return_feature_maps_level], size=segSize) / len(imgSizes)

            _, pred = torch.max(scores, dim=1)

            if self.return_feature_maps:
                return features

            return pred, result

    def get_edges(self, t):
        edge = torch.cuda.ByteTensor(t.size()).zero_()
        edge[:, :, :, 1:] = edge[:, :, :, 1:] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, :, :-1] = edge[:, :, :, :-1] | (t[:, :, :, 1:] != t[:, :, :, :-1])
        edge[:, :, 1:, :] = edge[:, :, 1:, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])
        edge[:, :, :-1, :] = edge[:, :, :-1, :] | (t[:, :, 1:, :] != t[:, :, :-1, :])

        # The original `if True: return edge.half()` made the float path unreachable.
        return edge.half()

# pyramid pooling, deep supervision
class PPMDeepsup(nn.Module):
    def __init__(self, num_class=NUM_CLASS, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6),
                 drop_last_conv=False):
        super().__init__()
        self.use_softmax = use_softmax
        self.drop_last_conv = drop_last_conv

        self.ppm = []
        for scale in pool_scales:
            self.ppm.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                BatchNorm2d(512),
                nn.ReLU(inplace=True)
            ))
        self.ppm = nn.ModuleList(self.ppm)
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        self.conv_last = nn.Sequential(
            nn.Conv2d(fc_dim + len(pool_scales) * 512, 512,
                      kernel_size=3, padding=1, bias=False),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1)
        )
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.dropout_deepsup = nn.Dropout2d(0.1)

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]

        input_size = conv5.size()
        ppm_out = [conv5]
        for pool_scale in self.ppm:
            ppm_out.append(nn.functional.interpolate(
                pool_scale(conv5),
                (input_size[2], input_size[3]),
                mode='bilinear', align_corners=False))
        ppm_out = torch.cat(ppm_out, 1)

        if self.drop_last_conv:
            return ppm_out
        else:
            x = self.conv_last(ppm_out)

            if self.use_softmax:  # is True during inference
                x = nn.functional.interpolate(
                    x, size=segSize, mode='bilinear', align_corners=False)
                x = nn.functional.softmax(x, dim=1)
                return x

            # deep sup
            conv4 = conv_out[-2]
            _ = self.cbr_deepsup(conv4)
            _ = self.dropout_deepsup(_)
            _ = self.conv_last_deepsup(_)

            x = nn.functional.log_softmax(x, dim=1)
            _ = nn.functional.log_softmax(_, dim=1)

            return (x, _)

class Resnet(nn.Module):
    def __init__(self, orig_resnet):
        super(Resnet, self).__init__()

        # take pretrained resnet, except AvgPool and FC
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu1 = orig_resnet.relu1
        self.conv2 = orig_resnet.conv2
        self.bn2 = orig_resnet.bn2
        self.relu2 = orig_resnet.relu2
        self.conv3 = orig_resnet.conv3
        self.bn3 = orig_resnet.bn3
        self.relu3 = orig_resnet.relu3
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def forward(self, x, return_feature_maps=False):
        conv_out = []

        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        conv_out.append(x)
        x = self.layer2(x)
        conv_out.append(x)
        x = self.layer3(x)
        conv_out.append(x)
        x = self.layer4(x)
        conv_out.append(x)

        if return_feature_maps:
            return conv_out
        return [x]

# Resnet Dilated
class ResnetDilated(nn.Module):
    def __init__(self, orig_resnet, dilate_scale=8):
        super().__init__()
        from functools import partial

        if dilate_scale == 8:
            orig_resnet.layer3.apply(
                partial(self._nostride_dilate, dilate=2))
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=4))
        elif dilate_scale == 16:
            orig_resnet.layer4.apply(
                partial(self._nostride_dilate, dilate=2))

        # take pretrained resnet, except AvgPool and FC
        self.conv1 = orig_resnet.conv1
        self.bn1 = orig_resnet.bn1
        self.relu1 = orig_resnet.relu1
        self.conv2 = orig_resnet.conv2
        self.bn2 = orig_resnet.bn2
        self.relu2 = orig_resnet.relu2
        self.conv3 = orig_resnet.conv3
        self.bn3 = orig_resnet.bn3
        self.relu3 = orig_resnet.relu3
        self.maxpool = orig_resnet.maxpool
        self.layer1 = orig_resnet.layer1
        self.layer2 = orig_resnet.layer2
        self.layer3 = orig_resnet.layer3
        self.layer4 = orig_resnet.layer4

    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # the convolution with stride
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            # other convolutions
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x, return_feature_maps=False):
        conv_out = []

        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        conv_out.append(x)
        x = self.layer2(x)
        conv_out.append(x)
        x = self.layer3(x)
        conv_out.append(x)
        x = self.layer4(x)
        conv_out.append(x)

        if return_feature_maps:
            return conv_out
        return [x]

class MobileNetV2Dilated(nn.Module):
    def __init__(self, orig_net, dilate_scale=8):
        super(MobileNetV2Dilated, self).__init__()
        from functools import partial

        # take pretrained mobilenet features
        self.features = orig_net.features[:-1]

        self.total_idx = len(self.features)
        self.down_idx = [2, 4, 7, 14]

        if dilate_scale == 8:
            for i in range(self.down_idx[-2], self.down_idx[-1]):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=4)
                )
        elif dilate_scale == 16:
            for i in range(self.down_idx[-1], self.total_idx):
                self.features[i].apply(
                    partial(self._nostride_dilate, dilate=2)
                )

    def _nostride_dilate(self, m, dilate):
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            # the convolution with stride
            if m.stride == (2, 2):
                m.stride = (1, 1)
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate // 2, dilate // 2)
                    m.padding = (dilate // 2, dilate // 2)
            # other convolutions
            else:
                if m.kernel_size == (3, 3):
                    m.dilation = (dilate, dilate)
                    m.padding = (dilate, dilate)

    def forward(self, x, return_feature_maps=False):
        if return_feature_maps:
            conv_out = []
            for i in range(self.total_idx):
                x = self.features[i](x)
                if i in self.down_idx:
                    conv_out.append(x)
            conv_out.append(x)
            return conv_out

        else:
            return [self.features(x)]

# last conv, deep supervision
class C1DeepSup(nn.Module):
    def __init__(self, num_class=150, fc_dim=2048, use_softmax=False, drop_last_conv=False):
        super(C1DeepSup, self).__init__()
        self.use_softmax = use_softmax
        self.drop_last_conv = drop_last_conv

        self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)
        self.cbr_deepsup = conv3x3_bn_relu(fc_dim // 2, fc_dim // 4, 1)

        # last conv
        self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)
        self.conv_last_deepsup = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]

        x = self.cbr(conv5)

        if self.drop_last_conv:
            return x
        else:
            x = self.conv_last(x)

            if self.use_softmax:  # is True during inference
                x = nn.functional.interpolate(
                    x, size=segSize, mode='bilinear', align_corners=False)
                x = nn.functional.softmax(x, dim=1)
                return x

            # deep sup
            conv4 = conv_out[-2]
            _ = self.cbr_deepsup(conv4)
            _ = self.conv_last_deepsup(_)

            x = nn.functional.log_softmax(x, dim=1)
            _ = nn.functional.log_softmax(_, dim=1)

            return (x, _)

# last conv
class C1(nn.Module):
    def __init__(self, num_class=150, fc_dim=2048, use_softmax=False):
        super(C1, self).__init__()
        self.use_softmax = use_softmax

        self.cbr = conv3x3_bn_relu(fc_dim, fc_dim // 4, 1)

        # last conv
        self.conv_last = nn.Conv2d(fc_dim // 4, num_class, 1, 1, 0)

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]
        x = self.cbr(conv5)
        x = self.conv_last(x)

        if self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            x = nn.functional.softmax(x, dim=1)
        else:
            x = nn.functional.log_softmax(x, dim=1)

        return x

# pyramid pooling
class PPM(nn.Module):
    def __init__(self, num_class=150, fc_dim=4096,
                 use_softmax=False, pool_scales=(1, 2, 3, 6)):
        super(PPM, self).__init__()
        self.use_softmax = use_softmax

        self.ppm = []
        for scale in pool_scales:
            self.ppm.append(nn.Sequential(
                nn.AdaptiveAvgPool2d(scale),
                nn.Conv2d(fc_dim, 512, kernel_size=1, bias=False),
                BatchNorm2d(512),
                nn.ReLU(inplace=True)
            ))
        self.ppm = nn.ModuleList(self.ppm)

        self.conv_last = nn.Sequential(
            nn.Conv2d(fc_dim + len(pool_scales) * 512, 512,
                      kernel_size=3, padding=1, bias=False),
            BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Dropout2d(0.1),
            nn.Conv2d(512, num_class, kernel_size=1)
        )

    def forward(self, conv_out, segSize=None):
        conv5 = conv_out[-1]

        input_size = conv5.size()
        ppm_out = [conv5]
        for pool_scale in self.ppm:
            ppm_out.append(nn.functional.interpolate(
                pool_scale(conv5),
                (input_size[2], input_size[3]),
                mode='bilinear', align_corners=False))
        ppm_out = torch.cat(ppm_out, 1)

        x = self.conv_last(ppm_out)

        if self.use_softmax:  # is True during inference
            x = nn.functional.interpolate(
                x, size=segSize, mode='bilinear', align_corners=False)
            x = nn.functional.softmax(x, dim=1)
        else:
            x = nn.functional.log_softmax(x, dim=1)
        return x
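That concludes `models/ade20k/base.py`. A hedged usage sketch of `SegmentationModule.predict` follows; note that, as uploaded, `__init__` does not forward the required `segmentation` argument to `ModelBuilder.get_encoder`, so the sketch supplies pre-built modules via `net_enc`/`net_dec` instead. Shapes are assumptions.

```
import torch

# weights='' makes build_encoder load the ImageNet-pretrained backbone (network download).
enc = ModelBuilder.build_encoder(arch='resnet50dilated', fc_dim=2048, weights='')
dec = ModelBuilder.build_decoder(arch='ppm_deepsup', fc_dim=2048,
                                 weights='', use_softmax=True)
model = SegmentationModule(weights_path=None, net_enc=enc, net_dec=dec)
model.eval().to(model.device)

img = torch.rand(1, 3, 300, 300).to(model.device)  # BCHW in [0, 1]
pred, per_scale = model.predict(img)               # pred: (1, 300, 300) class indices
```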
models/ade20k/color150.mat
ADDED
Binary file (502 Bytes).
models/ade20k/mobilenet.py
ADDED
@@ -0,0 +1,154 @@
"""
This MobileNetV2 implementation is modified from the following repository:
https://github.com/tonylins/pytorch-mobilenet-v2
"""

import torch.nn as nn
import math
from .utils import load_url
from .segm_lib.nn import SynchronizedBatchNorm2d

BatchNorm2d = SynchronizedBatchNorm2d


__all__ = ['mobilenetv2']


model_urls = {
    'mobilenetv2': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/mobilenet_v2.pth.tar',
}


def conv_bn(inp, oup, stride):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 3, stride, 1, bias=False),
        BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


def conv_1x1_bn(inp, oup):
    return nn.Sequential(
        nn.Conv2d(inp, oup, 1, 1, 0, bias=False),
        BatchNorm2d(oup),
        nn.ReLU6(inplace=True)
    )


class InvertedResidual(nn.Module):
    def __init__(self, inp, oup, stride, expand_ratio):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        assert stride in [1, 2]

        hidden_dim = round(inp * expand_ratio)
        self.use_res_connect = self.stride == 1 and inp == oup

        if expand_ratio == 1:
            self.conv = nn.Sequential(
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                BatchNorm2d(oup),
            )
        else:
            self.conv = nn.Sequential(
                # pw
                nn.Conv2d(inp, hidden_dim, 1, 1, 0, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # dw
                nn.Conv2d(hidden_dim, hidden_dim, 3, stride, 1, groups=hidden_dim, bias=False),
                BatchNorm2d(hidden_dim),
                nn.ReLU6(inplace=True),
                # pw-linear
                nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
                BatchNorm2d(oup),
            )

    def forward(self, x):
        if self.use_res_connect:
            return x + self.conv(x)
        else:
            return self.conv(x)


class MobileNetV2(nn.Module):
    def __init__(self, n_class=1000, input_size=224, width_mult=1.):
        super(MobileNetV2, self).__init__()
        block = InvertedResidual
        input_channel = 32
        last_channel = 1280
        inverted_residual_setting = [
            # t, c, n, s
            [1, 16, 1, 1],
            [6, 24, 2, 2],
            [6, 32, 3, 2],
            [6, 64, 4, 2],
            [6, 96, 3, 1],
            [6, 160, 3, 2],
            [6, 320, 1, 1],
        ]

        # building first layer
        assert input_size % 32 == 0
        input_channel = int(input_channel * width_mult)
        self.last_channel = int(last_channel * width_mult) if width_mult > 1.0 else last_channel
        self.features = [conv_bn(3, input_channel, 2)]
        # building inverted residual blocks
        for t, c, n, s in inverted_residual_setting:
            output_channel = int(c * width_mult)
            for i in range(n):
                if i == 0:
                    self.features.append(block(input_channel, output_channel, s, expand_ratio=t))
                else:
                    self.features.append(block(input_channel, output_channel, 1, expand_ratio=t))
                input_channel = output_channel
        # building last several layers
        self.features.append(conv_1x1_bn(input_channel, self.last_channel))
        # make it nn.Sequential
        self.features = nn.Sequential(*self.features)

        # building classifier
        self.classifier = nn.Sequential(
            nn.Dropout(0.2),
            nn.Linear(self.last_channel, n_class),
        )

        self._initialize_weights()

    def forward(self, x):
        x = self.features(x)
        x = x.mean(3).mean(2)
        x = self.classifier(x)
        return x

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                n = m.weight.size(1)
                m.weight.data.normal_(0, 0.01)
                m.bias.data.zero_()


def mobilenetv2(pretrained=False, **kwargs):
    """Constructs a MobileNet_V2 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = MobileNetV2(n_class=1000, **kwargs)
    if pretrained:
        model.load_state_dict(load_url(model_urls['mobilenetv2']), strict=False)
    return model
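A hedged sketch of how this backbone feeds the dilated encoder defined in `models/ade20k/base.py` (`pretrained=False` avoids the checkpoint download; the import paths assume the repo root is on `sys.path`):

```
import torch
from models.ade20k.mobilenet import mobilenetv2
from models.ade20k.base import MobileNetV2Dilated

backbone = mobilenetv2(pretrained=False)
encoder = MobileNetV2Dilated(backbone, dilate_scale=8)
feats = encoder(torch.rand(1, 3, 224, 224), return_feature_maps=True)
print(len(feats), feats[-1].shape[1])  # 5 feature maps; the final one has 320 channels
```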
models/ade20k/object150_info.csv
ADDED
@@ -0,0 +1,151 @@
Idx,Ratio,Train,Val,Stuff,Name
1,0.1576,11664,1172,1,wall
2,0.1072,6046,612,1,building;edifice
3,0.0878,8265,796,1,sky
4,0.0621,9336,917,1,floor;flooring
5,0.0480,6678,641,0,tree
6,0.0450,6604,643,1,ceiling
7,0.0398,4023,408,1,road;route
8,0.0231,1906,199,0,bed
9,0.0198,4688,460,0,windowpane;window
10,0.0183,2423,225,1,grass
11,0.0181,2874,294,0,cabinet
12,0.0166,3068,310,1,sidewalk;pavement
13,0.0160,5075,526,0,person;individual;someone;somebody;mortal;soul
14,0.0151,1804,190,1,earth;ground
15,0.0118,6666,796,0,door;double;door
16,0.0110,4269,411,0,table
17,0.0109,1691,160,1,mountain;mount
18,0.0104,3999,441,0,plant;flora;plant;life
19,0.0104,2149,217,0,curtain;drape;drapery;mantle;pall
20,0.0103,3261,318,0,chair
21,0.0098,3164,306,0,car;auto;automobile;machine;motorcar
22,0.0074,709,75,1,water
23,0.0067,3296,315,0,painting;picture
24,0.0065,1191,106,0,sofa;couch;lounge
25,0.0061,1516,162,0,shelf
26,0.0060,667,69,1,house
27,0.0053,651,57,1,sea
28,0.0052,1847,224,0,mirror
29,0.0046,1158,128,1,rug;carpet;carpeting
30,0.0044,480,44,1,field
31,0.0044,1172,98,0,armchair
32,0.0044,1292,184,0,seat
33,0.0033,1386,138,0,fence;fencing
34,0.0031,698,61,0,desk
35,0.0030,781,73,0,rock;stone
36,0.0027,380,43,0,wardrobe;closet;press
37,0.0026,3089,302,0,lamp
38,0.0024,404,37,0,bathtub;bathing;tub;bath;tub
39,0.0024,804,99,0,railing;rail
40,0.0023,1453,153,0,cushion
41,0.0023,411,37,0,base;pedestal;stand
42,0.0022,1440,162,0,box
43,0.0022,800,77,0,column;pillar
44,0.0020,2650,298,0,signboard;sign
45,0.0019,549,46,0,chest;of;drawers;chest;bureau;dresser
46,0.0019,367,36,0,counter
47,0.0018,311,30,1,sand
48,0.0018,1181,122,0,sink
49,0.0018,287,23,1,skyscraper
50,0.0018,468,38,0,fireplace;hearth;open;fireplace
51,0.0018,402,43,0,refrigerator;icebox
52,0.0018,130,12,1,grandstand;covered;stand
53,0.0018,561,64,1,path
54,0.0017,880,102,0,stairs;steps
55,0.0017,86,12,1,runway
56,0.0017,172,11,0,case;display;case;showcase;vitrine
57,0.0017,198,18,0,pool;table;billiard;table;snooker;table
58,0.0017,930,109,0,pillow
59,0.0015,139,18,0,screen;door;screen
60,0.0015,564,52,1,stairway;staircase
61,0.0015,320,26,1,river
62,0.0015,261,29,1,bridge;span
63,0.0014,275,22,0,bookcase
64,0.0014,335,60,0,blind;screen
65,0.0014,792,75,0,coffee;table;cocktail;table
66,0.0014,395,49,0,toilet;can;commode;crapper;pot;potty;stool;throne
67,0.0014,1309,138,0,flower
68,0.0013,1112,113,0,book
69,0.0013,266,27,1,hill
70,0.0013,659,66,0,bench
71,0.0012,331,31,0,countertop
72,0.0012,531,56,0,stove;kitchen;stove;range;kitchen;range;cooking;stove
73,0.0012,369,36,0,palm;palm;tree
74,0.0012,144,9,0,kitchen;island
75,0.0011,265,29,0,computer;computing;machine;computing;device;data;processor;electronic;computer;information;processing;system
76,0.0010,324,33,0,swivel;chair
77,0.0009,304,27,0,boat
78,0.0009,170,20,0,bar
79,0.0009,68,6,0,arcade;machine
80,0.0009,65,8,1,hovel;hut;hutch;shack;shanty
81,0.0009,248,25,0,bus;autobus;coach;charabanc;double-decker;jitney;motorbus;motorcoach;omnibus;passenger;vehicle
82,0.0008,492,49,0,towel
83,0.0008,2510,269,0,light;light;source
84,0.0008,440,39,0,truck;motortruck
85,0.0008,147,18,1,tower
86,0.0008,583,56,0,chandelier;pendant;pendent
87,0.0007,533,61,0,awning;sunshade;sunblind
88,0.0007,1989,239,0,streetlight;street;lamp
89,0.0007,71,5,0,booth;cubicle;stall;kiosk
90,0.0007,618,53,0,television;television;receiver;television;set;tv;tv;set;idiot;box;boob;tube;telly;goggle;box
91,0.0007,135,12,0,airplane;aeroplane;plane
92,0.0007,83,5,1,dirt;track
93,0.0007,178,17,0,apparel;wearing;apparel;dress;clothes
94,0.0006,1003,104,0,pole
95,0.0006,182,12,1,land;ground;soil
96,0.0006,452,50,0,bannister;banister;balustrade;balusters;handrail
97,0.0006,42,6,1,escalator;moving;staircase;moving;stairway
98,0.0006,307,31,0,ottoman;pouf;pouffe;puff;hassock
99,0.0006,965,114,0,bottle
100,0.0006,117,13,0,buffet;counter;sideboard
101,0.0006,354,35,0,poster;posting;placard;notice;bill;card
102,0.0006,108,9,1,stage
103,0.0006,557,55,0,van
104,0.0006,52,4,0,ship
105,0.0005,99,5,0,fountain
106,0.0005,57,4,1,conveyer;belt;conveyor;belt;conveyer;conveyor;transporter
107,0.0005,292,31,0,canopy
108,0.0005,77,9,0,washer;automatic;washer;washing;machine
109,0.0005,340,38,0,plaything;toy
110,0.0005,66,3,1,swimming;pool;swimming;bath;natatorium
111,0.0005,465,49,0,stool
112,0.0005,50,4,0,barrel;cask
113,0.0005,622,75,0,basket;handbasket
114,0.0005,80,9,1,waterfall;falls
115,0.0005,59,3,0,tent;collapsible;shelter
116,0.0005,531,72,0,bag
117,0.0005,282,30,0,minibike;motorbike
118,0.0005,73,7,0,cradle
119,0.0005,435,44,0,oven
120,0.0005,136,25,0,ball
121,0.0005,116,24,0,food;solid;food
122,0.0004,266,31,0,step;stair
123,0.0004,58,12,0,tank;storage;tank
124,0.0004,418,83,0,trade;name;brand;name;brand;marque
125,0.0004,319,43,0,microwave;microwave;oven
126,0.0004,1193,139,0,pot;flowerpot
127,0.0004,97,23,0,animal;animate;being;beast;brute;creature;fauna
128,0.0004,347,36,0,bicycle;bike;wheel;cycle
129,0.0004,52,5,1,lake
130,0.0004,246,22,0,dishwasher;dish;washer;dishwashing;machine
131,0.0004,108,13,0,screen;silver;screen;projection;screen
132,0.0004,201,30,0,blanket;cover
133,0.0004,285,21,0,sculpture
134,0.0004,268,27,0,hood;exhaust;hood
135,0.0003,1020,108,0,sconce
136,0.0003,1282,122,0,vase
137,0.0003,528,65,0,traffic;light;traffic;signal;stoplight
138,0.0003,453,57,0,tray
139,0.0003,671,100,0,ashcan;trash;can;garbage;can;wastebin;ash;bin;ash-bin;ashbin;dustbin;trash;barrel;trash;bin
140,0.0003,397,44,0,fan
141,0.0003,92,8,1,pier;wharf;wharfage;dock
142,0.0003,228,18,0,crt;screen
143,0.0003,570,59,0,plate
144,0.0003,217,22,0,monitor;monitoring;device
145,0.0003,206,19,0,bulletin;board;notice;board
146,0.0003,130,14,0,shower
147,0.0003,178,28,0,radiator
148,0.0002,504,57,0,glass;drinking;glass
149,0.0002,775,96,0,clock
150,0.0002,421,56,0,flag
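A small sketch for turning this table into an index-to-name lookup (the path is an assumption about where the repo is checked out):

```
import csv

with open('models/ade20k/object150_info.csv') as f:
    idx_to_name = {int(row['Idx']): row['Name'].split(';')[0]
                   for row in csv.DictReader(f)}
print(idx_to_name[3])  # 'sky'
```

The `Name` column packs synonyms separated by semicolons; taking the first entry gives a canonical label.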
models/ade20k/resnet.py
ADDED
@@ -0,0 +1,181 @@
"""Modified from https://github.com/CSAILVision/semantic-segmentation-pytorch"""

import math

import torch.nn as nn
from torch.nn import BatchNorm2d

from .utils import load_url

__all__ = ['ResNet', 'resnet50']


model_urls = {
    'resnet50': 'http://sceneparsing.csail.mit.edu/model/pretrained_resnet/resnet50-imagenet.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = BatchNorm2d(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = BatchNorm2d(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
                               padding=1, bias=False)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
        self.bn3 = BatchNorm2d(planes * 4)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)

        out += residual
        out = self.relu(out)

        return out


class ResNet(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 128
        super(ResNet, self).__init__()
        self.conv1 = conv3x3(3, 64, stride=2)
        self.bn1 = BatchNorm2d(64)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(64, 64)
        self.bn2 = BatchNorm2d(64)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = conv3x3(64, 128)
        self.bn3 = BatchNorm2d(128)
        self.relu3 = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
        self.avgpool = nn.AvgPool2d(7, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.relu2(self.bn2(self.conv2(x)))
        x = self.relu3(self.bn3(self.conv3(x)))
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def resnet50(pretrained=False, **kwargs):
    """Constructs a ResNet-50 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(load_url(model_urls['resnet50']), strict=False)
    return model


def resnet18(pretrained=False, **kwargs):
    """Constructs a ResNet-18 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    # NOTE: model_urls above only defines 'resnet50', so calling
    # resnet18(pretrained=True) would raise a KeyError as the file stands.
    model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
    if pretrained:
        model.load_state_dict(load_url(model_urls['resnet18']))
    return model
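A hedged sketch of the `resnet50dilated` encoder path, assembled by hand from this file and `base.py` (import paths assumed; `pretrained=False` skips the checkpoint download):

```
import torch
from models.ade20k.resnet import resnet50
from models.ade20k.base import ResnetDilated

backbone = resnet50(pretrained=False)
encoder = ResnetDilated(backbone, dilate_scale=8)
feats = encoder(torch.rand(1, 3, 256, 256), return_feature_maps=True)
print([tuple(f.shape) for f in feats])  # strides 4, 8, 8, 8; last map has 2048 channels
```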
models/ade20k/segm_lib/.DS_Store
ADDED
Binary file (6.15 kB).

models/ade20k/segm_lib/nn/.DS_Store
ADDED
Binary file (6.15 kB).
models/ade20k/segm_lib/nn/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .modules import *
from .parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to
models/ade20k/segm_lib/nn/modules/__init__.py
ADDED
@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
# File : __init__.py
# Author : Jiayuan Mao
# Email : [email protected]
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

from .batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, SynchronizedBatchNorm3d
from .replicate import DataParallelWithCallback, patch_replication_callback
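These exports make `SynchronizedBatchNorm*` a drop-in replacement for the stock batch-norm layers; on a single device the forward pass falls back to `F.batch_norm`. A minimal sketch (the import path assumes the repo root is on `sys.path`):

```
import torch
from models.ade20k.segm_lib.nn import SynchronizedBatchNorm2d

bn = SynchronizedBatchNorm2d(16)   # same constructor shape as nn.BatchNorm2d
y = bn(torch.randn(4, 16, 8, 8))
print(y.shape)                     # torch.Size([4, 16, 8, 8])
```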
models/ade20k/segm_lib/nn/modules/batchnorm.py
ADDED
@@ -0,0 +1,329 @@
# -*- coding: utf-8 -*-
# File : batchnorm.py
# Author : Jiayuan Mao
# Email : [email protected]
# Date : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import collections

import torch
import torch.nn.functional as F

from torch.nn.modules.batchnorm import _BatchNorm
from torch.nn.parallel._functions import ReduceAddCoalesced, Broadcast

from .comm import SyncMaster

__all__ = ['SynchronizedBatchNorm1d', 'SynchronizedBatchNorm2d', 'SynchronizedBatchNorm3d']


def _sum_ft(tensor):
    """sum over the first and last dimension"""
    return tensor.sum(dim=0).sum(dim=-1)


def _unsqueeze_ft(tensor):
    """add new dimensions at the front and the tail"""
    return tensor.unsqueeze(0).unsqueeze(-1)


_ChildMessage = collections.namedtuple('_ChildMessage', ['sum', 'ssum', 'sum_size'])
_MasterMessage = collections.namedtuple('_MasterMessage', ['sum', 'inv_std'])


class _SynchronizedBatchNorm(_BatchNorm):
    def __init__(self, num_features, eps=1e-5, momentum=0.001, affine=True):
        super(_SynchronizedBatchNorm, self).__init__(num_features, eps=eps, momentum=momentum, affine=affine)

        self._sync_master = SyncMaster(self._data_parallel_master)

        self._is_parallel = False
        self._parallel_id = None
        self._slave_pipe = None

        # custom batch norm statistics
        self._moving_average_fraction = 1. - momentum
        self.register_buffer('_tmp_running_mean', torch.zeros(self.num_features))
        self.register_buffer('_tmp_running_var', torch.ones(self.num_features))
        self.register_buffer('_running_iter', torch.ones(1))
        self._tmp_running_mean = self.running_mean.clone() * self._running_iter
        self._tmp_running_var = self.running_var.clone() * self._running_iter

    def forward(self, input):
        # If it is not parallel computation or is in evaluation mode, use PyTorch's implementation.
        if not (self._is_parallel and self.training):
            return F.batch_norm(
                input, self.running_mean, self.running_var, self.weight, self.bias,
                self.training, self.momentum, self.eps)

        # Resize the input to (B, C, -1).
        input_shape = input.size()
        input = input.view(input.size(0), self.num_features, -1)

        # Compute the sum and square-sum.
        sum_size = input.size(0) * input.size(2)
        input_sum = _sum_ft(input)
        input_ssum = _sum_ft(input ** 2)

        # Reduce-and-broadcast the statistics.
        if self._parallel_id == 0:
            mean, inv_std = self._sync_master.run_master(_ChildMessage(input_sum, input_ssum, sum_size))
        else:
            mean, inv_std = self._slave_pipe.run_slave(_ChildMessage(input_sum, input_ssum, sum_size))

        # Compute the output.
        if self.affine:
            # MJY:: Fuse the multiplication for speed.
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std * self.weight) + _unsqueeze_ft(self.bias)
        else:
            output = (input - _unsqueeze_ft(mean)) * _unsqueeze_ft(inv_std)

        # Reshape it.
        return output.view(input_shape)

    def __data_parallel_replicate__(self, ctx, copy_id):
        self._is_parallel = True
        self._parallel_id = copy_id

        # parallel_id == 0 means master device.
        if self._parallel_id == 0:
            ctx.sync_master = self._sync_master
        else:
            self._slave_pipe = ctx.sync_master.register_slave(copy_id)

    def _data_parallel_master(self, intermediates):
        """Reduce the sum and square-sum, compute the statistics, and broadcast it."""
        intermediates = sorted(intermediates, key=lambda i: i[1].sum.get_device())

        to_reduce = [i[1][:2] for i in intermediates]
        to_reduce = [j for i in to_reduce for j in i]  # flatten
        target_gpus = [i[1].sum.get_device() for i in intermediates]

        sum_size = sum([i[1].sum_size for i in intermediates])
        sum_, ssum = ReduceAddCoalesced.apply(target_gpus[0], 2, *to_reduce)

        mean, inv_std = self._compute_mean_std(sum_, ssum, sum_size)

        broadcasted = Broadcast.apply(target_gpus, mean, inv_std)

        outputs = []
        for i, rec in enumerate(intermediates):
            outputs.append((rec[0], _MasterMessage(*broadcasted[i * 2:i * 2 + 2])))

        return outputs

    def _add_weighted(self, dest, delta, alpha=1, beta=1, bias=0):
        """return *dest* by `dest := dest*alpha + delta*beta + bias`"""
        return dest * alpha + delta * beta + bias

    def _compute_mean_std(self, sum_, ssum, size):
        """Compute the mean and standard-deviation with sum and square-sum. This method
        also maintains the moving average on the master device."""
        assert size > 1, 'BatchNorm computes unbiased standard-deviation, which requires size > 1.'
        mean = sum_ / size
        sumvar = ssum - sum_ * mean
        unbias_var = sumvar / (size - 1)
        bias_var = sumvar / size

        self._tmp_running_mean = self._add_weighted(self._tmp_running_mean, mean.data, alpha=self._moving_average_fraction)
        self._tmp_running_var = self._add_weighted(self._tmp_running_var, unbias_var.data, alpha=self._moving_average_fraction)
        self._running_iter = self._add_weighted(self._running_iter, 1, alpha=self._moving_average_fraction)

        self.running_mean = self._tmp_running_mean / self._running_iter
        self.running_var = self._tmp_running_var / self._running_iter

        return mean, bias_var.clamp(self.eps) ** -0.5


class SynchronizedBatchNorm1d(_SynchronizedBatchNorm):
    r"""Applies Synchronized Batch Normalization over a 2d or 3d input that is seen as a
    mini-batch.

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm1d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, L)` slices, it's common terminology to call this Temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of size
            `batch_size x num_features [x width]`
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C)` or :math:`(N, C, L)`
        - Output: :math:`(N, C)` or :math:`(N, C, L)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm1d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 2 and input.dim() != 3:
            raise ValueError('expected 2D or 3D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm1d, self)._check_input_dim(input)


class SynchronizedBatchNorm2d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 4d input that is seen as a mini-batch
    of 3d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm2d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, H, W)` slices, it's common terminology to call this Spatial BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x height x width
        eps: a value added to the denominator for numerical stability.
            Default: 1e-5
        momentum: the value used for the running_mean and running_var
            computation. Default: 0.1
        affine: a boolean value that when set to ``True``, gives the layer learnable
            affine parameters. Default: ``True``

    Shape:
        - Input: :math:`(N, C, H, W)`
        - Output: :math:`(N, C, H, W)` (same shape as input)

    Examples:
        >>> # With Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100)
        >>> # Without Learnable Parameters
        >>> m = SynchronizedBatchNorm2d(100, affine=False)
        >>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45))
        >>> output = m(input)
    """

    def _check_input_dim(self, input):
        if input.dim() != 4:
            raise ValueError('expected 4D input (got {}D input)'
                             .format(input.dim()))
        super(SynchronizedBatchNorm2d, self)._check_input_dim(input)


class SynchronizedBatchNorm3d(_SynchronizedBatchNorm):
    r"""Applies Batch Normalization over a 5d input that is seen as a mini-batch
    of 4d inputs

    .. math::

        y = \frac{x - mean[x]}{ \sqrt{Var[x] + \epsilon}} * gamma + beta

    This module differs from the built-in PyTorch BatchNorm3d in that the mean and
    standard-deviation are reduced across all devices during training.

    For example, when one uses `nn.DataParallel` to wrap the network during
    training, PyTorch's implementation normalizes the tensor on each device using
    the statistics only on that device, which accelerates the computation and
    is also easy to implement, but the statistics might be inaccurate.
    Instead, in this synchronized version, the statistics will be computed
    over all training samples distributed on multiple devices.

    Note that, for the one-GPU or CPU-only case, this module behaves exactly the same
    as the built-in PyTorch implementation.

    The mean and standard-deviation are calculated per-dimension over
    the mini-batches and gamma and beta are learnable parameter vectors
    of size C (where C is the input size).

    During training, this layer keeps a running estimate of its computed mean
    and variance. The running sum is kept with a default momentum of 0.1.

    During evaluation, this running mean/variance is used for normalization.

    Because the BatchNorm is done over the `C` dimension, computing statistics
    on `(N, D, H, W)` slices, it's common terminology to call this Volumetric BatchNorm
    or Spatio-temporal BatchNorm.

    Args:
        num_features: num_features from an expected input of
            size batch_size x num_features x depth x height x width
        eps: a value added to the denominator for numerical stability.
eps: a value added to the denominator for numerical stability.
|
| 306 |
+
Default: 1e-5
|
| 307 |
+
momentum: the value used for the running_mean and running_var
|
| 308 |
+
computation. Default: 0.1
|
| 309 |
+
affine: a boolean value that when set to ``True``, gives the layer learnable
|
| 310 |
+
affine parameters. Default: ``True``
|
| 311 |
+
|
| 312 |
+
Shape:
|
| 313 |
+
- Input: :math:`(N, C, D, H, W)`
|
| 314 |
+
- Output: :math:`(N, C, D, H, W)` (same shape as input)
|
| 315 |
+
|
| 316 |
+
Examples:
|
| 317 |
+
>>> # With Learnable Parameters
|
| 318 |
+
>>> m = SynchronizedBatchNorm3d(100)
|
| 319 |
+
>>> # Without Learnable Parameters
|
| 320 |
+
>>> m = SynchronizedBatchNorm3d(100, affine=False)
|
| 321 |
+
>>> input = torch.autograd.Variable(torch.randn(20, 100, 35, 45, 10))
|
| 322 |
+
>>> output = m(input)
|
| 323 |
+
"""
|
| 324 |
+
|
| 325 |
+
def _check_input_dim(self, input):
|
| 326 |
+
if input.dim() != 5:
|
| 327 |
+
raise ValueError('expected 5D input (got {}D input)'
|
| 328 |
+
.format(input.dim()))
|
| 329 |
+
super(SynchronizedBatchNorm3d, self)._check_input_dim(input)
|
models/ade20k/segm_lib/nn/modules/comm.py
ADDED
@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
# File   : comm.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import queue
import collections
import threading

__all__ = ['FutureResult', 'SlavePipe', 'SyncMaster']


class FutureResult(object):
    """A thread-safe future implementation. Used only as a one-to-one pipe."""

    def __init__(self):
        self._result = None
        self._lock = threading.Lock()
        self._cond = threading.Condition(self._lock)

    def put(self, result):
        with self._lock:
            assert self._result is None, 'Previous result hasn\'t been fetched.'
            self._result = result
            self._cond.notify()

    def get(self):
        with self._lock:
            if self._result is None:
                self._cond.wait()

            res = self._result
            self._result = None
            return res


_MasterRegistry = collections.namedtuple('MasterRegistry', ['result'])
_SlavePipeBase = collections.namedtuple('_SlavePipeBase', ['identifier', 'queue', 'result'])


class SlavePipe(_SlavePipeBase):
    """Pipe for master-slave communication."""

    def run_slave(self, msg):
        self.queue.put((self.identifier, msg))
        ret = self.result.get()
        self.queue.put(True)
        return ret


class SyncMaster(object):
    """An abstract `SyncMaster` object.

    - During replication, as data parallel triggers a callback on each module, all slave devices should
    call `register(id)` and obtain a `SlavePipe` to communicate with the master.
    - During the forward pass, the master device invokes `run_master`; all messages from slave devices are
    collected and passed to a registered callback.
    - After receiving the messages, the master device should gather the information and determine the message
    to be passed back to each slave device.
    """

    def __init__(self, master_callback):
        """
        Args:
            master_callback: a callback to be invoked after having collected messages from slave devices.
        """
        self._master_callback = master_callback
        self._queue = queue.Queue()
        self._registry = collections.OrderedDict()
        self._activated = False

    def register_slave(self, identifier):
        """
        Register a slave device.

        Args:
            identifier: an identifier, usually the device id.

        Returns: a `SlavePipe` object which can be used to communicate with the master device.

        """
        if self._activated:
            assert self._queue.empty(), 'Queue is not clean before next initialization.'
            self._activated = False
            self._registry.clear()
        future = FutureResult()
        self._registry[identifier] = _MasterRegistry(future)
        return SlavePipe(identifier, self._queue, future)

    def run_master(self, master_msg):
        """
        Main entry for the master device in each forward pass.
        The messages are first collected from each device (including the master device), and then
        a callback is invoked to compute the message to be sent back to each device
        (including the master device).

        Args:
            master_msg: the message that the master wants to send to itself. This will be placed as the first
            message when calling `master_callback`. For detailed usage, see `_SynchronizedBatchNorm` for an example.

        Returns: the message to be sent back to the master device.

        """
        self._activated = True

        intermediates = [(0, master_msg)]
        for i in range(self.nr_slaves):
            intermediates.append(self._queue.get())

        results = self._master_callback(intermediates)
        assert results[0][0] == 0, 'The first result should belong to the master.'

        for i, res in results:
            if i == 0:
                continue
            self._registry[i].result.put(res)

        for i in range(self.nr_slaves):
            assert self._queue.get() is True

        return results[0][1]

    @property
    def nr_slaves(self):
        return len(self._registry)
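
The master-slave protocol above is generic: slaves push `(identifier, msg)` onto a shared queue, the master's callback computes per-device replies, and each slave acknowledges with `True`. A minimal sketch (not part of the upload; import path assumed as in the commit layout) of `SyncMaster` coordinating one master and two slave threads to compute a global sum:

```python
import threading

from models.ade20k.segm_lib.nn.modules.comm import SyncMaster

def master_callback(intermediates):
    # intermediates is a list of (identifier, msg) pairs, master (id 0) first.
    total = sum(msg for _, msg in intermediates)
    # Return one (identifier, result) pair per device, preserving order.
    return [(i, total) for i, _ in intermediates]

master = SyncMaster(master_callback)
pipes = [master.register_slave(i) for i in (1, 2)]  # register before the forward pass

results = {}

def slave(pipe, value):
    # run_slave blocks until the master publishes this slave's result.
    results[pipe.identifier] = pipe.run_slave(value)

threads = [threading.Thread(target=slave, args=(p, v))
           for p, v in zip(pipes, (20, 30))]
for t in threads:
    t.start()
results[0] = master.run_master(10)  # the master contributes 10
for t in threads:
    t.join()

print(results)  # {1: 60, 2: 60, 0: 60} -- every device sees the global sum
```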
models/ade20k/segm_lib/nn/modules/replicate.py
ADDED
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
# File   : replicate.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import functools

from torch.nn.parallel.data_parallel import DataParallel

__all__ = [
    'CallbackContext',
    'execute_replication_callbacks',
    'DataParallelWithCallback',
    'patch_replication_callback'
]


class CallbackContext(object):
    pass


def execute_replication_callbacks(modules):
    """
    Execute a replication callback `__data_parallel_replicate__` on each module created by the original replication.

    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Note that, as all replicas are isomorphic, we assign each sub-module a context
    (shared among multiple copies of this module on different devices).
    Through this context, different copies can share some information.

    We guarantee that the callback on the master copy (the first copy) will be called ahead of calling the callback
    of any slave copies.
    """
    master_copy = modules[0]
    nr_modules = len(list(master_copy.modules()))
    ctxs = [CallbackContext() for _ in range(nr_modules)]

    for i, module in enumerate(modules):
        for j, m in enumerate(module.modules()):
            if hasattr(m, '__data_parallel_replicate__'):
                m.__data_parallel_replicate__(ctxs[j], i)


class DataParallelWithCallback(DataParallel):
    """
    Data Parallel with a replication callback.

    A replication callback `__data_parallel_replicate__` of each module will be invoked after being created by
    the original `replicate` function.
    The callback will be invoked with arguments `__data_parallel_replicate__(ctx, copy_id)`

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
        # sync_bn.__data_parallel_replicate__ will be invoked.
    """

    def replicate(self, module, device_ids):
        modules = super(DataParallelWithCallback, self).replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules


def patch_replication_callback(data_parallel):
    """
    Monkey-patch an existing `DataParallel` object. Add the replication callback.
    Useful when you have a customized `DataParallel` implementation.

    Examples:
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallel(sync_bn, device_ids=[0, 1])
        > patch_replication_callback(sync_bn)
        # this is equivalent to
        > sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        > sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])
    """

    assert isinstance(data_parallel, DataParallel)

    old_replicate = data_parallel.replicate

    @functools.wraps(old_replicate)
    def new_replicate(module, device_ids):
        modules = old_replicate(module, device_ids)
        execute_replication_callbacks(modules)
        return modules

    data_parallel.replicate = new_replicate
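
A hedged sketch (not part of the upload, and requiring at least two CUDA devices) of the monkey-patching path, for when an existing `DataParallel` instance cannot simply be replaced with `DataParallelWithCallback`:

```python
import torch.nn as nn
from torch.nn.parallel import DataParallel

from models.ade20k.segm_lib.nn.modules.batchnorm import SynchronizedBatchNorm1d
from models.ade20k.segm_lib.nn.modules.replicate import patch_replication_callback

# Hypothetical model already wrapped with plain DataParallel elsewhere.
model = nn.Sequential(nn.Linear(8, 10), SynchronizedBatchNorm1d(10)).cuda()
wrapped = DataParallel(model, device_ids=[0, 1])

# After patching, every future replicate() call also fires
# __data_parallel_replicate__, wiring each replica to the shared SyncMaster.
patch_replication_callback(wrapped)
```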
models/ade20k/segm_lib/nn/modules/tests/test_numeric_batchnorm.py
ADDED
@@ -0,0 +1,56 @@
# -*- coding: utf-8 -*-
# File   : test_numeric_batchnorm.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.

import unittest

import torch
import torch.nn as nn
from torch.autograd import Variable

from sync_batchnorm.unittest import TorchTestCase


def handy_var(a, unbias=True):
    n = a.size(0)
    asum = a.sum(dim=0)
    as_sum = (a ** 2).sum(dim=0)  # a square sum
    sumvar = as_sum - asum * asum / n
    if unbias:
        return sumvar / (n - 1)
    else:
        return sumvar / n


class NumericTestCase(TorchTestCase):
    def testNumericBatchNorm(self):
        a = torch.rand(16, 10)
        bn = nn.BatchNorm2d(10, momentum=1, eps=1e-5, affine=False)
        bn.train()

        a_var1 = Variable(a, requires_grad=True)
        b_var1 = bn(a_var1)
        loss1 = b_var1.sum()
        loss1.backward()

        a_var2 = Variable(a, requires_grad=True)
        a_mean2 = a_var2.mean(dim=0, keepdim=True)
        a_std2 = torch.sqrt(handy_var(a_var2, unbias=False).clamp(min=1e-5))
        # a_std2 = torch.sqrt(a_var2.var(dim=0, keepdim=True, unbiased=False) + 1e-5)
        b_var2 = (a_var2 - a_mean2) / a_std2
        loss2 = b_var2.sum()
        loss2.backward()

        self.assertTensorClose(bn.running_mean, a.mean(dim=0))
        self.assertTensorClose(bn.running_var, handy_var(a))
        self.assertTensorClose(a_var1.data, a_var2.data)
        self.assertTensorClose(b_var1.data, b_var2.data)
        self.assertTensorClose(a_var1.grad, a_var2.grad)


if __name__ == '__main__':
    unittest.main()
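
`handy_var` computes the variance via the identity `sum((x - mean)^2) = sum(x^2) - (sum x)^2 / n`, with the optional Bessel correction `n - 1`. A quick self-contained check (not part of the upload) that this matches `torch.var`:

```python
import torch

x = torch.rand(16, 10)
asum, as_sum = x.sum(dim=0), (x ** 2).sum(dim=0)
sumvar = as_sum - asum * asum / x.size(0)  # sum of squared deviations
assert torch.allclose(sumvar / (x.size(0) - 1),
                      x.var(dim=0, unbiased=True), atol=1e-6)
```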
models/ade20k/segm_lib/nn/modules/tests/test_sync_batchnorm.py
ADDED
@@ -0,0 +1,111 @@
# -*- coding: utf-8 -*-
# File   : test_sync_batchnorm.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.

import unittest

import torch
import torch.nn as nn
from torch.autograd import Variable

from sync_batchnorm import SynchronizedBatchNorm1d, SynchronizedBatchNorm2d, DataParallelWithCallback
from sync_batchnorm.unittest import TorchTestCase


def handy_var(a, unbias=True):
    n = a.size(0)
    asum = a.sum(dim=0)
    as_sum = (a ** 2).sum(dim=0)  # a square sum
    sumvar = as_sum - asum * asum / n
    if unbias:
        return sumvar / (n - 1)
    else:
        return sumvar / n


def _find_bn(module):
    for m in module.modules():
        if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, SynchronizedBatchNorm1d, SynchronizedBatchNorm2d)):
            return m


class SyncTestCase(TorchTestCase):
    def _syncParameters(self, bn1, bn2):
        bn1.reset_parameters()
        bn2.reset_parameters()
        if bn1.affine and bn2.affine:
            bn2.weight.data.copy_(bn1.weight.data)
            bn2.bias.data.copy_(bn1.bias.data)

    def _checkBatchNormResult(self, bn1, bn2, input, is_train, cuda=False):
        """Check the forward and backward passes of the customized batch normalization."""
        bn1.train(mode=is_train)
        bn2.train(mode=is_train)

        if cuda:
            input = input.cuda()

        self._syncParameters(_find_bn(bn1), _find_bn(bn2))

        input1 = Variable(input, requires_grad=True)
        output1 = bn1(input1)
        output1.sum().backward()
        input2 = Variable(input, requires_grad=True)
        output2 = bn2(input2)
        output2.sum().backward()

        self.assertTensorClose(input1.data, input2.data)
        self.assertTensorClose(output1.data, output2.data)
        self.assertTensorClose(input1.grad, input2.grad)
        self.assertTensorClose(_find_bn(bn1).running_mean, _find_bn(bn2).running_mean)
        self.assertTensorClose(_find_bn(bn1).running_var, _find_bn(bn2).running_var)

    def testSyncBatchNormNormalTrain(self):
        bn = nn.BatchNorm1d(10)
        sync_bn = SynchronizedBatchNorm1d(10)

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True)

    def testSyncBatchNormNormalEval(self):
        bn = nn.BatchNorm1d(10)
        sync_bn = SynchronizedBatchNorm1d(10)

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False)

    def testSyncBatchNormSyncTrain(self):
        bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), True, cuda=True)

    def testSyncBatchNormSyncEval(self):
        bn = nn.BatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = SynchronizedBatchNorm1d(10, eps=1e-5, affine=False)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10), False, cuda=True)

    def testSyncBatchNorm2DSyncTrain(self):
        bn = nn.BatchNorm2d(10)
        sync_bn = SynchronizedBatchNorm2d(10)
        sync_bn = DataParallelWithCallback(sync_bn, device_ids=[0, 1])

        bn.cuda()
        sync_bn.cuda()

        self._checkBatchNormResult(bn, sync_bn, torch.rand(16, 10, 16, 16), True, cuda=True)


if __name__ == '__main__':
    unittest.main()
models/ade20k/segm_lib/nn/modules/unittest.py
ADDED
@@ -0,0 +1,29 @@
# -*- coding: utf-8 -*-
# File   : unittest.py
# Author : Jiayuan Mao
# Email  : [email protected]
# Date   : 27/01/2018
#
# This file is part of Synchronized-BatchNorm-PyTorch.
# https://github.com/vacancy/Synchronized-BatchNorm-PyTorch
# Distributed under MIT License.

import unittest

import numpy as np
from torch.autograd import Variable


def as_numpy(v):
    if isinstance(v, Variable):
        v = v.data
    return v.cpu().numpy()


class TorchTestCase(unittest.TestCase):
    def assertTensorClose(self, a, b, atol=1e-3, rtol=1e-3):
        npa, npb = as_numpy(a), as_numpy(b)
        self.assertTrue(
            np.allclose(npa, npb, atol=atol),
            'Tensor close check failed\n{}\n{}\nadiff={}, rdiff={}'.format(a, b, np.abs(npa - npb).max(), np.abs((npa - npb) / np.fmax(npa, 1e-5)).max())
        )
models/ade20k/segm_lib/nn/parallel/__init__.py
ADDED
@@ -0,0 +1 @@
from .data_parallel import UserScatteredDataParallel, user_scattered_collate, async_copy_to