diff --git a/Generation_Pipeline_filter/Atlas_X_1k.txt b/Generation_Pipeline_filter/Atlas_X_1k.txt new file mode 100644 index 0000000000000000000000000000000000000000..9184dd0cbff7a9e917f585256ad7abb4afc6ef31 --- /dev/null +++ b/Generation_Pipeline_filter/Atlas_X_1k.txt @@ -0,0 +1,956 @@ +BDMAP_00002654 +BDMAP_00002173 +BDMAP_00003294 +BDMAP_00001597 +BDMAP_00001557 +BDMAP_00003327 +BDMAP_00002075 +BDMAP_00004887 +BDMAP_00001434 +BDMAP_00001705 +BDMAP_00000710 +BDMAP_00002271 +BDMAP_00003406 +BDMAP_00003556 +BDMAP_00002103 +BDMAP_00002230 +BDMAP_00000427 +BDMAP_00002746 +BDMAP_00003483 +BDMAP_00003543 +BDMAP_00001396 +BDMAP_00000836 +BDMAP_00003808 +BDMAP_00002619 +BDMAP_00004183 +BDMAP_00001562 +BDMAP_00001414 +BDMAP_00004087 +BDMAP_00002704 +BDMAP_00004198 +BDMAP_00000285 +BDMAP_00005077 +BDMAP_00001343 +BDMAP_00002909 +BDMAP_00002849 +BDMAP_00002655 +BDMAP_00001015 +BDMAP_00003592 +BDMAP_00001676 +BDMAP_00001863 +BDMAP_00002404 +BDMAP_00001035 +BDMAP_00003457 +BDMAP_00001782 +BDMAP_00004586 +BDMAP_00004514 +BDMAP_00004165 +BDMAP_00001171 +BDMAP_00005140 +BDMAP_00005037 +BDMAP_00001769 +BDMAP_00004482 +BDMAP_00003551 +BDMAP_00000887 +BDMAP_00004103 +BDMAP_00002689 +BDMAP_00003727 +BDMAP_00002653 +BDMAP_00000034 +BDMAP_00001504 +BDMAP_00000889 +BDMAP_00004992 +BDMAP_00002065 +BDMAP_00003815 +BDMAP_00004494 +BDMAP_00001545 +BDMAP_00004954 +BDMAP_00002332 +BDMAP_00004288 +BDMAP_00005006 +BDMAP_00001865 +BDMAP_00000604 +BDMAP_00004616 +BDMAP_00001359 +BDMAP_00003956 +BDMAP_00004148 +BDMAP_00001426 +BDMAP_00003301 +BDMAP_00003300 +BDMAP_00000104 +BDMAP_00001185 +BDMAP_00004459 +BDMAP_00000805 +BDMAP_00001238 +BDMAP_00004066 +BDMAP_00001020 +BDMAP_00002626 +BDMAP_00002730 +BDMAP_00000241 +BDMAP_00002017 +BDMAP_00001055 +BDMAP_00005073 +BDMAP_00004296 +BDMAP_00003425 +BDMAP_00003749 +BDMAP_00004775 +BDMAP_00004843 +BDMAP_00003752 +BDMAP_00005105 +BDMAP_00003832 +BDMAP_00004262 +BDMAP_00002085 +BDMAP_00003824 +BDMAP_00001057 +BDMAP_00003812 +BDMAP_00000993 +BDMAP_00000176 +BDMAP_00000618 +BDMAP_00003133 +BDMAP_00004652 +BDMAP_00002437 +BDMAP_00001461 +BDMAP_00003847 +BDMAP_00003381 +BDMAP_00004229 +BDMAP_00001109 +BDMAP_00002930 +BDMAP_00003664 +BDMAP_00001853 +BDMAP_00000851 +BDMAP_00002152 +BDMAP_00004510 +BDMAP_00000362 +BDMAP_00003178 +BDMAP_00003168 +BDMAP_00000465 +BDMAP_00003603 +BDMAP_00002776 +BDMAP_00000480 +BDMAP_00003822 +BDMAP_00004113 +BDMAP_00002695 +BDMAP_00003513 +BDMAP_00001590 +BDMAP_00000826 +BDMAP_00002403 +BDMAP_00001169 +BDMAP_00002661 +BDMAP_00003920 +BDMAP_00000122 +BDMAP_00004130 +BDMAP_00002133 +BDMAP_00002612 +BDMAP_00003923 +BDMAP_00004278 +BDMAP_00004888 +BDMAP_00002422 +BDMAP_00004639 +BDMAP_00002856 +BDMAP_00001907 +BDMAP_00004175 +BDMAP_00002896 +BDMAP_00004257 +BDMAP_00003017 +BDMAP_00004509 +BDMAP_00003377 +BDMAP_00001704 +BDMAP_00002283 +BDMAP_00004664 +BDMAP_00001305 +BDMAP_00004481 +BDMAP_00000696 +BDMAP_00000716 +BDMAP_00002807 +BDMAP_00003608 +BDMAP_00000881 +BDMAP_00004561 +BDMAP_00001027 +BDMAP_00003002 +BDMAP_00002361 +BDMAP_00002289 +BDMAP_00000159 +BDMAP_00000809 +BDMAP_00003918 +BDMAP_00001636 +BDMAP_00003153 +BDMAP_00000413 +BDMAP_00000137 +BDMAP_00002472 +BDMAP_00001281 +BDMAP_00000965 +BDMAP_00002226 +BDMAP_00001605 +BDMAP_00003347 +BDMAP_00002471 +BDMAP_00002582 +BDMAP_00002114 +BDMAP_00005083 +BDMAP_00000438 +BDMAP_00002354 +BDMAP_00003580 +BDMAP_00003315 +BDMAP_00003612 +BDMAP_00004829 +BDMAP_00004395 +BDMAP_00000709 +BDMAP_00000273 +BDMAP_00004636 +BDMAP_00001732 +BDMAP_00004331 +BDMAP_00001868 +BDMAP_00001214 +BDMAP_00001275 +BDMAP_00001809 
+BDMAP_00004374 +BDMAP_00005009 +BDMAP_00001807 +BDMAP_00004294 +BDMAP_00004499 +BDMAP_00001251 +BDMAP_00004457 +BDMAP_00002495 +BDMAP_00001331 +BDMAP_00000481 +BDMAP_00000236 +BDMAP_00001862 +BDMAP_00002288 +BDMAP_00004620 +BDMAP_00001122 +BDMAP_00000882 +BDMAP_00002164 +BDMAP_00004196 +BDMAP_00003384 +BDMAP_00001710 +BDMAP_00003701 +BDMAP_00000607 +BDMAP_00000161 +BDMAP_00004065 +BDMAP_00003031 +BDMAP_00002216 +BDMAP_00001995 +BDMAP_00001584 +BDMAP_00000066 +BDMAP_00004475 +BDMAP_00001620 +BDMAP_00003658 +BDMAP_00003615 +BDMAP_00005113 +BDMAP_00004903 +BDMAP_00001125 +BDMAP_00003484 +BDMAP_00001325 +BDMAP_00000036 +BDMAP_00001370 +BDMAP_00002387 +BDMAP_00002396 +BDMAP_00003514 +BDMAP_00002918 +BDMAP_00004990 +BDMAP_00004106 +BDMAP_00000321 +BDMAP_00000713 +BDMAP_00002363 +BDMAP_00001445 +BDMAP_00000980 +BDMAP_00002485 +BDMAP_00002260 +BDMAP_00000388 +BDMAP_00001476 +BDMAP_00002592 +BDMAP_00003058 +BDMAP_00003364 +BDMAP_00000810 +BDMAP_00003329 +BDMAP_00001891 +BDMAP_00000117 +BDMAP_00001283 +BDMAP_00001128 +BDMAP_00005114 +BDMAP_00000692 +BDMAP_00000190 +BDMAP_00004579 +BDMAP_00005174 +BDMAP_00002690 +BDMAP_00004231 +BDMAP_00000219 +BDMAP_00002846 +BDMAP_00002057 +BDMAP_00001518 +BDMAP_00000589 +BDMAP_00003482 +BDMAP_00004817 +BDMAP_00003633 +BDMAP_00003890 +BDMAP_00002401 +BDMAP_00001223 +BDMAP_00004017 +BDMAP_00003400 +BDMAP_00000091 +BDMAP_00003363 +BDMAP_00004839 +BDMAP_00002383 +BDMAP_00004927 +BDMAP_00002451 +BDMAP_00004815 +BDMAP_00004783 +BDMAP_00005157 +BDMAP_00002373 +BDMAP_00001736 +BDMAP_00004943 +BDMAP_00004015 +BDMAP_00004773 +BDMAP_00001522 +BDMAP_00002171 +BDMAP_00002945 +BDMAP_00002990 +BDMAP_00001802 +BDMAP_00002326 +BDMAP_00000069 +BDMAP_00002185 +BDMAP_00001093 +BDMAP_00001487 +BDMAP_00001456 +BDMAP_00001045 +BDMAP_00001024 +BDMAP_00004615 +BDMAP_00000232 +BDMAP_00003722 +BDMAP_00001383 +BDMAP_00003267 +BDMAP_00002844 +BDMAP_00000030 +BDMAP_00001288 +BDMAP_00001483 +BDMAP_00000437 +BDMAP_00002855 +BDMAP_00003427 +BDMAP_00000771 +BDMAP_00004185 +BDMAP_00003740 +BDMAP_00004841 +BDMAP_00000062 +BDMAP_00004546 +BDMAP_00000662 +BDMAP_00002663 +BDMAP_00000936 +BDMAP_00002758 +BDMAP_00001892 +BDMAP_00002609 +BDMAP_00001982 +BDMAP_00005167 +BDMAP_00001945 +BDMAP_00001102 +BDMAP_00005170 +BDMAP_00000982 +BDMAP_00004129 +BDMAP_00001875 +BDMAP_00004735 +BDMAP_00000366 +BDMAP_00001175 +BDMAP_00002902 +BDMAP_00003558 +BDMAP_00002476 +BDMAP_00003694 +BDMAP_00000304 +BDMAP_00000225 +BDMAP_00002411 +BDMAP_00002304 +BDMAP_00000452 +BDMAP_00003598 +BDMAP_00001212 +BDMAP_00000683 +BDMAP_00005075 +BDMAP_00000162 +BDMAP_00002748 +BDMAP_00005099 +BDMAP_00002854 +BDMAP_00001289 +BDMAP_00000714 +BDMAP_00003849 +BDMAP_00003268 +BDMAP_00002529 +BDMAP_00001258 +BDMAP_00003438 +BDMAP_00000571 +BDMAP_00003853 +BDMAP_00003744 +BDMAP_00002829 +BDMAP_00000364 +BDMAP_00004039 +BDMAP_00000774 +BDMAP_00001834 +BDMAP_00001183 +BDMAP_00002458 +BDMAP_00004511 +BDMAP_00003255 +BDMAP_00003976 +BDMAP_00001924 +BDMAP_00004804 +BDMAP_00004163 +BDMAP_00001646 +BDMAP_00000435 +BDMAP_00002347 +BDMAP_00004297 +BDMAP_00002184 +BDMAP_00004712 +BDMAP_00003683 +BDMAP_00003657 +BDMAP_00004885 +BDMAP_00002947 +BDMAP_00002545 +BDMAP_00001119 +BDMAP_00001754 +BDMAP_00002267 +BDMAP_00003202 +BDMAP_00005108 +BDMAP_00001265 +BDMAP_00001092 +BDMAP_00004253 +BDMAP_00001563 +BDMAP_00001966 +BDMAP_00004304 +BDMAP_00000197 +BDMAP_00001273 +BDMAP_00003867 +BDMAP_00000859 +BDMAP_00001649 +BDMAP_00001664 +BDMAP_00003833 +BDMAP_00002710 +BDMAP_00001791 +BDMAP_00003932 +BDMAP_00002523 +BDMAP_00001632 +BDMAP_00002863 +BDMAP_00003762 
+BDMAP_00001040 +BDMAP_00003971 +BDMAP_00005097 +BDMAP_00001845 +BDMAP_00000989 +BDMAP_00003672 +BDMAP_00001114 +BDMAP_00002742 +BDMAP_00004373 +BDMAP_00004850 +BDMAP_00002278 +BDMAP_00001701 +BDMAP_00001804 +BDMAP_00002349 +BDMAP_00002167 +BDMAP_00002265 +BDMAP_00004417 +BDMAP_00000245 +BDMAP_00005022 +BDMAP_00000871 +BDMAP_00002803 +BDMAP_00000656 +BDMAP_00001095 +BDMAP_00003506 +BDMAP_00003359 +BDMAP_00005141 +BDMAP_00001617 +BDMAP_00002479 +BDMAP_00000778 +BDMAP_00000113 +BDMAP_00000439 +BDMAP_00003409 +BDMAP_00003769 +BDMAP_00001025 +BDMAP_00000469 +BDMAP_00002841 +BDMAP_00001906 +BDMAP_00002426 +BDMAP_00004228 +BDMAP_00000616 +BDMAP_00000547 +BDMAP_00002440 +BDMAP_00002188 +BDMAP_00002484 +BDMAP_00003385 +BDMAP_00001261 +BDMAP_00001441 +BDMAP_00001324 +BDMAP_00003549 +BDMAP_00002465 +BDMAP_00004014 +BDMAP_00000432 +BDMAP_00001067 +BDMAP_00001001 +BDMAP_00000940 +BDMAP_00004597 +BDMAP_00001104 +BDMAP_00001296 +BDMAP_00002562 +BDMAP_00001692 +BDMAP_00005151 +BDMAP_00000883 +BDMAP_00001533 +BDMAP_00001921 +BDMAP_00002410 +BDMAP_00002237 +BDMAP_00002328 +BDMAP_00003614 +BDMAP_00000562 +BDMAP_00001237 +BDMAP_00003333 +BDMAP_00004847 +BDMAP_00005119 +BDMAP_00003277 +BDMAP_00005120 +BDMAP_00005081 +BDMAP_00001607 +BDMAP_00001523 +BDMAP_00005017 +BDMAP_00001010 +BDMAP_00001126 +BDMAP_00001957 +BDMAP_00003776 +BDMAP_00000368 +BDMAP_00002199 +BDMAP_00000956 +BDMAP_00001752 +BDMAP_00005168 +BDMAP_00000205 +BDMAP_00002309 +BDMAP_00002419 +BDMAP_00000093 +BDMAP_00000698 +BDMAP_00004917 +BDMAP_00000434 +BDMAP_00004867 +BDMAP_00000429 +BDMAP_00003947 +BDMAP_00004030 +BDMAP_00001270 +BDMAP_00002402 +BDMAP_00000972 +BDMAP_00003330 +BDMAP_00003244 +BDMAP_00001200 +BDMAP_00000149 +BDMAP_00003252 +BDMAP_00002029 +BDMAP_00000154 +BDMAP_00002940 +BDMAP_00000152 +BDMAP_00001471 +BDMAP_00002737 +BDMAP_00000023 +BDMAP_00002251 +BDMAP_00000701 +BDMAP_00002166 +BDMAP_00001236 +BDMAP_00000329 +BDMAP_00000642 +BDMAP_00001397 +BDMAP_00003435 +BDMAP_00000913 +BDMAP_00005092 +BDMAP_00004925 +BDMAP_00003412 +BDMAP_00003957 +BDMAP_00003897 +BDMAP_00004398 +BDMAP_00001539 +BDMAP_00001911 +BDMAP_00002421 +BDMAP_00004745 +BDMAP_00002318 +BDMAP_00000470 +BDMAP_00002889 +BDMAP_00001912 +BDMAP_00003326 +BDMAP_00002275 +BDMAP_00002227 +BDMAP_00000926 +BDMAP_00004187 +BDMAP_00001148 +BDMAP_00003376 +BDMAP_00003774 +BDMAP_00003857 +BDMAP_00003650 +BDMAP_00005078 +BDMAP_00003151 +BDMAP_00001242 +BDMAP_00003215 +BDMAP_00000676 +BDMAP_00003396 +BDMAP_00003479 +BDMAP_00003781 +BDMAP_00005070 +BDMAP_00003631 +BDMAP_00003840 +BDMAP_00003640 +BDMAP_00000347 +BDMAP_00004645 +BDMAP_00000715 +BDMAP_00002871 +BDMAP_00004834 +BDMAP_00004493 +BDMAP_00001828 +BDMAP_00001565 +BDMAP_00000902 +BDMAP_00001908 +BDMAP_00002688 +BDMAP_00003130 +BDMAP_00000971 +BDMAP_00000192 +BDMAP_00002924 +BDMAP_00002845 +BDMAP_00000660 +BDMAP_00000324 +BDMAP_00004895 +BDMAP_00002751 +BDMAP_00001474 +BDMAP_00001218 +BDMAP_00001130 +BDMAP_00001697 +BDMAP_00002498 +BDMAP_00001768 +BDMAP_00000233 +BDMAP_00004416 +BDMAP_00003138 +BDMAP_00000138 +BDMAP_00004508 +BDMAP_00001514 +BDMAP_00000243 +BDMAP_00001747 +BDMAP_00002487 +BDMAP_00003943 +BDMAP_00000043 +BDMAP_00001835 +BDMAP_00002233 +BDMAP_00004897 +BDMAP_00001230 +BDMAP_00004956 +BDMAP_00005191 +BDMAP_00001444 +BDMAP_00002117 +BDMAP_00001598 +BDMAP_00000087 +BDMAP_00000725 +BDMAP_00004552 +BDMAP_00005064 +BDMAP_00003111 +BDMAP_00004420 +BDMAP_00004293 +BDMAP_00000449 +BDMAP_00001905 +BDMAP_00003569 +BDMAP_00005005 +BDMAP_00004600 +BDMAP_00001766 +BDMAP_00001656 +BDMAP_00000345 +BDMAP_00001753 +BDMAP_00004028 
+BDMAP_00000084 +BDMAP_00002253 +BDMAP_00004808 +BDMAP_00003052 +BDMAP_00002362 +BDMAP_00004435 +BDMAP_00004964 +BDMAP_00000516 +BDMAP_00004876 +BDMAP_00004651 +BDMAP_00000431 +BDMAP_00002022 +BDMAP_00001316 +BDMAP_00002359 +BDMAP_00004147 +BDMAP_00004264 +BDMAP_00004980 +BDMAP_00003685 +BDMAP_00004384 +BDMAP_00004199 +BDMAP_00002791 +BDMAP_00002120 +BDMAP_00002244 +BDMAP_00004462 +BDMAP_00000279 +BDMAP_00004676 +BDMAP_00000569 +BDMAP_00001517 +BDMAP_00004450 +BDMAP_00000414 +BDMAP_00000582 +BDMAP_00004558 +BDMAP_00001712 +BDMAP_00004796 +BDMAP_00004295 +BDMAP_00001842 +BDMAP_00001422 +BDMAP_00003036 +BDMAP_00001419 +BDMAP_00003576 +BDMAP_00000331 +BDMAP_00001225 +BDMAP_00004673 +BDMAP_00000977 +BDMAP_00000044 +BDMAP_00001826 +BDMAP_00001440 +BDMAP_00000574 +BDMAP_00004672 +BDMAP_00004830 +BDMAP_00004077 +BDMAP_00004793 +BDMAP_00004074 +BDMAP_00000139 +BDMAP_00003356 +BDMAP_00003713 +BDMAP_00003254 +BDMAP_00001333 +BDMAP_00004023 +BDMAP_00004880 +BDMAP_00002981 +BDMAP_00005160 +BDMAP_00001096 +BDMAP_00003109 +BDMAP_00003063 +BDMAP_00003973 +BDMAP_00004719 +BDMAP_00000542 +BDMAP_00004491 +BDMAP_00002172 +BDMAP_00000907 +BDMAP_00005154 +BDMAP_00003827 +BDMAP_00004541 +BDMAP_00003493 +BDMAP_00003461 +BDMAP_00000338 +BDMAP_00004016 +BDMAP_00002815 +BDMAP_00002805 +BDMAP_00000918 +BDMAP_00003141 +BDMAP_00001564 +BDMAP_00003392 +BDMAP_00000939 +BDMAP_00001368 +BDMAP_00004549 +BDMAP_00001707 +BDMAP_00001475 +BDMAP_00002232 +BDMAP_00000923 +BDMAP_00004104 +BDMAP_00004608 +BDMAP_00004825 +BDMAP_00001209 +BDMAP_00005185 +BDMAP_00002696 +BDMAP_00000828 +BDMAP_00001059 +BDMAP_00001647 +BDMAP_00000039 +BDMAP_00000935 +BDMAP_00002712 +BDMAP_00003451 +BDMAP_00000059 +BDMAP_00003516 +BDMAP_00002295 +BDMAP_00001516 +BDMAP_00002319 +BDMAP_00001077 +BDMAP_00003581 +BDMAP_00002884 +BDMAP_00003324 +BDMAP_00000128 +BDMAP_00002959 +BDMAP_00000411 +BDMAP_00003717 +BDMAP_00004995 +BDMAP_00000653 +BDMAP_00004031 +BDMAP_00003590 +BDMAP_00001215 +BDMAP_00001256 +BDMAP_00002273 +BDMAP_00000667 +BDMAP_00000373 +BDMAP_00003680 +BDMAP_00001784 +BDMAP_00001286 +BDMAP_00001246 +BDMAP_00003440 +BDMAP_00002656 +BDMAP_00003955 +BDMAP_00003930 +BDMAP_00001985 +BDMAP_00004328 +BDMAP_00004744 +BDMAP_00004529 +BDMAP_00004447 +BDMAP_00002252 +BDMAP_00003994 +BDMAP_00001711 +BDMAP_00000355 +BDMAP_00001836 +BDMAP_00003448 +BDMAP_00000855 +BDMAP_00002039 +BDMAP_00005063 +BDMAP_00004286 +BDMAP_00001823 +BDMAP_00002407 +BDMAP_00002933 +BDMAP_00003928 +BDMAP_00000447 +BDMAP_00003411 +BDMAP_00004641 +BDMAP_00003886 +BDMAP_00000240 +BDMAP_00001917 +BDMAP_00003952 +BDMAP_00001464 +BDMAP_00000614 +BDMAP_00003491 +BDMAP_00004427 +BDMAP_00004131 +BDMAP_00004011 +BDMAP_00000297 +BDMAP_00001511 +BDMAP_00000812 +BDMAP_00005020 +BDMAP_00004060 +BDMAP_00002496 +BDMAP_00003455 +BDMAP_00005169 +BDMAP_00000462 +BDMAP_00001502 +BDMAP_00000558 +BDMAP_00004216 +BDMAP_00000244 +BDMAP_00001602 +BDMAP_00003073 +BDMAP_00001618 +BDMAP_00000839 +BDMAP_00002333 +BDMAP_00002298 +BDMAP_00000873 +BDMAP_00001521 +BDMAP_00003946 +BDMAP_00000690 +BDMAP_00004969 +BDMAP_00000320 +BDMAP_00003074 +BDMAP_00004154 +BDMAP_00001420 +BDMAP_00002826 +BDMAP_00002076 +BDMAP_00002021 +BDMAP_00000837 +BDMAP_00000968 +BDMAP_00001138 +BDMAP_00002524 +BDMAP_00000532 +BDMAP_00002250 +BDMAP_00002282 +BDMAP_00003281 +BDMAP_00004738 +BDMAP_00004389 +BDMAP_00004922 +BDMAP_00002305 +BDMAP_00003070 +BDMAP_00002793 +BDMAP_00002986 +BDMAP_00000623 +BDMAP_00001794 +BDMAP_00002475 +BDMAP_00004415 +BDMAP_00001898 +BDMAP_00002936 +BDMAP_00003443 +BDMAP_00004550 +BDMAP_00004479 +BDMAP_00002041 
+BDMAP_00001806 +BDMAP_00002509 +BDMAP_00002616 +BDMAP_00005065 +BDMAP_00005085 +BDMAP_00001379 +BDMAP_00003911 +BDMAP_00002707 +BDMAP_00004097 +BDMAP_00003128 +BDMAP_00003996 +BDMAP_00000626 +BDMAP_00000263 +BDMAP_00001549 +BDMAP_00000229 +BDMAP_00001688 +BDMAP_00002313 +BDMAP_00003319 +BDMAP_00003343 +BDMAP_00004624 +BDMAP_00001737 +BDMAP_00001624 +BDMAP_00003358 +BDMAP_00000998 +BDMAP_00004195 +BDMAP_00001941 +BDMAP_00004870 +BDMAP_00000948 +BDMAP_00001496 +BDMAP_00000687 +BDMAP_00004033 +BDMAP_00001068 +BDMAP_00003520 +BDMAP_00000941 +BDMAP_00000867 +BDMAP_00000264 +BDMAP_00005067 +BDMAP_00000132 +BDMAP_00004650 +BDMAP_00003736 +BDMAP_00003564 +BDMAP_00001635 +BDMAP_00003898 +BDMAP_00004901 +BDMAP_00000400 +BDMAP_00004671 +BDMAP_00000353 +BDMAP_00001089 +BDMAP_00000572 +BDMAP_00002953 +BDMAP_00003600 +BDMAP_00003798 +BDMAP_00000987 +BDMAP_00000541 +BDMAP_00004717 +BDMAP_00002068 +BDMAP_00001977 +BDMAP_00002942 +BDMAP_00000416 +BDMAP_00002580 +BDMAP_00001410 +BDMAP_00000052 +BDMAP_00003361 +BDMAP_00001247 +BDMAP_00004894 +BDMAP_00002060 +BDMAP_00000319 +BDMAP_00004407 +BDMAP_00002099 +BDMAP_00004431 +BDMAP_00003225 +BDMAP_00003236 +BDMAP_00004981 +BDMAP_00000671 +BDMAP_00003444 +BDMAP_00003525 +BDMAP_00000259 +BDMAP_00003497 +BDMAP_00003767 +BDMAP_00004184 +BDMAP_00003524 +BDMAP_00000942 +BDMAP_00002719 +BDMAP_00004232 +BDMAP_00005186 +BDMAP_00003900 diff --git a/Generation_Pipeline_filter/best_metric_model_classification3d_dict.pth b/Generation_Pipeline_filter/best_metric_model_classification3d_dict.pth new file mode 100644 index 0000000000000000000000000000000000000000..b86cfac903c22ca1329f602b0360ea98f0c8fae7 --- /dev/null +++ b/Generation_Pipeline_filter/best_metric_model_classification3d_dict.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:331395112f91825f17114067c7218c6c9bd727378cbe4c43a1d47edb140b7282 +size 45594067 diff --git a/Generation_Pipeline_filter/get_syn_list.py b/Generation_Pipeline_filter/get_syn_list.py new file mode 100644 index 0000000000000000000000000000000000000000..a03571f9f096b48bd73a6cd8236435aac77dca75 --- /dev/null +++ b/Generation_Pipeline_filter/get_syn_list.py @@ -0,0 +1,26 @@ +import os + +organ = 'colon' +real_organ = [] +with open(f'real_set/{organ}.txt', 'r') as f: + real_organ=f.readlines() +real_organ = [i.split('\n')[0] for i in real_organ] + + +total_case = [] +with open(f'real_total.txt', 'r') as f: + total_case=f.readlines() +total_case = [i.split('\n')[0] for i in total_case] + + +absence2= list(set(total_case) - set(real_organ)) +absence2 = [i for i in absence2] +# breakpoint() + +filename = open(f'syn_{organ}/healthy_{organ}_1k.txt','a+')#dict转txt +for i in absence2: + filename.write(i) + filename.write('\n') +filename.close() + + diff --git a/Generation_Pipeline_filter/get_training_list.py b/Generation_Pipeline_filter/get_training_list.py new file mode 100644 index 0000000000000000000000000000000000000000..dba70b6bd3d7081cf58d741719b87d4ce4170511 --- /dev/null +++ b/Generation_Pipeline_filter/get_training_list.py @@ -0,0 +1,45 @@ +import os + +total_case = [] +with open(f'real_total.txt', 'r') as f: + total_case=f.readlines() +total_case = [i.split('\n')[0] for i in total_case] + + +real_organ = [] +with open(f'val_set/bodymap_liver.txt', 'r') as f: + real_organ=f.readlines() +real_organ = [i.split('\n')[0] for i in real_organ] +total_case= list(set(total_case) - set(real_organ)) +total_case = [i for i in total_case] + +real_organ = [] +with open(f'val_set/bodymap_pancreas.txt', 'r') as f: + 
real_organ=f.readlines() +real_organ = [i.split('\n')[0] for i in real_organ] +total_case= list(set(total_case) - set(real_organ)) +total_case = [i for i in total_case] + +real_organ = [] +with open(f'val_set/bodymap_kidney.txt', 'r') as f: + real_organ=f.readlines() +real_organ = [i.split('\n')[0] for i in real_organ] +total_case= list(set(total_case) - set(real_organ)) +total_case = [i for i in total_case] + +real_organ = [] +with open(f'val_set/bodymap_colon.txt', 'r') as f: + real_organ=f.readlines() +real_organ = [i.split('\n')[0] for i in real_organ] +total_case= list(set(total_case) - set(real_organ)) +total_case = [i for i in total_case] + + + +filename = open(f'Atlas_X_1k.txt','a+')#dict转txt +for i in total_case: + filename.write(i) + filename.write('\n') +filename.close() + + diff --git a/Generation_Pipeline_filter/real_set/colon.txt b/Generation_Pipeline_filter/real_set/colon.txt new file mode 100644 index 0000000000000000000000000000000000000000..c19f373284fab2de8fa90f1ccc891b5b514d9b72 --- /dev/null +++ b/Generation_Pipeline_filter/real_set/colon.txt @@ -0,0 +1,126 @@ +BDMAP_00001078 +BDMAP_00003031 +BDMAP_00002253 +BDMAP_00001732 +BDMAP_00000874 +BDMAP_00003847 +BDMAP_00003268 +BDMAP_00002846 +BDMAP_00001438 +BDMAP_00004650 +BDMAP_00003109 +BDMAP_00004121 +BDMAP_00004165 +BDMAP_00004676 +BDMAP_00003890 +BDMAP_00003327 +BDMAP_00000132 +BDMAP_00001215 +BDMAP_00001769 +BDMAP_00003412 +BDMAP_00002318 +BDMAP_00004624 +BDMAP_00000345 +BDMAP_00002230 +BDMAP_00003111 +BDMAP_00001015 +BDMAP_00001514 +BDMAP_00001924 +BDMAP_00002845 +BDMAP_00002598 +BDMAP_00001209 +BDMAP_00000373 +BDMAP_00001737 +BDMAP_00003113 +BDMAP_00004876 +BDMAP_00003640 +BDMAP_00001985 +BDMAP_00000138 +BDMAP_00000881 +BDMAP_00002739 +BDMAP_00003560 +BDMAP_00002612 +BDMAP_00001445 +BDMAP_00003827 +BDMAP_00001024 +BDMAP_00000568 +BDMAP_00001095 +BDMAP_00002458 +BDMAP_00002986 +BDMAP_00000913 +BDMAP_00000264 +BDMAP_00000690 +BDMAP_00002039 +BDMAP_00001426 +BDMAP_00002730 +BDMAP_00001917 +BDMAP_00005067 +BDMAP_00002924 +BDMAP_00005160 +BDMAP_00005073 +BDMAP_00000547 +BDMAP_00000942 +BDMAP_00002103 +BDMAP_00002654 +BDMAP_00004374 +BDMAP_00003510 +BDMAP_00004910 +BDMAP_00004558 +BDMAP_00004450 +BDMAP_00000152 +BDMAP_00004491 +BDMAP_00001237 +BDMAP_00001785 +BDMAP_00001865 +BDMAP_00000851 +BDMAP_00003357 +BDMAP_00004415 +BDMAP_00004615 +BDMAP_00003680 +BDMAP_00001875 +BDMAP_00004894 +BDMAP_00001835 +BDMAP_00000069 +BDMAP_00001809 +BDMAP_00004431 +BDMAP_00002704 +BDMAP_00002185 +BDMAP_00004384 +BDMAP_00003299 +BDMAP_00003333 +BDMAP_00002305 +BDMAP_00001598 +BDMAP_00002465 +BDMAP_00002199 +BDMAP_00002875 +BDMAP_00000828 +BDMAP_00003564 +BDMAP_00005001 +BDMAP_00004493 +BDMAP_00000190 +BDMAP_00000873 +BDMAP_00005170 +BDMAP_00002152 +BDMAP_00004163 +BDMAP_00000939 +BDMAP_00001212 +BDMAP_00001982 +BDMAP_00000552 +BDMAP_00004764 +BDMAP_00002401 +BDMAP_00002451 +BDMAP_00003634 +BDMAP_00005016 +BDMAP_00000716 +BDMAP_00003373 +BDMAP_00000030 +BDMAP_00003946 +BDMAP_00002828 +BDMAP_00004196 +BDMAP_00005005 +BDMAP_00003972 +BDMAP_00003172 +BDMAP_00004783 +BDMAP_00001102 +BDMAP_00004147 +BDMAP_00004604 diff --git a/Generation_Pipeline_filter/real_set/kidney.txt b/Generation_Pipeline_filter/real_set/kidney.txt new file mode 100644 index 0000000000000000000000000000000000000000..dc40eda19f629c971abb354633257a95ca7cea14 --- /dev/null +++ b/Generation_Pipeline_filter/real_set/kidney.txt @@ -0,0 +1,489 @@ +BDMAP_00000245 +BDMAP_00000036 +BDMAP_00003833 +BDMAP_00001517 +BDMAP_00004087 +BDMAP_00002807 +BDMAP_00002099 +BDMAP_00001602 
+BDMAP_00001035 +BDMAP_00002422 +BDMAP_00000626 +BDMAP_00002173 +BDMAP_00000240 +BDMAP_00001246 +BDMAP_00000582 +BDMAP_00003996 +BDMAP_00001707 +BDMAP_00000923 +BDMAP_00003411 +BDMAP_00004113 +BDMAP_00002582 +BDMAP_00001261 +BDMAP_00005167 +BDMAP_00004897 +BDMAP_00001169 +BDMAP_00001148 +BDMAP_00002164 +BDMAP_00002041 +BDMAP_00000889 +BDMAP_00001109 +BDMAP_00005009 +BDMAP_00001286 +BDMAP_00000297 +BDMAP_00005099 +BDMAP_00004257 +BDMAP_00005017 +BDMAP_00000604 +BDMAP_00002472 +BDMAP_00001225 +BDMAP_00005081 +BDMAP_00003491 +BDMAP_00001635 +BDMAP_00002075 +BDMAP_00000660 +BDMAP_00001238 +BDMAP_00002656 +BDMAP_00003558 +BDMAP_00001104 +BDMAP_00004066 +BDMAP_00003294 +BDMAP_00001607 +BDMAP_00001077 +BDMAP_00000653 +BDMAP_00001273 +BDMAP_00000616 +BDMAP_00002057 +BDMAP_00004586 +BDMAP_00004407 +BDMAP_00004922 +BDMAP_00002592 +BDMAP_00000149 +BDMAP_00000320 +BDMAP_00001511 +BDMAP_00000435 +BDMAP_00002746 +BDMAP_00004457 +BDMAP_00000805 +BDMAP_00002661 +BDMAP_00004552 +BDMAP_00004154 +BDMAP_00002902 +BDMAP_00000839 +BDMAP_00000233 +BDMAP_00000122 +BDMAP_00005151 +BDMAP_00004427 +BDMAP_00002936 +BDMAP_00003955 +BDMAP_00001863 +BDMAP_00002326 +BDMAP_00001420 +BDMAP_00000329 +BDMAP_00004561 +BDMAP_00003971 +BDMAP_00000935 +BDMAP_00000569 +BDMAP_00004956 +BDMAP_00000285 +BDMAP_00004597 +BDMAP_00001747 +BDMAP_00001059 +BDMAP_00002354 +BDMAP_00001656 +BDMAP_00004395 +BDMAP_00002942 +BDMAP_00004981 +BDMAP_00001768 +BDMAP_00002319 +BDMAP_00003947 +BDMAP_00001868 +BDMAP_00002065 +BDMAP_00002333 +BDMAP_00003358 +BDMAP_00001265 +BDMAP_00003952 +BDMAP_00001891 +BDMAP_00003576 +BDMAP_00000980 +BDMAP_00003300 +BDMAP_00001782 +BDMAP_00003717 +BDMAP_00001251 +BDMAP_00000044 +BDMAP_00004510 +BDMAP_00003315 +BDMAP_00002653 +BDMAP_00001045 +BDMAP_00003694 +BDMAP_00004216 +BDMAP_00001794 +BDMAP_00000532 +BDMAP_00002288 +BDMAP_00001256 +BDMAP_00000219 +BDMAP_00000710 +BDMAP_00003930 +BDMAP_00001636 +BDMAP_00003749 +BDMAP_00000998 +BDMAP_00000176 +BDMAP_00000429 +BDMAP_00001001 +BDMAP_00001908 +BDMAP_00003363 +BDMAP_00004903 +BDMAP_00004482 +BDMAP_00003178 +BDMAP_00003202 +BDMAP_00001230 +BDMAP_00003461 +BDMAP_00003281 +BDMAP_00000434 +BDMAP_00001218 +BDMAP_00003976 +BDMAP_00003455 +BDMAP_00001183 +BDMAP_00002609 +BDMAP_00001305 +BDMAP_00000364 +BDMAP_00003516 +BDMAP_00003956 +BDMAP_00000977 +BDMAP_00001784 +BDMAP_00004389 +BDMAP_00001711 +BDMAP_00000698 +BDMAP_00003153 +BDMAP_00001995 +BDMAP_00001549 +BDMAP_00001324 +BDMAP_00004195 +BDMAP_00001562 +BDMAP_00004074 +BDMAP_00001483 +BDMAP_00002085 +BDMAP_00001396 +BDMAP_00000241 +BDMAP_00004031 +BDMAP_00004775 +BDMAP_00001807 +BDMAP_00005120 +BDMAP_00004065 +BDMAP_00003943 +BDMAP_00002953 +BDMAP_00004232 +BDMAP_00002184 +BDMAP_00002407 +BDMAP_00003252 +BDMAP_00004296 +BDMAP_00000161 +BDMAP_00002981 +BDMAP_00003608 +BDMAP_00003128 +BDMAP_00000571 +BDMAP_00000259 +BDMAP_00003444 +BDMAP_00001647 +BDMAP_00000662 +BDMAP_00003774 +BDMAP_00001383 +BDMAP_00004616 +BDMAP_00001906 +BDMAP_00003740 +BDMAP_00001422 +BDMAP_00002631 +BDMAP_00004294 +BDMAP_00003994 +BDMAP_00004475 +BDMAP_00002744 +BDMAP_00001068 +BDMAP_00000667 +BDMAP_00001945 +BDMAP_00002710 +BDMAP_00002440 +BDMAP_00000833 +BDMAP_00003143 +BDMAP_00000062 +BDMAP_00003392 +BDMAP_00004373 +BDMAP_00001020 +BDMAP_00003603 +BDMAP_00001027 +BDMAP_00005114 +BDMAP_00003384 +BDMAP_00000794 +BDMAP_00001911 +BDMAP_00002437 +BDMAP_00004579 +BDMAP_00004250 +BDMAP_00002068 +BDMAP_00000608 +BDMAP_00004551 +BDMAP_00002884 +BDMAP_00004033 +BDMAP_00005105 +BDMAP_00002776 +BDMAP_00000414 +BDMAP_00003580 +BDMAP_00004712 +BDMAP_00002114 
+BDMAP_00002226 +BDMAP_00003923 +BDMAP_00002854 +BDMAP_00004039 +BDMAP_00004014 +BDMAP_00001289 +BDMAP_00003435 +BDMAP_00004578 +BDMAP_00002940 +BDMAP_00003164 +BDMAP_00002751 +BDMAP_00001516 +BDMAP_00003486 +BDMAP_00000279 +BDMAP_00001664 +BDMAP_00004738 +BDMAP_00001735 +BDMAP_00000562 +BDMAP_00000812 +BDMAP_00000511 +BDMAP_00004746 +BDMAP_00000452 +BDMAP_00004328 +BDMAP_00002017 +BDMAP_00002840 +BDMAP_00000039 +BDMAP_00002242 +BDMAP_00002775 +BDMAP_00003762 +BDMAP_00000229 +BDMAP_00003520 +BDMAP_00000725 +BDMAP_00000516 +BDMAP_00001941 +BDMAP_00003928 +BDMAP_00001255 +BDMAP_00001456 +BDMAP_00002410 +BDMAP_00002742 +BDMAP_00001688 +BDMAP_00000487 +BDMAP_00000469 +BDMAP_00002022 +BDMAP_00003058 +BDMAP_00004148 +BDMAP_00001977 +BDMAP_00000887 +BDMAP_00003448 +BDMAP_00001410 +BDMAP_00002383 +BDMAP_00003736 +BDMAP_00002626 +BDMAP_00001710 +BDMAP_00001130 +BDMAP_00001138 +BDMAP_00001413 +BDMAP_00003815 +BDMAP_00004130 +BDMAP_00004652 +BDMAP_00002864 +BDMAP_00000574 +BDMAP_00003493 +BDMAP_00003364 +BDMAP_00002648 +BDMAP_00001281 +BDMAP_00002655 +BDMAP_00001126 +BDMAP_00002804 +BDMAP_00000321 +BDMAP_00005191 +BDMAP_00004420 +BDMAP_00000304 +BDMAP_00003150 +BDMAP_00004620 +BDMAP_00000368 +BDMAP_00000066 +BDMAP_00003701 +BDMAP_00005174 +BDMAP_00002545 +BDMAP_00003957 +BDMAP_00004331 +BDMAP_00000687 +BDMAP_00001791 +BDMAP_00002959 +BDMAP_00004104 +BDMAP_00003073 +BDMAP_00003713 +BDMAP_00002363 +BDMAP_00000137 +BDMAP_00000104 +BDMAP_00002689 +BDMAP_00004990 +BDMAP_00003301 +BDMAP_00001434 +BDMAP_00000449 +BDMAP_00005113 +BDMAP_00003225 +BDMAP_00001359 +BDMAP_00001223 +BDMAP_00002803 +BDMAP_00000355 +BDMAP_00001826 +BDMAP_00004673 +BDMAP_00002251 +BDMAP_00000439 +BDMAP_00005085 +BDMAP_00003381 +BDMAP_00004645 +BDMAP_00000432 +BDMAP_00001444 +BDMAP_00001705 +BDMAP_00001892 +BDMAP_00002826 +BDMAP_00004671 +BDMAP_00000926 +BDMAP_00004817 +BDMAP_00004175 +BDMAP_00003484 +BDMAP_00003672 +BDMAP_00003267 +BDMAP_00001089 +BDMAP_00001496 +BDMAP_00003615 +BDMAP_00003832 +BDMAP_00002695 +BDMAP_00002696 +BDMAP_00004499 +BDMAP_00004867 +BDMAP_00004479 +BDMAP_00003600 +BDMAP_00000989 +BDMAP_00002421 +BDMAP_00003406 +BDMAP_00000263 +BDMAP_00002396 +BDMAP_00002265 +BDMAP_00000713 +BDMAP_00000883 +BDMAP_00001258 +BDMAP_00004253 +BDMAP_00004870 +BDMAP_00000331 +BDMAP_00004608 +BDMAP_00001518 +BDMAP_00002562 +BDMAP_00002889 +BDMAP_00001676 +BDMAP_00000117 +BDMAP_00003973 +BDMAP_00002509 +BDMAP_00002487 +BDMAP_00003457 +BDMAP_00000982 +BDMAP_00002260 +BDMAP_00001283 +BDMAP_00003506 +BDMAP_00000366 +BDMAP_00002133 +BDMAP_00000465 +BDMAP_00003767 +BDMAP_00001853 +BDMAP_00002361 +BDMAP_00004815 +BDMAP_00002933 +BDMAP_00000162 +BDMAP_00004925 +BDMAP_00005077 +BDMAP_00001533 +BDMAP_00001242 +BDMAP_00000871 +BDMAP_00000948 +BDMAP_00001119 +BDMAP_00004887 +BDMAP_00002404 +BDMAP_00003722 +BDMAP_00002426 +BDMAP_00002060 +BDMAP_00004850 +BDMAP_00003343 +BDMAP_00001624 +BDMAP_00000481 +BDMAP_00002166 +BDMAP_00003849 +BDMAP_00004808 +BDMAP_00002471 +BDMAP_00000656 +BDMAP_00003581 +BDMAP_00000023 +BDMAP_00003727 +BDMAP_00000319 +BDMAP_00003255 +BDMAP_00003752 +BDMAP_00000139 +BDMAP_00003614 +BDMAP_00003549 +BDMAP_00003808 +BDMAP_00002930 +BDMAP_00001128 +BDMAP_00004717 +BDMAP_00000826 +BDMAP_00002663 +BDMAP_00000837 +BDMAP_00000159 +BDMAP_00005154 +BDMAP_00002524 +BDMAP_00000968 +BDMAP_00004278 +BDMAP_00001325 +BDMAP_00000987 +BDMAP_00004901 +BDMAP_00003425 +BDMAP_00005006 +BDMAP_00004131 +BDMAP_00002403 +BDMAP_00001620 +BDMAP_00002347 +BDMAP_00001522 +BDMAP_00004011 +BDMAP_00001474 +BDMAP_00004744 +BDMAP_00002484 +BDMAP_00001370 
+BDMAP_00003324 +BDMAP_00001557 +BDMAP_00000867 +BDMAP_00001487 +BDMAP_00004980 +BDMAP_00000034 +BDMAP_00000936 +BDMAP_00000128 +BDMAP_00001275 +BDMAP_00004030 +BDMAP_00003359 +BDMAP_00003070 +BDMAP_00002476 +BDMAP_00002990 +BDMAP_00000810 +BDMAP_00003514 +BDMAP_00004834 +BDMAP_00003409 +BDMAP_00002498 +BDMAP_00004481 +BDMAP_00002273 +BDMAP_00002496 +BDMAP_00002871 +BDMAP_00000059 +BDMAP_00001475 +BDMAP_00000902 +BDMAP_00004417 +BDMAP_00005157 +BDMAP_00001752 +BDMAP_00001563 +BDMAP_00003063 +BDMAP_00001296 +BDMAP_00002707 +BDMAP_00000836 +BDMAP_00000353 +BDMAP_00000043 +BDMAP_00000244 diff --git a/Generation_Pipeline_filter/real_set/liver.txt b/Generation_Pipeline_filter/real_set/liver.txt new file mode 100644 index 0000000000000000000000000000000000000000..4e722fe9f916144960815ec7f01a5d3ba64a5d1d --- /dev/null +++ b/Generation_Pipeline_filter/real_set/liver.txt @@ -0,0 +1,159 @@ +BDMAP_00000400 +BDMAP_00003497 +BDMAP_00001270 +BDMAP_00001766 +BDMAP_00001309 +BDMAP_00004745 +BDMAP_00003002 +BDMAP_00004825 +BDMAP_00004416 +BDMAP_00002712 +BDMAP_00004830 +BDMAP_00000907 +BDMAP_00001957 +BDMAP_00000941 +BDMAP_00002841 +BDMAP_00001962 +BDMAP_00004462 +BDMAP_00004281 +BDMAP_00004890 +BDMAP_00003272 +BDMAP_00003377 +BDMAP_00005186 +BDMAP_00002172 +BDMAP_00000091 +BDMAP_00004639 +BDMAP_00000918 +BDMAP_00000671 +BDMAP_00004028 +BDMAP_00004529 +BDMAP_00001907 +BDMAP_00001122 +BDMAP_00003151 +BDMAP_00002252 +BDMAP_00003524 +BDMAP_00004704 +BDMAP_00000362 +BDMAP_00003932 +BDMAP_00004995 +BDMAP_00002748 +BDMAP_00004117 +BDMAP_00000480 +BDMAP_00001010 +BDMAP_00000100 +BDMAP_00001200 +BDMAP_00004103 +BDMAP_00004878 +BDMAP_00002282 +BDMAP_00001471 +BDMAP_00000232 +BDMAP_00003439 +BDMAP_00003857 +BDMAP_00004943 +BDMAP_00005130 +BDMAP_00002479 +BDMAP_00002909 +BDMAP_00004185 +BDMAP_00003569 +BDMAP_00001185 +BDMAP_00002849 +BDMAP_00003556 +BDMAP_00003052 +BDMAP_00000971 +BDMAP_00003330 +BDMAP_00000113 +BDMAP_00004600 +BDMAP_00002529 +BDMAP_00000437 +BDMAP_00003074 +BDMAP_00005139 +BDMAP_00001966 +BDMAP_00002791 +BDMAP_00001692 +BDMAP_00001786 +BDMAP_00001697 +BDMAP_00003798 +BDMAP_00000273 +BDMAP_00001114 +BDMAP_00003898 +BDMAP_00001397 +BDMAP_00003867 +BDMAP_00005065 +BDMAP_00001802 +BDMAP_00001539 +BDMAP_00000084 +BDMAP_00002955 +BDMAP_00002271 +BDMAP_00004459 +BDMAP_00004378 +BDMAP_00004435 +BDMAP_00001093 +BDMAP_00003897 +BDMAP_00003236 +BDMAP_00001502 +BDMAP_00001834 +BDMAP_00000347 +BDMAP_00000831 +BDMAP_00002717 +BDMAP_00002856 +BDMAP_00004199 +BDMAP_00000709 +BDMAP_00003481 +BDMAP_00002719 +BDMAP_00005083 +BDMAP_00002359 +BDMAP_00000642 +BDMAP_00000778 +BDMAP_00000745 +BDMAP_00000607 +BDMAP_00001236 +BDMAP_00001333 +BDMAP_00003920 +BDMAP_00003664 +BDMAP_00003911 +BDMAP_00002463 +BDMAP_00002419 +BDMAP_00000965 +BDMAP_00003513 +BDMAP_00004508 +BDMAP_00002283 +BDMAP_00004509 +BDMAP_00000615 +BDMAP_00001171 +BDMAP_00001343 +BDMAP_00002167 +BDMAP_00000205 +BDMAP_00002805 +BDMAP_00002275 +BDMAP_00002485 +BDMAP_00004228 +BDMAP_00004304 +BDMAP_00004187 +BDMAP_00001379 +BDMAP_00001753 +BDMAP_00000413 +BDMAP_00002289 +BDMAP_00000572 +BDMAP_00005119 +BDMAP_00004017 +BDMAP_00004016 +BDMAP_00002349 +BDMAP_00000101 +BDMAP_00003482 +BDMAP_00004839 +BDMAP_00001025 +BDMAP_00003361 +BDMAP_00002495 +BDMAP_00001055 +BDMAP_00002214 +BDMAP_00005097 +BDMAP_00005168 +BDMAP_00002267 +BDMAP_00001198 +BDMAP_00002918 +BDMAP_00004664 +BDMAP_00004888 +BDMAP_00000921 +BDMAP_00002373 +BDMAP_00001316 +BDMAP_00002117 diff --git a/Generation_Pipeline_filter/real_set/pancreas.txt b/Generation_Pipeline_filter/real_set/pancreas.txt new 
file mode 100644 index 0000000000000000000000000000000000000000..2f0d669103ff6594f0e57d4a55cadffb2c4a8d9c --- /dev/null +++ b/Generation_Pipeline_filter/real_set/pancreas.txt @@ -0,0 +1,281 @@ +BDMAP_00003244 +BDMAP_00005074 +BDMAP_00004804 +BDMAP_00004672 +BDMAP_00003133 +BDMAP_00004969 +BDMAP_00002278 +BDMAP_00001862 +BDMAP_00005185 +BDMAP_00004880 +BDMAP_00004770 +BDMAP_00002690 +BDMAP_00002944 +BDMAP_00003744 +BDMAP_00002021 +BDMAP_00003141 +BDMAP_00004927 +BDMAP_00001476 +BDMAP_00003551 +BDMAP_00004964 +BDMAP_00001605 +BDMAP_00002298 +BDMAP_00001746 +BDMAP_00000332 +BDMAP_00003590 +BDMAP_00000956 +BDMAP_00001649 +BDMAP_00003781 +BDMAP_00001523 +BDMAP_00003347 +BDMAP_00005022 +BDMAP_00004128 +BDMAP_00003612 +BDMAP_00003658 +BDMAP_00003812 +BDMAP_00003427 +BDMAP_00003502 +BDMAP_00001823 +BDMAP_00004847 +BDMAP_00003776 +BDMAP_00001205 +BDMAP_00000192 +BDMAP_00004511 +BDMAP_00001564 +BDMAP_00000416 +BDMAP_00005070 +BDMAP_00001040 +BDMAP_00004231 +BDMAP_00002945 +BDMAP_00001704 +BDMAP_00002402 +BDMAP_00000940 +BDMAP_00000243 +BDMAP_00001464 +BDMAP_00002793 +BDMAP_00001646 +BDMAP_00005020 +BDMAP_00004992 +BDMAP_00003017 +BDMAP_00001096 +BDMAP_00003451 +BDMAP_00001067 +BDMAP_00001331 +BDMAP_00000696 +BDMAP_00001461 +BDMAP_00003326 +BDMAP_00000715 +BDMAP_00000855 +BDMAP_00000087 +BDMAP_00000093 +BDMAP_00000324 +BDMAP_00003440 +BDMAP_00002387 +BDMAP_00004060 +BDMAP_00000714 +BDMAP_00001617 +BDMAP_00004494 +BDMAP_00002616 +BDMAP_00000225 +BDMAP_00001754 +BDMAP_00005075 +BDMAP_00002328 +BDMAP_00004229 +BDMAP_00000541 +BDMAP_00004447 +BDMAP_00004106 +BDMAP_00003592 +BDMAP_00003036 +BDMAP_00001125 +BDMAP_00001361 +BDMAP_00002863 +BDMAP_00002309 +BDMAP_00001905 +BDMAP_00004115 +BDMAP_00002216 +BDMAP_00004829 +BDMAP_00003443 +BDMAP_00001504 +BDMAP_00004885 +BDMAP_00003451 +BDMAP_00000679 +BDMAP_00002362 +BDMAP_00000388 +BDMAP_00003769 +BDMAP_00004198 +BDMAP_00004719 +BDMAP_00000809 +BDMAP_00003525 +BDMAP_00003138 +BDMAP_00005063 +BDMAP_00000676 +BDMAP_00000411 +BDMAP_00002523 +BDMAP_00003367 +BDMAP_00003961 +BDMAP_00003822 +BDMAP_00000462 +BDMAP_00001632 +BDMAP_00003840 +BDMAP_00003483 +BDMAP_00002313 +BDMAP_00000154 +BDMAP_00001828 +BDMAP_00003771 +BDMAP_00004550 +BDMAP_00001628 +BDMAP_00003479 +BDMAP_00003396 +BDMAP_00000431 +BDMAP_00004077 +BDMAP_00002899 +BDMAP_00000542 +BDMAP_00000438 +BDMAP_00003277 +BDMAP_00002295 +BDMAP_00005140 +BDMAP_00004183 +BDMAP_00002029 +BDMAP_00003385 +BDMAP_00000447 +BDMAP_00004262 +BDMAP_00000430 +BDMAP_00001247 +BDMAP_00003809 +BDMAP_00000771 +BDMAP_00004773 +BDMAP_00001175 +BDMAP_00000774 +BDMAP_00001419 +BDMAP_00003319 +BDMAP_00001712 +BDMAP_00004129 +BDMAP_00002688 +BDMAP_00004858 +BDMAP_00003886 +BDMAP_00004184 +BDMAP_00000589 +BDMAP_00001414 +BDMAP_00001590 +BDMAP_00002896 +BDMAP_00005064 +BDMAP_00004514 +BDMAP_00003884 +BDMAP_00001565 +BDMAP_00000236 +BDMAP_00001736 +BDMAP_00004895 +BDMAP_00001597 +BDMAP_00003631 +BDMAP_00000692 +BDMAP_00004843 +BDMAP_00004288 +BDMAP_00000623 +BDMAP_00004398 +BDMAP_00001368 +BDMAP_00000701 +BDMAP_00002855 +BDMAP_00004293 +BDMAP_00001806 +BDMAP_00000882 +BDMAP_00004796 +BDMAP_00002603 +BDMAP_00005155 +BDMAP_00001836 +BDMAP_00001440 +BDMAP_00004295 +BDMAP_00000859 +BDMAP_00002120 +BDMAP_00001092 +BDMAP_00002171 +BDMAP_00002947 +BDMAP_00005169 +BDMAP_00004015 +BDMAP_00001804 +BDMAP_00003329 +BDMAP_00003657 +BDMAP_00000427 +BDMAP_00001921 +BDMAP_00003215 +BDMAP_00001521 +BDMAP_00001288 +BDMAP_00003918 +BDMAP_00004097 +BDMAP_00003598 +BDMAP_00000614 +BDMAP_00004541 +BDMAP_00004264 +BDMAP_00001618 +BDMAP_00001842 +BDMAP_00002076 
+BDMAP_00002332 +BDMAP_00003683 +BDMAP_00001214 +BDMAP_00003685 +BDMAP_00002244 +BDMAP_00003114 +BDMAP_00001057 +BDMAP_00004917 +BDMAP_00003543 +BDMAP_00003633 +BDMAP_00001898 +BDMAP_00000683 +BDMAP_00005141 +BDMAP_00003853 +BDMAP_00003650 +BDMAP_00002619 +BDMAP_00002250 +BDMAP_00002304 +BDMAP_00002815 +BDMAP_00002188 +BDMAP_00001701 +BDMAP_00004023 +BDMAP_00002233 +BDMAP_00003130 +BDMAP_00004286 +BDMAP_00002227 +BDMAP_00003254 +BDMAP_00003376 +BDMAP_00001441 +BDMAP_00004954 +BDMAP_00000052 +BDMAP_00000558 +BDMAP_00005092 +BDMAP_00000993 +BDMAP_00001912 +BDMAP_00003168 +BDMAP_00001545 +BDMAP_00005078 +BDMAP_00000618 +BDMAP_00004546 +BDMAP_00002580 +BDMAP_00000197 +BDMAP_00000972 +BDMAP_00002237 +BDMAP_00004549 +BDMAP_00004841 +BDMAP_00004741 +BDMAP_00003824 +BDMAP_00005108 +BDMAP_00004651 +BDMAP_00005037 +BDMAP_00000470 +BDMAP_00002829 +BDMAP_00003438 +BDMAP_00002411 +BDMAP_00004793 +BDMAP_00004636 +BDMAP_00004641 +BDMAP_00002737 +BDMAP_00003356 +BDMAP_00001845 +BDMAP_00004735 +BDMAP_00000338 +BDMAP_00002844 +BDMAP_00001584 +BDMAP_00003900 +BDMAP_00002232 +BDMAP_00004297 +BDMAP_00003400 +BDMAP_00002758 +BDMAP_00002475 diff --git a/Generation_Pipeline_filter/real_total.txt b/Generation_Pipeline_filter/real_total.txt new file mode 100644 index 0000000000000000000000000000000000000000..6ad82b735dc6c762eb481d2f0d29160601f75127 --- /dev/null +++ b/Generation_Pipeline_filter/real_total.txt @@ -0,0 +1,1054 @@ +BDMAP_00002856 +BDMAP_00004199 +BDMAP_00000709 +BDMAP_00003481 +BDMAP_00002719 +BDMAP_00005083 +BDMAP_00002359 +BDMAP_00000642 +BDMAP_00000778 +BDMAP_00000745 +BDMAP_00000607 +BDMAP_00001236 +BDMAP_00001333 +BDMAP_00003920 +BDMAP_00003664 +BDMAP_00003911 +BDMAP_00002463 +BDMAP_00002419 +BDMAP_00000965 +BDMAP_00003513 +BDMAP_00004508 +BDMAP_00002283 +BDMAP_00004509 +BDMAP_00000615 +BDMAP_00001171 +BDMAP_00001343 +BDMAP_00002167 +BDMAP_00000205 +BDMAP_00002805 +BDMAP_00002275 +BDMAP_00002485 +BDMAP_00004228 +BDMAP_00004304 +BDMAP_00004187 +BDMAP_00001379 +BDMAP_00001753 +BDMAP_00000413 +BDMAP_00002289 +BDMAP_00000572 +BDMAP_00005119 +BDMAP_00004017 +BDMAP_00004016 +BDMAP_00002349 +BDMAP_00000101 +BDMAP_00003482 +BDMAP_00004839 +BDMAP_00001025 +BDMAP_00003361 +BDMAP_00002495 +BDMAP_00001055 +BDMAP_00002214 +BDMAP_00005097 +BDMAP_00005168 +BDMAP_00002267 +BDMAP_00001198 +BDMAP_00002918 +BDMAP_00004664 +BDMAP_00004888 +BDMAP_00000921 +BDMAP_00002373 +BDMAP_00001316 +BDMAP_00002117 +BDMAP_00001361 +BDMAP_00002863 +BDMAP_00002309 +BDMAP_00001905 +BDMAP_00004115 +BDMAP_00002216 +BDMAP_00004829 +BDMAP_00003443 +BDMAP_00001504 +BDMAP_00004885 +BDMAP_00003451 +BDMAP_00000679 +BDMAP_00002362 +BDMAP_00000388 +BDMAP_00003769 +BDMAP_00004198 +BDMAP_00004719 +BDMAP_00000809 +BDMAP_00003525 +BDMAP_00003138 +BDMAP_00005063 +BDMAP_00000676 +BDMAP_00000411 +BDMAP_00002523 +BDMAP_00003367 +BDMAP_00003961 +BDMAP_00003822 +BDMAP_00000462 +BDMAP_00001632 +BDMAP_00003840 +BDMAP_00003483 +BDMAP_00002313 +BDMAP_00000154 +BDMAP_00001828 +BDMAP_00003771 +BDMAP_00004550 +BDMAP_00001628 +BDMAP_00003479 +BDMAP_00003396 +BDMAP_00000431 +BDMAP_00004077 +BDMAP_00002899 +BDMAP_00000542 +BDMAP_00000438 +BDMAP_00003277 +BDMAP_00002295 +BDMAP_00005140 +BDMAP_00004183 +BDMAP_00002029 +BDMAP_00003385 +BDMAP_00000447 +BDMAP_00004262 +BDMAP_00000430 +BDMAP_00001247 +BDMAP_00003809 +BDMAP_00000771 +BDMAP_00004773 +BDMAP_00001175 +BDMAP_00000774 +BDMAP_00001419 +BDMAP_00003319 +BDMAP_00001712 +BDMAP_00004129 +BDMAP_00002688 +BDMAP_00004858 +BDMAP_00003886 +BDMAP_00004184 +BDMAP_00000589 +BDMAP_00001414 +BDMAP_00001590 +BDMAP_00002896 
+BDMAP_00005064 +BDMAP_00004514 +BDMAP_00003884 +BDMAP_00001565 +BDMAP_00000236 +BDMAP_00001736 +BDMAP_00004895 +BDMAP_00001597 +BDMAP_00003631 +BDMAP_00000692 +BDMAP_00004843 +BDMAP_00004288 +BDMAP_00000623 +BDMAP_00004398 +BDMAP_00001368 +BDMAP_00000701 +BDMAP_00002855 +BDMAP_00004293 +BDMAP_00001806 +BDMAP_00000882 +BDMAP_00004796 +BDMAP_00002603 +BDMAP_00005155 +BDMAP_00001836 +BDMAP_00001440 +BDMAP_00004295 +BDMAP_00000859 +BDMAP_00002120 +BDMAP_00001092 +BDMAP_00002171 +BDMAP_00002947 +BDMAP_00005169 +BDMAP_00004015 +BDMAP_00001804 +BDMAP_00003329 +BDMAP_00003657 +BDMAP_00000427 +BDMAP_00001921 +BDMAP_00003215 +BDMAP_00001521 +BDMAP_00001288 +BDMAP_00003918 +BDMAP_00004097 +BDMAP_00003598 +BDMAP_00000614 +BDMAP_00004541 +BDMAP_00004264 +BDMAP_00001618 +BDMAP_00001842 +BDMAP_00002076 +BDMAP_00002332 +BDMAP_00003683 +BDMAP_00001214 +BDMAP_00003685 +BDMAP_00002244 +BDMAP_00003114 +BDMAP_00001057 +BDMAP_00004917 +BDMAP_00003543 +BDMAP_00003633 +BDMAP_00001898 +BDMAP_00000683 +BDMAP_00005141 +BDMAP_00003853 +BDMAP_00003650 +BDMAP_00002619 +BDMAP_00002250 +BDMAP_00002304 +BDMAP_00002815 +BDMAP_00002188 +BDMAP_00001701 +BDMAP_00004023 +BDMAP_00002233 +BDMAP_00003130 +BDMAP_00004286 +BDMAP_00002227 +BDMAP_00003254 +BDMAP_00003376 +BDMAP_00001441 +BDMAP_00004954 +BDMAP_00000052 +BDMAP_00000558 +BDMAP_00005092 +BDMAP_00000993 +BDMAP_00001912 +BDMAP_00003168 +BDMAP_00001545 +BDMAP_00005078 +BDMAP_00000618 +BDMAP_00004546 +BDMAP_00002580 +BDMAP_00000197 +BDMAP_00000972 +BDMAP_00002237 +BDMAP_00004549 +BDMAP_00004841 +BDMAP_00004741 +BDMAP_00003824 +BDMAP_00005108 +BDMAP_00004651 +BDMAP_00005037 +BDMAP_00000470 +BDMAP_00002829 +BDMAP_00003438 +BDMAP_00002411 +BDMAP_00004793 +BDMAP_00004636 +BDMAP_00004641 +BDMAP_00002737 +BDMAP_00003356 +BDMAP_00001845 +BDMAP_00004735 +BDMAP_00000338 +BDMAP_00002844 +BDMAP_00001584 +BDMAP_00003900 +BDMAP_00002232 +BDMAP_00004297 +BDMAP_00003400 +BDMAP_00002758 +BDMAP_00002475 +BDMAP_00000245 +BDMAP_00000036 +BDMAP_00003833 +BDMAP_00001517 +BDMAP_00004087 +BDMAP_00002807 +BDMAP_00002099 +BDMAP_00001602 +BDMAP_00001035 +BDMAP_00002422 +BDMAP_00000626 +BDMAP_00002173 +BDMAP_00000240 +BDMAP_00001246 +BDMAP_00000582 +BDMAP_00003996 +BDMAP_00001707 +BDMAP_00000923 +BDMAP_00003411 +BDMAP_00004113 +BDMAP_00002582 +BDMAP_00001261 +BDMAP_00005167 +BDMAP_00004897 +BDMAP_00001169 +BDMAP_00001148 +BDMAP_00002164 +BDMAP_00002041 +BDMAP_00000889 +BDMAP_00001109 +BDMAP_00005009 +BDMAP_00001286 +BDMAP_00000297 +BDMAP_00005099 +BDMAP_00004257 +BDMAP_00005017 +BDMAP_00000604 +BDMAP_00002472 +BDMAP_00001225 +BDMAP_00005081 +BDMAP_00003491 +BDMAP_00001635 +BDMAP_00002075 +BDMAP_00000660 +BDMAP_00001238 +BDMAP_00002656 +BDMAP_00003558 +BDMAP_00001104 +BDMAP_00004066 +BDMAP_00003294 +BDMAP_00001607 +BDMAP_00001077 +BDMAP_00000653 +BDMAP_00001273 +BDMAP_00000616 +BDMAP_00002057 +BDMAP_00004586 +BDMAP_00004407 +BDMAP_00004922 +BDMAP_00002592 +BDMAP_00000149 +BDMAP_00000320 +BDMAP_00001511 +BDMAP_00000435 +BDMAP_00002746 +BDMAP_00004457 +BDMAP_00000805 +BDMAP_00002661 +BDMAP_00004552 +BDMAP_00004154 +BDMAP_00002902 +BDMAP_00000839 +BDMAP_00000233 +BDMAP_00000122 +BDMAP_00005151 +BDMAP_00004427 +BDMAP_00002936 +BDMAP_00003955 +BDMAP_00001863 +BDMAP_00002326 +BDMAP_00001420 +BDMAP_00000329 +BDMAP_00004561 +BDMAP_00003971 +BDMAP_00000935 +BDMAP_00000569 +BDMAP_00004956 +BDMAP_00000285 +BDMAP_00004597 +BDMAP_00001747 +BDMAP_00001059 +BDMAP_00002354 +BDMAP_00001656 +BDMAP_00004395 +BDMAP_00002942 +BDMAP_00004981 +BDMAP_00001768 +BDMAP_00002319 +BDMAP_00003947 +BDMAP_00001868 +BDMAP_00002065 
+BDMAP_00002333 +BDMAP_00003358 +BDMAP_00001265 +BDMAP_00003952 +BDMAP_00001891 +BDMAP_00003576 +BDMAP_00000980 +BDMAP_00003300 +BDMAP_00001782 +BDMAP_00003717 +BDMAP_00001251 +BDMAP_00000044 +BDMAP_00004510 +BDMAP_00003315 +BDMAP_00002653 +BDMAP_00001045 +BDMAP_00003694 +BDMAP_00004216 +BDMAP_00001794 +BDMAP_00000532 +BDMAP_00002288 +BDMAP_00001256 +BDMAP_00000219 +BDMAP_00000710 +BDMAP_00003930 +BDMAP_00001636 +BDMAP_00003749 +BDMAP_00000998 +BDMAP_00000176 +BDMAP_00000429 +BDMAP_00001001 +BDMAP_00001908 +BDMAP_00003363 +BDMAP_00004903 +BDMAP_00004482 +BDMAP_00003178 +BDMAP_00003202 +BDMAP_00001230 +BDMAP_00003461 +BDMAP_00003281 +BDMAP_00000434 +BDMAP_00001218 +BDMAP_00003976 +BDMAP_00003455 +BDMAP_00001183 +BDMAP_00002609 +BDMAP_00001305 +BDMAP_00000364 +BDMAP_00003516 +BDMAP_00003956 +BDMAP_00000977 +BDMAP_00001784 +BDMAP_00004389 +BDMAP_00001711 +BDMAP_00000698 +BDMAP_00003153 +BDMAP_00001995 +BDMAP_00001549 +BDMAP_00001324 +BDMAP_00004195 +BDMAP_00001562 +BDMAP_00004074 +BDMAP_00001483 +BDMAP_00002085 +BDMAP_00001396 +BDMAP_00000241 +BDMAP_00004031 +BDMAP_00004775 +BDMAP_00001807 +BDMAP_00005120 +BDMAP_00004065 +BDMAP_00003943 +BDMAP_00002953 +BDMAP_00004232 +BDMAP_00002184 +BDMAP_00002407 +BDMAP_00003252 +BDMAP_00004296 +BDMAP_00000161 +BDMAP_00002981 +BDMAP_00003608 +BDMAP_00003128 +BDMAP_00000571 +BDMAP_00000259 +BDMAP_00003444 +BDMAP_00001647 +BDMAP_00000662 +BDMAP_00003774 +BDMAP_00001383 +BDMAP_00004616 +BDMAP_00001906 +BDMAP_00003740 +BDMAP_00001422 +BDMAP_00002631 +BDMAP_00004294 +BDMAP_00003994 +BDMAP_00004475 +BDMAP_00002744 +BDMAP_00001068 +BDMAP_00000667 +BDMAP_00001945 +BDMAP_00002710 +BDMAP_00002440 +BDMAP_00000833 +BDMAP_00003143 +BDMAP_00000062 +BDMAP_00003392 +BDMAP_00004373 +BDMAP_00001020 +BDMAP_00003603 +BDMAP_00001027 +BDMAP_00005114 +BDMAP_00003384 +BDMAP_00000794 +BDMAP_00001911 +BDMAP_00002437 +BDMAP_00004579 +BDMAP_00004250 +BDMAP_00002068 +BDMAP_00000608 +BDMAP_00004551 +BDMAP_00002884 +BDMAP_00004033 +BDMAP_00005105 +BDMAP_00002776 +BDMAP_00000414 +BDMAP_00003580 +BDMAP_00004712 +BDMAP_00002114 +BDMAP_00002226 +BDMAP_00003923 +BDMAP_00002854 +BDMAP_00004039 +BDMAP_00004014 +BDMAP_00001289 +BDMAP_00003435 +BDMAP_00004578 +BDMAP_00002940 +BDMAP_00003164 +BDMAP_00002751 +BDMAP_00001516 +BDMAP_00003486 +BDMAP_00000279 +BDMAP_00001664 +BDMAP_00004738 +BDMAP_00001735 +BDMAP_00000562 +BDMAP_00000812 +BDMAP_00000511 +BDMAP_00004746 +BDMAP_00000452 +BDMAP_00004328 +BDMAP_00002017 +BDMAP_00002840 +BDMAP_00000039 +BDMAP_00002242 +BDMAP_00002775 +BDMAP_00003762 +BDMAP_00000229 +BDMAP_00003520 +BDMAP_00000725 +BDMAP_00000516 +BDMAP_00001941 +BDMAP_00003928 +BDMAP_00001255 +BDMAP_00001456 +BDMAP_00002410 +BDMAP_00002742 +BDMAP_00001688 +BDMAP_00000487 +BDMAP_00000469 +BDMAP_00002022 +BDMAP_00003058 +BDMAP_00004148 +BDMAP_00001977 +BDMAP_00000887 +BDMAP_00003448 +BDMAP_00001410 +BDMAP_00002383 +BDMAP_00003736 +BDMAP_00002626 +BDMAP_00001710 +BDMAP_00001130 +BDMAP_00001138 +BDMAP_00001413 +BDMAP_00003815 +BDMAP_00004130 +BDMAP_00004652 +BDMAP_00002864 +BDMAP_00000574 +BDMAP_00003493 +BDMAP_00003364 +BDMAP_00002648 +BDMAP_00001281 +BDMAP_00002655 +BDMAP_00001126 +BDMAP_00002804 +BDMAP_00000321 +BDMAP_00005191 +BDMAP_00004420 +BDMAP_00000304 +BDMAP_00003150 +BDMAP_00004620 +BDMAP_00000368 +BDMAP_00000066 +BDMAP_00003701 +BDMAP_00005174 +BDMAP_00002545 +BDMAP_00003957 +BDMAP_00004331 +BDMAP_00000687 +BDMAP_00001791 +BDMAP_00002959 +BDMAP_00004104 +BDMAP_00003073 +BDMAP_00003713 +BDMAP_00002363 +BDMAP_00000137 +BDMAP_00000104 +BDMAP_00002689 +BDMAP_00004990 +BDMAP_00003301 
+BDMAP_00001434 +BDMAP_00000449 +BDMAP_00005113 +BDMAP_00003225 +BDMAP_00001359 +BDMAP_00001223 +BDMAP_00002803 +BDMAP_00000355 +BDMAP_00001826 +BDMAP_00004673 +BDMAP_00002251 +BDMAP_00000439 +BDMAP_00005085 +BDMAP_00003381 +BDMAP_00004645 +BDMAP_00000432 +BDMAP_00001444 +BDMAP_00001705 +BDMAP_00001892 +BDMAP_00002826 +BDMAP_00004671 +BDMAP_00000926 +BDMAP_00004817 +BDMAP_00004175 +BDMAP_00003484 +BDMAP_00003672 +BDMAP_00003267 +BDMAP_00001089 +BDMAP_00001496 +BDMAP_00003615 +BDMAP_00003832 +BDMAP_00002695 +BDMAP_00002696 +BDMAP_00004499 +BDMAP_00004867 +BDMAP_00004479 +BDMAP_00003600 +BDMAP_00000989 +BDMAP_00002421 +BDMAP_00003406 +BDMAP_00000263 +BDMAP_00002396 +BDMAP_00002265 +BDMAP_00000713 +BDMAP_00000883 +BDMAP_00001258 +BDMAP_00004253 +BDMAP_00004870 +BDMAP_00000331 +BDMAP_00004608 +BDMAP_00001518 +BDMAP_00002562 +BDMAP_00002889 +BDMAP_00001676 +BDMAP_00000117 +BDMAP_00003973 +BDMAP_00002509 +BDMAP_00002487 +BDMAP_00003457 +BDMAP_00000982 +BDMAP_00002260 +BDMAP_00001283 +BDMAP_00003506 +BDMAP_00000366 +BDMAP_00002133 +BDMAP_00000465 +BDMAP_00003767 +BDMAP_00001853 +BDMAP_00002361 +BDMAP_00004815 +BDMAP_00002933 +BDMAP_00000162 +BDMAP_00004925 +BDMAP_00005077 +BDMAP_00001533 +BDMAP_00001242 +BDMAP_00000871 +BDMAP_00000948 +BDMAP_00001119 +BDMAP_00004887 +BDMAP_00002404 +BDMAP_00003722 +BDMAP_00002426 +BDMAP_00002060 +BDMAP_00004850 +BDMAP_00003343 +BDMAP_00001624 +BDMAP_00000481 +BDMAP_00002166 +BDMAP_00003849 +BDMAP_00004808 +BDMAP_00002471 +BDMAP_00000656 +BDMAP_00003581 +BDMAP_00000023 +BDMAP_00003727 +BDMAP_00000319 +BDMAP_00003255 +BDMAP_00003752 +BDMAP_00000139 +BDMAP_00003614 +BDMAP_00003549 +BDMAP_00003808 +BDMAP_00002930 +BDMAP_00001128 +BDMAP_00004717 +BDMAP_00000826 +BDMAP_00002663 +BDMAP_00000837 +BDMAP_00000159 +BDMAP_00005154 +BDMAP_00002524 +BDMAP_00000968 +BDMAP_00004278 +BDMAP_00001325 +BDMAP_00000987 +BDMAP_00004901 +BDMAP_00003425 +BDMAP_00005006 +BDMAP_00004131 +BDMAP_00002403 +BDMAP_00001620 +BDMAP_00002347 +BDMAP_00001522 +BDMAP_00004011 +BDMAP_00001474 +BDMAP_00004744 +BDMAP_00002484 +BDMAP_00001370 +BDMAP_00003324 +BDMAP_00001557 +BDMAP_00000867 +BDMAP_00001487 +BDMAP_00004980 +BDMAP_00000034 +BDMAP_00000936 +BDMAP_00000128 +BDMAP_00001275 +BDMAP_00004030 +BDMAP_00003359 +BDMAP_00003070 +BDMAP_00002476 +BDMAP_00002990 +BDMAP_00000810 +BDMAP_00003514 +BDMAP_00004834 +BDMAP_00003409 +BDMAP_00002498 +BDMAP_00004481 +BDMAP_00002273 +BDMAP_00002496 +BDMAP_00002871 +BDMAP_00000059 +BDMAP_00001475 +BDMAP_00000902 +BDMAP_00004417 +BDMAP_00005157 +BDMAP_00001752 +BDMAP_00001563 +BDMAP_00003063 +BDMAP_00001296 +BDMAP_00002707 +BDMAP_00000836 +BDMAP_00000353 +BDMAP_00000043 +BDMAP_00000244 +BDMAP_00000264 +BDMAP_00000690 +BDMAP_00002039 +BDMAP_00001426 +BDMAP_00002730 +BDMAP_00001917 +BDMAP_00005067 +BDMAP_00002924 +BDMAP_00005160 +BDMAP_00005073 +BDMAP_00000547 +BDMAP_00000942 +BDMAP_00002103 +BDMAP_00002654 +BDMAP_00004374 +BDMAP_00003510 +BDMAP_00004910 +BDMAP_00004558 +BDMAP_00004450 +BDMAP_00000152 +BDMAP_00004491 +BDMAP_00001237 +BDMAP_00001785 +BDMAP_00001865 +BDMAP_00000851 +BDMAP_00003357 +BDMAP_00004415 +BDMAP_00004615 +BDMAP_00003680 +BDMAP_00001875 +BDMAP_00004894 +BDMAP_00001835 +BDMAP_00000069 +BDMAP_00001809 +BDMAP_00004431 +BDMAP_00002704 +BDMAP_00002185 +BDMAP_00004384 +BDMAP_00003299 +BDMAP_00003333 +BDMAP_00002305 +BDMAP_00001598 +BDMAP_00002465 +BDMAP_00002199 +BDMAP_00002875 +BDMAP_00000828 +BDMAP_00003564 +BDMAP_00005001 +BDMAP_00004493 +BDMAP_00000190 +BDMAP_00000873 +BDMAP_00005170 +BDMAP_00002152 +BDMAP_00004163 +BDMAP_00000939 +BDMAP_00001212 
+BDMAP_00001982 +BDMAP_00000552 +BDMAP_00004764 +BDMAP_00002401 +BDMAP_00002451 +BDMAP_00003634 +BDMAP_00005016 +BDMAP_00000716 +BDMAP_00003373 +BDMAP_00000030 +BDMAP_00003946 +BDMAP_00002828 +BDMAP_00004196 +BDMAP_00005005 +BDMAP_00003972 +BDMAP_00003172 +BDMAP_00004783 +BDMAP_00001102 +BDMAP_00004147 +BDMAP_00004604 +BDMAP_00000400 +BDMAP_00003497 +BDMAP_00001270 +BDMAP_00001766 +BDMAP_00001309 +BDMAP_00004745 +BDMAP_00003002 +BDMAP_00004825 +BDMAP_00004416 +BDMAP_00002712 +BDMAP_00004830 +BDMAP_00000907 +BDMAP_00001957 +BDMAP_00000941 +BDMAP_00002841 +BDMAP_00001962 +BDMAP_00004462 +BDMAP_00004281 +BDMAP_00004890 +BDMAP_00003272 +BDMAP_00003377 +BDMAP_00005186 +BDMAP_00002172 +BDMAP_00000091 +BDMAP_00004639 +BDMAP_00000918 +BDMAP_00000671 +BDMAP_00004028 +BDMAP_00004529 +BDMAP_00001907 +BDMAP_00001122 +BDMAP_00003151 +BDMAP_00002252 +BDMAP_00003524 +BDMAP_00004704 +BDMAP_00000362 +BDMAP_00003932 +BDMAP_00004995 +BDMAP_00002748 +BDMAP_00004117 +BDMAP_00000480 +BDMAP_00001010 +BDMAP_00000100 +BDMAP_00001200 +BDMAP_00004103 +BDMAP_00004878 +BDMAP_00002282 +BDMAP_00001471 +BDMAP_00000232 +BDMAP_00003439 +BDMAP_00003857 +BDMAP_00004943 +BDMAP_00005130 +BDMAP_00002479 +BDMAP_00002909 +BDMAP_00004185 +BDMAP_00003569 +BDMAP_00001185 +BDMAP_00001078 +BDMAP_00003031 +BDMAP_00002253 +BDMAP_00001732 +BDMAP_00000874 +BDMAP_00003847 +BDMAP_00003268 +BDMAP_00002846 +BDMAP_00001438 +BDMAP_00004650 +BDMAP_00003109 +BDMAP_00004121 +BDMAP_00004165 +BDMAP_00004676 +BDMAP_00003890 +BDMAP_00003327 +BDMAP_00000132 +BDMAP_00001215 +BDMAP_00001769 +BDMAP_00003412 +BDMAP_00002318 +BDMAP_00004624 +BDMAP_00000345 +BDMAP_00002230 +BDMAP_00003111 +BDMAP_00001015 +BDMAP_00001514 +BDMAP_00001924 +BDMAP_00002845 +BDMAP_00002598 +BDMAP_00001209 +BDMAP_00000373 +BDMAP_00001737 +BDMAP_00003113 +BDMAP_00004876 +BDMAP_00003640 +BDMAP_00001985 +BDMAP_00000138 +BDMAP_00000881 +BDMAP_00002739 +BDMAP_00003560 +BDMAP_00002612 +BDMAP_00001445 +BDMAP_00003827 +BDMAP_00001024 +BDMAP_00000568 +BDMAP_00001095 +BDMAP_00002458 +BDMAP_00002986 +BDMAP_00000913 +BDMAP_00002849 +BDMAP_00003556 +BDMAP_00003052 +BDMAP_00000971 +BDMAP_00003330 +BDMAP_00000113 +BDMAP_00004600 +BDMAP_00002529 +BDMAP_00000437 +BDMAP_00003074 +BDMAP_00005139 +BDMAP_00001966 +BDMAP_00002791 +BDMAP_00001692 +BDMAP_00001786 +BDMAP_00001697 +BDMAP_00003798 +BDMAP_00000273 +BDMAP_00001114 +BDMAP_00003898 +BDMAP_00001397 +BDMAP_00003867 +BDMAP_00005065 +BDMAP_00001802 +BDMAP_00001539 +BDMAP_00000084 +BDMAP_00002955 +BDMAP_00002271 +BDMAP_00004459 +BDMAP_00004378 +BDMAP_00004435 +BDMAP_00001093 +BDMAP_00003897 +BDMAP_00003236 +BDMAP_00001502 +BDMAP_00001834 +BDMAP_00000347 +BDMAP_00000831 +BDMAP_00002717 +BDMAP_00003244 +BDMAP_00005074 +BDMAP_00004804 +BDMAP_00004672 +BDMAP_00003133 +BDMAP_00004969 +BDMAP_00002278 +BDMAP_00001862 +BDMAP_00005185 +BDMAP_00004880 +BDMAP_00004770 +BDMAP_00002690 +BDMAP_00002944 +BDMAP_00003744 +BDMAP_00002021 +BDMAP_00003141 +BDMAP_00004927 +BDMAP_00001476 +BDMAP_00003551 +BDMAP_00004964 +BDMAP_00001605 +BDMAP_00002298 +BDMAP_00001746 +BDMAP_00000332 +BDMAP_00003590 +BDMAP_00000956 +BDMAP_00001649 +BDMAP_00003781 +BDMAP_00001523 +BDMAP_00003347 +BDMAP_00005022 +BDMAP_00004128 +BDMAP_00003612 +BDMAP_00003658 +BDMAP_00003812 +BDMAP_00003427 +BDMAP_00003502 +BDMAP_00001823 +BDMAP_00004847 +BDMAP_00003776 +BDMAP_00001205 +BDMAP_00000192 +BDMAP_00004511 +BDMAP_00001564 +BDMAP_00000416 +BDMAP_00005070 +BDMAP_00001040 +BDMAP_00004231 +BDMAP_00002945 +BDMAP_00001704 +BDMAP_00002402 +BDMAP_00000940 +BDMAP_00000243 +BDMAP_00001464 +BDMAP_00002793 
+BDMAP_00001646 +BDMAP_00005020 +BDMAP_00004992 +BDMAP_00003017 +BDMAP_00001096 +BDMAP_00001067 +BDMAP_00001331 +BDMAP_00000696 +BDMAP_00001461 +BDMAP_00003326 +BDMAP_00000715 +BDMAP_00000855 +BDMAP_00000087 +BDMAP_00000093 +BDMAP_00000324 +BDMAP_00003440 +BDMAP_00002387 +BDMAP_00004060 +BDMAP_00000714 +BDMAP_00001617 +BDMAP_00004494 +BDMAP_00002616 +BDMAP_00000225 +BDMAP_00001754 +BDMAP_00005075 +BDMAP_00002328 +BDMAP_00004229 +BDMAP_00000541 +BDMAP_00004447 +BDMAP_00004106 +BDMAP_00003592 +BDMAP_00003036 +BDMAP_00001125 diff --git a/Generation_Pipeline_filter/resample.py b/Generation_Pipeline_filter/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..1179efedb86328d1832d50547f1b29877708d608 --- /dev/null +++ b/Generation_Pipeline_filter/resample.py @@ -0,0 +1,120 @@ +import os, time, csv +import numpy as np +import torch +from sklearn.metrics import confusion_matrix +from scipy import ndimage +from scipy.ndimage import label +from functools import partial +import monai +from monai.transforms import AsDiscrete,AsDiscreted,Compose,Invertd,SaveImaged +from monai import transforms, data +import nibabel as nib + +import warnings +warnings.filterwarnings("ignore") + +import argparse +parser = argparse.ArgumentParser(description='colon tumor validation') + +# file dir +parser.add_argument('--data_root', default=None, type=str) +parser.add_argument('--organ_type', default='colon', type=str) +parser.add_argument('--save_dir', default='out', type=str) +parser.add_argument('--data_file', default='out', type=str) +parser.add_argument('--ddim_ts', default=50, type=int) +parser.add_argument('--fg_thresh', default=30, type=int) +parser.add_argument('--start', default=0, type=int) +parser.add_argument('--end', default=1000, type=int) + +def voxel2R(A): + return (np.array(A)/4*3/np.pi)**(1/3) + +def _get_loader(args): + # val_data_dir = args.val_dir + # datalist_json = args.json_dir + val_org_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image"], dtype=np.int16), + transforms.AddChanneld(keys=["image"]), + transforms.Orientationd(keys=["image"], axcodes="RAS"), + transforms.Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear")), + # transforms.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True), + transforms.ToTensord(keys=["image"]), + ] + ) + + val_img=[] + val_name=[] + + for line in open(args.data_file): + # name = line.strip().split()[1].split('.')[0] + # val_img.append(args.data_root + line.strip().split()[0]) + # val_lbl.append(args.data_root + line.strip().split()[1]) + # breakpoint() + name = line.strip() + val_img.append(os.path.join(args.data_root, name, 'ct.nii.gz')) + val_name.append(name) + data_dicts_val = [{'image': image, 'name': name} + for image, name in zip(val_img, val_name)] + + if args.end < len(data_dicts_val): + data_dicts_val = data_dicts_val[args.start:args.end] + else: + data_dicts_val = data_dicts_val[args.start:] + print('val len {}'.format(len(data_dicts_val))) + val_org_ds = data.Dataset(data_dicts_val, transform=val_org_transform) + val_org_loader = data.DataLoader(val_org_ds, batch_size=1, shuffle=False, num_workers=4, sampler=None, pin_memory=True) + + post_transforms = Compose([ + Invertd( + keys=['image'], + transform=val_org_transform, + orig_keys="image", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ) + ]) + return val_org_loader, post_transforms + +def main(): + args = parser.parse_args() + output_dir = args.save_dir + if not 
os.path.exists(output_dir): + os.makedirs(output_dir) + print("MAIN Argument values:") + for k, v in vars(args).items(): + print(k, '=>', v) + print('-----------------') + + ## loader and post_transform + val_loader, post_transforms = _get_loader(args) + + start_time=0 + with torch.no_grad(): + for idx, val_data in enumerate(val_loader): + print('idx',idx) + if idx == 0: + start_time = time.time() + + data_names = val_data['name'] + case_name = data_names[0].split('/')[-1] + print('case_name', case_name) + original_affine = val_data["image_meta_dict"]["original_affine"][0].numpy() + + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + synt_data = val_data[0]['image'][0] + final_data = synt_data.cpu().numpy() + + # synt_data = val_data['image'] + # final_data = synt_data.cpu().numpy()[0,0] + # breakpoint() + os.makedirs(os.path.join(output_dir, f'{case_name}'), exist_ok=True) + nib.save(nib.Nifti1Image(final_data, original_affine), os.path.join(output_dir, f'{case_name}/ct.nii.gz')) + # breakpoint() + print('time = ', time.time()-start_time) + start_time = time.time() + + +if __name__ == "__main__": + main() diff --git a/Generation_Pipeline_filter/syn_colon/CT_syn_colon_data_new.py b/Generation_Pipeline_filter/syn_colon/CT_syn_colon_data_new.py new file mode 100644 index 0000000000000000000000000000000000000000..47826dbd7c886251c0dee31969415805bb3ac729 --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/CT_syn_colon_data_new.py @@ -0,0 +1,230 @@ +import os, time, csv +import numpy as np +import torch +from sklearn.metrics import confusion_matrix +from scipy import ndimage +from scipy.ndimage import label +from functools import partial +import monai +from monai.transforms import AsDiscrete,AsDiscreted,Compose,Invertd,SaveImaged +from monai import transforms, data +from TumorGeneration.utils import synthesize_colon_tumor, synt_model_prepare +import nibabel as nib + +import warnings +warnings.filterwarnings("ignore") + +import argparse +parser = argparse.ArgumentParser(description='colon tumor validation') + +# file dir +parser.add_argument('--data_root', default=None, type=str) +parser.add_argument('--organ_type', default='colon', type=str) +parser.add_argument('--save_dir', default='out', type=str) +parser.add_argument('--data_file', default='out', type=str) +parser.add_argument('--ddim_ts', default=50, type=int) +parser.add_argument('--fg_thresh', default=30, type=int) +parser.add_argument('--start', default=0, type=int) +parser.add_argument('--end', default=1000, type=int) + +def voxel2R(A): + return (np.array(A)/4*3/np.pi)**(1/3) + +class RandCropByPosNegLabeld_select(transforms.RandCropByPosNegLabeld): + def __init__(self, keys, label_key, spatial_size, + pos=1.0, neg=1.0, num_samples=1, + image_key=None, image_threshold=0.0, allow_missing_keys=True, + fg_thresh=0): + super().__init__(keys=keys, label_key=label_key, spatial_size=spatial_size, + pos=pos, neg=neg, num_samples=num_samples, + image_key=image_key, image_threshold=image_threshold, allow_missing_keys=allow_missing_keys) + self.fg_thresh = fg_thresh + + def R2voxel(self,R): + return (4/3*np.pi)*(R)**(3) + + def __call__(self, data): + d = dict(data) + data_name = d['name'] + d.pop('name') + + if '10_Decathlon' in data_name or '05_KiTS' in data_name: + d_crop = super().__call__(d) + + else: + flag=0 + while 1: + flag+=1 + + d_crop = super().__call__(d) + pixel_num = (d_crop[0]['label']>0).sum() + + if pixel_num > self.R2voxel(self.fg_thresh): + break + if flag>5 and pixel_num > 
self.R2voxel(max(self.fg_thresh-5, 5)): + break + if flag>10 and pixel_num > self.R2voxel(max(self.fg_thresh-10, 5)): + break + if flag>15 and pixel_num > self.R2voxel(max(self.fg_thresh-15, 5)): + break + if flag>20 and pixel_num > self.R2voxel(max(self.fg_thresh-20, 5)): + break + if flag>25 and pixel_num > self.R2voxel(max(self.fg_thresh-25, 5)): + break + if flag>50: + break + + d_crop[0]['name'] = data_name + + return d_crop + +def _get_loader(args): + # val_data_dir = args.val_dir + # datalist_json = args.json_dir + val_org_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label", "raw_image"]), + transforms.AddChanneld(keys=["image", "label", "raw_image"]), + transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), + transforms.Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear")), + transforms.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True), + transforms.SpatialPadd(keys=["image", "label"], mode=["minimum", "constant"], spatial_size=[96, 96, 96]), + RandCropByPosNegLabeld_select( + keys=["image", "label", "name"], + label_key="label", + spatial_size=(96, 96, 96), + pos=1, + neg=0, + num_samples=1, + image_key="image", + image_threshold=0, + fg_thresh = args.fg_thresh, + ), + transforms.ToTensord(keys=["image", "label", "raw_image"]), + ] + ) + + val_img=[] + val_lbl=[] + val_name=[] + + for line in open(args.data_file): + # name = line.strip().split()[1].split('.')[0] + # val_img.append(args.data_root + line.strip().split()[0]) + # val_lbl.append(args.data_root + line.strip().split()[1]) + # breakpoint() + name = line.strip() + val_img.append(os.path.join(args.data_root, name, 'ct.nii.gz')) + val_lbl.append(os.path.join(args.data_root, name, 'segmentations/colon.nii.gz')) + val_name.append(name) + data_dicts_val = [{'image': image, 'raw_image':image, 'label': label, 'name': name} + for image, label, name in zip(val_img, val_lbl, val_name)] + + if args.end < len(data_dicts_val): + data_dicts_val = data_dicts_val[args.start:args.end] + else: + data_dicts_val = data_dicts_val[args.start:] + print('val len {}'.format(len(data_dicts_val))) + val_org_ds = data.Dataset(data_dicts_val, transform=val_org_transform) + val_org_loader = data.DataLoader(val_org_ds, batch_size=1, shuffle=False, num_workers=4, sampler=None, pin_memory=True) + + post_transforms = Compose([ + Invertd( + keys=['image'], + transform=val_org_transform, + orig_keys="image", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ), + Invertd( + keys=['label'], + transform=val_org_transform, + orig_keys="label", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ) + ]) + return val_org_loader, post_transforms + +def main(): + args = parser.parse_args() + output_dir = args.save_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print("MAIN Argument values:") + for k, v in vars(args).items(): + print(k, '=>', v) + print('-----------------') + + ## loader and post_transform + val_loader, post_transforms = _get_loader(args) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device) + model.load_state_dict(torch.load("../best_metric_model_classification3d_dict.pth")) + model.eval() + + start_time=0 + with torch.no_grad(): + for idx, val_data in enumerate(val_loader): + print('idx',idx) + if idx == 0: + start_time = time.time() + # val_inputs = 
val_data["image"] + # name = val_data['label_meta_dict']['filename_or_obj'][0].split('/')[-1].split('.')[0] + + vqgan, noearly_sampler= synt_model_prepare(device = torch.device("cuda"), fold=0, organ=args.organ_type) + + healthy_data, healthy_target, data_names, raw_data = val_data['image'], val_data['label'], val_data['name'], val_data['raw_image'] + case_name = data_names[0].split('/')[-1] + print('case_name', case_name) + original_affine = val_data["label_meta_dict"]["original_affine"][0].numpy() + + if healthy_target.sum() == 0: + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + tumor_mask = val_data[0]['label'][0].cpu().numpy().astype(np.uint8) + tumor_mask_ = np.zeros_like(tumor_mask) + nib.save(nib.Nifti1Image(tumor_mask_, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/colon_tumor.nii.gz')) + continue + + healthy_data, healthy_target = healthy_data.cuda(), healthy_target.cuda() + healthy_target = (healthy_target>0).to(healthy_target) + flag=0 + while 1: + flag+=1 + synt_data, synt_target = synthesize_colon_tumor(healthy_data, healthy_target, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + + syn_confidence = model(synt_data).sigmoid()[:,1] + if syn_confidence>0.01: + break + elif flag > 40 and syn_confidence>0.005: + break + elif flag > 60 and syn_confidence>0.001: + break + + val_data['image'] = synt_data.detach() + val_data['label'] = synt_target.detach() + + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + synt_data = val_data[0]['image'][0] + synt_target = val_data[0]['label'][0] + final_data = raw_data[0,0] + + synt_data = (synt_data*(250+175)-175) + final_data[synt_target>1] = synt_data[synt_target>1] + final_data = final_data.cpu().numpy() + final_label = (synt_target>=1.5).cpu().numpy().astype(np.uint8) + + os.makedirs(os.path.join(output_dir, f'{case_name}'), exist_ok=True) + os.makedirs(os.path.join(output_dir, f'{case_name}/segmentations'), exist_ok=True) + nib.save(nib.Nifti1Image(final_data, original_affine), os.path.join(output_dir, f'{case_name}/ct.nii.gz')) + nib.save(nib.Nifti1Image(final_label, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/colon_tumor.nii.gz')) + + print('time = ', time.time()-start_time) + start_time = time.time() + + +if __name__ == "__main__": + main() diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/.DS_Store b/Generation_Pipeline_filter/syn_colon/TumorGeneration/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..fb0d207e124cfecc777faae5b4a56c5ca1b9cd2e Binary files /dev/null and b/Generation_Pipeline_filter/syn_colon/TumorGeneration/.DS_Store differ diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/README.md b/Generation_Pipeline_filter/syn_colon/TumorGeneration/README.md new file mode 100644 index 0000000000000000000000000000000000000000..201911031ca988c410781a60dc3c192a65ee56b3 --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/README.md @@ -0,0 +1,5 @@ +```bash +wget https://www.dropbox.com/scl/fi/k856fhk60kck8uqxtxazw/model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll +mv model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll model_weight.tar.gz +tar -xzvf model_weight.tar.gz +``` diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/TumorGenerated.py b/Generation_Pipeline_filter/syn_colon/TumorGeneration/TumorGenerated.py new file mode 100644 index 0000000000000000000000000000000000000000..9983cf047b1532f5edd20c0d4c78102d6d611eb9 --- /dev/null +++ 
b/Generation_Pipeline_filter/syn_colon/TumorGeneration/TumorGenerated.py @@ -0,0 +1,39 @@ +import random +from typing import Hashable, Mapping, Dict + +from monai.config import KeysCollection +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.transform import MapTransform, RandomizableTransform + +from .utils_ import SynthesisTumor +import numpy as np + +class TumorGenerated(RandomizableTransform, MapTransform): + def __init__(self, + keys: KeysCollection, + prob: float = 0.1, + tumor_prob = [0.2, 0.2, 0.2, 0.2, 0.2], + allow_missing_keys: bool = False + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + random.seed(0) + np.random.seed(0) + + self.tumor_types = ['tiny', 'small', 'medium', 'large', 'mix'] + + assert len(tumor_prob) == 5 + self.tumor_prob = np.array(tumor_prob) + + + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + + if self._do_transform and (np.max(d['label']) <= 1): + tumor_type = np.random.choice(self.tumor_types, p=self.tumor_prob.ravel()) + + d['image'][0], d['label'][0] = SynthesisTumor(d['image'][0], d['label'][0], tumor_type) + # print(tumor_type, d['image'].shape, np.max(d['label'])) + return d diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/__init__.py b/Generation_Pipeline_filter/syn_colon/TumorGeneration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d9bbd5e8cede113145b2742ebdd63d7226fe6396 --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/__init__.py @@ -0,0 +1,5 @@ +### Online Version TumorGeneration ### + +from .TumorGenerated import TumorGenerated + +# from .utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor \ No newline at end of file diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc b/Generation_Pipeline_filter/syn_colon/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0ebfed0cd33d072a0561f1b2c881ab987c39b98 Binary files /dev/null and b/Generation_Pipeline_filter/syn_colon/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter/syn_colon/TumorGeneration/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a0c6fe95cfdcd5c8b93417763b81cf141f23bef Binary files /dev/null and b/Generation_Pipeline_filter/syn_colon/TumorGeneration/__pycache__/__init__.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/__pycache__/utils.cpython-38.pyc b/Generation_Pipeline_filter/syn_colon/TumorGeneration/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0756ebf64f7e3068fc02220df45239da35516ff Binary files /dev/null and b/Generation_Pipeline_filter/syn_colon/TumorGeneration/__pycache__/utils.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/__pycache__/utils_.cpython-38.pyc b/Generation_Pipeline_filter/syn_colon/TumorGeneration/__pycache__/utils_.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2ccff6985d92f26b80eca3e6ab2d9a009aabe71 Binary files /dev/null and 
b/Generation_Pipeline_filter/syn_colon/TumorGeneration/__pycache__/utils_.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/diffusion_config/ddpm.yaml b/Generation_Pipeline_filter/syn_colon/TumorGeneration/diffusion_config/ddpm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..474a973b1f76c2a026e2df916c4a153b5bbea05f --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/diffusion_config/ddpm.yaml @@ -0,0 +1,29 @@ +vqgan_ckpt: None + +# Have to be derived from VQ-GAN Latent space dimensions +diffusion_img_size: 24 +diffusion_depth_size: 24 +diffusion_num_channels: 17 # 17 +out_dim: 8 +dim_mults: [1,2,4,8] +results_folder: checkpoints/ddpm/ +results_folder_postfix: 'own_dataset_t2' +load_milestone: False # False + +batch_size: 2 # 40 +num_workers: 20 +logger: wandb +objective: pred_x0 +save_and_sample_every: 1000 +denoising_fn: Unet3D +train_lr: 1e-4 +timesteps: 2 # number of steps +sampling_timesteps: 250 # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper]) +loss_type: l1 # L1 or L2 +train_num_steps: 700000 # total training steps +gradient_accumulate_every: 2 # gradient accumulation steps +ema_decay: 0.995 # exponential moving average decay +amp: False # turn on mixed precision +num_sample_rows: 1 +gpus: 0 + diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/diffusion_config/vq_gan_3d.yaml b/Generation_Pipeline_filter/syn_colon/TumorGeneration/diffusion_config/vq_gan_3d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29940377c5cadb0c06322de8ac60d0713f799024 --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/diffusion_config/vq_gan_3d.yaml @@ -0,0 +1,37 @@ +seed: 1234 +batch_size: 2 # 30 +num_workers: 32 # 30 + +gpus: 1 +accumulate_grad_batches: 1 +default_root_dir: checkpoints/vq_gan/ +default_root_dir_postfix: 'flair' +resume_from_checkpoint: +max_steps: -1 +max_epochs: -1 +precision: 16 +gradient_clip_val: 1.0 + + +embedding_dim: 8 # 256 +n_codes: 16384 # 2048 +n_hiddens: 16 +lr: 3e-4 +downsample: [2, 2, 2] # [4, 4, 4] +disc_channels: 64 +disc_layers: 3 +discriminator_iter_start: 10000 # 50000 +disc_loss_type: hinge +image_gan_weight: 1.0 +video_gan_weight: 1.0 +l1_weight: 4.0 +gan_feat_weight: 4.0 # 0.0 +perceptual_weight: 4.0 # 0.0 +i3d_feat: False +restart_thres: 1.0 +no_random_restart: False +norm_type: group +padding_type: replicate +num_groups: 32 + + diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__init__.py b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a01336e46e37471097d8e1420d0ffd5d803a1edd --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__init__.py @@ -0,0 +1 @@ +from .diffusion import Unet3D, GaussianDiffusion, Tester diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..800688be0198f7f33c5329c5a467a39ab6f58611 Binary files /dev/null and b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc 
b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e795973913e71f310074432d53ecbe72a127b9a2 Binary files /dev/null and b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37a2c3b0735cf9b01c20cfb80ae9ced9228687c4 Binary files /dev/null and b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac6f099866a2f7b7329ec1923619ceb7a8114e14 Binary files /dev/null and b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4a010ee9f6d95cbccb3c1cb4897eb530d556419 Binary files /dev/null and b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/ddim.py b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..1bc51f94355732aedb0ffd254b8166144c468370 --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/ddim.py @@ -0,0 +1,206 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from .util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): # "uniform" 'quad' + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + 
self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + # breakpoint() + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, T, H, W = shape + # breakpoint() + size = (batch_size, C, T, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * 
self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + # iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(time_range): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + # breakpoint() + e_t = self.model.denoise_fn(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.denoise_fn(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/diffusion.py b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..87111dc3b2ab671e798efc3a320ae81d024f20b9 --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/diffusion.py @@ -0,0 +1,1016 @@ +"Largely taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch" + +import math +import copy +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial + +from torch.utils import data +from pathlib import Path +from torch.optim import Adam +from torchvision import transforms as T, utils +from torch.cuda.amp import autocast, GradScaler +from PIL import Image + +from tqdm import tqdm +from einops import rearrange +from einops_exts import check_shape, rearrange_many + +from rotary_embedding_torch import RotaryEmbedding + +from .text import tokenize, bert_embed, BERT_MODEL_DIM +from torch.utils.data import Dataset, DataLoader +from ..vq_gan_3d.model.vqgan import VQGAN + +import matplotlib.pyplot as plt + +# helpers functions + + +def exists(x): + return x is not None + + +def noop(*args, **kwargs): + pass + + +def is_odd(n): + return (n % 2) == 1 + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +def cycle(dl): + while True: + for data in dl: + yield data + + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + + +def prob_mask_like(shape, prob, device): + if prob == 1: + return torch.ones(shape, device=device, dtype=torch.bool) + elif prob == 0: + return torch.zeros(shape, device=device, dtype=torch.bool) + else: + return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob + + +def is_list_str(x): + if not isinstance(x, (list, tuple)): + return False + return all([type(el) == str for el in x]) + +# relative positional bias + + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128 + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / + max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min( + val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype=torch.long, device=device) + k_pos = torch.arange(n, dtype=torch.long, device=device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket( + rel_pos, num_buckets=self.num_buckets, 
max_distance=self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') + +# small helper modules + + +class EMA(): + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_model_average(self, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = self.update_average(old_weight, up_weight) + + def update_average(self, old, new): + if old is None: + return new + return old * self.beta + (1 - self.beta) * new + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +def Upsample(dim): + return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) + + +def Downsample(dim): + return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) + + +class LayerNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1)) + + def forward(self, x): + var = torch.var(x, dim=1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=1, keepdim=True) + return (x - mean) / (var + self.eps).sqrt() * self.gamma + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = LayerNorm(dim) + + def forward(self, x, **kwargs): + x = self.norm(x) + return self.fn(x, **kwargs) + +# building block modules + + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1)) + self.norm = nn.GroupNorm(groups, dim_out) + self.act = nn.SiLU() + + def forward(self, x, scale_shift=None): + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + return self.act(x) + + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, dim_out * 2) + ) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) + self.res_conv = nn.Conv3d( + dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb=None): + + scale_shift = None + if exists(self.mlp): + assert exists(time_emb), 'time emb must be passed in' + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') + scale_shift = time_emb.chunk(2, dim=1) + + h = self.block1(x, scale_shift=scale_shift) + + h = self.block2(h) + return h + self.res_conv(x) + + +class SpatialLinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, f, h, w = 
x.shape + x = rearrange(x, 'b c f h w -> (b f) c h w') + + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = rearrange_many( + qkv, 'b (h c) x y -> b h c (x y)', h=self.heads) + + q = q.softmax(dim=-2) + k = k.softmax(dim=-1) + + q = q * self.scale + context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + + out = torch.einsum('b h d e, b h d n -> b h e n', context, q) + out = rearrange(out, 'b h c (x y) -> b (h c) x y', + h=self.heads, x=h, y=w) + out = self.to_out(out) + return rearrange(out, '(b f) c h w -> b c f h w', b=b) + +# attention along space and time + + +class EinopsToAndFrom(nn.Module): + def __init__(self, from_einops, to_einops, fn): + super().__init__() + self.from_einops = from_einops + self.to_einops = to_einops + self.fn = fn + + def forward(self, x, **kwargs): + shape = x.shape + reconstitute_kwargs = dict( + tuple(zip(self.from_einops.split(' '), shape))) + x = rearrange(x, f'{self.from_einops} -> {self.to_einops}') + x = self.fn(x, **kwargs) + x = rearrange( + x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + heads=4, + dim_head=32, + rotary_emb=None + ): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + + self.rotary_emb = rotary_emb + self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False) + self.to_out = nn.Linear(hidden_dim, dim, bias=False) + + def forward( + self, + x, + pos_bias=None, + focus_present_mask=None + ): + n, device = x.shape[-2], x.device + + qkv = self.to_qkv(x).chunk(3, dim=-1) + + if exists(focus_present_mask) and focus_present_mask.all(): + # if all batch samples are focusing on present + # it would be equivalent to passing that token's values through to the output + values = qkv[-1] + return self.to_out(values) + + # split out heads + + q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) + + # scale + + q = q * self.scale + + # rotate positions into queries and keys for time attention + + if exists(self.rotary_emb): + q = self.rotary_emb.rotate_queries_or_keys(q) + k = self.rotary_emb.rotate_queries_or_keys(k) + + # similarity + + sim = einsum('... h i d, ... h j d -> ... h i j', q, k) + + # relative positional bias + + if exists(pos_bias): + sim = sim + pos_bias + + if exists(focus_present_mask) and not (~focus_present_mask).all(): + attend_all_mask = torch.ones( + (n, n), device=device, dtype=torch.bool) + attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) + + mask = torch.where( + rearrange(focus_present_mask, 'b -> b 1 1 1 1'), + rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), + rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), + ) + + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + # numerical stability + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + # aggregate values + + out = einsum('... h i j, ... h j d -> ... h i d', attn, v) + out = rearrange(out, '... h n d -> ... 
n (h d)') + return self.to_out(out) + +# model + + +class Unet3D(nn.Module): + def __init__( + self, + dim, + cond_dim=None, + out_dim=None, + dim_mults=(1, 2, 4, 8), + channels=3, + attn_heads=8, + attn_dim_head=32, + use_bert_text_cond=False, + init_dim=None, + init_kernel_size=7, + use_sparse_linear_attn=True, + block_type='resnet', + resnet_groups=8 + ): + super().__init__() + self.channels = channels + + # temporal attention and its relative positional encoding + + rotary_emb = RotaryEmbedding(min(32, attn_dim_head)) + + def temporal_attn(dim): return EinopsToAndFrom('b c f h w', 'b (h w) f c', Attention( + dim, heads=attn_heads, dim_head=attn_dim_head, rotary_emb=rotary_emb)) + + # realistically will not be able to generate that many frames of video... yet + self.time_rel_pos_bias = RelativePositionBias( + heads=attn_heads, max_distance=32) + + # initial conv + + init_dim = default(init_dim, dim) + assert is_odd(init_kernel_size) + + init_padding = init_kernel_size // 2 + self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, + init_kernel_size), padding=(0, init_padding, init_padding)) + + self.init_temporal_attn = Residual( + PreNorm(init_dim, temporal_attn(init_dim))) + + # dimensions + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + # time conditioning + + time_dim = dim * 4 + self.time_mlp = nn.Sequential( + SinusoidalPosEmb(dim), + nn.Linear(dim, time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim) + ) + + # text conditioning + + self.has_cond = exists(cond_dim) or use_bert_text_cond + cond_dim = BERT_MODEL_DIM if use_bert_text_cond else cond_dim + + self.null_cond_emb = nn.Parameter( + torch.randn(1, cond_dim)) if self.has_cond else None + + cond_dim = time_dim + int(cond_dim or 0) + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + + num_resolutions = len(in_out) + # block type + + block_klass = partial(ResnetBlock, groups=resnet_groups) + block_klass_cond = partial(block_klass, time_emb_dim=cond_dim) + + # modules for all layers + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + block_klass_cond(dim_in, dim_out), + block_klass_cond(dim_out, dim_out), + Residual(PreNorm(dim_out, SpatialLinearAttention( + dim_out, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), + Residual(PreNorm(dim_out, temporal_attn(dim_out))), + Downsample(dim_out) if not is_last else nn.Identity() + ])) + + mid_dim = dims[-1] + self.mid_block1 = block_klass_cond(mid_dim, mid_dim) + + spatial_attn = EinopsToAndFrom( + 'b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads)) + + self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn)) + self.mid_temporal_attn = Residual( + PreNorm(mid_dim, temporal_attn(mid_dim))) + + self.mid_block2 = block_klass_cond(mid_dim, mid_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind >= (num_resolutions - 1) + + self.ups.append(nn.ModuleList([ + block_klass_cond(dim_out * 2, dim_in), + block_klass_cond(dim_in, dim_in), + Residual(PreNorm(dim_in, SpatialLinearAttention( + dim_in, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), + Residual(PreNorm(dim_in, temporal_attn(dim_in))), + Upsample(dim_in) if not is_last else nn.Identity() + ])) + + out_dim = default(out_dim, channels) + self.final_conv = nn.Sequential( + block_klass(dim * 2, dim), + nn.Conv3d(dim, out_dim, 1) + ) + + def forward_with_cond_scale( + self, + 
*args, + cond_scale=2., + **kwargs + ): + logits = self.forward(*args, null_cond_prob=0., **kwargs) + if cond_scale == 1 or not self.has_cond: + return logits + + null_logits = self.forward(*args, null_cond_prob=1., **kwargs) + return null_logits + (logits - null_logits) * cond_scale + + def forward( + self, + x, + time, + cond=None, + null_cond_prob=0., + focus_present_mask=None, + # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time) + prob_focus_present=0. + ): + assert not (self.has_cond and not exists(cond) + ), 'cond must be passed in if cond_dim specified' + x = torch.cat([x, cond], dim=1) + + batch, device = x.shape[0], x.device + + focus_present_mask = default(focus_present_mask, lambda: prob_mask_like( + (batch,), prob_focus_present, device=device)) + + time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device) + + x = self.init_conv(x) + r = x.clone() + + x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias) + + t = self.time_mlp(time) if exists(self.time_mlp) else None # [2, 128] + + # classifier free guidance + + if self.has_cond: + batch, device = x.shape[0], x.device + mask = prob_mask_like((batch,), null_cond_prob, device=device) + cond = torch.where(rearrange(mask, 'b -> b 1'), + self.null_cond_emb, cond) + t = torch.cat((t, cond), dim=-1) + + h = [] + + for block1, block2, spatial_attn, temporal_attn, downsample in self.downs: + x = block1(x, t) + x = block2(x, t) + x = spatial_attn(x) + x = temporal_attn(x, pos_bias=time_rel_pos_bias, + focus_present_mask=focus_present_mask) + h.append(x) + x = downsample(x) + + # [2, 256, 32, 4, 4] + x = self.mid_block1(x, t) + x = self.mid_spatial_attn(x) + x = self.mid_temporal_attn( + x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask) + x = self.mid_block2(x, t) + + for block1, block2, spatial_attn, temporal_attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim=1) + x = block1(x, t) + x = block2(x, t) + x = spatial_attn(x) + x = temporal_attn(x, pos_bias=time_rel_pos_bias, + focus_present_mask=focus_present_mask) + x = upsample(x) + + x = torch.cat((x, r), dim=1) + return self.final_conv(x) + +# gaussian diffusion trainer class + + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = torch.linspace(0, timesteps, steps, dtype=torch.float64) + alphas_cumprod = torch.cos( + ((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0, 0.9999) + + +class GaussianDiffusion(nn.Module): + def __init__( + self, + denoise_fn, + *, + image_size, + num_frames, + text_use_bert_cls=False, + channels=3, + timesteps=1000, + loss_type='l1', + use_dynamic_thres=False, # from the Imagen paper + dynamic_thres_percentile=0.9, + vqgan_ckpt=None, + device=None + ): + super().__init__() + self.channels = channels + self.image_size = image_size + self.num_frames = num_frames + self.denoise_fn = denoise_fn + self.device = device + + if vqgan_ckpt: + self.vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt).cuda() + self.vqgan.eval() + else: + self.vqgan = None + + betas = cosine_beta_schedule(timesteps) + + alphas = 1. 
- betas + alphas_cumprod = torch.cumprod(alphas, axis=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + # register buffer helper function that casts float64 to float32 + + def register_buffer(name, val): return self.register_buffer( + name, val.to(torch.float32)) + + register_buffer('betas', betas) + register_buffer('alphas_cumprod', alphas_cumprod) + register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + + register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) + register_buffer('sqrt_one_minus_alphas_cumprod', + torch.sqrt(1. - alphas_cumprod)) + register_buffer('log_one_minus_alphas_cumprod', + torch.log(1. - alphas_cumprod)) + register_buffer('sqrt_recip_alphas_cumprod', + torch.sqrt(1. / alphas_cumprod)) + register_buffer('sqrt_recipm1_alphas_cumprod', + torch.sqrt(1. / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + + posterior_variance = betas * \ + (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + + register_buffer('posterior_variance', posterior_variance) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + + register_buffer('posterior_log_variance_clipped', + torch.log(posterior_variance.clamp(min=1e-20))) + register_buffer('posterior_mean_coef1', betas * + torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) + register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) + * torch.sqrt(alphas) / (1. - alphas_cumprod)) + + # text conditioning parameters + + self.text_use_bert_cls = text_use_bert_cls + + # dynamic thresholding when sampling + + self.use_dynamic_thres = use_dynamic_thres + self.dynamic_thres_percentile = dynamic_thres_percentile + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. - self.alphas_cumprod, t, x_start.shape) + log_variance = extract( + self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract( + self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool, cond=None, cond_scale=1.): + x_recon = self.predict_start_from_noise( + x, t=t, noise=self.denoise_fn.forward_with_cond_scale(x, t, cond=cond, cond_scale=cond_scale)) + + if clip_denoised: + s = 1. + if self.use_dynamic_thres: + s = torch.quantile( + rearrange(x_recon, 'b ... -> b (...)').abs(), + self.dynamic_thres_percentile, + dim=-1 + ) + + s.clamp_(min=1.) 
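+        # descriptive note (added): reshape the per-sample threshold for broadcasting, then the lines below clip x_recon to [-s, s] and rescale it into [-1, 1]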
+ s = s.view(-1, *((1,) * (x_recon.ndim - 1))) + + # clip by threshold, depending on whether static or dynamic + x_recon = x_recon.clamp(-s, s) / s + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.inference_mode() + def p_sample(self, x, t, cond=None, cond_scale=1., clip_denoised=True): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised, cond=cond, cond_scale=cond_scale) + noise = torch.randn_like(x) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, + *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.inference_mode() + def p_sample_loop(self, shape, cond=None, cond_scale=1.): + device = self.betas.device + + b = shape[0] + img = torch.randn(shape, device=device) + # print('cond', cond.shape) + for i in reversed(range(0, self.num_timesteps)): + img = self.p_sample(img, torch.full( + (b,), i, device=device, dtype=torch.long), cond=cond, cond_scale=cond_scale) + + return img + + @torch.inference_mode() + def sample(self, cond=None, cond_scale=1., batch_size=16): + device = next(self.denoise_fn.parameters()).device + + if is_list_str(cond): + cond = bert_embed(tokenize(cond)).to(device) + + # batch_size = cond.shape[0] if exists(cond) else batch_size + batch_size = batch_size + image_size = self.image_size + channels = 8 # self.channels + num_frames = self.num_frames + # print((batch_size, channels, num_frames, image_size, image_size)) + # print('cond_',cond.shape) + _sample = self.p_sample_loop( + (batch_size, channels, num_frames, image_size, image_size), cond=cond, cond_scale=cond_scale) + + if isinstance(self.vqgan, VQGAN): + # denormalize TODO: Remove eventually + _sample = (((_sample + 1.0) / 2.0) * (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) + self.vqgan.codebook.embeddings.min() + + _sample = self.vqgan.decode(_sample, quantize=True) + else: + unnormalize_img(_sample) + + return _sample + + @torch.inference_mode() + def interpolate(self, x1, x2, t=None, lam=0.5): + b, *_, device = *x1.shape, x1.device + t = default(t, self.num_timesteps - 1) + + assert x1.shape == x2.shape + + t_batched = torch.stack([torch.tensor(t, device=device)] * b) + xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) + + img = (1 - lam) * xt1 + lam * xt2 + for i in reversed(range(0, t)): + img = self.p_sample(img, torch.full( + (b,), i, device=device, dtype=torch.long)) + + return img + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, + t, x_start.shape) * noise + ) + + def p_losses(self, x_start, t, cond=None, noise=None, **kwargs): + b, c, f, h, w, device = *x_start.shape, x_start.device + noise = default(noise, lambda: torch.randn_like(x_start)) + # breakpoint() + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # [2, 8, 32, 32, 32] + + if is_list_str(cond): + cond = bert_embed( + tokenize(cond), return_cls_repr=self.text_use_bert_cls) + cond = cond.to(device) + + x_recon = self.denoise_fn(x_noisy, t, cond=cond, **kwargs) + + if self.loss_type == 'l1': + loss = F.l1_loss(noise, x_recon) + elif self.loss_type == 'l2': + loss = F.mse_loss(noise, x_recon) + else: 
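+            # descriptive note (added): any loss_type other than 'l1' or 'l2' is unsupported here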
+ raise NotImplementedError() + + return loss + + def forward(self, x, *args, **kwargs): + bs = int(x.shape[0]/2) + img=x[:bs,...] + mask=x[bs:,...] + mask_=(1-mask).detach() + masked_img = (img*mask_).detach() + masked_img=masked_img.permute(0,1,-1,-3,-2) + img=img.permute(0,1,-1,-3,-2) + mask=mask.permute(0,1,-1,-3,-2) + # breakpoint() + if isinstance(self.vqgan, VQGAN): + with torch.no_grad(): + img = self.vqgan.encode( + img, quantize=False, include_embeddings=True) + # normalize to -1 and 1 + img = ((img - self.vqgan.codebook.embeddings.min()) / + (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + masked_img = self.vqgan.encode( + masked_img, quantize=False, include_embeddings=True) + # normalize to -1 and 1 + masked_img = ((masked_img - self.vqgan.codebook.embeddings.min()) / + (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + else: + print("Hi") + img = normalize_img(img) + masked_img = normalize_img(masked_img) + mask = mask*2.0 - 1.0 + cc = torch.nn.functional.interpolate(mask, size=masked_img.shape[-3:]) + cond = torch.cat((masked_img, cc), dim=1) + + b, device, img_size, = img.shape[0], img.device, self.image_size + t = torch.randint(0, self.num_timesteps, (b,), device=device).long() + # breakpoint() + return self.p_losses(img, t, cond=cond, *args, **kwargs) + +# trainer class + + +CHANNELS_TO_MODE = { + 1: 'L', + 3: 'RGB', + 4: 'RGBA' +} + + +def seek_all_images(img, channels=3): + assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid' + mode = CHANNELS_TO_MODE[channels] + + i = 0 + while True: + try: + img.seek(i) + yield img.convert(mode) + except EOFError: + break + i += 1 + +# tensor of shape (channels, frames, height, width) -> gif + + +def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True): + tensor = ((tensor - tensor.min()) / (tensor.max() - tensor.min())) * 1.0 + images = map(T.ToPILImage(), tensor.unbind(dim=1)) + first_img, *rest_imgs = images + first_img.save(path, save_all=True, append_images=rest_imgs, + duration=duration, loop=loop, optimize=optimize) + return images + +# gif -> (channels, frame, height, width) tensor + + +def gif_to_tensor(path, channels=3, transform=T.ToTensor()): + img = Image.open(path) + tensors = tuple(map(transform, seek_all_images(img, channels=channels))) + return torch.stack(tensors, dim=1) + + +def identity(t, *args, **kwargs): + return t + + +def normalize_img(t): + return t * 2 - 1 + + +def unnormalize_img(t): + return (t + 1) * 0.5 + + +def cast_num_frames(t, *, frames): + f = t.shape[1] + + if f == frames: + return t + + if f > frames: + return t[:, :frames] + + return F.pad(t, (0, 0, 0, 0, 0, frames - f)) + + +class Dataset(data.Dataset): + def __init__( + self, + folder, + image_size, + channels=3, + num_frames=16, + horizontal_flip=False, + force_num_frames=True, + exts=['gif'] + ): + super().__init__() + self.folder = folder + self.image_size = image_size + self.channels = channels + self.paths = [p for ext in exts for p in Path( + f'{folder}').glob(f'**/*.{ext}')] + + self.cast_num_frames_fn = partial( + cast_num_frames, frames=num_frames) if force_num_frames else identity + + self.transform = T.Compose([ + T.Resize(image_size), + T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity), + T.CenterCrop(image_size), + T.ToTensor() + ]) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = self.paths[index] + tensor = gif_to_tensor(path, self.channels, 
transform=self.transform) + return self.cast_num_frames_fn(tensor) + +# trainer class + + +class Tester(object): + def __init__( + self, + diffusion_model, + ): + super().__init__() + self.model = diffusion_model + self.ema_model = copy.deepcopy(self.model) + self.step=0 + self.image_size = diffusion_model.image_size + + self.reset_parameters() + + def reset_parameters(self): + self.ema_model.load_state_dict(self.model.state_dict()) + + + def load(self, milestone, map_location=None, **kwargs): + if milestone == -1: + all_milestones = [int(p.stem.split('-')[-1]) + for p in Path(self.results_folder).glob('**/*.pt')] + assert len( + all_milestones) > 0, 'need to have at least one milestone to load from latest checkpoint (milestone == -1)' + milestone = max(all_milestones) + + if map_location: + data = torch.load(milestone, map_location=map_location) + else: + data = torch.load(milestone) + + self.step = data['step'] + self.model.load_state_dict(data['model'], **kwargs) + self.ema_model.load_state_dict(data['ema'], **kwargs) + + diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/text.py b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/text.py new file mode 100644 index 0000000000000000000000000000000000000000..42dca78bec5075fb4f59c522aa3d00cc395d7536 --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/text.py @@ -0,0 +1,94 @@ +"Taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch" + +import torch +from einops import rearrange + + +def exists(val): + return val is not None + +# singleton globals + + +MODEL = None +TOKENIZER = None +BERT_MODEL_DIM = 768 + + +def get_tokenizer(): + global TOKENIZER + if not exists(TOKENIZER): + TOKENIZER = torch.hub.load( + 'huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased') + return TOKENIZER + + +def get_bert(): + global MODEL + if not exists(MODEL): + MODEL = torch.hub.load( + 'huggingface/pytorch-transformers', 'model', 'bert-base-cased') + if torch.cuda.is_available(): + MODEL = MODEL.cuda() + + return MODEL + +# tokenize + + +def tokenize(texts, add_special_tokens=True): + if not isinstance(texts, (list, tuple)): + texts = [texts] + + tokenizer = get_tokenizer() + + encoding = tokenizer.batch_encode_plus( + texts, + add_special_tokens=add_special_tokens, + padding=True, + return_tensors='pt' + ) + + token_ids = encoding.input_ids + return token_ids + +# embedding function + + +@torch.no_grad() +def bert_embed( + token_ids, + return_cls_repr=False, + eps=1e-8, + pad_id=0. 
+): + model = get_bert() + mask = token_ids != pad_id + + if torch.cuda.is_available(): + token_ids = token_ids.cuda() + mask = mask.cuda() + + outputs = model( + input_ids=token_ids, + attention_mask=mask, + output_hidden_states=True + ) + + hidden_state = outputs.hidden_states[-1] + + if return_cls_repr: + # return [cls] as representation + return hidden_state[:, 0] + + if not exists(mask): + return hidden_state.mean(dim=1) + + # mean all tokens excluding [cls], accounting for length + mask = mask[:, 1:] + mask = rearrange(mask, 'b n -> b n 1') + + numer = (hidden_state[:, 1:] * mask).sum(dim=1) + denom = mask.sum(dim=1) + masked_mean = numer / (denom + eps) + return diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/time_embedding.py b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/time_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..7c911fa4f485295005cda9c9c2099db3ddbda15c --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/time_embedding.py @@ -0,0 +1,75 @@ +import math +import torch +import torch.nn as nn + +from monai.networks.layers.utils import get_act_layer + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, emb_dim=16, downscale_freq_shift=1, max_period=10000, flip_sin_to_cos=False): + super().__init__() + self.emb_dim = emb_dim + self.downscale_freq_shift = downscale_freq_shift + self.max_period = max_period + self.flip_sin_to_cos = flip_sin_to_cos + + def forward(self, x): + device = x.device + half_dim = self.emb_dim // 2 + emb = math.log(self.max_period) / \ + (half_dim - self.downscale_freq_shift) + emb = torch.exp(-emb*torch.arange(half_dim, device=device)) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + + if self.flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + if self.emb_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class LearnedSinusoidalPosEmb(nn.Module): + """ following @crowsonkb 's lead with learned sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, emb_dim): + super().__init__() + self.emb_dim = emb_dim + half_dim = emb_dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = x[:, None] + freqs = x * self.weights[None, :] * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + if self.emb_dim % 2 == 1: + fouriered = torch.nn.functional.pad(fouriered, (0, 1, 0, 0)) + return fouriered + + +class TimeEmbbeding(nn.Module): + def __init__( + self, + emb_dim=64, + pos_embedder=SinusoidalPosEmb, + pos_embedder_kwargs={}, + act_name=("SWISH", {}) # Swish = SiLU + ): + super().__init__() + self.emb_dim = emb_dim + self.pos_emb_dim = pos_embedder_kwargs.get('emb_dim', emb_dim//4) + pos_embedder_kwargs['emb_dim'] = self.pos_emb_dim + self.pos_embedder = pos_embedder(**pos_embedder_kwargs) + + self.time_emb = nn.Sequential( + self.pos_embedder, + nn.Linear(self.pos_emb_dim, self.emb_dim), + get_act_layer(act_name), + nn.Linear(self.emb_dim, self.emb_dim) + ) + + def forward(self, time): + return self.time_emb(time) diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/unet.py b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/unet.py new file mode 100644 index 
0000000000000000000000000000000000000000..e9cae8021e874a92b6baaaf4decd802516c2dc87 --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/unet.py @@ -0,0 +1,226 @@ +from ddpm.time_embedding import TimeEmbbeding + +import monai.networks.nets as nets +import torch +import torch.nn as nn +from einops import rearrange + +from monai.networks.blocks import UnetBasicBlock, UnetResBlock, UnetUpBlock, Convolution, UnetOutBlock +from monai.networks.layers.utils import get_act_layer + + +class DownBlock(nn.Module): + def __init__( + self, + spatial_dims, + in_ch, + out_ch, + time_emb_dim, + cond_emb_dim, + act_name=("swish", {}), + **kwargs): + super(DownBlock, self).__init__() + self.loca_time_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(time_emb_dim, in_ch) # in_ch * 2 + ) + if cond_emb_dim is not None: + self.loca_cond_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(cond_emb_dim, in_ch), + ) + self.down_op = UnetBasicBlock( + spatial_dims, in_ch, out_ch, act_name=act_name, **kwargs) + + def forward(self, x, time_emb, cond_emb): + b, c, *_ = x.shape + sp_dim = x.ndim-2 + + # ------------ Time ---------- + time_emb = self.loca_time_embedder(time_emb) + time_emb = time_emb.reshape(b, c, *((1,)*sp_dim)) + # scale, shift = time_emb.chunk(2, dim = 1) + + # ------------ Combine ------------ + # x = x * (scale + 1) + shift + x = x + time_emb + + # ----------- Condition ------------ + if cond_emb is not None: + cond_emb = self.loca_cond_embedder(cond_emb) + cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim)) + x = x + cond_emb + + # ----------- Image --------- + y = self.down_op(x) + return y + + +class UpBlock(nn.Module): + def __init__( + self, + spatial_dims, + skip_ch, + enc_ch, + time_emb_dim, + cond_emb_dim, + act_name=("swish", {}), + **kwargs): + super(UpBlock, self).__init__() + self.up_op = UnetUpBlock(spatial_dims, enc_ch, + skip_ch, act_name=act_name, **kwargs) + self.loca_time_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(time_emb_dim, skip_ch * 2), + ) + if cond_emb_dim is not None: + self.loca_cond_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(cond_emb_dim, skip_ch * 2), + ) + + def forward(self, x_skip, x_enc, time_emb, cond_emb): + b, c, *_ = x_enc.shape + sp_dim = x_enc.ndim-2 + + # ----------- Time -------------- + time_emb = self.loca_time_embedder(time_emb) + time_emb = time_emb.reshape(b, c, *((1,)*sp_dim)) + # scale, shift = time_emb.chunk(2, dim = 1) + + # -------- Combine ------------- + # y = x * (scale + 1) + shift + x_enc = x_enc + time_emb + + # ----------- Condition ------------ + if cond_emb is not None: + cond_emb = self.loca_cond_embedder(cond_emb) + cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim)) + x_enc = x_enc + cond_emb + + # ----------- Image ------------- + y = self.up_op(x_enc, x_skip) + + # -------- Combine ------------- + # y = y * (scale + 1) + shift + + return y + + +class UNet(nn.Module): + + def __init__(self, + in_ch=1, + out_ch=1, + spatial_dims=3, + hid_chs=[32, 64, 128, 256, 512], + kernel_sizes=[(1, 3, 3), (1, 3, 3), (1, 3, 3), 3, 3], + strides=[1, (1, 2, 2), (1, 2, 2), 2, 2], + upsample_kernel_sizes=None, + act_name=("SWISH", {}), + norm_name=("INSTANCE", {"affine": True}), + time_embedder=TimeEmbbeding, + time_embedder_kwargs={}, + cond_embedder=None, + cond_embedder_kwargs={}, + # True = all but last layer, 0/False=disable, 1=only first layer, ... 
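+                 # Added note (not in the original comment): with the defaults above this is a
+                 # five-level 3D U-Net: the in-convolution keeps resolution (stride 1), the next
+                 # two encoder levels downsample in-plane only (stride (1, 2, 2)), and the last
+                 # two halve all three axes. The flag below sets how many auxiliary
+                 # deep-supervision heads are created in self.outc_ver.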
+ deep_ver_supervision=True, + estimate_variance=False, + use_self_conditioning=False, + **kwargs + ): + super().__init__() + if upsample_kernel_sizes is None: + upsample_kernel_sizes = strides[1:] + + # ------------- Time-Embedder----------- + self.time_embedder = time_embedder(**time_embedder_kwargs) + + # ------------- Condition-Embedder----------- + if cond_embedder is not None: + self.cond_embedder = cond_embedder(**cond_embedder_kwargs) + cond_emb_dim = self.cond_embedder.emb_dim + else: + self.cond_embedder = None + cond_emb_dim = None + + # ----------- In-Convolution ------------ + in_ch = in_ch*2 if use_self_conditioning else in_ch + self.inc = UnetBasicBlock(spatial_dims, in_ch, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0], + act_name=act_name, norm_name=norm_name, **kwargs) + + # ----------- Encoder ---------------- + self.encoders = nn.ModuleList([ + DownBlock(spatial_dims, hid_chs[i-1], hid_chs[i], time_emb_dim=self.time_embedder.emb_dim, + cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[ + i], stride=strides[i], act_name=act_name, + norm_name=norm_name, **kwargs) + for i in range(1, len(strides)) + ]) + + # ------------ Decoder ---------- + self.decoders = nn.ModuleList([ + UpBlock(spatial_dims, hid_chs[i], hid_chs[i+1], time_emb_dim=self.time_embedder.emb_dim, + cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[i + + 1], stride=strides[i+1], act_name=act_name, + norm_name=norm_name, upsample_kernel_size=upsample_kernel_sizes[i], **kwargs) + for i in range(len(strides)-1) + ]) + + # --------------- Out-Convolution ---------------- + out_ch_hor = out_ch*2 if estimate_variance else out_ch + self.outc = UnetOutBlock( + spatial_dims, hid_chs[0], out_ch_hor, dropout=None) + if isinstance(deep_ver_supervision, bool): + deep_ver_supervision = len( + strides)-2 if deep_ver_supervision else 0 + self.outc_ver = nn.ModuleList([ + UnetOutBlock(spatial_dims, hid_chs[i], out_ch, dropout=None) + for i in range(1, deep_ver_supervision+1) + ]) + + def forward(self, x_t, t, cond=None, self_cond=None, **kwargs): + condition = cond + # x_t [B, C, (D), H, W] + # t [B,] + + # -------- In-Convolution -------------- + x = [None for _ in range(len(self.encoders)+1)] + x_t = torch.cat([x_t, self_cond], + dim=1) if self_cond is not None else x_t + x[0] = self.inc(x_t) + + # -------- Time Embedding (Gloabl) ----------- + time_emb = self.time_embedder(t) # [B, C] + + # -------- Condition Embedding (Gloabl) ----------- + if (condition is None) or (self.cond_embedder is None): + cond_emb = None + else: + cond_emb = self.cond_embedder(condition) # [B, C] + + # --------- Encoder -------------- + for i in range(len(self.encoders)): + x[i+1] = self.encoders[i](x[i], time_emb, cond_emb) + + # -------- Decoder ----------- + for i in range(len(self.decoders), 0, -1): + x[i-1] = self.decoders[i-1](x[i-1], x[i], time_emb, cond_emb) + + # ---------Out-Convolution ------------ + y_hor = self.outc(x[0]) + y_ver = [outc_ver_i(x[i+1]) + for i, outc_ver_i in enumerate(self.outc_ver)] + + return y_hor # , y_ver + + def forward_with_cond_scale(self, *args, cond_scale=0., **kwargs): + return self.forward(*args, **kwargs) + + +if __name__ == '__main__': + model = UNet(in_ch=3) + input = torch.randn((1, 3, 16, 128, 128)) + time = torch.randn((1,)) + out_hor, out_ver = model(input, time) + print(out_hor[0].shape) diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/util.py b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/util.py new file mode 100644 index 
0000000000000000000000000000000000000000..2047cf56802046bbf37cfd33f819ed75a239a7df --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/ddpm/util.py @@ -0,0 +1,271 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +# from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + if c != 1: + steps_out = ddim_timesteps + 1 + else: + steps_out = ddim_timesteps + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. 
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. 
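+    Added example (not in the original docstring): a per-element loss of shape
+    [B, C, D, H, W] is reduced to shape [B], i.e. one scalar per batch element.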
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac7bdfd0edc6757534df4352a2952faf3c5c588b Binary files /dev/null and b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__init__.py b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38c08845c1ee2f7b4b52bf95fa282b0808931a3a --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__init__.py @@ -0,0 +1,3 @@ +from .vqgan import VQGAN +from .codebook import Codebook +from .lpips import LPIPS diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62702f0203b5aca1df69361cf1271d4659fc63e8 Binary files /dev/null and b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc differ diff --git 
a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..809398d341751492f5da4cb0646ed6a5e2a58fd1 Binary files /dev/null and b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..832b32b8e9b977c5dfdf1fff0f125c414e662370 Binary files /dev/null and b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef6892b233146a743404465dc24fd5974e4e72c5 Binary files /dev/null and b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth new file mode 100644 index 0000000000000000000000000000000000000000..f57dcf5cc764d61c8a460365847fb2137ff0a62d --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868 +size 7289 diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/codebook.py b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/codebook.py new file mode 100644 index 0000000000000000000000000000000000000000..c17ed701a2824cc0dcd75670d47a6ab3842e9d35 --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/codebook.py @@ -0,0 +1,109 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from ..utils import shift_dim + + +class Codebook(nn.Module): + def __init__(self, n_codes, embedding_dim, no_random_restart=False, restart_thres=1.0): + super().__init__() + self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim)) + self.register_buffer('N', torch.zeros(n_codes)) + self.register_buffer('z_avg', self.embeddings.data.clone()) + + self.n_codes = n_codes + self.embedding_dim = embedding_dim + self._need_init = True + self.no_random_restart = no_random_restart + self.restart_thres = restart_thres + + def _tile(self, x): + d, ew = x.shape + if d < self.n_codes: + n_repeats = (self.n_codes + d - 1) // d + std = 0.01 / np.sqrt(ew) + x = x.repeat(n_repeats, 1) + x = x + torch.randn_like(x) * std + return x + + def _init_embeddings(self, z): + # z: [b, c, t, h, w] + self._need_init = False + breakpoint() + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [65536, 8] [2, 8, 32, 32, 32] + y = self._tile(flat_inputs) # [65536, 8] + + d = y.shape[0] + _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + self.embeddings.data.copy_(_k_rand) + self.z_avg.data.copy_(_k_rand) + self.N.data.copy_(torch.ones(self.n_codes)) + + def forward(self, z): + # z: [b, c, t, h, w] + if self._need_init and self.training: + self._init_embeddings(z) + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [bthw, c] [65536, 8] + distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \ + - 2 * flat_inputs @ self.embeddings.t() \ + + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) # [bthw, c] [65536, 8] + + encoding_indices = torch.argmin(distances, dim=1) # [65536] + encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as( + flat_inputs) # [bthw, ncode] [65536, 16384] + encoding_indices = encoding_indices.view( + z.shape[0], *z.shape[2:]) # [b, t, h, w, ncode] [2, 32, 32, 32] + + embeddings = F.embedding( + encoding_indices, self.embeddings) # [b, t, h, w, c] self.embeddings [16384, 8] + embeddings = shift_dim(embeddings, -1, 1) # [b, c, t, h, w] [2, 8, 32, 32, 32] + + commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) + + # EMA codebook update + if self.training: + n_total = encode_onehot.sum(dim=0) # [16384] + encode_sum = flat_inputs.t() @ encode_onehot # [8, 16384] + if dist.is_initialized(): + dist.all_reduce(n_total) + dist.all_reduce(encode_sum) + + self.N.data.mul_(0.99).add_(n_total, alpha=0.01) + self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) + + n = self.N.sum() + weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n + encode_normalized = self.z_avg / weights.unsqueeze(1) + self.embeddings.data.copy_(encode_normalized) + + y = self._tile(flat_inputs) + _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + + if not self.no_random_restart: + usage = (self.N.view(self.n_codes, 1) + >= self.restart_thres).float() + self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) + + embeddings_st = (embeddings - z).detach() + z + + avg_probs = torch.mean(encode_onehot, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * + torch.log(avg_probs + 1e-10))) + + return dict(embeddings=embeddings_st, encodings=encoding_indices, + commitment_loss=commitment_loss, perplexity=perplexity) + + def dictionary_lookup(self, encodings): + embeddings = F.embedding(encodings, 
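+        # Added note (not in the TATS original): maps integer code indices (e.g. the
+        # [b, t, h, w] `encodings` returned by forward) back to their codebook vectors,
+        # yielding a [b, t, h, w, embedding_dim] tensor.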
self.embeddings) + return embeddings diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/lpips.py b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..019434b9b29945d67e6e0a95dec324df7ff908f3 --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/lpips.py @@ -0,0 +1,181 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" + +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + + +from collections import namedtuple +from torchvision import models +import torch.nn as nn +import torch +from tqdm import tqdm +import requests +import os +import hashlib +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format( + name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + # self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + self.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name is not "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + model.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer( + input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in 
range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor( + outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor( + [-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor( + [.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, + padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, + h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..a64ec8878c0ecbf418775e2958b2cbf578ce2918 --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py @@ -0,0 +1,561 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import math +import argparse +import numpy as np +import pickle as pkl + +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from ..utils import shift_dim, adopt_weight, comp_getattr +from .lpips import LPIPS +from .codebook import Codebook + + +def silu(x): + return x*torch.sigmoid(x) + + +class SiLU(nn.Module): + def __init__(self): + super(SiLU, self).__init__() + + def forward(self, x): + return silu(x) + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake))) + return d_loss + + +class VQGAN(pl.LightningModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.embedding_dim = cfg.model.embedding_dim # 8 + self.n_codes = cfg.model.n_codes # 16384 + + self.encoder = Encoder(cfg.model.n_hiddens, # 16 + cfg.model.downsample, # [2, 2, 2] + cfg.dataset.image_channels, # 1 + cfg.model.norm_type, # group + cfg.model.padding_type, # replicate + cfg.model.num_groups, # 32 + ) + self.decoder = Decoder( + cfg.model.n_hiddens, cfg.model.downsample, cfg.dataset.image_channels, cfg.model.norm_type, cfg.model.num_groups) + self.enc_out_ch = self.encoder.out_channels + self.pre_vq_conv = SamePadConv3d( + self.enc_out_ch, cfg.model.embedding_dim, 1, padding_type=cfg.model.padding_type) + self.post_vq_conv = SamePadConv3d( + cfg.model.embedding_dim, self.enc_out_ch, 1) + + self.codebook = Codebook(cfg.model.n_codes, cfg.model.embedding_dim, + no_random_restart=cfg.model.no_random_restart, restart_thres=cfg.model.restart_thres) + + self.gan_feat_weight = cfg.model.gan_feat_weight + # TODO: Changed batchnorm from sync to normal + self.image_discriminator = NLayerDiscriminator( + cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm2d) + self.video_discriminator = NLayerDiscriminator3D( + cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm3d) + + if cfg.model.disc_loss_type == 'vanilla': + self.disc_loss = vanilla_d_loss + elif cfg.model.disc_loss_type == 'hinge': + self.disc_loss = hinge_d_loss + + self.perceptual_model = LPIPS().eval() + + self.image_gan_weight = cfg.model.image_gan_weight + self.video_gan_weight = cfg.model.video_gan_weight + + self.perceptual_weight = cfg.model.perceptual_weight + + self.l1_weight = cfg.model.l1_weight + self.save_hyperparameters() + + def encode(self, x, include_embeddings=False, quantize=True): + h = self.pre_vq_conv(self.encoder(x)) + if quantize: + vq_output = self.codebook(h) + if include_embeddings: + return vq_output['embeddings'], vq_output['encodings'] + else: + return vq_output['encodings'] + return h + + def decode(self, latent, quantize=False): + if quantize: + vq_output = self.codebook(latent) + latent = vq_output['encodings'] + h = F.embedding(latent, self.codebook.embeddings) + h = self.post_vq_conv(shift_dim(h, -1, 1)) + return self.decoder(h) + + def forward(self, x, optimizer_idx=None, log_image=False): + B, C, T, H, W = x.shape + z = self.pre_vq_conv(self.encoder(x)) # [2, 32, 32, 32, 32] [2, 8, 32, 32, 32] + vq_output = self.codebook(z) # ['embeddings', 'encodings', 'commitment_loss', 
'perplexity'] + x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings'])) # [2, 8, 32, 32, 32] [2, 32, 32, 32, 32] + + recon_loss = F.l1_loss(x_recon, x) * self.l1_weight + + # Selects one random 2D image from each 3D Image + frame_idx = torch.randint(0, T, [B]).cuda() + frame_idx_selected = frame_idx.reshape(-1, + 1, 1, 1, 1).repeat(1, C, 1, H, W) # [2, 1, 1, 64, 64] + frames = torch.gather(x, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64] + frames_recon = torch.gather(x_recon, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64] + + if log_image: + return frames, frames_recon, x, x_recon + + if optimizer_idx == 0: + # Autoencoder - train the "generator" + + # Perceptual loss + perceptual_loss = 0 + if self.perceptual_weight > 0: + perceptual_loss = self.perceptual_model( + frames, frames_recon).mean() * self.perceptual_weight + + # Discriminator loss (turned on after a certain epoch) + logits_image_fake, pred_image_fake = self.image_discriminator( + frames_recon) + logits_video_fake, pred_video_fake = self.video_discriminator( + x_recon) + g_image_loss = -torch.mean(logits_image_fake) + g_video_loss = -torch.mean(logits_video_fake) + g_loss = self.image_gan_weight*g_image_loss + self.video_gan_weight*g_video_loss + disc_factor = adopt_weight( + self.global_step, threshold=self.cfg.model.discriminator_iter_start) + aeloss = disc_factor * g_loss + + # GAN feature matching loss - tune features such that we get the same prediction result on the discriminator + image_gan_feat_loss = 0 + video_gan_feat_loss = 0 + feat_weights = 4.0 / (3 + 1) + if self.image_gan_weight > 0: + logits_image_real, pred_image_real = self.image_discriminator( + frames) + for i in range(len(pred_image_fake)-1): + image_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_image_fake[i], pred_image_real[i].detach( + )) * (self.image_gan_weight > 0) + if self.video_gan_weight > 0: + logits_video_real, pred_video_real = self.video_discriminator( + x) + for i in range(len(pred_video_fake)-1): + video_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_video_fake[i], pred_video_real[i].detach( + )) * (self.video_gan_weight > 0) + gan_feat_loss = disc_factor * self.gan_feat_weight * \ + (image_gan_feat_loss + video_gan_feat_loss) + + self.log("train/g_image_loss", g_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/g_video_loss", g_video_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/image_gan_feat_loss", image_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/video_gan_feat_loss", video_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/perceptual_loss", perceptual_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log("train/recon_loss", recon_loss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + self.log("train/aeloss", aeloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + self.log("train/commitment_loss", vq_output['commitment_loss'], + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log('train/perplexity', vq_output['perplexity'], + prog_bar=True, logger=True, on_step=True, on_epoch=True) + return recon_loss, x_recon, vq_output, aeloss, perceptual_loss, gan_feat_loss + + if optimizer_idx == 1: + # Train discriminator + logits_image_real, _ = self.image_discriminator(frames.detach()) + logits_video_real, _ = self.video_discriminator(x.detach()) + + logits_image_fake, _ = self.image_discriminator( + frames_recon.detach()) + logits_video_fake, _ = 
self.video_discriminator(x_recon.detach()) + + d_image_loss = self.disc_loss(logits_image_real, logits_image_fake) + d_video_loss = self.disc_loss(logits_video_real, logits_video_fake) + disc_factor = adopt_weight( + self.global_step, threshold=self.cfg.model.discriminator_iter_start) + discloss = disc_factor * \ + (self.image_gan_weight*d_image_loss + + self.video_gan_weight*d_video_loss) + + self.log("train/logits_image_real", logits_image_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_image_fake", logits_image_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_video_real", logits_video_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_video_fake", logits_video_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/d_image_loss", d_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/d_video_loss", d_video_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/discloss", discloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + return discloss + + perceptual_loss = self.perceptual_model( + frames, frames_recon) * self.perceptual_weight + return recon_loss, x_recon, vq_output, perceptual_loss + + def training_step(self, batch, batch_idx, optimizer_idx): + x = batch['image'] + if optimizer_idx == 0: + recon_loss, _, vq_output, aeloss, perceptual_loss, gan_feat_loss = self.forward( + x, optimizer_idx) + commitment_loss = vq_output['commitment_loss'] + loss = recon_loss + commitment_loss + aeloss + perceptual_loss + gan_feat_loss + if optimizer_idx == 1: + discloss = self.forward(x, optimizer_idx) + loss = discloss + return loss + + def validation_step(self, batch, batch_idx): + x = batch['image'] # TODO: batch['stft'] + recon_loss, _, vq_output, perceptual_loss = self.forward(x) + self.log('val/recon_loss', recon_loss, prog_bar=True) + self.log('val/perceptual_loss', perceptual_loss, prog_bar=True) + self.log('val/perplexity', vq_output['perplexity'], prog_bar=True) + self.log('val/commitment_loss', + vq_output['commitment_loss'], prog_bar=True) + + def configure_optimizers(self): + lr = self.cfg.model.lr + opt_ae = torch.optim.Adam(list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.pre_vq_conv.parameters()) + + list(self.post_vq_conv.parameters()) + + list(self.codebook.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(list(self.image_discriminator.parameters()) + + list(self.video_discriminator.parameters()), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def log_images(self, batch, **kwargs): + log = dict() + x = batch['image'] + x = x.to(self.device) + frames, frames_rec, _, _ = self(x, log_image=True) + log["inputs"] = frames + log["reconstructions"] = frames_rec + #log['mean_org'] = batch['mean_org'] + #log['std_org'] = batch['std_org'] + return log + + def log_videos(self, batch, **kwargs): + log = dict() + x = batch['image'] + _, _, x, x_rec = self(x, log_image=True) + log["inputs"] = x + log["reconstructions"] = x_rec + #log['mean_org'] = batch['mean_org'] + #log['std_org'] = batch['std_org'] + return log + + +def Normalize(in_channels, norm_type='group', num_groups=32): + assert norm_type in ['group', 'batch'] + if norm_type == 'group': + # TODO Changed num_groups from 32 to 8 + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == 'batch': + return 
torch.nn.SyncBatchNorm(in_channels) + + +class Encoder(nn.Module): + def __init__(self, n_hiddens, downsample, image_channel=3, norm_type='group', padding_type='replicate', num_groups=32): + super().__init__() + n_times_downsample = np.array([int(math.log2(d)) for d in downsample]) + self.conv_blocks = nn.ModuleList() + max_ds = n_times_downsample.max() + + self.conv_first = SamePadConv3d( + image_channel, n_hiddens, kernel_size=3, padding_type=padding_type) + + for i in range(max_ds): + block = nn.Module() + in_channels = n_hiddens * 2**i + out_channels = n_hiddens * 2**(i+1) + stride = tuple([2 if d > 0 else 1 for d in n_times_downsample]) + block.down = SamePadConv3d( + in_channels, out_channels, 4, stride=stride, padding_type=padding_type) + block.res = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + self.conv_blocks.append(block) + n_times_downsample -= 1 + + self.final_block = nn.Sequential( + Normalize(out_channels, norm_type, num_groups=num_groups), + SiLU() + ) + + self.out_channels = out_channels + + def forward(self, x): + h = self.conv_first(x) + for block in self.conv_blocks: + h = block.down(h) + h = block.res(h) + h = self.final_block(h) + return h + + +class Decoder(nn.Module): + def __init__(self, n_hiddens, upsample, image_channel, norm_type='group', num_groups=32): + super().__init__() + + n_times_upsample = np.array([int(math.log2(d)) for d in upsample]) + max_us = n_times_upsample.max() + + in_channels = n_hiddens*2**max_us + self.final_block = nn.Sequential( + Normalize(in_channels, norm_type, num_groups=num_groups), + SiLU() + ) + + self.conv_blocks = nn.ModuleList() + for i in range(max_us): + block = nn.Module() + in_channels = in_channels if i == 0 else n_hiddens*2**(max_us-i+1) + out_channels = n_hiddens*2**(max_us-i) + us = tuple([2 if d > 0 else 1 for d in n_times_upsample]) + block.up = SamePadConvTranspose3d( + in_channels, out_channels, 4, stride=us) + block.res1 = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + block.res2 = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + self.conv_blocks.append(block) + n_times_upsample -= 1 + + self.conv_last = SamePadConv3d( + out_channels, image_channel, kernel_size=3) + + def forward(self, x): + h = self.final_block(x) + for i, block in enumerate(self.conv_blocks): + h = block.up(h) + h = block.res1(h) + h = block.res2(h) + h = self.conv_last(h) + return h + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group', padding_type='replicate', num_groups=32): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, norm_type, num_groups=num_groups) + self.conv1 = SamePadConv3d( + in_channels, out_channels, kernel_size=3, padding_type=padding_type) + self.dropout = torch.nn.Dropout(dropout) + self.norm2 = Normalize(in_channels, norm_type, num_groups=num_groups) + self.conv2 = SamePadConv3d( + out_channels, out_channels, kernel_size=3, padding_type=padding_type) + if self.in_channels != self.out_channels: + self.conv_shortcut = SamePadConv3d( + in_channels, out_channels, kernel_size=3, padding_type=padding_type) + + def forward(self, x): + h = x + h = self.norm1(h) + h = silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = silu(h) + h = self.conv2(h) + + 
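+        # Added note (not in the TATS original): self.conv_shortcut below is only defined
+        # in __init__ when in_channels != out_channels, so the residual sum x + h stays
+        # shape-compatible in both branches.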
if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) + + return x+h + + +# Does not support dilation +class SamePadConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + # assumes that the input shape is divisible by stride + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + self.padding_type = padding_type + + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, + stride=stride, padding=0, bias=bias) + + def forward(self, x): + return self.conv(F.pad(x, self.pad_input, mode=self.padding_type)) + + +class SamePadConvTranspose3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + self.padding_type = padding_type + + self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, + stride=stride, bias=bias, + padding=tuple([k - 1 for k in kernel_size])) + + def forward(self, x): + return self.convt(F.pad(x, self.pad_input, mode=self.padding_type)) + + +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + # def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=True): + super(NLayerDiscriminator, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ + + +class NLayerDiscriminator3D(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + 
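+        # Added note (not in the original): 3D counterpart of the PatchGAN-style
+        # NLayerDiscriminator above, built from Conv3d blocks so it scores whole
+        # sub-volumes; with getIntermFeat=True, forward() also returns the intermediate
+        # activations used for the GAN feature-matching loss in VQGAN.forward.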
super(NLayerDiscriminator3D, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv3d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv3d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/utils.py b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..acf1a21270d98d0a1ea9935d00f9f16d89d54551 --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/ldm/vq_gan_3d/utils.py @@ -0,0 +1,177 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import warnings +import torch +import imageio + +import math +import numpy as np + +import sys +import pdb as pdb_original +import logging + +import imageio.core.util +logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR) + + +class ForkedPdb(pdb_original.Pdb): + """A Pdb subclass that may be used + from a forked multiprocessing child + + """ + + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + pdb_original.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + +# Shifts src_tf dim to dest dim +# i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + + dims = list(range(n_dims)) + del dims[src_dim] + + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x + + +# reshapes tensor start from dim i (inclusive) +# to dim j (exclusive) to the desired shape +# e.g. 
if x.shape = (b, thw, c) then +# view_range(x, 1, 2, (t, h, w)) returns +# x of shape (b, t, h, w, c) +def view_range(x, i, j, shape): + shape = tuple(shape) + + n_dims = len(x.shape) + if i < 0: + i = n_dims + i + + if j is None: + j = n_dims + elif j < 0: + j = n_dims + j + + assert 0 <= i < j <= n_dims + + x_shape = x.shape + target_shape = x_shape[:i] + shape + x_shape[j:] + return x.view(target_shape) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def tensor_slice(x, begin, size): + assert all([b >= 0 for b in begin]) + size = [l - b if s == -1 else s + for s, b, l in zip(size, begin, x.shape)] + assert all([s >= 0 for s in size]) + + slices = [slice(b, b + s) for b, s in zip(begin, size)] + return x[slices] + + +def adopt_weight(global_step, threshold=0, value=0.): + weight = 1 + if global_step < threshold: + weight = value + return weight + + +def save_video_grid(video, fname, nrow=None, fps=6): + b, c, t, h, w = video.shape + video = video.permute(0, 2, 3, 4, 1) + video = (video.cpu().numpy() * 255).astype('uint8') + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = np.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype='uint8') + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + video = [] + for i in range(t): + video.append(video_grid[i]) + imageio.mimsave(fname, video, fps=fps) + ## skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '5'}) + #print('saved videos to', fname) + + +def comp_getattr(args, attr_name, default=None): + if hasattr(args, attr_name): + return getattr(args, attr_name) + else: + return default + + +def visualize_tensors(t, name=None, nest=0): + if name is not None: + print(name, "current nest: ", nest) + print("type: ", type(t)) + if 'dict' in str(type(t)): + print(t.keys()) + for k in t.keys(): + if t[k] is None: + print(k, "None") + else: + if 'Tensor' in str(type(t[k])): + print(k, t[k].shape) + elif 'dict' in str(type(t[k])): + print(k, 'dict') + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t[k])): + print(k, len(t[k])) + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t)): + print("list length: ", len(t)) + for t2 in t: + visualize_tensors(t2, name, nest + 1) + elif 'Tensor' in str(type(t)): + print(t.shape) + else: + print(t) + return "" diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/utils.py b/Generation_Pipeline_filter/syn_colon/TumorGeneration/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d9d7972cd1ad2f2e4e1070d747c50df734755a68 --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/utils.py @@ -0,0 +1,233 @@ +### Tumor Generateion +import random +import cv2 +import elasticdeform +import numpy as np +from scipy.ndimage import gaussian_filter +from TumorGeneration.ldm.ddpm.ddim import DDIMSampler +import skimage + +def im2col(A, BSZ, stepsize=1): + # Parameters + M, N = A.shape + # 
Get Starting block indices + start_idx = np.arange( + 0, M-BSZ[0]+1, stepsize)[:, None]*N + np.arange(0, N-BSZ[1]+1, stepsize) + # Get offsetted indices across the height and width of input array + offset_idx = np.arange(BSZ[0])[:, None]*N + np.arange(BSZ[1]) + # Get all actual indices & index into input array for final output + return np.take(A, start_idx.ravel()[:, None] + offset_idx.ravel()) + +def seg_to_instance_bd(seg: np.ndarray, + tsz_h: int = 1) -> np.ndarray: + """Generate instance contour map from segmentation masks. + """ + + tsz = tsz_h*2+1 + tsz=int(tsz) + kernel = np.ones((tsz, tsz, tsz), np.uint8) + dilated_seg_mask = skimage.morphology.binary_erosion(seg.astype('uint8'), kernel) + + dilated_seg_mask = dilated_seg_mask.astype(np.uint8) + bd = seg-dilated_seg_mask + bd = (bd>0).astype('uint8') + + return bd + +def sector_mask(shape,centre,radius,angle_range): + """ + Return a boolean mask for a circular sector. The start/stop angles in + `angle_range` should be given in clockwise order. + """ + + x,y = np.ogrid[:shape[0],:shape[1]] + cx,cy = centre + tmin,tmax = np.deg2rad(angle_range) + + # ensure stop angle > start angle + if tmax < tmin: + tmax += 2*np.pi + + # convert cartesian --> polar coordinates + r2 = (x-cx)*(x-cx) + (y-cy)*(y-cy) + theta = np.arctan2(x-cx,y-cy) - tmin + + # wrap angles between 0 and 2*pi + theta %= (2*np.pi) + + # circular mask + circmask = r2 <= radius*radius + + # angular mask + anglemask = theta <= (tmax-tmin) + + return circmask*anglemask + +from scipy.ndimage import label +import elasticdeform +def generate_random_mask(organ_mask): + # initialize tumor mask + tumor_mask = np.zeros_like(organ_mask) + + # randowm mask angle + start_angle = random.randint(0, 360) + angle_range = random.randint(90, 360) + + # generate organ boundary + erode_sz = angle_range//45 * 1 + 3 + # select_size = [3.5, 4, 4.5, 5.0, 5.5, 6.0] + # erode_sz = np.random.choice(select_size) + # print('erode_sz', erode_sz) + organ_bd = seg_to_instance_bd(organ_mask, tsz_h=erode_sz) + + # organ mask range + z_valid_list = np.where(np.any(organ_bd, axis=(0, 1)))[0] + valid_num = len(z_valid_list) + z_valid_list = z_valid_list[round(valid_num*0.25):round(valid_num*0.75)] + # print(z_valid_list) + z = random.choice(z_valid_list) + + # sample thickness + z_thickness = random.randint(10, 20) # 10-20 + # print('z, z_thickness', z, z_thickness) + # crop + tumor_mask[:,:,max(0,z-z_thickness):min(95,z+z_thickness)] = organ_bd[:,:,max(0,z-z_thickness):min(95,z+z_thickness)] + + # random select one + tumor_mask, nb = label(tumor_mask) + sample_id = random.randint(1, nb) + sample_tumor_mask = (tumor_mask==sample_id).astype(np.uint8) + + z_valid = np.where(np.any(sample_tumor_mask, axis=(0, 1)))[0] + z = z_valid[round(0.5 * len(z_valid))] + + # randowm mask region + selected_slice = sample_tumor_mask[..., z] + coordinates = np.argwhere(selected_slice == 1) + center_x, center_y = int(coordinates[:,0].mean()), int(coordinates[:,1].mean()) + # start_angle = random.randint(0, 360) + # angle_range = random.randint(90, 360) + mask_region = sector_mask(selected_slice.shape,(center_x,center_y), 48, (start_angle,start_angle+angle_range)) + mask_region = np.repeat(mask_region[:,:,np.newaxis], axis=-1, repeats=96) + + # elasticdeform + # sigma = random.uniform(1,2) + sigma = random.uniform(2,5) + # sigma = random.uniform(5,10) + deform_tumor_mask = elasticdeform.deform_random_grid(sample_tumor_mask, sigma=sigma, points=3, order=0, axis=(0,1)) + # deform_tumor_mask = 
elasticdeform.deform_random_grid(deform_tumor_mask, sigma=sigma, points=3, order=0, axis=(1,2)) + # deform_tumor_mask = elasticdeform.deform_random_grid(deform_tumor_mask, sigma=sigma, points=3, order=0, axis=(0,2)) + + # final_tumor_mask = deform_tumor_mask*mask_region*organ_mask + final_tumor_mask = deform_tumor_mask*mask_region + + return final_tumor_mask + + +from .ldm.vq_gan_3d.model.vqgan import VQGAN +import matplotlib.pyplot as plt +import SimpleITK as sitk +from .ldm.ddpm import Unet3D, GaussianDiffusion, Tester +from hydra import initialize, compose +import torch +import yaml +def synt_model_prepare(device, vqgan_ckpt='TumorGeneration/model_weight/recon_colon.ckpt', diffusion_ckpt='TumorGeneration/model_weight/', fold=0, organ='colon'): + with initialize(config_path="diffusion_config/"): + cfg=compose(config_name="ddpm.yaml") + vqgan_ckpt = 'TumorGeneration/model_weight/recon_colon.ckpt' + diffusion_ckpt = 'TumorGeneration/model_weight/diffusion_colon.pt' + print('diffusion_ckpt',diffusion_ckpt) + vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt) + vqgan = vqgan.to(device) + vqgan.eval() + + + noearly_Unet3D = Unet3D( + dim=cfg.diffusion_img_size, + dim_mults=cfg.dim_mults, + channels=cfg.diffusion_num_channels, + out_dim=cfg.out_dim + ).to(device) + + noearly_diffusion = GaussianDiffusion( + noearly_Unet3D, + vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt, + image_size=cfg.diffusion_img_size, + num_frames=cfg.diffusion_depth_size, + channels=cfg.diffusion_num_channels, + timesteps=200, # cfg.timesteps, + loss_type=cfg.loss_type, + device=device + ).to(device) + + noearly_checkpoint = torch.load(diffusion_ckpt, map_location=device) + + noearly_diffusion.load_state_dict(noearly_checkpoint['ema']) + + noearly_sampler = DDIMSampler(noearly_diffusion, schedule="cosine") + # breakpoint() + return vqgan, noearly_sampler + +def synthesize_colon_tumor(ct_volume, organ_mask, vqgan, sampler, ddim_ts=50): + device=ct_volume.device + + # generate tumor mask + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + tumor_mask = generate_random_mask(organ_mask_np[bs,0]) + # tumor_mask = organ_mask_np[bs,0] + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + # breakpoint() + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + shape = masked_volume_feat.shape[-4:] + samples_ddim, _ = sampler.sample(S=ddim_ts, + conditioning=cond, + batch_size=1, + shape=shape, + verbose=False) + samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() - + vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min() + + sample = vqgan.decode(samples_ddim, quantize=True) + + # post-process + mask_01 = 
torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(1, 2) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + + # final_volume_ = (sample+1.0)/2.0 + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.zeros_like(organ_mask) + organ_tumor_mask[organ_mask==1] = 1 + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask diff --git a/Generation_Pipeline_filter/syn_colon/TumorGeneration/utils_.py b/Generation_Pipeline_filter/syn_colon/TumorGeneration/utils_.py new file mode 100644 index 0000000000000000000000000000000000000000..312e72d1571b3cd996fd567bcedf939443a4e182 --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/TumorGeneration/utils_.py @@ -0,0 +1,298 @@ +### Tumor Generateion +import random +import cv2 +import elasticdeform +import numpy as np +from scipy.ndimage import gaussian_filter + +# Step 1: Random select (numbers) location for tumor. +def random_select(mask_scan): + # we first find z index and then sample point with z slice + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]] + + # we need to strict number z's position (0.3 - 0.7 in the middle of liver) + z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start + + liver_mask = mask_scan[..., z] + + # erode the mask (we don't want the edge points) + kernel = np.ones((5,5), dtype=np.uint8) + liver_mask = cv2.erode(liver_mask, kernel, iterations=1) + + coordinates = np.argwhere(liver_mask == 1) + random_index = np.random.randint(0, len(coordinates)) + xyz = coordinates[random_index].tolist() # get x,y + xyz.append(z) + potential_points = xyz + + return potential_points + +# Step 2 : generate the ellipsoid +def get_ellipsoid(x, y, z): + """" + x, y, z is the radius of this ellipsoid in x, y, z direction respectly. 
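+    Returns an integer array of shape (4*x, 4*y, 4*z) with voxels inside the ellipsoid set to 1.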
+ """ + sh = (4*x, 4*y, 4*z) + out = np.zeros(sh, int) + aux = np.zeros(sh) + radii = np.array([x, y, z]) + com = np.array([2*x, 2*y, 2*z]) # center point + + # calculate the ellipsoid + bboxl = np.floor(com-radii).clip(0,None).astype(int) + bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int) + roi = out[tuple(map(slice,bboxl,bboxh))] + roiaux = aux[tuple(map(slice,bboxl,bboxh))] + logrid = *map(np.square,np.ogrid[tuple( + map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]), + dst = (1-sum(logrid)).clip(0,None) + mask = dst>roiaux + roi[mask] = 1 + np.copyto(roiaux,dst,where=mask) + + return out + +def get_fixed_geo(mask_scan, tumor_type): + + enlarge_x, enlarge_y, enlarge_z = 160, 160, 160 + geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8) + tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32 + + if tumor_type == 'tiny': + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + if tumor_type == 'small': + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'medium': + num_tumor = random.randint(2, 5) + for _ in range(num_tumor): + # medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = 
random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'large': + num_tumor = random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == "mix": + # tiny + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + # small + num_tumor = random.randint(5,10) + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), 
int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + # medium + num_tumor = random.randint(2, 5) + for _ in range(num_tumor): + # medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + # large + num_tumor = random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, 
enlarge_z//2:-enlarge_z//2] + # texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + geo_mask = (geo_mask * mask_scan) >=1 + + return geo_mask + + +def get_tumor(volume_scan, mask_scan, tumor_type): + tumor_mask = get_fixed_geo(mask_scan, tumor_type) + + sigma = np.random.uniform(1, 2) + # difference = np.random.uniform(65, 145) + difference = 1 + + # blur the boundary + tumor_mask_blur = gaussian_filter(tumor_mask*1.0, sigma) + + # NOTE: reconstructed line (missing in the original, which left `abnormally` undefined); assumes the blurred tumor mask is subtracted from the organ region + abnormally = (volume_scan - tumor_mask_blur * difference) * mask_scan + abnormally_full = volume_scan * (1 - mask_scan) + abnormally + abnormally_mask = mask_scan + tumor_mask + + return abnormally_full, abnormally_mask + +def SynthesisTumor(volume_scan, mask_scan, tumor_type): + # for speed_generate_tumor, we only send the liver part into the generate program + x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0][[0, -1]] + y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0][[0, -1]] + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]] + + # shrink the boundary + x_start, x_end = max(0, x_start+1), min(mask_scan.shape[0], x_end-1) + y_start, y_end = max(0, y_start+1), min(mask_scan.shape[1], y_end-1) + z_start, z_end = max(0, z_start+1), min(mask_scan.shape[2], z_end-1) + + ct_volume = volume_scan[x_start:x_end, y_start:y_end, z_start:z_end] + organ_mask = mask_scan[x_start:x_end, y_start:y_end, z_start:z_end] + + # input texture shape: 420, 300, 320 + # we need to cut it into the shape of liver_mask + # for examples, the liver_mask.shape = 286, 173, 46; we should change the texture shape + x_length, y_length, z_length = 64, 64, 64 + crop_x = random.randint(x_start, x_end - x_length - 1) # random select the start point, -1 is to avoid boundary check + crop_y = random.randint(y_start, y_end - y_length - 1) + crop_z = random.randint(z_start, z_end - z_length - 1) + + ct_volume, organ_tumor_mask = get_tumor(ct_volume, organ_mask, tumor_type) + volume_scan[x_start:x_end, y_start:y_end, z_start:z_end] = ct_volume + mask_scan[x_start:x_end, y_start:y_end, z_start:z_end] = organ_tumor_mask + + return volume_scan, mask_scan diff --git a/Generation_Pipeline_filter/syn_colon/healthy_colon_1k.txt b/Generation_Pipeline_filter/syn_colon/healthy_colon_1k.txt new file mode 100644 index 0000000000000000000000000000000000000000..780aae5af26bdf8a51701d20142eabd922526d0d --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/healthy_colon_1k.txt @@ -0,0 +1,928 @@ +BDMAP_00001823 +BDMAP_00003074 +BDMAP_00001305 +BDMAP_00001635 +BDMAP_00002359 +BDMAP_00001265 +BDMAP_00000701 +BDMAP_00000771 +BDMAP_00003581 +BDMAP_00002523 +BDMAP_00004028 +BDMAP_00005151 +BDMAP_00001183 +BDMAP_00001656 +BDMAP_00003898 +BDMAP_00001845 +BDMAP_00000481 +BDMAP_00003324 +BDMAP_00002688 +BDMAP_00000948 +BDMAP_00004796 +BDMAP_00004198 +BDMAP_00003514 +BDMAP_00000432 +BDMAP_00003832 +BDMAP_00001296 +BDMAP_00003683 +BDMAP_00001607 +BDMAP_00004745 +BDMAP_00005167 +BDMAP_00005154 +BDMAP_00003598 +BDMAP_00003551 +BDMAP_00000176 +BDMAP_00004719 +BDMAP_00003722 +BDMAP_00002690 +BDMAP_00002244 +BDMAP_00000883 +BDMAP_00000926 +BDMAP_00002849 +BDMAP_00004549 +BDMAP_00004017 +BDMAP_00003482 +BDMAP_00003225 +BDMAP_00000416 +BDMAP_00002387 +BDMAP_00002022 +BDMAP_00002909 +BDMAP_00003236 +BDMAP_00000465 +BDMAP_00001784 +BDMAP_00004103 +BDMAP_00000656 +BDMAP_00004850 +BDMAP_00002955 +BDMAP_00003633 +BDMAP_00000137 +BDMAP_00004529 +BDMAP_00004903 +BDMAP_00001309 +BDMAP_00002216 +BDMAP_00001444 +BDMAP_00000263 +BDMAP_00004066 +BDMAP_00003920 +BDMAP_00001434 +BDMAP_00004890 +BDMAP_00000400 
+BDMAP_00001238 +BDMAP_00003592 +BDMAP_00000431 +BDMAP_00002304 +BDMAP_00000285 +BDMAP_00004995 +BDMAP_00004264 +BDMAP_00001440 +BDMAP_00001383 +BDMAP_00003614 +BDMAP_00005157 +BDMAP_00003608 +BDMAP_00002619 +BDMAP_00000615 +BDMAP_00000084 +BDMAP_00002804 +BDMAP_00002592 +BDMAP_00001868 +BDMAP_00002021 +BDMAP_00000297 +BDMAP_00003202 +BDMAP_00000411 +BDMAP_00005070 +BDMAP_00003364 +BDMAP_00004395 +BDMAP_00002075 +BDMAP_00002844 +BDMAP_00002712 +BDMAP_00000714 +BDMAP_00002717 +BDMAP_00004895 +BDMAP_00000698 +BDMAP_00003384 +BDMAP_00001286 +BDMAP_00001562 +BDMAP_00004228 +BDMAP_00000831 +BDMAP_00000855 +BDMAP_00004672 +BDMAP_00000882 +BDMAP_00004992 +BDMAP_00002232 +BDMAP_00003849 +BDMAP_00004880 +BDMAP_00004074 +BDMAP_00002626 +BDMAP_00004262 +BDMAP_00000368 +BDMAP_00002826 +BDMAP_00000837 +BDMAP_00001911 +BDMAP_00001557 +BDMAP_00001126 +BDMAP_00002328 +BDMAP_00002959 +BDMAP_00002562 +BDMAP_00003600 +BDMAP_00001057 +BDMAP_00000940 +BDMAP_00002120 +BDMAP_00002227 +BDMAP_00000122 +BDMAP_00002479 +BDMAP_00002805 +BDMAP_00004980 +BDMAP_00001862 +BDMAP_00000778 +BDMAP_00003749 +BDMAP_00000245 +BDMAP_00000989 +BDMAP_00001247 +BDMAP_00000623 +BDMAP_00004113 +BDMAP_00002278 +BDMAP_00004841 +BDMAP_00001602 +BDMAP_00001464 +BDMAP_00001712 +BDMAP_00003815 +BDMAP_00002407 +BDMAP_00003150 +BDMAP_00001711 +BDMAP_00002273 +BDMAP_00002751 +BDMAP_00005074 +BDMAP_00001068 +BDMAP_00004447 +BDMAP_00000977 +BDMAP_00004297 +BDMAP_00000812 +BDMAP_00004641 +BDMAP_00001422 +BDMAP_00003385 +BDMAP_00003164 +BDMAP_00002475 +BDMAP_00002166 +BDMAP_00004232 +BDMAP_00000826 +BDMAP_00003769 +BDMAP_00003569 +BDMAP_00003853 +BDMAP_00004494 +BDMAP_00004011 +BDMAP_00002776 +BDMAP_00001517 +BDMAP_00004304 +BDMAP_00004645 +BDMAP_00000091 +BDMAP_00004738 +BDMAP_00000725 +BDMAP_00003771 +BDMAP_00002524 +BDMAP_00000161 +BDMAP_00000902 +BDMAP_00001786 +BDMAP_00002332 +BDMAP_00004175 +BDMAP_00002419 +BDMAP_00004077 +BDMAP_00004295 +BDMAP_00002871 +BDMAP_00004148 +BDMAP_00000676 +BDMAP_00001782 +BDMAP_00003947 +BDMAP_00003513 +BDMAP_00003130 +BDMAP_00001545 +BDMAP_00000667 +BDMAP_00005078 +BDMAP_00003435 +BDMAP_00002545 +BDMAP_00002498 +BDMAP_00001255 +BDMAP_00004065 +BDMAP_00002099 +BDMAP_00001504 +BDMAP_00001863 +BDMAP_00000542 +BDMAP_00002326 +BDMAP_00005155 +BDMAP_00001476 +BDMAP_00000388 +BDMAP_00000159 +BDMAP_00004060 +BDMAP_00000332 +BDMAP_00004087 +BDMAP_00000516 +BDMAP_00000574 +BDMAP_00004943 +BDMAP_00004514 +BDMAP_00003329 +BDMAP_00001597 +BDMAP_00002172 +BDMAP_00000833 +BDMAP_00004187 +BDMAP_00004744 +BDMAP_00001676 +BDMAP_00003558 +BDMAP_00003438 +BDMAP_00001957 +BDMAP_00004128 +BDMAP_00005140 +BDMAP_00002656 +BDMAP_00004817 +BDMAP_00000745 +BDMAP_00000205 +BDMAP_00000671 +BDMAP_00001962 +BDMAP_00003543 +BDMAP_00001620 +BDMAP_00003128 +BDMAP_00003409 +BDMAP_00000982 +BDMAP_00004015 +BDMAP_00001707 +BDMAP_00002068 +BDMAP_00001236 +BDMAP_00003973 +BDMAP_00004870 +BDMAP_00000366 +BDMAP_00003685 +BDMAP_00001096 +BDMAP_00003347 +BDMAP_00001892 +BDMAP_00003740 +BDMAP_00004773 +BDMAP_00002260 +BDMAP_00002815 +BDMAP_00000972 +BDMAP_00000998 +BDMAP_00003063 +BDMAP_00001791 +BDMAP_00002085 +BDMAP_00002275 +BDMAP_00004016 +BDMAP_00000438 +BDMAP_00000709 +BDMAP_00004416 +BDMAP_00003884 +BDMAP_00002237 +BDMAP_00001794 +BDMAP_00004378 +BDMAP_00000713 +BDMAP_00004286 +BDMAP_00001109 +BDMAP_00001223 +BDMAP_00001027 +BDMAP_00001001 +BDMAP_00005097 +BDMAP_00002942 +BDMAP_00000607 +BDMAP_00002940 +BDMAP_00002930 +BDMAP_00003377 +BDMAP_00004509 +BDMAP_00000923 +BDMAP_00001413 +BDMAP_00001636 +BDMAP_00001705 +BDMAP_00000273 +BDMAP_00003840 
+BDMAP_00001333 +BDMAP_00005092 +BDMAP_00001368 +BDMAP_00003994 +BDMAP_00004925 +BDMAP_00001370 +BDMAP_00003455 +BDMAP_00002631 +BDMAP_00005174 +BDMAP_00005009 +BDMAP_00001549 +BDMAP_00001941 +BDMAP_00000154 +BDMAP_00001521 +BDMAP_00002653 +BDMAP_00001148 +BDMAP_00000774 +BDMAP_00005105 +BDMAP_00002421 +BDMAP_00000139 +BDMAP_00003867 +BDMAP_00003479 +BDMAP_00004741 +BDMAP_00001516 +BDMAP_00002396 +BDMAP_00003481 +BDMAP_00000324 +BDMAP_00002841 +BDMAP_00003326 +BDMAP_00002437 +BDMAP_00000100 +BDMAP_00004586 +BDMAP_00004867 +BDMAP_00001040 +BDMAP_00001185 +BDMAP_00001461 +BDMAP_00000692 +BDMAP_00001563 +BDMAP_00002289 +BDMAP_00004901 +BDMAP_00001632 +BDMAP_00000558 +BDMAP_00000469 +BDMAP_00001966 +BDMAP_00003315 +BDMAP_00002313 +BDMAP_00005006 +BDMAP_00000439 +BDMAP_00004551 +BDMAP_00003294 +BDMAP_00001807 +BDMAP_00004579 +BDMAP_00002057 +BDMAP_00002060 +BDMAP_00004508 +BDMAP_00004104 +BDMAP_00000052 +BDMAP_00003439 +BDMAP_00001502 +BDMAP_00005186 +BDMAP_00002529 +BDMAP_00002775 +BDMAP_00004834 +BDMAP_00001496 +BDMAP_00002319 +BDMAP_00002856 +BDMAP_00004552 +BDMAP_00004878 +BDMAP_00001331 +BDMAP_00001912 +BDMAP_00002758 +BDMAP_00000414 +BDMAP_00004288 +BDMAP_00000805 +BDMAP_00004597 +BDMAP_00003178 +BDMAP_00001752 +BDMAP_00003943 +BDMAP_00004652 +BDMAP_00004541 +BDMAP_00000614 +BDMAP_00004639 +BDMAP_00001804 +BDMAP_00005063 +BDMAP_00002807 +BDMAP_00000062 +BDMAP_00005119 +BDMAP_00004417 +BDMAP_00005075 +BDMAP_00001441 +BDMAP_00002373 +BDMAP_00002041 +BDMAP_00003727 +BDMAP_00001483 +BDMAP_00001128 +BDMAP_00004927 +BDMAP_00001119 +BDMAP_00004106 +BDMAP_00000355 +BDMAP_00002354 +BDMAP_00004030 +BDMAP_00004847 +BDMAP_00000618 +BDMAP_00003736 +BDMAP_00002803 +BDMAP_00005099 +BDMAP_00003168 +BDMAP_00000941 +BDMAP_00000243 +BDMAP_00001664 +BDMAP_00001747 +BDMAP_00003774 +BDMAP_00004917 +BDMAP_00000867 +BDMAP_00000435 +BDMAP_00003822 +BDMAP_00003411 +BDMAP_00000965 +BDMAP_00003612 +BDMAP_00004023 +BDMAP_00002333 +BDMAP_00001270 +BDMAP_00002616 +BDMAP_00004511 +BDMAP_00005130 +BDMAP_00000642 +BDMAP_00002471 +BDMAP_00000589 +BDMAP_00002509 +BDMAP_00004561 +BDMAP_00001275 +BDMAP_00003133 +BDMAP_00000626 +BDMAP_00003491 +BDMAP_00000993 +BDMAP_00003493 +BDMAP_00004499 +BDMAP_00002065 +BDMAP_00001175 +BDMAP_00002696 +BDMAP_00000319 +BDMAP_00002410 +BDMAP_00002485 +BDMAP_00001258 +BDMAP_00000660 +BDMAP_00003272 +BDMAP_00004183 +BDMAP_00003359 +BDMAP_00000956 +BDMAP_00004462 +BDMAP_00001704 +BDMAP_00000039 +BDMAP_00001853 +BDMAP_00003857 +BDMAP_00000572 +BDMAP_00005168 +BDMAP_00000304 +BDMAP_00002426 +BDMAP_00000244 +BDMAP_00001646 +BDMAP_00000413 +BDMAP_00004735 +BDMAP_00002476 +BDMAP_00004039 +BDMAP_00000219 +BDMAP_00004651 +BDMAP_00005065 +BDMAP_00004281 +BDMAP_00000113 +BDMAP_00003956 +BDMAP_00002226 +BDMAP_00004130 +BDMAP_00002707 +BDMAP_00000430 +BDMAP_00002661 +BDMAP_00001617 +BDMAP_00002298 +BDMAP_00003930 +BDMAP_00000687 +BDMAP_00004195 +BDMAP_00001647 +BDMAP_00000487 +BDMAP_00003367 +BDMAP_00003277 +BDMAP_00004600 +BDMAP_00003497 +BDMAP_00004546 +BDMAP_00004808 +BDMAP_00002981 +BDMAP_00000229 +BDMAP_00004185 +BDMAP_00003406 +BDMAP_00002422 +BDMAP_00002947 +BDMAP_00001261 +BDMAP_00005037 +BDMAP_00003590 +BDMAP_00003058 +BDMAP_00003461 +BDMAP_00003151 +BDMAP_00001035 +BDMAP_00001289 +BDMAP_00000087 +BDMAP_00004981 +BDMAP_00001836 +BDMAP_00004712 +BDMAP_00002363 +BDMAP_00002495 +BDMAP_00004398 +BDMAP_00003457 +BDMAP_00003752 +BDMAP_00001891 +BDMAP_00004373 +BDMAP_00001590 +BDMAP_00003506 +BDMAP_00001921 +BDMAP_00004229 +BDMAP_00001898 +BDMAP_00003483 +BDMAP_00004616 +BDMAP_00002648 +BDMAP_00000562 
+BDMAP_00002403 +BDMAP_00003361 +BDMAP_00000887 +BDMAP_00001283 +BDMAP_00002719 +BDMAP_00005064 +BDMAP_00002793 +BDMAP_00002242 +BDMAP_00004278 +BDMAP_00002117 +BDMAP_00000320 +BDMAP_00005191 +BDMAP_00000809 +BDMAP_00000859 +BDMAP_00003955 +BDMAP_00004253 +BDMAP_00004031 +BDMAP_00005139 +BDMAP_00003244 +BDMAP_00000149 +BDMAP_00001414 +BDMAP_00001945 +BDMAP_00004510 +BDMAP_00003824 +BDMAP_00001361 +BDMAP_00000662 +BDMAP_00005022 +BDMAP_00000434 +BDMAP_00000241 +BDMAP_00000710 +BDMAP_00005120 +BDMAP_00002383 +BDMAP_00003036 +BDMAP_00002609 +BDMAP_00004922 +BDMAP_00004407 +BDMAP_00004481 +BDMAP_00001225 +BDMAP_00003556 +BDMAP_00000329 +BDMAP_00003052 +BDMAP_00003396 +BDMAP_00002164 +BDMAP_00001077 +BDMAP_00003153 +BDMAP_00003776 +BDMAP_00002710 +BDMAP_00004746 +BDMAP_00000066 +BDMAP_00005085 +BDMAP_00004435 +BDMAP_00002695 +BDMAP_00001828 +BDMAP_00003392 +BDMAP_00003976 +BDMAP_00002744 +BDMAP_00002214 +BDMAP_00000569 +BDMAP_00000571 +BDMAP_00004888 +BDMAP_00003301 +BDMAP_00004956 +BDMAP_00003809 +BDMAP_00002265 +BDMAP_00002944 +BDMAP_00004457 +BDMAP_00001768 +BDMAP_00001020 +BDMAP_00000541 +BDMAP_00000101 +BDMAP_00003664 +BDMAP_00003255 +BDMAP_00001379 +BDMAP_00002347 +BDMAP_00000128 +BDMAP_00002252 +BDMAP_00001697 +BDMAP_00002953 +BDMAP_00001122 +BDMAP_00003525 +BDMAP_00003070 +BDMAP_00004829 +BDMAP_00002233 +BDMAP_00001288 +BDMAP_00002791 +BDMAP_00004199 +BDMAP_00004184 +BDMAP_00003381 +BDMAP_00001766 +BDMAP_00003114 +BDMAP_00004804 +BDMAP_00002184 +BDMAP_00001138 +BDMAP_00000044 +BDMAP_00002271 +BDMAP_00003603 +BDMAP_00001523 +BDMAP_00004097 +BDMAP_00002440 +BDMAP_00004664 +BDMAP_00003808 +BDMAP_00000427 +BDMAP_00002362 +BDMAP_00005169 +BDMAP_00000023 +BDMAP_00003833 +BDMAP_00001710 +BDMAP_00001518 +BDMAP_00004482 +BDMAP_00003549 +BDMAP_00002171 +BDMAP_00002309 +BDMAP_00000338 +BDMAP_00000715 +BDMAP_00003897 +BDMAP_00003812 +BDMAP_00004257 +BDMAP_00001753 +BDMAP_00000117 +BDMAP_00001456 +BDMAP_00004115 +BDMAP_00003319 +BDMAP_00003744 +BDMAP_00004154 +BDMAP_00003658 +BDMAP_00001214 +BDMAP_00004293 +BDMAP_00001842 +BDMAP_00001420 +BDMAP_00003343 +BDMAP_00001325 +BDMAP_00000921 +BDMAP_00002582 +BDMAP_00002864 +BDMAP_00000889 +BDMAP_00001092 +BDMAP_00000968 +BDMAP_00002402 +BDMAP_00004427 +BDMAP_00001605 +BDMAP_00000462 +BDMAP_00005081 +BDMAP_00002463 +BDMAP_00000839 +BDMAP_00000437 +BDMAP_00000604 +BDMAP_00001104 +BDMAP_00001281 +BDMAP_00000679 +BDMAP_00004717 +BDMAP_00001511 +BDMAP_00003281 +BDMAP_00001977 +BDMAP_00000653 +BDMAP_00000232 +BDMAP_00004328 +BDMAP_00002496 +BDMAP_00000987 +BDMAP_00003717 +BDMAP_00004897 +BDMAP_00003713 +BDMAP_00002889 +BDMAP_00003657 +BDMAP_00002829 +BDMAP_00004839 +BDMAP_00001397 +BDMAP_00001908 +BDMAP_00003911 +BDMAP_00004843 +BDMAP_00004969 +BDMAP_00003918 +BDMAP_00004216 +BDMAP_00000034 +BDMAP_00003923 +BDMAP_00000225 +BDMAP_00003576 +BDMAP_00002884 +BDMAP_00002472 +BDMAP_00001688 +BDMAP_00001246 +BDMAP_00004620 +BDMAP_00005017 +BDMAP_00002990 +BDMAP_00000971 +BDMAP_00004578 +BDMAP_00001735 +BDMAP_00002655 +BDMAP_00000233 +BDMAP_00001205 +BDMAP_00003073 +BDMAP_00003957 +BDMAP_00001093 +BDMAP_00003440 +BDMAP_00001251 +BDMAP_00004793 +BDMAP_00000162 +BDMAP_00003444 +BDMAP_00001533 +BDMAP_00003971 +BDMAP_00001584 +BDMAP_00000036 +BDMAP_00002251 +BDMAP_00003141 +BDMAP_00002484 +BDMAP_00004770 +BDMAP_00001487 +BDMAP_00001754 +BDMAP_00003356 +BDMAP_00000353 +BDMAP_00001419 +BDMAP_00001802 +BDMAP_00003701 +BDMAP_00005141 +BDMAP_00000321 +BDMAP_00001746 +BDMAP_00000364 +BDMAP_00003900 +BDMAP_00001995 +BDMAP_00001025 +BDMAP_00004231 +BDMAP_00000918 +BDMAP_00001130 
+BDMAP_00003443 +BDMAP_00003215 +BDMAP_00004815 +BDMAP_00002933 +BDMAP_00000192 +BDMAP_00003615 +BDMAP_00004704 +BDMAP_00001218 +BDMAP_00002295 +BDMAP_00000429 +BDMAP_00000532 +BDMAP_00001474 +BDMAP_00003961 +BDMAP_00004129 +BDMAP_00000362 +BDMAP_00002863 +BDMAP_00003267 +BDMAP_00001198 +BDMAP_00000259 +BDMAP_00000683 +BDMAP_00001256 +BDMAP_00003252 +BDMAP_00004475 +BDMAP_00004250 +BDMAP_00004887 +BDMAP_00000240 +BDMAP_00003767 +BDMAP_00003427 +BDMAP_00000043 +BDMAP_00003448 +BDMAP_00001114 +BDMAP_00001067 +BDMAP_00001089 +BDMAP_00002133 +BDMAP_00004033 +BDMAP_00002896 +BDMAP_00003138 +BDMAP_00001010 +BDMAP_00001059 +BDMAP_00004990 +BDMAP_00000936 +BDMAP_00001359 +BDMAP_00005077 +BDMAP_00000582 +BDMAP_00004296 +BDMAP_00005114 +BDMAP_00004389 +BDMAP_00004673 +BDMAP_00003254 +BDMAP_00003516 +BDMAP_00001475 +BDMAP_00002580 +BDMAP_00002689 +BDMAP_00004671 +BDMAP_00003762 +BDMAP_00003330 +BDMAP_00002188 +BDMAP_00001736 +BDMAP_00002404 +BDMAP_00003502 +BDMAP_00004117 +BDMAP_00004964 +BDMAP_00002742 +BDMAP_00000093 +BDMAP_00002361 +BDMAP_00000794 +BDMAP_00002349 +BDMAP_00001273 +BDMAP_00000449 +BDMAP_00001628 +BDMAP_00002250 +BDMAP_00004479 +BDMAP_00000608 +BDMAP_00001834 +BDMAP_00002267 +BDMAP_00001125 +BDMAP_00000447 +BDMAP_00005113 +BDMAP_00004014 +BDMAP_00001701 +BDMAP_00003952 +BDMAP_00003520 +BDMAP_00000347 +BDMAP_00002936 +BDMAP_00001624 +BDMAP_00000104 +BDMAP_00002487 +BDMAP_00005020 +BDMAP_00000511 +BDMAP_00001564 +BDMAP_00004294 +BDMAP_00004858 +BDMAP_00004608 +BDMAP_00003650 +BDMAP_00001171 +BDMAP_00000059 +BDMAP_00000871 +BDMAP_00003996 +BDMAP_00001169 +BDMAP_00003363 +BDMAP_00003376 +BDMAP_00002167 +BDMAP_00002737 +BDMAP_00003694 +BDMAP_00001396 +BDMAP_00005083 +BDMAP_00002918 +BDMAP_00003580 +BDMAP_00001324 +BDMAP_00002855 +BDMAP_00001649 +BDMAP_00004459 +BDMAP_00002288 +BDMAP_00004830 +BDMAP_00004775 +BDMAP_00000279 +BDMAP_00002114 +BDMAP_00005185 +BDMAP_00004885 +BDMAP_00000236 +BDMAP_00003928 +BDMAP_00002663 +BDMAP_00002282 +BDMAP_00003798 +BDMAP_00001055 +BDMAP_00002945 +BDMAP_00001316 +BDMAP_00003451 +BDMAP_00000696 +BDMAP_00003143 +BDMAP_00001522 +BDMAP_00000452 +BDMAP_00002603 +BDMAP_00004131 +BDMAP_00001045 +BDMAP_00004954 +BDMAP_00003358 +BDMAP_00000980 +BDMAP_00001343 +BDMAP_00001410 +BDMAP_00002173 +BDMAP_00002840 +BDMAP_00001200 +BDMAP_00001905 +BDMAP_00003425 +BDMAP_00003672 +BDMAP_00003781 +BDMAP_00001906 +BDMAP_00004636 +BDMAP_00000836 +BDMAP_00002076 +BDMAP_00001230 +BDMAP_00003932 +BDMAP_00002029 +BDMAP_00000331 +BDMAP_00000197 +BDMAP_00001539 +BDMAP_00003524 +BDMAP_00001692 +BDMAP_00004550 +BDMAP_00004331 +BDMAP_00004825 +BDMAP_00002411 +BDMAP_00003484 +BDMAP_00000480 +BDMAP_00003886 +BDMAP_00001907 +BDMAP_00002746 +BDMAP_00002899 +BDMAP_00004420 +BDMAP_00002748 +BDMAP_00003002 +BDMAP_00003300 +BDMAP_00005108 +BDMAP_00003400 +BDMAP_00002283 +BDMAP_00003486 +BDMAP_00000935 +BDMAP_00001618 +BDMAP_00001565 +BDMAP_00000616 +BDMAP_00000810 +BDMAP_00001826 +BDMAP_00000470 +BDMAP_00002017 +BDMAP_00003631 +BDMAP_00001242 +BDMAP_00000907 +BDMAP_00001806 +BDMAP_00002854 +BDMAP_00002902 +BDMAP_00003017 +BDMAP_00001471 diff --git a/Generation_Pipeline_filter/syn_colon/requirements.txt b/Generation_Pipeline_filter/syn_colon/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9c1067c8eb4f44699bbf517601396113cea4b370 --- /dev/null +++ b/Generation_Pipeline_filter/syn_colon/requirements.txt @@ -0,0 +1,94 @@ +absl-py==1.1.0 +accelerate==0.11.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +antlr4-python3-runtime==4.9.3 +async-timeout==4.0.2 +attrs==21.4.0 
+autopep8==1.6.0 +cachetools==5.2.0 +certifi==2022.6.15 +charset-normalizer==2.0.12 +click==8.1.3 +cycler==0.11.0 +Deprecated==1.2.13 +docker-pycreds==0.4.0 +einops==0.4.1 +einops-exts==0.0.3 +ema-pytorch==0.0.8 +fonttools==4.34.4 +frozenlist==1.3.0 +fsspec==2022.5.0 +ftfy==6.1.1 +future==0.18.2 +gitdb==4.0.9 +GitPython==3.1.27 +google-auth==2.9.0 +google-auth-oauthlib==0.4.6 +grpcio==1.47.0 +h5py==3.7.0 +humanize==4.2.2 +hydra-core==1.2.0 +idna==3.3 +imageio==2.19.3 +imageio-ffmpeg==0.4.7 +importlib-metadata==4.12.0 +importlib-resources==5.9.0 +joblib==1.1.0 +kiwisolver==1.4.3 +lxml==4.9.1 +Markdown==3.3.7 +matplotlib==3.5.2 +multidict==6.0.2 +networkx==2.8.5 +nibabel==4.0.1 +nilearn==0.9.1 +numpy==1.23.0 +oauthlib==3.2.0 +omegaconf==2.2.3 +pandas==1.4.3 +Pillow==9.1.1 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycodestyle==2.8.0 +pyDeprecate==0.3.1 +pydicom==2.3.0 +pytorch-lightning==1.6.4 +pytz==2022.1 +PyWavelets==1.3.0 +PyYAML==6.0 +pyzmq==19.0.2 +regex==2022.6.2 +requests==2.28.0 +requests-oauthlib==1.3.1 +rotary-embedding-torch==0.1.5 +rsa==4.8 +scikit-image==0.19.3 +scikit-learn==1.1.2 +scikit-video==1.1.11 +scipy==1.8.1 +seaborn==0.11.2 +sentry-sdk==1.7.2 +setproctitle==1.2.3 +shortuuid==1.0.9 +SimpleITK==2.1.1.2 +smmap==5.0.0 +tensorboard==2.9.1 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +threadpoolctl==3.1.0 +tifffile==2022.8.3 +toml==0.10.2 +torch-tb-profiler==0.4.0 +torchio==0.18.80 +torchmetrics==0.9.1 +tqdm==4.64.0 +typing_extensions==4.2.0 +urllib3==1.26.9 +wandb==0.12.21 +Werkzeug==2.1.2 +wrapt==1.14.1 +yarl==1.7.2 +zipp==3.8.0 +wandb +tensorboardX==2.4.1 diff --git a/Generation_Pipeline_filter/syn_kidney/CT_syn_kidney_data_new.py b/Generation_Pipeline_filter/syn_kidney/CT_syn_kidney_data_new.py new file mode 100644 index 0000000000000000000000000000000000000000..9d62e38f79497f7ad3e89a9c5c12f40ea02bf4ca --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/CT_syn_kidney_data_new.py @@ -0,0 +1,241 @@ +import os, time, csv +import numpy as np +import torch +from sklearn.metrics import confusion_matrix +from scipy import ndimage +from scipy.ndimage import label +from functools import partial +import monai +from monai.transforms import AsDiscrete,AsDiscreted,Compose,Invertd,SaveImaged +from monai import transforms, data +from TumorGeneration.utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor, synt_model_prepare +import nibabel as nib + +import warnings +warnings.filterwarnings("ignore") + +import argparse +parser = argparse.ArgumentParser(description='kidney tumor validation') + +# file dir +parser.add_argument('--data_root', default=None, type=str) +parser.add_argument('--organ_type', default='kidney', type=str) +parser.add_argument('--save_dir', default='out', type=str) +parser.add_argument('--data_file', default='out', type=str) +parser.add_argument('--ddim_ts', default=50, type=int) +parser.add_argument('--fg_thresh', default=30, type=int) +parser.add_argument('--start', default=0, type=int) +parser.add_argument('--end', default=1000, type=int) + +def voxel2R(A): + return (np.array(A)/4*3/np.pi)**(1/3) + +class RandCropByPosNegLabeld_select(transforms.RandCropByPosNegLabeld): + def __init__(self, keys, label_key, spatial_size, + pos=1.0, neg=1.0, num_samples=1, + image_key=None, image_threshold=0.0, allow_missing_keys=True, + fg_thresh=0): + super().__init__(keys=keys, label_key=label_key, spatial_size=spatial_size, + pos=pos, neg=neg, num_samples=num_samples, + image_key=image_key, image_threshold=image_threshold, 
allow_missing_keys=allow_missing_keys) + self.fg_thresh = fg_thresh + + def R2voxel(self,R): + return (4/3*np.pi)*(R)**(3) + + def __call__(self, data): + d = dict(data) + data_name = d['name'] + d.pop('name') + + if '10_Decathlon' in data_name or '05_KiTS' in data_name: + d_crop = super().__call__(d) + + else: + flag=0 + while 1: + flag+=1 + + d_crop = super().__call__(d) + pixel_num = (d_crop[0]['label']>0).sum() + + if pixel_num > self.R2voxel(self.fg_thresh): + break + if flag>5 and pixel_num > self.R2voxel(max(self.fg_thresh-5, 5)): + break + if flag>10 and pixel_num > self.R2voxel(max(self.fg_thresh-10, 5)): + break + if flag>15 and pixel_num > self.R2voxel(max(self.fg_thresh-15, 5)): + break + if flag>20 and pixel_num > self.R2voxel(max(self.fg_thresh-20, 5)): + break + if flag>25 and pixel_num > self.R2voxel(max(self.fg_thresh-25, 5)): + break + if flag>50: + break + + d_crop[0]['name'] = data_name + + return d_crop + +def _get_loader(args): + # val_data_dir = args.val_dir + # datalist_json = args.json_dir + val_org_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label", "raw_image"]), + transforms.AddChanneld(keys=["image", "label", "raw_image"]), + transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), + transforms.Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear")), + transforms.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True), + transforms.SpatialPadd(keys=["image", "label"], mode=["minimum", "constant"], spatial_size=[96, 96, 96]), + RandCropByPosNegLabeld_select( + keys=["image", "label", "name"], + label_key="label", + spatial_size=(96, 96, 96), + pos=1, + neg=0, + num_samples=1, + image_key="image", + image_threshold=0, + fg_thresh = args.fg_thresh, + ), + transforms.ToTensord(keys=["image", "label", "raw_image"]), + ] + ) + + val_img=[] + val_lbl=[] + val_name=[] + + for line in open(args.data_file): + # name = line.strip().split()[1].split('.')[0] + # val_img.append(args.data_root + line.strip().split()[0]) + # val_lbl.append(args.data_root + line.strip().split()[1]) + # breakpoint() + name = line.strip() + val_img.append(os.path.join(args.data_root, name, 'ct.nii.gz')) + val_lbl.append(os.path.join(args.data_root, name, 'segmentations/kidney_left.nii.gz')) + val_name.append(name) + data_dicts_val = [{'image': image, 'raw_image':image, 'label': label, 'name': name} + for image, label, name in zip(val_img, val_lbl, val_name)] + + if args.end < len(data_dicts_val): + data_dicts_val = data_dicts_val[args.start:args.end] + else: + data_dicts_val = data_dicts_val[args.start:] + print('val len {}'.format(len(data_dicts_val))) + val_org_ds = data.Dataset(data_dicts_val, transform=val_org_transform) + val_org_loader = data.DataLoader(val_org_ds, batch_size=1, shuffle=False, num_workers=4, sampler=None, pin_memory=True) + + post_transforms = Compose([ + Invertd( + keys=['image'], + transform=val_org_transform, + orig_keys="image", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ), + Invertd( + keys=['label'], + transform=val_org_transform, + orig_keys="label", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ) + ]) + return val_org_loader, post_transforms + +def main(): + args = parser.parse_args() + output_dir = args.save_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print("MAIN Argument values:") + for k, v in vars(args).items(): + print(k, '=>', v) + print('-----------------') + + ## loader and 
post_transform + val_loader, post_transforms = _get_loader(args) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device) + model.load_state_dict(torch.load("../best_metric_model_classification3d_dict.pth")) + model.eval() + + start_time=0 + with torch.no_grad(): + for idx, val_data in enumerate(val_loader): + print('idx',idx) + if idx == 0: + start_time = time.time() + # val_inputs = val_data["image"] + # name = val_data['label_meta_dict']['filename_or_obj'][0].split('/')[-1].split('.')[0] + + vqgan, early_sampler, noearly_sampler= synt_model_prepare(device = torch.device("cuda"), fold=0, organ=args.organ_type) + + healthy_data, healthy_target, data_names, raw_data = val_data['image'], val_data['label'], val_data['name'], val_data['raw_image'] + case_name = data_names[0].split('/')[-1] + print('case_name', case_name) + original_affine = val_data["label_meta_dict"]["original_affine"][0].numpy() + + if healthy_target.sum() == 0: + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + tumor_mask = val_data[0]['label'][0].cpu().numpy().astype(np.uint8) + tumor_mask_ = np.zeros_like(tumor_mask) + nib.save(nib.Nifti1Image(tumor_mask_, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/kidney_tumor.nii.gz')) + continue + + healthy_data, healthy_target = healthy_data.cuda(), healthy_target.cuda() + healthy_target = (healthy_target>0).to(healthy_target) + + tumor_types = ['early', 'medium', 'large'] + # tumor_probs = np.array([0.45, 0.45, 0.1]) + # tumor_probs = np.array([1.0, 0.0, 0.0]) + tumor_probs = np.array([0.5, 0.4, 0.1]) + synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + print('synthetic_tumor_type',synthetic_tumor_type) + flag=0 + while 1: + flag+=1 + if synthetic_tumor_type == 'early': + synt_data, synt_target = synthesize_early_tumor(healthy_data, healthy_target, args.organ_type, vqgan, early_sampler) + elif synthetic_tumor_type == 'medium': + synt_data, synt_target = synthesize_medium_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + elif synthetic_tumor_type == 'large': + synt_data, synt_target = synthesize_large_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + + syn_confidence = model(synt_data).sigmoid()[:,1] + if syn_confidence>0.01: + break + elif flag > 40 and syn_confidence>0.005: + break + elif flag > 60 and syn_confidence>0.001: + break + val_data['image'] = synt_data.detach() + val_data['label'] = synt_target.detach() + + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + synt_data = val_data[0]['image'][0] + synt_target = val_data[0]['label'][0] + final_data = raw_data[0,0] + + synt_data = (synt_data*(250+175)-175) + final_data[synt_target>1] = synt_data[synt_target>1] + final_data = final_data.cpu().numpy() + final_label = (synt_target>=1.5).cpu().numpy().astype(np.uint8) + + os.makedirs(os.path.join(output_dir, f'{case_name}'), exist_ok=True) + os.makedirs(os.path.join(output_dir, f'{case_name}/segmentations'), exist_ok=True) + nib.save(nib.Nifti1Image(final_data, original_affine), os.path.join(output_dir, f'{case_name}/ct.nii.gz')) + nib.save(nib.Nifti1Image(final_label, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/kidney_tumor.nii.gz')) + + print('time = ', time.time()-start_time) + start_time = time.time() + + +if __name__ == "__main__": 
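+    # Entry point: parse the CLI arguments, build the kidney case loader, synthesize tumors with the diffusion pipeline, filter them with the DenseNet121 classifier, and write ct.nii.gz and segmentations/kidney_tumor.nii.gz for each case.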
+ main() diff --git a/Generation_Pipeline_filter/syn_kidney/CT_syn_kidney_data_new2.py b/Generation_Pipeline_filter/syn_kidney/CT_syn_kidney_data_new2.py new file mode 100644 index 0000000000000000000000000000000000000000..e389a274f4260e8481ccf67e879eec4160987a3c --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/CT_syn_kidney_data_new2.py @@ -0,0 +1,251 @@ +import os, time, csv +import numpy as np +import torch +from sklearn.metrics import confusion_matrix +from scipy import ndimage +from scipy.ndimage import label +from functools import partial +import monai +from monai.transforms import AsDiscrete,AsDiscreted,Compose,Invertd,SaveImaged +from monai import transforms, data +from TumorGeneration.utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor, synt_model_prepare +import nibabel as nib + +import warnings +warnings.filterwarnings("ignore") + +import argparse +parser = argparse.ArgumentParser(description='kidney tumor validation') + +# file dir +parser.add_argument('--data_root', default=None, type=str) +parser.add_argument('--organ_type', default='kidney', type=str) +parser.add_argument('--save_dir', default='out', type=str) +parser.add_argument('--data_file', default='out', type=str) +parser.add_argument('--ddim_ts', default=50, type=int) +parser.add_argument('--fg_thresh', default=30, type=int) +parser.add_argument('--start', default=0, type=int) +parser.add_argument('--end', default=1000, type=int) + +def voxel2R(A): + return (np.array(A)/4*3/np.pi)**(1/3) + +class RandCropByPosNegLabeld_select(transforms.RandCropByPosNegLabeld): + def __init__(self, keys, label_key, spatial_size, + pos=1.0, neg=1.0, num_samples=1, + image_key=None, image_threshold=0.0, allow_missing_keys=True, + fg_thresh=0): + super().__init__(keys=keys, label_key=label_key, spatial_size=spatial_size, + pos=pos, neg=neg, num_samples=num_samples, + image_key=image_key, image_threshold=image_threshold, allow_missing_keys=allow_missing_keys) + self.fg_thresh = fg_thresh + + def R2voxel(self,R): + return (4/3*np.pi)*(R)**(3) + + def __call__(self, data): + d = dict(data) + data_name = d['name'] + d.pop('name') + + if '10_Decathlon' in data_name or '05_KiTS' in data_name: + d_crop = super().__call__(d) + + else: + flag=0 + while 1: + flag+=1 + + d_crop = super().__call__(d) + pixel_num = (d_crop[0]['label']>0).sum() + + if pixel_num > self.R2voxel(self.fg_thresh): + break + if flag>5 and pixel_num > self.R2voxel(max(self.fg_thresh-5, 5)): + break + if flag>10 and pixel_num > self.R2voxel(max(self.fg_thresh-10, 5)): + break + if flag>15 and pixel_num > self.R2voxel(max(self.fg_thresh-15, 5)): + break + if flag>20 and pixel_num > self.R2voxel(max(self.fg_thresh-20, 5)): + break + if flag>25 and pixel_num > self.R2voxel(max(self.fg_thresh-25, 5)): + break + if flag>50: + break + + d_crop[0]['name'] = data_name + + return d_crop + +def _get_loader(args): + # val_data_dir = args.val_dir + # datalist_json = args.json_dir + val_org_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label", 'tumor_label', "raw_image"]), + transforms.AddChanneld(keys=["image", "label", 'tumor_label', "raw_image"]), + transforms.Orientationd(keys=["image", "label", 'tumor_label'], axcodes="RAS"), + transforms.Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear")), + transforms.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True), + transforms.SpatialPadd(keys=["image", "label"], mode=["minimum", "constant"], 
spatial_size=[96, 96, 96]), + RandCropByPosNegLabeld_select( + keys=["image", "label", "name"], + label_key="label", + spatial_size=(96, 96, 96), + pos=1, + neg=0, + num_samples=1, + image_key="image", + image_threshold=0, + fg_thresh = args.fg_thresh, + ), + transforms.ToTensord(keys=["image", "label", 'tumor_label', "raw_image"]), + ] + ) + + val_img=[] + val_lbl=[] + val_name=[] + tumor_lbl=[] + for line in open(args.data_file): + # name = line.strip().split()[1].split('.')[0] + # val_img.append(args.data_root + line.strip().split()[0]) + # val_lbl.append(args.data_root + line.strip().split()[1]) + # breakpoint() + name = line.strip() + val_img.append(os.path.join(args.data_root, name, 'ct.nii.gz')) + val_lbl.append(os.path.join(args.data_root, name, 'segmentations/kidney_right.nii.gz')) + tumor_lbl.append(os.path.join(args.data_root, name, 'segmentations/kidney_tumor.nii.gz')) + val_name.append(name) + data_dicts_val = [{'image': image, 'raw_image':image, 'label': label, 'tumor_label':tumor_label,'name': name} + for image, label, tumor_label, name in zip(val_img, val_lbl, tumor_lbl, val_name)] + + if args.end < len(data_dicts_val): + data_dicts_val = data_dicts_val[args.start:args.end] + else: + data_dicts_val = data_dicts_val[args.start:] + print('val len {}'.format(len(data_dicts_val))) + val_org_ds = data.Dataset(data_dicts_val, transform=val_org_transform) + val_org_loader = data.DataLoader(val_org_ds, batch_size=1, shuffle=False, num_workers=4, sampler=None, pin_memory=True) + + post_transforms = Compose([ + Invertd( + keys=['image'], + transform=val_org_transform, + orig_keys="image", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ), + Invertd( + keys=['label'], + transform=val_org_transform, + orig_keys="label", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ), + Invertd( + keys=['tumor_label'], + transform=val_org_transform, + orig_keys="tumor_label", + nearest_interp=False, + to_tensor=True, + ) + ]) + return val_org_loader, post_transforms + +def main(): + args = parser.parse_args() + output_dir = args.save_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print("MAIN Argument values:") + for k, v in vars(args).items(): + print(k, '=>', v) + print('-----------------') + + ## loader and post_transform + val_loader, post_transforms = _get_loader(args) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device) + model.load_state_dict(torch.load("../best_metric_model_classification3d_dict.pth")) + model.eval() + + start_time=0 + with torch.no_grad(): + for idx, val_data in enumerate(val_loader): + print('idx',idx) + if idx == 0: + start_time = time.time() + # val_inputs = val_data["image"] + # name = val_data['label_meta_dict']['filename_or_obj'][0].split('/')[-1].split('.')[0] + + vqgan, early_sampler, noearly_sampler= synt_model_prepare(device = torch.device("cuda"), fold=0, organ=args.organ_type) + + healthy_data, healthy_target, data_names, raw_data = val_data['image'], val_data['label'], val_data['name'], val_data['raw_image'] + case_name = data_names[0].split('/')[-1] + print('case_name', case_name) + original_affine = val_data["label_meta_dict"]["original_affine"][0].numpy() + + if healthy_target.sum() == 0: + # val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + # tumor_mask = val_data[0]['label'][0].cpu().numpy().astype(np.uint8) + # tumor_mask_ = np.zeros_like(tumor_mask) + 
# nib.save(nib.Nifti1Image(tumor_mask_, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/kidney_tumor.nii.gz')) + continue + + healthy_data, healthy_target = healthy_data.cuda(), healthy_target.cuda() + healthy_target = (healthy_target>0).to(healthy_target) + + tumor_types = ['early', 'medium', 'large'] + # tumor_probs = np.array([0.45, 0.45, 0.1]) + # tumor_probs = np.array([1.0, 0.0, 0.0]) + tumor_probs = np.array([0.5, 0.4, 0.1]) + synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + print('synthetic_tumor_type',synthetic_tumor_type) + flag=0 + while 1: + flag+=1 + if synthetic_tumor_type == 'early': + synt_data, synt_target = synthesize_early_tumor(healthy_data, healthy_target, args.organ_type, vqgan, early_sampler) + elif synthetic_tumor_type == 'medium': + synt_data, synt_target = synthesize_medium_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + elif synthetic_tumor_type == 'large': + synt_data, synt_target = synthesize_large_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + + syn_confidence = model(synt_data).sigmoid()[:,1] + if syn_confidence>0.01: + break + elif flag > 40 and syn_confidence>0.005: + break + elif flag > 60 and syn_confidence>0.001: + break + val_data['image'] = synt_data.detach() + val_data['label'] = synt_target.detach() + + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + synt_data = val_data[0]['image'][0] + synt_target = val_data[0]['label'][0] + tumor_mask = val_data[0]['tumor_label'][0] + final_data = raw_data[0,0] + + synt_data = (synt_data*(250+175)-175) + final_data[synt_target>1] = synt_data[synt_target>1] + final_data = final_data.cpu().numpy() + final_label = (synt_target>=1.5).cpu().numpy().astype(np.uint8) + final_label[tumor_mask==1] = 1 + + os.makedirs(os.path.join(output_dir, f'{case_name}'), exist_ok=True) + os.makedirs(os.path.join(output_dir, f'{case_name}/segmentations'), exist_ok=True) + nib.save(nib.Nifti1Image(final_data, original_affine), os.path.join(output_dir, f'{case_name}/ct.nii.gz')) + nib.save(nib.Nifti1Image(final_label, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/kidney_tumor.nii.gz')) + + print('time = ', time.time()-start_time) + start_time = time.time() + + +if __name__ == "__main__": + main() diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/.DS_Store b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..fb0d207e124cfecc777faae5b4a56c5ca1b9cd2e Binary files /dev/null and b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/.DS_Store differ diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/README.md b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/README.md new file mode 100644 index 0000000000000000000000000000000000000000..201911031ca988c410781a60dc3c192a65ee56b3 --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/README.md @@ -0,0 +1,5 @@ +```bash +wget https://www.dropbox.com/scl/fi/k856fhk60kck8uqxtxazw/model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll +mv model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll model_weight.tar.gz +tar -xzvf model_weight.tar.gz +``` diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/TumorGenerated.py b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/TumorGenerated.py new file mode 100644 index 
0000000000000000000000000000000000000000..9983cf047b1532f5edd20c0d4c78102d6d611eb9 --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/TumorGenerated.py @@ -0,0 +1,39 @@ +import random +from typing import Hashable, Mapping, Dict + +from monai.config import KeysCollection +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.transform import MapTransform, RandomizableTransform + +from .utils_ import SynthesisTumor +import numpy as np + +class TumorGenerated(RandomizableTransform, MapTransform): + def __init__(self, + keys: KeysCollection, + prob: float = 0.1, + tumor_prob = [0.2, 0.2, 0.2, 0.2, 0.2], + allow_missing_keys: bool = False + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + random.seed(0) + np.random.seed(0) + + self.tumor_types = ['tiny', 'small', 'medium', 'large', 'mix'] + + assert len(tumor_prob) == 5 + self.tumor_prob = np.array(tumor_prob) + + + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + + if self._do_transform and (np.max(d['label']) <= 1): + tumor_type = np.random.choice(self.tumor_types, p=self.tumor_prob.ravel()) + + d['image'][0], d['label'][0] = SynthesisTumor(d['image'][0], d['label'][0], tumor_type) + # print(tumor_type, d['image'].shape, np.max(d['label'])) + return d diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/__init__.py b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1afc8a195ba5fd106ca18d4e219c123a75e6e831 --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/__init__.py @@ -0,0 +1,5 @@ +### Online Version TumorGeneration ### + +from .TumorGenerated import TumorGenerated + +from .utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor \ No newline at end of file diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8eba2478d809b697c49e4425ba2fc619b4554f12 Binary files /dev/null and b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4d7e4379e19d41aaf49dbc8394daeb6ab80b6bc Binary files /dev/null and b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/__pycache__/__init__.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/__pycache__/utils.cpython-38.pyc b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3065f278859daa8566ac262917eed6ba21daffd1 Binary files /dev/null and b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/__pycache__/utils.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/__pycache__/utils_.cpython-38.pyc b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/__pycache__/utils_.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..8c35e0a87fa9e862aaefa5d34991f85f12516a30 Binary files /dev/null and b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/__pycache__/utils_.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/diffusion_config/ddpm.yaml b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/diffusion_config/ddpm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..474a973b1f76c2a026e2df916c4a153b5bbea05f --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/diffusion_config/ddpm.yaml @@ -0,0 +1,29 @@ +vqgan_ckpt: None + +# Have to be derived from VQ-GAN Latent space dimensions +diffusion_img_size: 24 +diffusion_depth_size: 24 +diffusion_num_channels: 17 # 17 +out_dim: 8 +dim_mults: [1,2,4,8] +results_folder: checkpoints/ddpm/ +results_folder_postfix: 'own_dataset_t2' +load_milestone: False # False + +batch_size: 2 # 40 +num_workers: 20 +logger: wandb +objective: pred_x0 +save_and_sample_every: 1000 +denoising_fn: Unet3D +train_lr: 1e-4 +timesteps: 2 # number of steps +sampling_timesteps: 250 # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper]) +loss_type: l1 # L1 or L2 +train_num_steps: 700000 # total training steps +gradient_accumulate_every: 2 # gradient accumulation steps +ema_decay: 0.995 # exponential moving average decay +amp: False # turn on mixed precision +num_sample_rows: 1 +gpus: 0 + diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/diffusion_config/vq_gan_3d.yaml b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/diffusion_config/vq_gan_3d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29940377c5cadb0c06322de8ac60d0713f799024 --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/diffusion_config/vq_gan_3d.yaml @@ -0,0 +1,37 @@ +seed: 1234 +batch_size: 2 # 30 +num_workers: 32 # 30 + +gpus: 1 +accumulate_grad_batches: 1 +default_root_dir: checkpoints/vq_gan/ +default_root_dir_postfix: 'flair' +resume_from_checkpoint: +max_steps: -1 +max_epochs: -1 +precision: 16 +gradient_clip_val: 1.0 + + +embedding_dim: 8 # 256 +n_codes: 16384 # 2048 +n_hiddens: 16 +lr: 3e-4 +downsample: [2, 2, 2] # [4, 4, 4] +disc_channels: 64 +disc_layers: 3 +discriminator_iter_start: 10000 # 50000 +disc_loss_type: hinge +image_gan_weight: 1.0 +video_gan_weight: 1.0 +l1_weight: 4.0 +gan_feat_weight: 4.0 # 0.0 +perceptual_weight: 4.0 # 0.0 +i3d_feat: False +restart_thres: 1.0 +no_random_restart: False +norm_type: group +padding_type: replicate +num_groups: 32 + + diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__init__.py b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a01336e46e37471097d8e1420d0ffd5d803a1edd --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__init__.py @@ -0,0 +1 @@ +from .diffusion import Unet3D, GaussianDiffusion, Tester diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e2e233d6ef842d8b6623b3c34f1efde66e12e471 Binary files /dev/null and b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc differ diff --git 
a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7b9717ebc5153d3d8a33bbefbb8ac5a9e73f742 Binary files /dev/null and b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1854ef5db100c68b0b4add2f150825df3ae5eb4 Binary files /dev/null and b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ec4dc0860dc1632c385c75fa9723a8920ceac40 Binary files /dev/null and b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7ac1a6aea44a1d1741652f2f0a598fb18b2730d Binary files /dev/null and b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20c6786f9995cbd455c900944b0ab9501a97c202 Binary files /dev/null and b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e6ce39946ee6c03ff7518cbcb1309135e0354b2 Binary files /dev/null and b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/ddim.py b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..1bc51f94355732aedb0ffd254b8166144c468370 --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/ddim.py @@ -0,0 +1,206 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from .util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == 
torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): # "uniform" 'quad' + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + # breakpoint() + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
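+                               # unconditional_guidance_scale / unconditional_conditioning enable classifier-free guidance in p_sample_ddim when the scale is not 1.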
+ **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, T, H, W = shape + # breakpoint() + size = (batch_size, C, T, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + # iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(time_range): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. 
- mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + # breakpoint() + e_t = self.model.denoise_fn(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.denoise_fn(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/diffusion.py b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..87111dc3b2ab671e798efc3a320ae81d024f20b9 --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/diffusion.py @@ -0,0 +1,1016 @@ +"Largely taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch" + +import math +import copy +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial + +from torch.utils import data +from pathlib import Path +from torch.optim import Adam +from torchvision import transforms as T, utils +from torch.cuda.amp import autocast, GradScaler +from PIL import Image + +from tqdm import tqdm +from einops import rearrange +from einops_exts import check_shape, rearrange_many + +from rotary_embedding_torch import RotaryEmbedding + +from .text import tokenize, bert_embed, BERT_MODEL_DIM +from torch.utils.data import Dataset, DataLoader +from ..vq_gan_3d.model.vqgan import VQGAN + +import matplotlib.pyplot as plt + +# helpers functions + + +def exists(x): + return x is not None + + +def noop(*args, **kwargs): + pass + + +def is_odd(n): + return (n % 2) == 1 + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +def cycle(dl): + while True: + for data in dl: + yield data + + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + + +def prob_mask_like(shape, prob, device): + if prob == 1: + return torch.ones(shape, device=device, dtype=torch.bool) + elif prob == 0: + return torch.zeros(shape, device=device, dtype=torch.bool) + else: + return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob + + +def is_list_str(x): + if not isinstance(x, (list, tuple)): + return False + return all([type(el) == str for el in x]) + +# relative positional bias + + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128 + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / + max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min( + val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype=torch.long, device=device) + k_pos = torch.arange(n, dtype=torch.long, device=device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket( + rel_pos, num_buckets=self.num_buckets, 
max_distance=self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') + +# small helper modules + + +class EMA(): + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_model_average(self, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = self.update_average(old_weight, up_weight) + + def update_average(self, old, new): + if old is None: + return new + return old * self.beta + (1 - self.beta) * new + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +def Upsample(dim): + return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) + + +def Downsample(dim): + return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) + + +class LayerNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1)) + + def forward(self, x): + var = torch.var(x, dim=1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=1, keepdim=True) + return (x - mean) / (var + self.eps).sqrt() * self.gamma + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = LayerNorm(dim) + + def forward(self, x, **kwargs): + x = self.norm(x) + return self.fn(x, **kwargs) + +# building block modules + + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1)) + self.norm = nn.GroupNorm(groups, dim_out) + self.act = nn.SiLU() + + def forward(self, x, scale_shift=None): + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + return self.act(x) + + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, dim_out * 2) + ) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) + self.res_conv = nn.Conv3d( + dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb=None): + + scale_shift = None + if exists(self.mlp): + assert exists(time_emb), 'time emb must be passed in' + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') + scale_shift = time_emb.chunk(2, dim=1) + + h = self.block1(x, scale_shift=scale_shift) + + h = self.block2(h) + return h + self.res_conv(x) + + +class SpatialLinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, f, h, w = 
x.shape + x = rearrange(x, 'b c f h w -> (b f) c h w') + + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = rearrange_many( + qkv, 'b (h c) x y -> b h c (x y)', h=self.heads) + + q = q.softmax(dim=-2) + k = k.softmax(dim=-1) + + q = q * self.scale + context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + + out = torch.einsum('b h d e, b h d n -> b h e n', context, q) + out = rearrange(out, 'b h c (x y) -> b (h c) x y', + h=self.heads, x=h, y=w) + out = self.to_out(out) + return rearrange(out, '(b f) c h w -> b c f h w', b=b) + +# attention along space and time + + +class EinopsToAndFrom(nn.Module): + def __init__(self, from_einops, to_einops, fn): + super().__init__() + self.from_einops = from_einops + self.to_einops = to_einops + self.fn = fn + + def forward(self, x, **kwargs): + shape = x.shape + reconstitute_kwargs = dict( + tuple(zip(self.from_einops.split(' '), shape))) + x = rearrange(x, f'{self.from_einops} -> {self.to_einops}') + x = self.fn(x, **kwargs) + x = rearrange( + x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + heads=4, + dim_head=32, + rotary_emb=None + ): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + + self.rotary_emb = rotary_emb + self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False) + self.to_out = nn.Linear(hidden_dim, dim, bias=False) + + def forward( + self, + x, + pos_bias=None, + focus_present_mask=None + ): + n, device = x.shape[-2], x.device + + qkv = self.to_qkv(x).chunk(3, dim=-1) + + if exists(focus_present_mask) and focus_present_mask.all(): + # if all batch samples are focusing on present + # it would be equivalent to passing that token's values through to the output + values = qkv[-1] + return self.to_out(values) + + # split out heads + + q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) + + # scale + + q = q * self.scale + + # rotate positions into queries and keys for time attention + + if exists(self.rotary_emb): + q = self.rotary_emb.rotate_queries_or_keys(q) + k = self.rotary_emb.rotate_queries_or_keys(k) + + # similarity + + sim = einsum('... h i d, ... h j d -> ... h i j', q, k) + + # relative positional bias + + if exists(pos_bias): + sim = sim + pos_bias + + if exists(focus_present_mask) and not (~focus_present_mask).all(): + attend_all_mask = torch.ones( + (n, n), device=device, dtype=torch.bool) + attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) + + mask = torch.where( + rearrange(focus_present_mask, 'b -> b 1 1 1 1'), + rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), + rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), + ) + + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + # numerical stability + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + # aggregate values + + out = einsum('... h i j, ... h j d -> ... h i d', attn, v) + out = rearrange(out, '... h n d -> ... 
n (h d)') + return self.to_out(out) + +# model + + +class Unet3D(nn.Module): + def __init__( + self, + dim, + cond_dim=None, + out_dim=None, + dim_mults=(1, 2, 4, 8), + channels=3, + attn_heads=8, + attn_dim_head=32, + use_bert_text_cond=False, + init_dim=None, + init_kernel_size=7, + use_sparse_linear_attn=True, + block_type='resnet', + resnet_groups=8 + ): + super().__init__() + self.channels = channels + + # temporal attention and its relative positional encoding + + rotary_emb = RotaryEmbedding(min(32, attn_dim_head)) + + def temporal_attn(dim): return EinopsToAndFrom('b c f h w', 'b (h w) f c', Attention( + dim, heads=attn_heads, dim_head=attn_dim_head, rotary_emb=rotary_emb)) + + # realistically will not be able to generate that many frames of video... yet + self.time_rel_pos_bias = RelativePositionBias( + heads=attn_heads, max_distance=32) + + # initial conv + + init_dim = default(init_dim, dim) + assert is_odd(init_kernel_size) + + init_padding = init_kernel_size // 2 + self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, + init_kernel_size), padding=(0, init_padding, init_padding)) + + self.init_temporal_attn = Residual( + PreNorm(init_dim, temporal_attn(init_dim))) + + # dimensions + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + # time conditioning + + time_dim = dim * 4 + self.time_mlp = nn.Sequential( + SinusoidalPosEmb(dim), + nn.Linear(dim, time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim) + ) + + # text conditioning + + self.has_cond = exists(cond_dim) or use_bert_text_cond + cond_dim = BERT_MODEL_DIM if use_bert_text_cond else cond_dim + + self.null_cond_emb = nn.Parameter( + torch.randn(1, cond_dim)) if self.has_cond else None + + cond_dim = time_dim + int(cond_dim or 0) + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + + num_resolutions = len(in_out) + # block type + + block_klass = partial(ResnetBlock, groups=resnet_groups) + block_klass_cond = partial(block_klass, time_emb_dim=cond_dim) + + # modules for all layers + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + block_klass_cond(dim_in, dim_out), + block_klass_cond(dim_out, dim_out), + Residual(PreNorm(dim_out, SpatialLinearAttention( + dim_out, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), + Residual(PreNorm(dim_out, temporal_attn(dim_out))), + Downsample(dim_out) if not is_last else nn.Identity() + ])) + + mid_dim = dims[-1] + self.mid_block1 = block_klass_cond(mid_dim, mid_dim) + + spatial_attn = EinopsToAndFrom( + 'b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads)) + + self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn)) + self.mid_temporal_attn = Residual( + PreNorm(mid_dim, temporal_attn(mid_dim))) + + self.mid_block2 = block_klass_cond(mid_dim, mid_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind >= (num_resolutions - 1) + + self.ups.append(nn.ModuleList([ + block_klass_cond(dim_out * 2, dim_in), + block_klass_cond(dim_in, dim_in), + Residual(PreNorm(dim_in, SpatialLinearAttention( + dim_in, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), + Residual(PreNorm(dim_in, temporal_attn(dim_in))), + Upsample(dim_in) if not is_last else nn.Identity() + ])) + + out_dim = default(out_dim, channels) + self.final_conv = nn.Sequential( + block_klass(dim * 2, dim), + nn.Conv3d(dim, out_dim, 1) + ) + + def forward_with_cond_scale( + self, + 
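+        # classifier-free guidance: blend the conditional and null-conditioned predictions by cond_scale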
*args, + cond_scale=2., + **kwargs + ): + logits = self.forward(*args, null_cond_prob=0., **kwargs) + if cond_scale == 1 or not self.has_cond: + return logits + + null_logits = self.forward(*args, null_cond_prob=1., **kwargs) + return null_logits + (logits - null_logits) * cond_scale + + def forward( + self, + x, + time, + cond=None, + null_cond_prob=0., + focus_present_mask=None, + # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time) + prob_focus_present=0. + ): + assert not (self.has_cond and not exists(cond) + ), 'cond must be passed in if cond_dim specified' + x = torch.cat([x, cond], dim=1) + + batch, device = x.shape[0], x.device + + focus_present_mask = default(focus_present_mask, lambda: prob_mask_like( + (batch,), prob_focus_present, device=device)) + + time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device) + + x = self.init_conv(x) + r = x.clone() + + x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias) + + t = self.time_mlp(time) if exists(self.time_mlp) else None # [2, 128] + + # classifier free guidance + + if self.has_cond: + batch, device = x.shape[0], x.device + mask = prob_mask_like((batch,), null_cond_prob, device=device) + cond = torch.where(rearrange(mask, 'b -> b 1'), + self.null_cond_emb, cond) + t = torch.cat((t, cond), dim=-1) + + h = [] + + for block1, block2, spatial_attn, temporal_attn, downsample in self.downs: + x = block1(x, t) + x = block2(x, t) + x = spatial_attn(x) + x = temporal_attn(x, pos_bias=time_rel_pos_bias, + focus_present_mask=focus_present_mask) + h.append(x) + x = downsample(x) + + # [2, 256, 32, 4, 4] + x = self.mid_block1(x, t) + x = self.mid_spatial_attn(x) + x = self.mid_temporal_attn( + x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask) + x = self.mid_block2(x, t) + + for block1, block2, spatial_attn, temporal_attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim=1) + x = block1(x, t) + x = block2(x, t) + x = spatial_attn(x) + x = temporal_attn(x, pos_bias=time_rel_pos_bias, + focus_present_mask=focus_present_mask) + x = upsample(x) + + x = torch.cat((x, r), dim=1) + return self.final_conv(x) + +# gaussian diffusion trainer class + + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = torch.linspace(0, timesteps, steps, dtype=torch.float64) + alphas_cumprod = torch.cos( + ((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0, 0.9999) + + +class GaussianDiffusion(nn.Module): + def __init__( + self, + denoise_fn, + *, + image_size, + num_frames, + text_use_bert_cls=False, + channels=3, + timesteps=1000, + loss_type='l1', + use_dynamic_thres=False, # from the Imagen paper + dynamic_thres_percentile=0.9, + vqgan_ckpt=None, + device=None + ): + super().__init__() + self.channels = channels + self.image_size = image_size + self.num_frames = num_frames + self.denoise_fn = denoise_fn + self.device = device + + if vqgan_ckpt: + self.vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt).cuda() + self.vqgan.eval() + else: + self.vqgan = None + + betas = cosine_beta_schedule(timesteps) + + alphas = 1. 
- betas + alphas_cumprod = torch.cumprod(alphas, axis=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + # register buffer helper function that casts float64 to float32 + + def register_buffer(name, val): return self.register_buffer( + name, val.to(torch.float32)) + + register_buffer('betas', betas) + register_buffer('alphas_cumprod', alphas_cumprod) + register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + + register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) + register_buffer('sqrt_one_minus_alphas_cumprod', + torch.sqrt(1. - alphas_cumprod)) + register_buffer('log_one_minus_alphas_cumprod', + torch.log(1. - alphas_cumprod)) + register_buffer('sqrt_recip_alphas_cumprod', + torch.sqrt(1. / alphas_cumprod)) + register_buffer('sqrt_recipm1_alphas_cumprod', + torch.sqrt(1. / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + + posterior_variance = betas * \ + (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + + register_buffer('posterior_variance', posterior_variance) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + + register_buffer('posterior_log_variance_clipped', + torch.log(posterior_variance.clamp(min=1e-20))) + register_buffer('posterior_mean_coef1', betas * + torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) + register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) + * torch.sqrt(alphas) / (1. - alphas_cumprod)) + + # text conditioning parameters + + self.text_use_bert_cls = text_use_bert_cls + + # dynamic thresholding when sampling + + self.use_dynamic_thres = use_dynamic_thres + self.dynamic_thres_percentile = dynamic_thres_percentile + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. - self.alphas_cumprod, t, x_start.shape) + log_variance = extract( + self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract( + self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool, cond=None, cond_scale=1.): + x_recon = self.predict_start_from_noise( + x, t=t, noise=self.denoise_fn.forward_with_cond_scale(x, t, cond=cond, cond_scale=cond_scale)) + + if clip_denoised: + s = 1. + if self.use_dynamic_thres: + s = torch.quantile( + rearrange(x_recon, 'b ... -> b (...)').abs(), + self.dynamic_thres_percentile, + dim=-1 + ) + + s.clamp_(min=1.) 
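+                # broadcast the per-sample threshold over all non-batch dims, then clip x_recon to [-s, s] and rescale back into [-1, 1]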
+ s = s.view(-1, *((1,) * (x_recon.ndim - 1))) + + # clip by threshold, depending on whether static or dynamic + x_recon = x_recon.clamp(-s, s) / s + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.inference_mode() + def p_sample(self, x, t, cond=None, cond_scale=1., clip_denoised=True): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised, cond=cond, cond_scale=cond_scale) + noise = torch.randn_like(x) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, + *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.inference_mode() + def p_sample_loop(self, shape, cond=None, cond_scale=1.): + device = self.betas.device + + b = shape[0] + img = torch.randn(shape, device=device) + # print('cond', cond.shape) + for i in reversed(range(0, self.num_timesteps)): + img = self.p_sample(img, torch.full( + (b,), i, device=device, dtype=torch.long), cond=cond, cond_scale=cond_scale) + + return img + + @torch.inference_mode() + def sample(self, cond=None, cond_scale=1., batch_size=16): + device = next(self.denoise_fn.parameters()).device + + if is_list_str(cond): + cond = bert_embed(tokenize(cond)).to(device) + + # batch_size = cond.shape[0] if exists(cond) else batch_size + batch_size = batch_size + image_size = self.image_size + channels = 8 # self.channels + num_frames = self.num_frames + # print((batch_size, channels, num_frames, image_size, image_size)) + # print('cond_',cond.shape) + _sample = self.p_sample_loop( + (batch_size, channels, num_frames, image_size, image_size), cond=cond, cond_scale=cond_scale) + + if isinstance(self.vqgan, VQGAN): + # denormalize TODO: Remove eventually + _sample = (((_sample + 1.0) / 2.0) * (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) + self.vqgan.codebook.embeddings.min() + + _sample = self.vqgan.decode(_sample, quantize=True) + else: + unnormalize_img(_sample) + + return _sample + + @torch.inference_mode() + def interpolate(self, x1, x2, t=None, lam=0.5): + b, *_, device = *x1.shape, x1.device + t = default(t, self.num_timesteps - 1) + + assert x1.shape == x2.shape + + t_batched = torch.stack([torch.tensor(t, device=device)] * b) + xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) + + img = (1 - lam) * xt1 + lam * xt2 + for i in reversed(range(0, t)): + img = self.p_sample(img, torch.full( + (b,), i, device=device, dtype=torch.long)) + + return img + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, + t, x_start.shape) * noise + ) + + def p_losses(self, x_start, t, cond=None, noise=None, **kwargs): + b, c, f, h, w, device = *x_start.shape, x_start.device + noise = default(noise, lambda: torch.randn_like(x_start)) + # breakpoint() + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # [2, 8, 32, 32, 32] + + if is_list_str(cond): + cond = bert_embed( + tokenize(cond), return_cls_repr=self.text_use_bert_cls) + cond = cond.to(device) + + x_recon = self.denoise_fn(x_noisy, t, cond=cond, **kwargs) + + if self.loss_type == 'l1': + loss = F.l1_loss(noise, x_recon) + elif self.loss_type == 'l2': + loss = F.mse_loss(noise, x_recon) + else: 
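+            # only 'l1' and 'l2' loss types are supported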
+ raise NotImplementedError() + + return loss + + def forward(self, x, *args, **kwargs): + bs = int(x.shape[0]/2) + img=x[:bs,...] + mask=x[bs:,...] + mask_=(1-mask).detach() + masked_img = (img*mask_).detach() + masked_img=masked_img.permute(0,1,-1,-3,-2) + img=img.permute(0,1,-1,-3,-2) + mask=mask.permute(0,1,-1,-3,-2) + # breakpoint() + if isinstance(self.vqgan, VQGAN): + with torch.no_grad(): + img = self.vqgan.encode( + img, quantize=False, include_embeddings=True) + # normalize to -1 and 1 + img = ((img - self.vqgan.codebook.embeddings.min()) / + (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + masked_img = self.vqgan.encode( + masked_img, quantize=False, include_embeddings=True) + # normalize to -1 and 1 + masked_img = ((masked_img - self.vqgan.codebook.embeddings.min()) / + (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + else: + print("Hi") + img = normalize_img(img) + masked_img = normalize_img(masked_img) + mask = mask*2.0 - 1.0 + cc = torch.nn.functional.interpolate(mask, size=masked_img.shape[-3:]) + cond = torch.cat((masked_img, cc), dim=1) + + b, device, img_size, = img.shape[0], img.device, self.image_size + t = torch.randint(0, self.num_timesteps, (b,), device=device).long() + # breakpoint() + return self.p_losses(img, t, cond=cond, *args, **kwargs) + +# trainer class + + +CHANNELS_TO_MODE = { + 1: 'L', + 3: 'RGB', + 4: 'RGBA' +} + + +def seek_all_images(img, channels=3): + assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid' + mode = CHANNELS_TO_MODE[channels] + + i = 0 + while True: + try: + img.seek(i) + yield img.convert(mode) + except EOFError: + break + i += 1 + +# tensor of shape (channels, frames, height, width) -> gif + + +def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True): + tensor = ((tensor - tensor.min()) / (tensor.max() - tensor.min())) * 1.0 + images = map(T.ToPILImage(), tensor.unbind(dim=1)) + first_img, *rest_imgs = images + first_img.save(path, save_all=True, append_images=rest_imgs, + duration=duration, loop=loop, optimize=optimize) + return images + +# gif -> (channels, frame, height, width) tensor + + +def gif_to_tensor(path, channels=3, transform=T.ToTensor()): + img = Image.open(path) + tensors = tuple(map(transform, seek_all_images(img, channels=channels))) + return torch.stack(tensors, dim=1) + + +def identity(t, *args, **kwargs): + return t + + +def normalize_img(t): + return t * 2 - 1 + + +def unnormalize_img(t): + return (t + 1) * 0.5 + + +def cast_num_frames(t, *, frames): + f = t.shape[1] + + if f == frames: + return t + + if f > frames: + return t[:, :frames] + + return F.pad(t, (0, 0, 0, 0, 0, frames - f)) + + +class Dataset(data.Dataset): + def __init__( + self, + folder, + image_size, + channels=3, + num_frames=16, + horizontal_flip=False, + force_num_frames=True, + exts=['gif'] + ): + super().__init__() + self.folder = folder + self.image_size = image_size + self.channels = channels + self.paths = [p for ext in exts for p in Path( + f'{folder}').glob(f'**/*.{ext}')] + + self.cast_num_frames_fn = partial( + cast_num_frames, frames=num_frames) if force_num_frames else identity + + self.transform = T.Compose([ + T.Resize(image_size), + T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity), + T.CenterCrop(image_size), + T.ToTensor() + ]) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = self.paths[index] + tensor = gif_to_tensor(path, self.channels, 
transform=self.transform) + return self.cast_num_frames_fn(tensor) + +# trainer class + + +class Tester(object): + def __init__( + self, + diffusion_model, + ): + super().__init__() + self.model = diffusion_model + self.ema_model = copy.deepcopy(self.model) + self.step=0 + self.image_size = diffusion_model.image_size + + self.reset_parameters() + + def reset_parameters(self): + self.ema_model.load_state_dict(self.model.state_dict()) + + + def load(self, milestone, map_location=None, **kwargs): + if milestone == -1: + all_milestones = [int(p.stem.split('-')[-1]) + for p in Path(self.results_folder).glob('**/*.pt')] + assert len( + all_milestones) > 0, 'need to have at least one milestone to load from latest checkpoint (milestone == -1)' + milestone = max(all_milestones) + + if map_location: + data = torch.load(milestone, map_location=map_location) + else: + data = torch.load(milestone) + + self.step = data['step'] + self.model.load_state_dict(data['model'], **kwargs) + self.ema_model.load_state_dict(data['ema'], **kwargs) + + diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/text.py b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/text.py new file mode 100644 index 0000000000000000000000000000000000000000..42dca78bec5075fb4f59c522aa3d00cc395d7536 --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/text.py @@ -0,0 +1,94 @@ +"Taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch" + +import torch +from einops import rearrange + + +def exists(val): + return val is not None + +# singleton globals + + +MODEL = None +TOKENIZER = None +BERT_MODEL_DIM = 768 + + +def get_tokenizer(): + global TOKENIZER + if not exists(TOKENIZER): + TOKENIZER = torch.hub.load( + 'huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased') + return TOKENIZER + + +def get_bert(): + global MODEL + if not exists(MODEL): + MODEL = torch.hub.load( + 'huggingface/pytorch-transformers', 'model', 'bert-base-cased') + if torch.cuda.is_available(): + MODEL = MODEL.cuda() + + return MODEL + +# tokenize + + +def tokenize(texts, add_special_tokens=True): + if not isinstance(texts, (list, tuple)): + texts = [texts] + + tokenizer = get_tokenizer() + + encoding = tokenizer.batch_encode_plus( + texts, + add_special_tokens=add_special_tokens, + padding=True, + return_tensors='pt' + ) + + token_ids = encoding.input_ids + return token_ids + +# embedding function + + +@torch.no_grad() +def bert_embed( + token_ids, + return_cls_repr=False, + eps=1e-8, + pad_id=0. 
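+    # pad_id marks padding tokens, which are excluded from the mean pooling below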
+): + model = get_bert() + mask = token_ids != pad_id + + if torch.cuda.is_available(): + token_ids = token_ids.cuda() + mask = mask.cuda() + + outputs = model( + input_ids=token_ids, + attention_mask=mask, + output_hidden_states=True + ) + + hidden_state = outputs.hidden_states[-1] + + if return_cls_repr: + # return [cls] as representation + return hidden_state[:, 0] + + if not exists(mask): + return hidden_state.mean(dim=1) + + # mean all tokens excluding [cls], accounting for length + mask = mask[:, 1:] + mask = rearrange(mask, 'b n -> b n 1') + + numer = (hidden_state[:, 1:] * mask).sum(dim=1) + denom = mask.sum(dim=1) + masked_mean = numer / (denom + eps) + return diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/time_embedding.py b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/time_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..7c911fa4f485295005cda9c9c2099db3ddbda15c --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/time_embedding.py @@ -0,0 +1,75 @@ +import math +import torch +import torch.nn as nn + +from monai.networks.layers.utils import get_act_layer + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, emb_dim=16, downscale_freq_shift=1, max_period=10000, flip_sin_to_cos=False): + super().__init__() + self.emb_dim = emb_dim + self.downscale_freq_shift = downscale_freq_shift + self.max_period = max_period + self.flip_sin_to_cos = flip_sin_to_cos + + def forward(self, x): + device = x.device + half_dim = self.emb_dim // 2 + emb = math.log(self.max_period) / \ + (half_dim - self.downscale_freq_shift) + emb = torch.exp(-emb*torch.arange(half_dim, device=device)) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + + if self.flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + if self.emb_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class LearnedSinusoidalPosEmb(nn.Module): + """ following @crowsonkb 's lead with learned sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, emb_dim): + super().__init__() + self.emb_dim = emb_dim + half_dim = emb_dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = x[:, None] + freqs = x * self.weights[None, :] * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + if self.emb_dim % 2 == 1: + fouriered = torch.nn.functional.pad(fouriered, (0, 1, 0, 0)) + return fouriered + + +class TimeEmbbeding(nn.Module): + def __init__( + self, + emb_dim=64, + pos_embedder=SinusoidalPosEmb, + pos_embedder_kwargs={}, + act_name=("SWISH", {}) # Swish = SiLU + ): + super().__init__() + self.emb_dim = emb_dim + self.pos_emb_dim = pos_embedder_kwargs.get('emb_dim', emb_dim//4) + pos_embedder_kwargs['emb_dim'] = self.pos_emb_dim + self.pos_embedder = pos_embedder(**pos_embedder_kwargs) + + self.time_emb = nn.Sequential( + self.pos_embedder, + nn.Linear(self.pos_emb_dim, self.emb_dim), + get_act_layer(act_name), + nn.Linear(self.emb_dim, self.emb_dim) + ) + + def forward(self, time): + return self.time_emb(time) diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/unet.py b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/unet.py new file mode 100644 index 
0000000000000000000000000000000000000000..e9cae8021e874a92b6baaaf4decd802516c2dc87 --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/unet.py @@ -0,0 +1,226 @@ +from ddpm.time_embedding import TimeEmbbeding + +import monai.networks.nets as nets +import torch +import torch.nn as nn +from einops import rearrange + +from monai.networks.blocks import UnetBasicBlock, UnetResBlock, UnetUpBlock, Convolution, UnetOutBlock +from monai.networks.layers.utils import get_act_layer + + +class DownBlock(nn.Module): + def __init__( + self, + spatial_dims, + in_ch, + out_ch, + time_emb_dim, + cond_emb_dim, + act_name=("swish", {}), + **kwargs): + super(DownBlock, self).__init__() + self.loca_time_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(time_emb_dim, in_ch) # in_ch * 2 + ) + if cond_emb_dim is not None: + self.loca_cond_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(cond_emb_dim, in_ch), + ) + self.down_op = UnetBasicBlock( + spatial_dims, in_ch, out_ch, act_name=act_name, **kwargs) + + def forward(self, x, time_emb, cond_emb): + b, c, *_ = x.shape + sp_dim = x.ndim-2 + + # ------------ Time ---------- + time_emb = self.loca_time_embedder(time_emb) + time_emb = time_emb.reshape(b, c, *((1,)*sp_dim)) + # scale, shift = time_emb.chunk(2, dim = 1) + + # ------------ Combine ------------ + # x = x * (scale + 1) + shift + x = x + time_emb + + # ----------- Condition ------------ + if cond_emb is not None: + cond_emb = self.loca_cond_embedder(cond_emb) + cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim)) + x = x + cond_emb + + # ----------- Image --------- + y = self.down_op(x) + return y + + +class UpBlock(nn.Module): + def __init__( + self, + spatial_dims, + skip_ch, + enc_ch, + time_emb_dim, + cond_emb_dim, + act_name=("swish", {}), + **kwargs): + super(UpBlock, self).__init__() + self.up_op = UnetUpBlock(spatial_dims, enc_ch, + skip_ch, act_name=act_name, **kwargs) + self.loca_time_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(time_emb_dim, skip_ch * 2), + ) + if cond_emb_dim is not None: + self.loca_cond_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(cond_emb_dim, skip_ch * 2), + ) + + def forward(self, x_skip, x_enc, time_emb, cond_emb): + b, c, *_ = x_enc.shape + sp_dim = x_enc.ndim-2 + + # ----------- Time -------------- + time_emb = self.loca_time_embedder(time_emb) + time_emb = time_emb.reshape(b, c, *((1,)*sp_dim)) + # scale, shift = time_emb.chunk(2, dim = 1) + + # -------- Combine ------------- + # y = x * (scale + 1) + shift + x_enc = x_enc + time_emb + + # ----------- Condition ------------ + if cond_emb is not None: + cond_emb = self.loca_cond_embedder(cond_emb) + cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim)) + x_enc = x_enc + cond_emb + + # ----------- Image ------------- + y = self.up_op(x_enc, x_skip) + + # -------- Combine ------------- + # y = y * (scale + 1) + shift + + return y + + +class UNet(nn.Module): + + def __init__(self, + in_ch=1, + out_ch=1, + spatial_dims=3, + hid_chs=[32, 64, 128, 256, 512], + kernel_sizes=[(1, 3, 3), (1, 3, 3), (1, 3, 3), 3, 3], + strides=[1, (1, 2, 2), (1, 2, 2), 2, 2], + upsample_kernel_sizes=None, + act_name=("SWISH", {}), + norm_name=("INSTANCE", {"affine": True}), + time_embedder=TimeEmbbeding, + time_embedder_kwargs={}, + cond_embedder=None, + cond_embedder_kwargs={}, + # True = all but last layer, 0/False=disable, 1=only first layer, ... 
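+                 # note: forward() currently returns only the full-resolution output y_hor; the deep-supervision heads are built but unused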
+ deep_ver_supervision=True, + estimate_variance=False, + use_self_conditioning=False, + **kwargs + ): + super().__init__() + if upsample_kernel_sizes is None: + upsample_kernel_sizes = strides[1:] + + # ------------- Time-Embedder----------- + self.time_embedder = time_embedder(**time_embedder_kwargs) + + # ------------- Condition-Embedder----------- + if cond_embedder is not None: + self.cond_embedder = cond_embedder(**cond_embedder_kwargs) + cond_emb_dim = self.cond_embedder.emb_dim + else: + self.cond_embedder = None + cond_emb_dim = None + + # ----------- In-Convolution ------------ + in_ch = in_ch*2 if use_self_conditioning else in_ch + self.inc = UnetBasicBlock(spatial_dims, in_ch, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0], + act_name=act_name, norm_name=norm_name, **kwargs) + + # ----------- Encoder ---------------- + self.encoders = nn.ModuleList([ + DownBlock(spatial_dims, hid_chs[i-1], hid_chs[i], time_emb_dim=self.time_embedder.emb_dim, + cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[ + i], stride=strides[i], act_name=act_name, + norm_name=norm_name, **kwargs) + for i in range(1, len(strides)) + ]) + + # ------------ Decoder ---------- + self.decoders = nn.ModuleList([ + UpBlock(spatial_dims, hid_chs[i], hid_chs[i+1], time_emb_dim=self.time_embedder.emb_dim, + cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[i + + 1], stride=strides[i+1], act_name=act_name, + norm_name=norm_name, upsample_kernel_size=upsample_kernel_sizes[i], **kwargs) + for i in range(len(strides)-1) + ]) + + # --------------- Out-Convolution ---------------- + out_ch_hor = out_ch*2 if estimate_variance else out_ch + self.outc = UnetOutBlock( + spatial_dims, hid_chs[0], out_ch_hor, dropout=None) + if isinstance(deep_ver_supervision, bool): + deep_ver_supervision = len( + strides)-2 if deep_ver_supervision else 0 + self.outc_ver = nn.ModuleList([ + UnetOutBlock(spatial_dims, hid_chs[i], out_ch, dropout=None) + for i in range(1, deep_ver_supervision+1) + ]) + + def forward(self, x_t, t, cond=None, self_cond=None, **kwargs): + condition = cond + # x_t [B, C, (D), H, W] + # t [B,] + + # -------- In-Convolution -------------- + x = [None for _ in range(len(self.encoders)+1)] + x_t = torch.cat([x_t, self_cond], + dim=1) if self_cond is not None else x_t + x[0] = self.inc(x_t) + + # -------- Time Embedding (Gloabl) ----------- + time_emb = self.time_embedder(t) # [B, C] + + # -------- Condition Embedding (Gloabl) ----------- + if (condition is None) or (self.cond_embedder is None): + cond_emb = None + else: + cond_emb = self.cond_embedder(condition) # [B, C] + + # --------- Encoder -------------- + for i in range(len(self.encoders)): + x[i+1] = self.encoders[i](x[i], time_emb, cond_emb) + + # -------- Decoder ----------- + for i in range(len(self.decoders), 0, -1): + x[i-1] = self.decoders[i-1](x[i-1], x[i], time_emb, cond_emb) + + # ---------Out-Convolution ------------ + y_hor = self.outc(x[0]) + y_ver = [outc_ver_i(x[i+1]) + for i, outc_ver_i in enumerate(self.outc_ver)] + + return y_hor # , y_ver + + def forward_with_cond_scale(self, *args, cond_scale=0., **kwargs): + return self.forward(*args, **kwargs) + + +if __name__ == '__main__': + model = UNet(in_ch=3) + input = torch.randn((1, 3, 16, 128, 128)) + time = torch.randn((1,)) + out_hor, out_ver = model(input, time) + print(out_hor[0].shape) diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/util.py b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/util.py new file mode 100644 index 
0000000000000000000000000000000000000000..2047cf56802046bbf37cfd33f819ed75a239a7df --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/ddpm/util.py @@ -0,0 +1,271 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +# from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + if c != 1: + steps_out = ddim_timesteps + 1 + else: + steps_out = ddim_timesteps + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. 
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. 
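+    :return: a tensor of shape [N] holding one mean per batch element.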
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbf3deb85a40168e55d5395f42018dd06034c55b Binary files /dev/null and b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__init__.py b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38c08845c1ee2f7b4b52bf95fa282b0808931a3a --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__init__.py @@ -0,0 +1,3 @@ +from .vqgan import VQGAN +from .codebook import Codebook +from .lpips import LPIPS diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c2851c21c12d1381767f75f705c7b896990b106 Binary files /dev/null and b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc differ diff --git 
a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff829982a3c7189eca26440f3947bc15d9040771 Binary files /dev/null and b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55de5141f6e8504572dfd74c30009219697c8c7d Binary files /dev/null and b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0c1659ef1fe1d88ada76ea763eafd39c5b70820 Binary files /dev/null and b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth new file mode 100644 index 0000000000000000000000000000000000000000..f57dcf5cc764d61c8a460365847fb2137ff0a62d --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868 +size 7289 diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/codebook.py b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/codebook.py new file mode 100644 index 0000000000000000000000000000000000000000..c17ed701a2824cc0dcd75670d47a6ab3842e9d35 --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/codebook.py @@ -0,0 +1,109 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from ..utils import shift_dim + + +class Codebook(nn.Module): + def __init__(self, n_codes, embedding_dim, no_random_restart=False, restart_thres=1.0): + super().__init__() + self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim)) + self.register_buffer('N', torch.zeros(n_codes)) + self.register_buffer('z_avg', self.embeddings.data.clone()) + + self.n_codes = n_codes + self.embedding_dim = embedding_dim + self._need_init = True + self.no_random_restart = no_random_restart + self.restart_thres = restart_thres + + def _tile(self, x): + d, ew = x.shape + if d < self.n_codes: + n_repeats = (self.n_codes + d - 1) // d + std = 0.01 / np.sqrt(ew) + x = x.repeat(n_repeats, 1) + x = x + torch.randn_like(x) * std + return x + + def _init_embeddings(self, z): + # z: [b, c, t, h, w] + self._need_init = False + breakpoint() + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [65536, 8] [2, 8, 32, 32, 32] + y = self._tile(flat_inputs) # [65536, 8] + + d = y.shape[0] + _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + self.embeddings.data.copy_(_k_rand) + self.z_avg.data.copy_(_k_rand) + self.N.data.copy_(torch.ones(self.n_codes)) + + def forward(self, z): + # z: [b, c, t, h, w] + if self._need_init and self.training: + self._init_embeddings(z) + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [bthw, c] [65536, 8] + distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \ + - 2 * flat_inputs @ self.embeddings.t() \ + + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) # [bthw, c] [65536, 8] + + encoding_indices = torch.argmin(distances, dim=1) # [65536] + encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as( + flat_inputs) # [bthw, ncode] [65536, 16384] + encoding_indices = encoding_indices.view( + z.shape[0], *z.shape[2:]) # [b, t, h, w, ncode] [2, 32, 32, 32] + + embeddings = F.embedding( + encoding_indices, self.embeddings) # [b, t, h, w, c] self.embeddings [16384, 8] + embeddings = shift_dim(embeddings, -1, 1) # [b, c, t, h, w] [2, 8, 32, 32, 32] + + commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) + + # EMA codebook update + if self.training: + n_total = encode_onehot.sum(dim=0) # [16384] + encode_sum = flat_inputs.t() @ encode_onehot # [8, 16384] + if dist.is_initialized(): + dist.all_reduce(n_total) + dist.all_reduce(encode_sum) + + self.N.data.mul_(0.99).add_(n_total, alpha=0.01) + self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) + + n = self.N.sum() + weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n + encode_normalized = self.z_avg / weights.unsqueeze(1) + self.embeddings.data.copy_(encode_normalized) + + y = self._tile(flat_inputs) + _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + + if not self.no_random_restart: + usage = (self.N.view(self.n_codes, 1) + >= self.restart_thres).float() + self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) + + embeddings_st = (embeddings - z).detach() + z + + avg_probs = torch.mean(encode_onehot, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * + torch.log(avg_probs + 1e-10))) + + return dict(embeddings=embeddings_st, encodings=encoding_indices, + commitment_loss=commitment_loss, perplexity=perplexity) + + def dictionary_lookup(self, encodings): + embeddings = F.embedding(encodings, 
self.embeddings) + return embeddings diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/lpips.py b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..019434b9b29945d67e6e0a95dec324df7ff908f3 --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/lpips.py @@ -0,0 +1,181 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" + +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + + +from collections import namedtuple +from torchvision import models +import torch.nn as nn +import torch +from tqdm import tqdm +import requests +import os +import hashlib +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format( + name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + # self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + self.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name is not "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + model.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer( + input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in 
range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor( + outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor( + [-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor( + [.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, + padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, + h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..a64ec8878c0ecbf418775e2958b2cbf578ce2918 --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py @@ -0,0 +1,561 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import math +import argparse +import numpy as np +import pickle as pkl + +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from ..utils import shift_dim, adopt_weight, comp_getattr +from .lpips import LPIPS +from .codebook import Codebook + + +def silu(x): + return x*torch.sigmoid(x) + + +class SiLU(nn.Module): + def __init__(self): + super(SiLU, self).__init__() + + def forward(self, x): + return silu(x) + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake))) + return d_loss + + +class VQGAN(pl.LightningModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.embedding_dim = cfg.model.embedding_dim # 8 + self.n_codes = cfg.model.n_codes # 16384 + + self.encoder = Encoder(cfg.model.n_hiddens, # 16 + cfg.model.downsample, # [2, 2, 2] + cfg.dataset.image_channels, # 1 + cfg.model.norm_type, # group + cfg.model.padding_type, # replicate + cfg.model.num_groups, # 32 + ) + self.decoder = Decoder( + cfg.model.n_hiddens, cfg.model.downsample, cfg.dataset.image_channels, cfg.model.norm_type, cfg.model.num_groups) + self.enc_out_ch = self.encoder.out_channels + self.pre_vq_conv = SamePadConv3d( + self.enc_out_ch, cfg.model.embedding_dim, 1, padding_type=cfg.model.padding_type) + self.post_vq_conv = SamePadConv3d( + cfg.model.embedding_dim, self.enc_out_ch, 1) + + self.codebook = Codebook(cfg.model.n_codes, cfg.model.embedding_dim, + no_random_restart=cfg.model.no_random_restart, restart_thres=cfg.model.restart_thres) + + self.gan_feat_weight = cfg.model.gan_feat_weight + # TODO: Changed batchnorm from sync to normal + self.image_discriminator = NLayerDiscriminator( + cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm2d) + self.video_discriminator = NLayerDiscriminator3D( + cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm3d) + + if cfg.model.disc_loss_type == 'vanilla': + self.disc_loss = vanilla_d_loss + elif cfg.model.disc_loss_type == 'hinge': + self.disc_loss = hinge_d_loss + + self.perceptual_model = LPIPS().eval() + + self.image_gan_weight = cfg.model.image_gan_weight + self.video_gan_weight = cfg.model.video_gan_weight + + self.perceptual_weight = cfg.model.perceptual_weight + + self.l1_weight = cfg.model.l1_weight + self.save_hyperparameters() + + def encode(self, x, include_embeddings=False, quantize=True): + h = self.pre_vq_conv(self.encoder(x)) + if quantize: + vq_output = self.codebook(h) + if include_embeddings: + return vq_output['embeddings'], vq_output['encodings'] + else: + return vq_output['encodings'] + return h + + def decode(self, latent, quantize=False): + if quantize: + vq_output = self.codebook(latent) + latent = vq_output['encodings'] + h = F.embedding(latent, self.codebook.embeddings) + h = self.post_vq_conv(shift_dim(h, -1, 1)) + return self.decoder(h) + + def forward(self, x, optimizer_idx=None, log_image=False): + B, C, T, H, W = x.shape + z = self.pre_vq_conv(self.encoder(x)) # [2, 32, 32, 32, 32] [2, 8, 32, 32, 32] + vq_output = self.codebook(z) # ['embeddings', 'encodings', 'commitment_loss', 
'perplexity'] + x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings'])) # [2, 8, 32, 32, 32] [2, 32, 32, 32, 32] + + recon_loss = F.l1_loss(x_recon, x) * self.l1_weight + + # Selects one random 2D image from each 3D Image + frame_idx = torch.randint(0, T, [B]).cuda() + frame_idx_selected = frame_idx.reshape(-1, + 1, 1, 1, 1).repeat(1, C, 1, H, W) # [2, 1, 1, 64, 64] + frames = torch.gather(x, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64] + frames_recon = torch.gather(x_recon, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64] + + if log_image: + return frames, frames_recon, x, x_recon + + if optimizer_idx == 0: + # Autoencoder - train the "generator" + + # Perceptual loss + perceptual_loss = 0 + if self.perceptual_weight > 0: + perceptual_loss = self.perceptual_model( + frames, frames_recon).mean() * self.perceptual_weight + + # Discriminator loss (turned on after a certain epoch) + logits_image_fake, pred_image_fake = self.image_discriminator( + frames_recon) + logits_video_fake, pred_video_fake = self.video_discriminator( + x_recon) + g_image_loss = -torch.mean(logits_image_fake) + g_video_loss = -torch.mean(logits_video_fake) + g_loss = self.image_gan_weight*g_image_loss + self.video_gan_weight*g_video_loss + disc_factor = adopt_weight( + self.global_step, threshold=self.cfg.model.discriminator_iter_start) + aeloss = disc_factor * g_loss + + # GAN feature matching loss - tune features such that we get the same prediction result on the discriminator + image_gan_feat_loss = 0 + video_gan_feat_loss = 0 + feat_weights = 4.0 / (3 + 1) + if self.image_gan_weight > 0: + logits_image_real, pred_image_real = self.image_discriminator( + frames) + for i in range(len(pred_image_fake)-1): + image_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_image_fake[i], pred_image_real[i].detach( + )) * (self.image_gan_weight > 0) + if self.video_gan_weight > 0: + logits_video_real, pred_video_real = self.video_discriminator( + x) + for i in range(len(pred_video_fake)-1): + video_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_video_fake[i], pred_video_real[i].detach( + )) * (self.video_gan_weight > 0) + gan_feat_loss = disc_factor * self.gan_feat_weight * \ + (image_gan_feat_loss + video_gan_feat_loss) + + self.log("train/g_image_loss", g_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/g_video_loss", g_video_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/image_gan_feat_loss", image_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/video_gan_feat_loss", video_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/perceptual_loss", perceptual_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log("train/recon_loss", recon_loss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + self.log("train/aeloss", aeloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + self.log("train/commitment_loss", vq_output['commitment_loss'], + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log('train/perplexity', vq_output['perplexity'], + prog_bar=True, logger=True, on_step=True, on_epoch=True) + return recon_loss, x_recon, vq_output, aeloss, perceptual_loss, gan_feat_loss + + if optimizer_idx == 1: + # Train discriminator + logits_image_real, _ = self.image_discriminator(frames.detach()) + logits_video_real, _ = self.video_discriminator(x.detach()) + + logits_image_fake, _ = self.image_discriminator( + frames_recon.detach()) + logits_video_fake, _ = 
self.video_discriminator(x_recon.detach()) + + d_image_loss = self.disc_loss(logits_image_real, logits_image_fake) + d_video_loss = self.disc_loss(logits_video_real, logits_video_fake) + disc_factor = adopt_weight( + self.global_step, threshold=self.cfg.model.discriminator_iter_start) + discloss = disc_factor * \ + (self.image_gan_weight*d_image_loss + + self.video_gan_weight*d_video_loss) + + self.log("train/logits_image_real", logits_image_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_image_fake", logits_image_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_video_real", logits_video_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_video_fake", logits_video_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/d_image_loss", d_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/d_video_loss", d_video_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/discloss", discloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + return discloss + + perceptual_loss = self.perceptual_model( + frames, frames_recon) * self.perceptual_weight + return recon_loss, x_recon, vq_output, perceptual_loss + + def training_step(self, batch, batch_idx, optimizer_idx): + x = batch['image'] + if optimizer_idx == 0: + recon_loss, _, vq_output, aeloss, perceptual_loss, gan_feat_loss = self.forward( + x, optimizer_idx) + commitment_loss = vq_output['commitment_loss'] + loss = recon_loss + commitment_loss + aeloss + perceptual_loss + gan_feat_loss + if optimizer_idx == 1: + discloss = self.forward(x, optimizer_idx) + loss = discloss + return loss + + def validation_step(self, batch, batch_idx): + x = batch['image'] # TODO: batch['stft'] + recon_loss, _, vq_output, perceptual_loss = self.forward(x) + self.log('val/recon_loss', recon_loss, prog_bar=True) + self.log('val/perceptual_loss', perceptual_loss, prog_bar=True) + self.log('val/perplexity', vq_output['perplexity'], prog_bar=True) + self.log('val/commitment_loss', + vq_output['commitment_loss'], prog_bar=True) + + def configure_optimizers(self): + lr = self.cfg.model.lr + opt_ae = torch.optim.Adam(list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.pre_vq_conv.parameters()) + + list(self.post_vq_conv.parameters()) + + list(self.codebook.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(list(self.image_discriminator.parameters()) + + list(self.video_discriminator.parameters()), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def log_images(self, batch, **kwargs): + log = dict() + x = batch['image'] + x = x.to(self.device) + frames, frames_rec, _, _ = self(x, log_image=True) + log["inputs"] = frames + log["reconstructions"] = frames_rec + #log['mean_org'] = batch['mean_org'] + #log['std_org'] = batch['std_org'] + return log + + def log_videos(self, batch, **kwargs): + log = dict() + x = batch['image'] + _, _, x, x_rec = self(x, log_image=True) + log["inputs"] = x + log["reconstructions"] = x_rec + #log['mean_org'] = batch['mean_org'] + #log['std_org'] = batch['std_org'] + return log + + +def Normalize(in_channels, norm_type='group', num_groups=32): + assert norm_type in ['group', 'batch'] + if norm_type == 'group': + # TODO Changed num_groups from 32 to 8 + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == 'batch': + return 
torch.nn.SyncBatchNorm(in_channels) + + +class Encoder(nn.Module): + def __init__(self, n_hiddens, downsample, image_channel=3, norm_type='group', padding_type='replicate', num_groups=32): + super().__init__() + n_times_downsample = np.array([int(math.log2(d)) for d in downsample]) + self.conv_blocks = nn.ModuleList() + max_ds = n_times_downsample.max() + + self.conv_first = SamePadConv3d( + image_channel, n_hiddens, kernel_size=3, padding_type=padding_type) + + for i in range(max_ds): + block = nn.Module() + in_channels = n_hiddens * 2**i + out_channels = n_hiddens * 2**(i+1) + stride = tuple([2 if d > 0 else 1 for d in n_times_downsample]) + block.down = SamePadConv3d( + in_channels, out_channels, 4, stride=stride, padding_type=padding_type) + block.res = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + self.conv_blocks.append(block) + n_times_downsample -= 1 + + self.final_block = nn.Sequential( + Normalize(out_channels, norm_type, num_groups=num_groups), + SiLU() + ) + + self.out_channels = out_channels + + def forward(self, x): + h = self.conv_first(x) + for block in self.conv_blocks: + h = block.down(h) + h = block.res(h) + h = self.final_block(h) + return h + + +class Decoder(nn.Module): + def __init__(self, n_hiddens, upsample, image_channel, norm_type='group', num_groups=32): + super().__init__() + + n_times_upsample = np.array([int(math.log2(d)) for d in upsample]) + max_us = n_times_upsample.max() + + in_channels = n_hiddens*2**max_us + self.final_block = nn.Sequential( + Normalize(in_channels, norm_type, num_groups=num_groups), + SiLU() + ) + + self.conv_blocks = nn.ModuleList() + for i in range(max_us): + block = nn.Module() + in_channels = in_channels if i == 0 else n_hiddens*2**(max_us-i+1) + out_channels = n_hiddens*2**(max_us-i) + us = tuple([2 if d > 0 else 1 for d in n_times_upsample]) + block.up = SamePadConvTranspose3d( + in_channels, out_channels, 4, stride=us) + block.res1 = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + block.res2 = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + self.conv_blocks.append(block) + n_times_upsample -= 1 + + self.conv_last = SamePadConv3d( + out_channels, image_channel, kernel_size=3) + + def forward(self, x): + h = self.final_block(x) + for i, block in enumerate(self.conv_blocks): + h = block.up(h) + h = block.res1(h) + h = block.res2(h) + h = self.conv_last(h) + return h + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group', padding_type='replicate', num_groups=32): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, norm_type, num_groups=num_groups) + self.conv1 = SamePadConv3d( + in_channels, out_channels, kernel_size=3, padding_type=padding_type) + self.dropout = torch.nn.Dropout(dropout) + self.norm2 = Normalize(in_channels, norm_type, num_groups=num_groups) + self.conv2 = SamePadConv3d( + out_channels, out_channels, kernel_size=3, padding_type=padding_type) + if self.in_channels != self.out_channels: + self.conv_shortcut = SamePadConv3d( + in_channels, out_channels, kernel_size=3, padding_type=padding_type) + + def forward(self, x): + h = x + h = self.norm1(h) + h = silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = silu(h) + h = self.conv2(h) + + 
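+        # If the channel count changes, the conv_shortcut branch below projects the input so the residual addition x + h stays shape-compatible.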
if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) + + return x+h + + +# Does not support dilation +class SamePadConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + # assumes that the input shape is divisible by stride + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + self.padding_type = padding_type + + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, + stride=stride, padding=0, bias=bias) + + def forward(self, x): + return self.conv(F.pad(x, self.pad_input, mode=self.padding_type)) + + +class SamePadConvTranspose3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + self.padding_type = padding_type + + self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, + stride=stride, bias=bias, + padding=tuple([k - 1 for k in kernel_size])) + + def forward(self, x): + return self.convt(F.pad(x, self.pad_input, mode=self.padding_type)) + + +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + # def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=True): + super(NLayerDiscriminator, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ + + +class NLayerDiscriminator3D(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + 
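+        # 3D counterpart of NLayerDiscriminator above: the same PatchGAN-style layer stack built from Conv3d, so it scores whole sub-volumes rather than single 2D frames.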
super(NLayerDiscriminator3D, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv3d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv3d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/utils.py b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..acf1a21270d98d0a1ea9935d00f9f16d89d54551 --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/ldm/vq_gan_3d/utils.py @@ -0,0 +1,177 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import warnings +import torch +import imageio + +import math +import numpy as np + +import sys +import pdb as pdb_original +import logging + +import imageio.core.util +logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR) + + +class ForkedPdb(pdb_original.Pdb): + """A Pdb subclass that may be used + from a forked multiprocessing child + + """ + + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + pdb_original.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + +# Shifts src_tf dim to dest dim +# i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + + dims = list(range(n_dims)) + del dims[src_dim] + + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x + + +# reshapes tensor start from dim i (inclusive) +# to dim j (exclusive) to the desired shape +# e.g. 
if x.shape = (b, thw, c) then +# view_range(x, 1, 2, (t, h, w)) returns +# x of shape (b, t, h, w, c) +def view_range(x, i, j, shape): + shape = tuple(shape) + + n_dims = len(x.shape) + if i < 0: + i = n_dims + i + + if j is None: + j = n_dims + elif j < 0: + j = n_dims + j + + assert 0 <= i < j <= n_dims + + x_shape = x.shape + target_shape = x_shape[:i] + shape + x_shape[j:] + return x.view(target_shape) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def tensor_slice(x, begin, size): + assert all([b >= 0 for b in begin]) + size = [l - b if s == -1 else s + for s, b, l in zip(size, begin, x.shape)] + assert all([s >= 0 for s in size]) + + slices = [slice(b, b + s) for b, s in zip(begin, size)] + return x[slices] + + +def adopt_weight(global_step, threshold=0, value=0.): + weight = 1 + if global_step < threshold: + weight = value + return weight + + +def save_video_grid(video, fname, nrow=None, fps=6): + b, c, t, h, w = video.shape + video = video.permute(0, 2, 3, 4, 1) + video = (video.cpu().numpy() * 255).astype('uint8') + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = np.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype='uint8') + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + video = [] + for i in range(t): + video.append(video_grid[i]) + imageio.mimsave(fname, video, fps=fps) + ## skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '5'}) + #print('saved videos to', fname) + + +def comp_getattr(args, attr_name, default=None): + if hasattr(args, attr_name): + return getattr(args, attr_name) + else: + return default + + +def visualize_tensors(t, name=None, nest=0): + if name is not None: + print(name, "current nest: ", nest) + print("type: ", type(t)) + if 'dict' in str(type(t)): + print(t.keys()) + for k in t.keys(): + if t[k] is None: + print(k, "None") + else: + if 'Tensor' in str(type(t[k])): + print(k, t[k].shape) + elif 'dict' in str(type(t[k])): + print(k, 'dict') + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t[k])): + print(k, len(t[k])) + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t)): + print("list length: ", len(t)) + for t2 in t: + visualize_tensors(t2, name, nest + 1) + elif 'Tensor' in str(type(t)): + print(t.shape) + else: + print(t) + return "" diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/utils.py b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..54a3d68432d165ab2895859a89d7be4d150e9721 --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/utils.py @@ -0,0 +1,465 @@ +### Tumor Generateion +import random +import cv2 +import elasticdeform +import numpy as np +from scipy.ndimage import gaussian_filter +from TumorGeneration.ldm.ddpm.ddim import DDIMSampler + +# Step 1: Random select (numbers) location for tumor. 
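+# A minimal, hypothetical usage sketch (the names organ_label and label_volume below are assumptions, not part of this module):
+#   organ_mask = (label_volume == organ_label).astype(np.uint8)   # binary organ mask, shape [H, W, D]
+#   x, y, z = random_select(organ_mask, organ_type='kidney')      # one in-organ seed voxel
+# The z index is restricted to the middle 30-70% of the organ extent and, for 'liver',
+# the selected slice is eroded first so seed points avoid the organ boundary.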
+def random_select(mask_scan, organ_type): + # we first find z index and then sample point with z slice + # print('mask_scan',np.unique(mask_scan)) + # print('pixel num', (mask_scan == 1).sum()) + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0].min(), np.where(np.any(mask_scan, axis=(0, 1)))[0].max() + # print('z_start, z_end',z_start, z_end) + # we need to strict number z's position (0.3 - 0.7 in the middle of liver) + while 1: + z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start + liver_mask = mask_scan[..., z] + # erode the mask (we don't want the edge points) + if organ_type == 'liver': + kernel = np.ones((5,5), dtype=np.uint8) + liver_mask = cv2.erode(liver_mask, kernel, iterations=1) + if (liver_mask == 1).sum() > 0: + break + + + + # print('liver_mask', (liver_mask == 1).sum()) + coordinates = np.argwhere(liver_mask == 1) + random_index = np.random.randint(0, len(coordinates)) + xyz = coordinates[random_index].tolist() # get x,y + xyz.append(z) + potential_points = xyz + + return potential_points + +def center_select(mask_scan): + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0].min(), np.where(np.any(mask_scan, axis=(0, 1)))[0].max() + x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0].min(), np.where(np.any(mask_scan, axis=(1, 2)))[0].max() + y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0].min(), np.where(np.any(mask_scan, axis=(0, 2)))[0].max() + + z = round(0.5 * (z_end - z_start)) + z_start + x = round(0.5 * (x_end - x_start)) + x_start + y = round(0.5 * (y_end - y_start)) + y_start + + xyz = [x, y, z] + potential_points = xyz + + return potential_points + +# Step 2 : generate the ellipsoid +def get_ellipsoid(x, y, z): + """" + x, y, z is the radius of this ellipsoid in x, y, z direction respectly. 
+ """ + sh = (4*x, 4*y, 4*z) + out = np.zeros(sh, int) + aux = np.zeros(sh) + radii = np.array([x, y, z]) + com = np.array([2*x, 2*y, 2*z]) # center point + + # calculate the ellipsoid + bboxl = np.floor(com-radii).clip(0,None).astype(int) + bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int) + roi = out[tuple(map(slice,bboxl,bboxh))] + roiaux = aux[tuple(map(slice,bboxl,bboxh))] + logrid = *map(np.square,np.ogrid[tuple( + map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]), + dst = (1-sum(logrid)).clip(0,None) + mask = dst>roiaux + roi[mask] = 1 + np.copyto(roiaux,dst,where=mask) + + return out + +def get_fixed_geo(mask_scan, tumor_type, organ_type): + if tumor_type == 'large': + enlarge_x, enlarge_y, enlarge_z = 280, 280, 280 + else: + enlarge_x, enlarge_y, enlarge_z = 160, 160, 160 + geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8) + tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32 + + if tumor_type == 'tiny': + num_tumor = random.randint(1,3) + # num_tumor = 1 + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + if tumor_type == 'small': + num_tumor = random.randint(1,3) + # num_tumor = 1 + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'medium': + # num_tumor = random.randint(1, 3) + num_tumor = 1 + for _ in range(num_tumor): + # 
medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'large': + num_tumor = 1 # random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + if organ_type == 'liver' or organ_type == 'kidney' : + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + else: + x_start, x_end = np.where(np.any(geo, axis=(1, 2)))[0].min(), np.where(np.any(geo, axis=(1, 2)))[0].max() + y_start, y_end = np.where(np.any(geo, axis=(0, 2)))[0].min(), np.where(np.any(geo, axis=(0, 2)))[0].max() + z_start, z_end = np.where(np.any(geo, axis=(0, 1)))[0].min(), np.where(np.any(geo, axis=(0, 1)))[0].max() + geo = geo[x_start:x_end, y_start:y_end, z_start:z_end] + + point = center_select(mask_scan) + + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low = new_point[0] - geo.shape[0]//2 + y_low = new_point[1] - geo.shape[1]//2 + z_low = new_point[2] - geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_low+geo.shape[0], y_low:y_low+geo.shape[1], z_low:z_low+geo.shape[2]] += geo + + geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + # texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + if ((tumor_type == 'medium') or (tumor_type == 'large')) and (organ_type == 'kidney'): + if random.random() > 0.5: + geo_mask = (geo_mask>=1) + else: + geo_mask = (geo_mask * mask_scan) >=1 + else: + geo_mask = (geo_mask * mask_scan) >=1 + + return geo_mask + + +from .ldm.vq_gan_3d.model.vqgan 
import VQGAN +import matplotlib.pyplot as plt +import SimpleITK as sitk +from .ldm.ddpm import Unet3D, GaussianDiffusion, Tester +from hydra import initialize, compose +import torch +import yaml +def synt_model_prepare(device, vqgan_ckpt='TumorGeneration/model_weight/recon_96d4_all.ckpt', diffusion_ckpt='TumorGeneration/model_weight/', fold=0, organ='liver'): + with initialize(config_path="diffusion_config/"): + cfg=compose(config_name="ddpm.yaml") + print('diffusion_ckpt',diffusion_ckpt) + vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt) + vqgan = vqgan.to(device) + vqgan.eval() + + early_Unet3D = Unet3D( + dim=cfg.diffusion_img_size, + dim_mults=cfg.dim_mults, + channels=cfg.diffusion_num_channels, + out_dim=cfg.out_dim + ).to(device) + + early_diffusion = GaussianDiffusion( + early_Unet3D, + vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt, + image_size=cfg.diffusion_img_size, + num_frames=cfg.diffusion_depth_size, + channels=cfg.diffusion_num_channels, + timesteps=4, # cfg.timesteps, + loss_type=cfg.loss_type, + device=device + ).to(device) + + noearly_Unet3D = Unet3D( + dim=cfg.diffusion_img_size, + dim_mults=cfg.dim_mults, + channels=cfg.diffusion_num_channels, + out_dim=cfg.out_dim + ).to(device) + + noearly_diffusion = GaussianDiffusion( + noearly_Unet3D, + vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt, + image_size=cfg.diffusion_img_size, + num_frames=cfg.diffusion_depth_size, + channels=cfg.diffusion_num_channels, + timesteps=200, # cfg.timesteps, + loss_type=cfg.loss_type, + device=device + ).to(device) + + early_tester = Tester(early_diffusion) + # noearly_tester = Tester(noearly_diffusion) + early_tester.load(diffusion_ckpt+'diff_{}_fold{}_early.pt'.format(organ, fold), map_location=device) + # noearly_tester.load(diffusion_ckpt+'diff_liver_fold{}_noearly_t200.pt'.format(fold), map_location=device) + + # early_checkpoint = torch.load(diffusion_ckpt+'diff_liver_fold{}_early.pt'.format(fold), map_location=device) + noearly_checkpoint = torch.load(diffusion_ckpt+'diff_{}_fold{}_noearly_t200.pt'.format(organ, fold), map_location=device) + # early_diffusion.load_state_dict(early_checkpoint['ema']) + noearly_diffusion.load_state_dict(noearly_checkpoint['ema']) + # early_sampler = DDIMSampler(early_diffusion, schedule="cosine") + noearly_sampler = DDIMSampler(noearly_diffusion, schedule="cosine") + # breakpoint() + return vqgan, early_tester, noearly_sampler + +def synthesize_early_tumor(ct_volume, organ_mask, organ_type, vqgan, tester): + device=ct_volume.device + + # generate tumor mask + tumor_types = ['tiny', 'small'] + # tumor_probs = np.array([0.5, 0.5]) + tumor_probs = np.array([0.2, 0.8]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - 
vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + tester.ema_model.eval() + sample = tester.ema_model.sample(batch_size=volume.shape[0], cond=cond) + + # if organ_type == 'liver' or organ_type == 'kidney' : + + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask + +def synthesize_medium_tumor(ct_volume, organ_mask, organ_type, vqgan, sampler, ddim_ts=50): + device=ct_volume.device + + # generate tumor mask + # tumor_types = ['large'] + # tumor_probs = np.array([1.0]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + # synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + synthetic_tumor_type = 'medium' + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + shape = masked_volume_feat.shape[-4:] + samples_ddim, _ = sampler.sample(S=ddim_ts, + conditioning=cond, + batch_size=1, + shape=shape, + verbose=False) + samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() - + vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min() + + sample = vqgan.decode(samples_ddim, quantize=True) + + # if organ_type == 'liver' or organ_type == 'kidney': + # post-process + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, 
max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask + +def synthesize_large_tumor(ct_volume, organ_mask, organ_type, vqgan, sampler, ddim_ts=50): + device=ct_volume.device + + # generate tumor mask + # tumor_types = ['large'] + # tumor_probs = np.array([1.0]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + # synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + synthetic_tumor_type = 'large' + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + shape = masked_volume_feat.shape[-4:] + samples_ddim, _ = sampler.sample(S=ddim_ts, + conditioning=cond, + batch_size=1, + shape=shape, + verbose=False) + samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() - + vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min() + + sample = vqgan.decode(samples_ddim, quantize=True) + + # if organ_type == 'liver' or organ_type == 'kidney': + # post-process + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask \ No newline at end of file diff --git a/Generation_Pipeline_filter/syn_kidney/TumorGeneration/utils_.py b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/utils_.py new file mode 100644 index 0000000000000000000000000000000000000000..312e72d1571b3cd996fd567bcedf939443a4e182 --- /dev/null +++ 
b/Generation_Pipeline_filter/syn_kidney/TumorGeneration/utils_.py
@@ -0,0 +1,298 @@
+### Tumor Generation
+import random
+import cv2
+import elasticdeform
+import numpy as np
+from scipy.ndimage import gaussian_filter
+
+# Step 1: randomly select a location for the tumor.
+def random_select(mask_scan):
+    # we first find the z extent, then sample a point within that z slice
+    z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]]
+
+    # restrict z to the middle of the organ (0.3 - 0.7 of its extent)
+    z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start
+
+    liver_mask = mask_scan[..., z]
+
+    # erode the mask (we don't want the edge points)
+    kernel = np.ones((5,5), dtype=np.uint8)
+    liver_mask = cv2.erode(liver_mask, kernel, iterations=1)
+
+    coordinates = np.argwhere(liver_mask == 1)
+    random_index = np.random.randint(0, len(coordinates))
+    xyz = coordinates[random_index].tolist() # get x,y
+    xyz.append(z)
+    potential_points = xyz
+
+    return potential_points
+
+# Step 2: generate the ellipsoid
+def get_ellipsoid(x, y, z):
+    """
+    x, y, z are the radii of this ellipsoid in the x, y, z directions respectively.
+    """
+    sh = (4*x, 4*y, 4*z)
+    out = np.zeros(sh, int)
+    aux = np.zeros(sh)
+    radii = np.array([x, y, z])
+    com = np.array([2*x, 2*y, 2*z])  # center point
+
+    # calculate the ellipsoid
+    bboxl = np.floor(com-radii).clip(0,None).astype(int)
+    bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int)
+    roi = out[tuple(map(slice,bboxl,bboxh))]
+    roiaux = aux[tuple(map(slice,bboxl,bboxh))]
+    logrid = *map(np.square,np.ogrid[tuple(
+        map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]),
+    dst = (1-sum(logrid)).clip(0,None)
+    mask = dst>roiaux
+    roi[mask] = 1
+    np.copyto(roiaux,dst,where=mask)
+
+    return out
+
+def get_fixed_geo(mask_scan, tumor_type):
+
+    enlarge_x, enlarge_y, enlarge_z = 160, 160, 160
+    geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8)
+    tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32
+
+    if tumor_type == 'tiny':
+        num_tumor = random.randint(3,10)
+        for _ in range(num_tumor):
+            # Tiny tumor
+            x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
+            y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
+            z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
+            sigma = random.uniform(0.5,1)
+
+            geo = get_ellipsoid(x, y, z)
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
+            point = random_select(mask_scan)
+            new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
+            x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
+            y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
+            z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
+
+            # paste tiny tumor geo into test sample
+            geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
+
+    if tumor_type == 'small':
+        num_tumor = random.randint(3,10)
+        for _ in range(num_tumor):
+            # Small tumor
+            x = random.randint(int(0.75*small_radius), int(1.25*small_radius))
+            y = random.randint(int(0.75*small_radius), int(1.25*small_radius))
+            z = random.randint(int(0.75*small_radius), int(1.25*small_radius))
+            sigma = random.randint(1, 2)
+
+            geo = 
get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'medium': + num_tumor = random.randint(2, 5) + for _ in range(num_tumor): + # medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'large': + num_tumor = random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == "mix": + # tiny + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = 
random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + # small + num_tumor = random.randint(5,10) + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + # medium + num_tumor = random.randint(2, 5) + for _ in range(num_tumor): + # medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + # large + num_tumor = random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = 
random.randint(int(0.75*large_radius), int(1.25*large_radius))
+            z = random.randint(int(0.75*large_radius), int(1.25*large_radius))
+            sigma = random.randint(5, 10)
+            geo = get_ellipsoid(x, y, z)
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
+            # texture = get_texture((4*x, 4*y, 4*z))
+            point = random_select(mask_scan)
+            new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
+            x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
+            y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
+            z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
+
+            # paste large tumor geo into test sample
+            geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
+            # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
+
+    geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2]
+    # texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2]
+    geo_mask = (geo_mask * mask_scan) >=1
+
+    return geo_mask
+
+
+def get_tumor(volume_scan, mask_scan, tumor_type):
+    tumor_mask = get_fixed_geo(mask_scan, tumor_type)
+
+    sigma = np.random.uniform(1, 2)
+    # difference = np.random.uniform(65, 145)
+    difference = 1
+
+    # blur the tumor boundary
+    tumor_mask_blur = gaussian_filter(tumor_mask*1.0, sigma)
+
+    # darken the blurred tumor region inside the organ, then paste the organ back into the full scan
+    abnormally = (volume_scan - tumor_mask_blur * difference) * mask_scan
+    abnormally_full = volume_scan * (1 - mask_scan) + abnormally
+    # label organ voxels as 1 and tumor voxels as 2
+    abnormally_mask = mask_scan + tumor_mask
+
+    return abnormally_full, abnormally_mask
+
+def SynthesisTumor(volume_scan, mask_scan, tumor_type):
+    # to speed up generation, only the organ region is sent into the tumor-synthesis routine
+    x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0][[0, -1]]
+    y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0][[0, -1]]
+    z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]]
+
+    # shrink the boundary
+    x_start, x_end = max(0, x_start+1), min(mask_scan.shape[0], x_end-1)
+    y_start, y_end = max(0, y_start+1), min(mask_scan.shape[1], y_end-1)
+    z_start, z_end = max(0, z_start+1), min(mask_scan.shape[2], z_end-1)
+
+    ct_volume = volume_scan[x_start:x_end, y_start:y_end, z_start:z_end]
+    organ_mask = mask_scan[x_start:x_end, y_start:y_end, z_start:z_end]
+
+    ct_volume, organ_tumor_mask = get_tumor(ct_volume, organ_mask, tumor_type)
+    volume_scan[x_start:x_end, y_start:y_end, z_start:z_end] = ct_volume
+    mask_scan[x_start:x_end, y_start:y_end, z_start:z_end] = organ_tumor_mask
+
+    return volume_scan, mask_scan
diff --git a/Generation_Pipeline_filter/syn_kidney/healthy_kidney_1k.txt b/Generation_Pipeline_filter/syn_kidney/healthy_kidney_1k.txt
new file mode 100644
index 0000000000000000000000000000000000000000..f9487280079db99e7abc891c32997dfa6f4e6751
--- /dev/null
+++ 
b/Generation_Pipeline_filter/syn_kidney/healthy_kidney_1k.txt @@ -0,0 +1,565 @@ +BDMAP_00002275 +BDMAP_00001907 +BDMAP_00002712 +BDMAP_00004615 +BDMAP_00004651 +BDMAP_00002230 +BDMAP_00002955 +BDMAP_00004183 +BDMAP_00002304 +BDMAP_00002029 +BDMAP_00001646 +BDMAP_00002909 +BDMAP_00002328 +BDMAP_00004829 +BDMAP_00001093 +BDMAP_00002117 +BDMAP_00004600 +BDMAP_00003771 +BDMAP_00001198 +BDMAP_00003451 +BDMAP_00002719 +BDMAP_00002846 +BDMAP_00002282 +BDMAP_00003827 +BDMAP_00001649 +BDMAP_00005141 +BDMAP_00000941 +BDMAP_00002875 +BDMAP_00004641 +BDMAP_00003373 +BDMAP_00001924 +BDMAP_00003897 +BDMAP_00005074 +BDMAP_00001753 +BDMAP_00000101 +BDMAP_00003412 +BDMAP_00002945 +BDMAP_00002598 +BDMAP_00004858 +BDMAP_00001632 +BDMAP_00003327 +BDMAP_00005130 +BDMAP_00004783 +BDMAP_00002844 +BDMAP_00002479 +BDMAP_00001464 +BDMAP_00001809 +BDMAP_00003385 +BDMAP_00003918 +BDMAP_00004995 +BDMAP_00004447 +BDMAP_00003972 +BDMAP_00003438 +BDMAP_00003898 +BDMAP_00001057 +BDMAP_00005005 +BDMAP_00003244 +BDMAP_00003631 +BDMAP_00004103 +BDMAP_00000069 +BDMAP_00001736 +BDMAP_00003002 +BDMAP_00004704 +BDMAP_00001055 +BDMAP_00000447 +BDMAP_00000778 +BDMAP_00005097 +BDMAP_00004264 +BDMAP_00004304 +BDMAP_00005170 +BDMAP_00000547 +BDMAP_00004764 +BDMAP_00004229 +BDMAP_00001414 +BDMAP_00001828 +BDMAP_00003151 +BDMAP_00003769 +BDMAP_00001962 +BDMAP_00003333 +BDMAP_00000676 +BDMAP_00001704 +BDMAP_00004459 +BDMAP_00003683 +BDMAP_00003439 +BDMAP_00004016 +BDMAP_00000438 +BDMAP_00004117 +BDMAP_00001785 +BDMAP_00002688 +BDMAP_00000913 +BDMAP_00000942 +BDMAP_00003400 +BDMAP_00003824 +BDMAP_00000470 +BDMAP_00002918 +BDMAP_00002828 +BDMAP_00004286 +BDMAP_00001845 +BDMAP_00002791 +BDMAP_00004672 +BDMAP_00002717 +BDMAP_00002856 +BDMAP_00002188 +BDMAP_00001701 +BDMAP_00001175 +BDMAP_00002841 +BDMAP_00003254 +BDMAP_00004508 +BDMAP_00000373 +BDMAP_00001565 +BDMAP_00002214 +BDMAP_00000701 +BDMAP_00000690 +BDMAP_00001215 +BDMAP_00000324 +BDMAP_00004015 +BDMAP_00004196 +BDMAP_00001419 +BDMAP_00000618 +BDMAP_00003640 +BDMAP_00001697 +BDMAP_00000332 +BDMAP_00004023 +BDMAP_00002815 +BDMAP_00004199 +BDMAP_00003890 +BDMAP_00002529 +BDMAP_00004843 +BDMAP_00002076 +BDMAP_00004895 +BDMAP_00000623 +BDMAP_00002244 +BDMAP_00000205 +BDMAP_00001185 +BDMAP_00003133 +BDMAP_00001957 +BDMAP_00001015 +BDMAP_00003932 +BDMAP_00001010 +BDMAP_00001102 +BDMAP_00004880 +BDMAP_00004664 +BDMAP_00002748 +BDMAP_00000430 +BDMAP_00004293 +BDMAP_00002829 +BDMAP_00000558 +BDMAP_00000084 +BDMAP_00001438 +BDMAP_00001917 +BDMAP_00004129 +BDMAP_00000232 +BDMAP_00002463 +BDMAP_00004839 +BDMAP_00003664 +BDMAP_00004604 +BDMAP_00002021 +BDMAP_00004550 +BDMAP_00004106 +BDMAP_00004128 +BDMAP_00000696 +BDMAP_00002411 +BDMAP_00003569 +BDMAP_00001912 +BDMAP_00003036 +BDMAP_00001288 +BDMAP_00002216 +BDMAP_00002199 +BDMAP_00000100 +BDMAP_00003634 +BDMAP_00000345 +BDMAP_00000614 +BDMAP_00001769 +BDMAP_00002580 +BDMAP_00004676 +BDMAP_00000388 +BDMAP_00003357 +BDMAP_00004431 +BDMAP_00002359 +BDMAP_00000132 +BDMAP_00004097 +BDMAP_00003847 +BDMAP_00003017 +BDMAP_00003680 +BDMAP_00001737 +BDMAP_00003361 +BDMAP_00003377 +BDMAP_00000437 +BDMAP_00002237 +BDMAP_00003900 +BDMAP_00001754 +BDMAP_00004288 +BDMAP_00002612 +BDMAP_00003329 +BDMAP_00004187 +BDMAP_00000873 +BDMAP_00003525 +BDMAP_00000921 +BDMAP_00004231 +BDMAP_00001343 +BDMAP_00004793 +BDMAP_00001898 +BDMAP_00002271 +BDMAP_00002313 +BDMAP_00002896 +BDMAP_00000851 +BDMAP_00004165 +BDMAP_00003840 +BDMAP_00000338 +BDMAP_00000715 +BDMAP_00004295 +BDMAP_00000236 +BDMAP_00001985 +BDMAP_00003633 +BDMAP_00004825 +BDMAP_00002305 +BDMAP_00001237 
+BDMAP_00002419 +BDMAP_00001766 +BDMAP_00004546 +BDMAP_00000881 +BDMAP_00001836 +BDMAP_00003052 +BDMAP_00001502 +BDMAP_00003483 +BDMAP_00003396 +BDMAP_00005119 +BDMAP_00003299 +BDMAP_00000568 +BDMAP_00003590 +BDMAP_00002616 +BDMAP_00001835 +BDMAP_00002172 +BDMAP_00004964 +BDMAP_00002944 +BDMAP_00002465 +BDMAP_00002227 +BDMAP_00001905 +BDMAP_00002603 +BDMAP_00003111 +BDMAP_00004398 +BDMAP_00002373 +BDMAP_00000093 +BDMAP_00001247 +BDMAP_00003172 +BDMAP_00001865 +BDMAP_00001545 +BDMAP_00000411 +BDMAP_00002349 +BDMAP_00001617 +BDMAP_00003884 +BDMAP_00000809 +BDMAP_00003497 +BDMAP_00003961 +BDMAP_00005139 +BDMAP_00001628 +BDMAP_00004969 +BDMAP_00004228 +BDMAP_00001316 +BDMAP_00005160 +BDMAP_00001024 +BDMAP_00005073 +BDMAP_00001209 +BDMAP_00004954 +BDMAP_00003798 +BDMAP_00005063 +BDMAP_00001476 +BDMAP_00000243 +BDMAP_00003809 +BDMAP_00001309 +BDMAP_00003886 +BDMAP_00002758 +BDMAP_00002289 +BDMAP_00001862 +BDMAP_00004804 +BDMAP_00003113 +BDMAP_00001361 +BDMAP_00000692 +BDMAP_00001523 +BDMAP_00004115 +BDMAP_00002387 +BDMAP_00003781 +BDMAP_00000087 +BDMAP_00001823 +BDMAP_00000940 +BDMAP_00004719 +BDMAP_00004624 +BDMAP_00002849 +BDMAP_00003657 +BDMAP_00001461 +BDMAP_00002690 +BDMAP_00003236 +BDMAP_00004558 +BDMAP_00004639 +BDMAP_00004541 +BDMAP_00005083 +BDMAP_00000907 +BDMAP_00000972 +BDMAP_00001200 +BDMAP_00003168 +BDMAP_00000828 +BDMAP_00004450 +BDMAP_00001597 +BDMAP_00003867 +BDMAP_00001746 +BDMAP_00002252 +BDMAP_00002947 +BDMAP_00004878 +BDMAP_00001842 +BDMAP_00002654 +BDMAP_00002185 +BDMAP_00001802 +BDMAP_00001040 +BDMAP_00004198 +BDMAP_00000831 +BDMAP_00004491 +BDMAP_00003109 +BDMAP_00002120 +BDMAP_00001834 +BDMAP_00002619 +BDMAP_00000138 +BDMAP_00004773 +BDMAP_00001236 +BDMAP_00002402 +BDMAP_00001598 +BDMAP_00000714 +BDMAP_00003356 +BDMAP_00000462 +BDMAP_00001114 +BDMAP_00000607 +BDMAP_00004297 +BDMAP_00004841 +BDMAP_00005022 +BDMAP_00000572 +BDMAP_00000541 +BDMAP_00005140 +BDMAP_00004415 +BDMAP_00003946 +BDMAP_00003319 +BDMAP_00003510 +BDMAP_00004163 +BDMAP_00002458 +BDMAP_00005020 +BDMAP_00004511 +BDMAP_00004549 +BDMAP_00005155 +BDMAP_00004147 +BDMAP_00004876 +BDMAP_00002103 +BDMAP_00000882 +BDMAP_00003138 +BDMAP_00005037 +BDMAP_00003853 +BDMAP_00002039 +BDMAP_00000774 +BDMAP_00004741 +BDMAP_00001171 +BDMAP_00004636 +BDMAP_00002332 +BDMAP_00004894 +BDMAP_00002730 +BDMAP_00001125 +BDMAP_00003822 +BDMAP_00003592 +BDMAP_00001368 +BDMAP_00003513 +BDMAP_00003612 +BDMAP_00005169 +BDMAP_00004017 +BDMAP_00002855 +BDMAP_00000152 +BDMAP_00000091 +BDMAP_00004529 +BDMAP_00003443 +BDMAP_00003543 +BDMAP_00002267 +BDMAP_00004462 +BDMAP_00000874 +BDMAP_00002793 +BDMAP_00001471 +BDMAP_00001605 +BDMAP_00000709 +BDMAP_00004435 +BDMAP_00003524 +BDMAP_00000965 +BDMAP_00000939 +BDMAP_00002278 +BDMAP_00002295 +BDMAP_00000971 +BDMAP_00004917 +BDMAP_00003812 +BDMAP_00002401 +BDMAP_00003074 +BDMAP_00004028 +BDMAP_00001982 +BDMAP_00004281 +BDMAP_00000347 +BDMAP_00001732 +BDMAP_00001205 +BDMAP_00001379 +BDMAP_00001095 +BDMAP_00004770 +BDMAP_00002283 +BDMAP_00000052 +BDMAP_00000192 +BDMAP_00003564 +BDMAP_00003427 +BDMAP_00004888 +BDMAP_00005016 +BDMAP_00004745 +BDMAP_00001078 +BDMAP_00001122 +BDMAP_00001584 +BDMAP_00003551 +BDMAP_00002495 +BDMAP_00000589 +BDMAP_00005065 +BDMAP_00002171 +BDMAP_00004830 +BDMAP_00001804 +BDMAP_00004493 +BDMAP_00000400 +BDMAP_00000745 +BDMAP_00001333 +BDMAP_00004890 +BDMAP_00002845 +BDMAP_00001875 +BDMAP_00001096 +BDMAP_00004060 +BDMAP_00002451 +BDMAP_00002523 +BDMAP_00002899 +BDMAP_00000642 +BDMAP_00005075 +BDMAP_00003685 +BDMAP_00004650 +BDMAP_00001618 +BDMAP_00000771 +BDMAP_00003920 
+BDMAP_00002309 +BDMAP_00004847 +BDMAP_00002485 +BDMAP_00001590 +BDMAP_00001692 +BDMAP_00003502 +BDMAP_00000431 +BDMAP_00000679 +BDMAP_00002986 +BDMAP_00003277 +BDMAP_00004885 +BDMAP_00000427 +BDMAP_00000716 +BDMAP_00003744 +BDMAP_00001806 +BDMAP_00003857 +BDMAP_00000859 +BDMAP_00001067 +BDMAP_00004121 +BDMAP_00002475 +BDMAP_00002318 +BDMAP_00003114 +BDMAP_00001712 +BDMAP_00001214 +BDMAP_00000362 +BDMAP_00001441 +BDMAP_00003272 +BDMAP_00000956 +BDMAP_00005064 +BDMAP_00000154 +BDMAP_00005186 +BDMAP_00003658 +BDMAP_00002704 +BDMAP_00004796 +BDMAP_00000197 +BDMAP_00005070 +BDMAP_00005001 +BDMAP_00000480 +BDMAP_00005078 +BDMAP_00001564 +BDMAP_00001025 +BDMAP_00003598 +BDMAP_00004262 +BDMAP_00001092 +BDMAP_00004185 +BDMAP_00003776 +BDMAP_00001270 +BDMAP_00000615 +BDMAP_00003141 +BDMAP_00003330 +BDMAP_00000190 +BDMAP_00003650 +BDMAP_00001397 +BDMAP_00005185 +BDMAP_00001966 +BDMAP_00004184 +BDMAP_00004992 +BDMAP_00004416 +BDMAP_00000993 +BDMAP_00001445 +BDMAP_00003482 +BDMAP_00004514 +BDMAP_00001504 +BDMAP_00000416 +BDMAP_00002805 +BDMAP_00002232 +BDMAP_00004384 +BDMAP_00001921 +BDMAP_00001426 +BDMAP_00004910 +BDMAP_00003560 +BDMAP_00003130 +BDMAP_00005108 +BDMAP_00000113 +BDMAP_00001521 +BDMAP_00003556 +BDMAP_00003376 +BDMAP_00000273 +BDMAP_00004735 +BDMAP_00001539 +BDMAP_00004494 +BDMAP_00001212 +BDMAP_00005067 +BDMAP_00000413 +BDMAP_00002863 +BDMAP_00000671 +BDMAP_00004927 +BDMAP_00002167 +BDMAP_00002152 +BDMAP_00005168 +BDMAP_00003911 +BDMAP_00002250 +BDMAP_00003215 +BDMAP_00002737 +BDMAP_00001514 +BDMAP_00003440 +BDMAP_00003031 +BDMAP_00001786 +BDMAP_00000552 +BDMAP_00004943 +BDMAP_00003268 +BDMAP_00002233 +BDMAP_00002362 +BDMAP_00001440 +BDMAP_00000225 +BDMAP_00003347 +BDMAP_00002739 +BDMAP_00003479 +BDMAP_00003481 +BDMAP_00003326 +BDMAP_00000683 +BDMAP_00004378 +BDMAP_00003367 +BDMAP_00000855 +BDMAP_00002298 +BDMAP_00004077 +BDMAP_00002253 +BDMAP_00001331 +BDMAP_00000542 +BDMAP_00002924 +BDMAP_00005092 +BDMAP_00004374 +BDMAP_00004509 +BDMAP_00000264 +BDMAP_00000918 +BDMAP_00000030 diff --git a/Generation_Pipeline_filter/syn_kidney/requirements.txt b/Generation_Pipeline_filter/syn_kidney/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9c1067c8eb4f44699bbf517601396113cea4b370 --- /dev/null +++ b/Generation_Pipeline_filter/syn_kidney/requirements.txt @@ -0,0 +1,94 @@ +absl-py==1.1.0 +accelerate==0.11.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +antlr4-python3-runtime==4.9.3 +async-timeout==4.0.2 +attrs==21.4.0 +autopep8==1.6.0 +cachetools==5.2.0 +certifi==2022.6.15 +charset-normalizer==2.0.12 +click==8.1.3 +cycler==0.11.0 +Deprecated==1.2.13 +docker-pycreds==0.4.0 +einops==0.4.1 +einops-exts==0.0.3 +ema-pytorch==0.0.8 +fonttools==4.34.4 +frozenlist==1.3.0 +fsspec==2022.5.0 +ftfy==6.1.1 +future==0.18.2 +gitdb==4.0.9 +GitPython==3.1.27 +google-auth==2.9.0 +google-auth-oauthlib==0.4.6 +grpcio==1.47.0 +h5py==3.7.0 +humanize==4.2.2 +hydra-core==1.2.0 +idna==3.3 +imageio==2.19.3 +imageio-ffmpeg==0.4.7 +importlib-metadata==4.12.0 +importlib-resources==5.9.0 +joblib==1.1.0 +kiwisolver==1.4.3 +lxml==4.9.1 +Markdown==3.3.7 +matplotlib==3.5.2 +multidict==6.0.2 +networkx==2.8.5 +nibabel==4.0.1 +nilearn==0.9.1 +numpy==1.23.0 +oauthlib==3.2.0 +omegaconf==2.2.3 +pandas==1.4.3 +Pillow==9.1.1 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycodestyle==2.8.0 +pyDeprecate==0.3.1 +pydicom==2.3.0 +pytorch-lightning==1.6.4 +pytz==2022.1 +PyWavelets==1.3.0 +PyYAML==6.0 +pyzmq==19.0.2 +regex==2022.6.2 +requests==2.28.0 +requests-oauthlib==1.3.1 +rotary-embedding-torch==0.1.5 +rsa==4.8 
+scikit-image==0.19.3 +scikit-learn==1.1.2 +scikit-video==1.1.11 +scipy==1.8.1 +seaborn==0.11.2 +sentry-sdk==1.7.2 +setproctitle==1.2.3 +shortuuid==1.0.9 +SimpleITK==2.1.1.2 +smmap==5.0.0 +tensorboard==2.9.1 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +threadpoolctl==3.1.0 +tifffile==2022.8.3 +toml==0.10.2 +torch-tb-profiler==0.4.0 +torchio==0.18.80 +torchmetrics==0.9.1 +tqdm==4.64.0 +typing_extensions==4.2.0 +urllib3==1.26.9 +wandb==0.12.21 +Werkzeug==2.1.2 +wrapt==1.14.1 +yarl==1.7.2 +zipp==3.8.0 +wandb +tensorboardX==2.4.1 diff --git a/Generation_Pipeline_filter/syn_liver/CT_syn_data.py b/Generation_Pipeline_filter/syn_liver/CT_syn_data.py new file mode 100644 index 0000000000000000000000000000000000000000..f362dc24a8c53afc66b1327d43009009b27190f7 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/CT_syn_data.py @@ -0,0 +1,242 @@ +import os, time, csv +import numpy as np +import torch +from sklearn.metrics import confusion_matrix +from scipy import ndimage +from scipy.ndimage import label +from functools import partial +import monai +from monai.transforms import AsDiscrete,AsDiscreted,Compose,Invertd,SaveImaged +from monai import transforms, data +from TumorGeneration.utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor, synt_model_prepare +import nibabel as nib + +import warnings +warnings.filterwarnings("ignore") + +import argparse +parser = argparse.ArgumentParser(description='liver tumor validation') + +# file dir +parser.add_argument('--data_root', default='/mnt/data/qichen/data/AILab/AbdomenAtlasX_Mini_synt', type=str) +parser.add_argument('--organ_type', default='liver', type=str) +parser.add_argument('--save_dir', default='out', type=str) +parser.add_argument('--data_file', default='healthy_liver_1k.txt', type=str) +parser.add_argument('--ddim_ts', default=50, type=int) +parser.add_argument('--fg_thresh', default=30, type=int) +parser.add_argument('--start', default=0, type=int) +parser.add_argument('--end', default=1000, type=int) + +def voxel2R(A): + return (np.array(A)/4*3/np.pi)**(1/3) + +class RandCropByPosNegLabeld_select(transforms.RandCropByPosNegLabeld): + def __init__(self, keys, label_key, spatial_size, + pos=1.0, neg=1.0, num_samples=1, + image_key=None, image_threshold=0.0, allow_missing_keys=True, + fg_thresh=0): + super().__init__(keys=keys, label_key=label_key, spatial_size=spatial_size, + pos=pos, neg=neg, num_samples=num_samples, + image_key=image_key, image_threshold=image_threshold, allow_missing_keys=allow_missing_keys) + self.fg_thresh = fg_thresh + + def R2voxel(self,R): + return (4/3*np.pi)*(R)**(3) + + def __call__(self, data): + d = dict(data) + data_name = d['name'] + d.pop('name') + + if '10_Decathlon' in data_name or '05_KiTS' in data_name: + d_crop = super().__call__(d) + + else: + flag=0 + while 1: + flag+=1 + + d_crop = super().__call__(d) + pixel_num = (d_crop[0]['label']>0).sum() + + if pixel_num > self.R2voxel(self.fg_thresh): + break + if flag>5 and pixel_num > self.R2voxel(max(self.fg_thresh-5, 5)): + break + if flag>10 and pixel_num > self.R2voxel(max(self.fg_thresh-10, 5)): + break + if flag>15 and pixel_num > self.R2voxel(max(self.fg_thresh-15, 5)): + break + if flag>20 and pixel_num > self.R2voxel(max(self.fg_thresh-20, 5)): + break + if flag>25 and pixel_num > self.R2voxel(max(self.fg_thresh-25, 5)): + break + if flag>50: + break + + d_crop[0]['name'] = data_name + + return d_crop + +def _get_loader(args): + # val_data_dir = args.val_dir + # datalist_json = args.json_dir + 
val_org_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label", "raw_image"]), + transforms.AddChanneld(keys=["image", "label", "raw_image"]), + transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), + transforms.Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear")), + transforms.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True), + transforms.SpatialPadd(keys=["image", "label"], mode=["minimum", "constant"], spatial_size=[96, 96, 96]), + RandCropByPosNegLabeld_select( + keys=["image", "label", "name"], + label_key="label", + spatial_size=(96, 96, 96), + pos=1, + neg=0, + num_samples=1, + image_key="image", + image_threshold=0, + fg_thresh = args.fg_thresh, + ), + transforms.ToTensord(keys=["image", "label", "raw_image"]), + ] + ) + + val_img=[] + val_lbl=[] + val_name=[] + + for line in open(args.data_file): + # name = line.strip().split()[1].split('.')[0] + # val_img.append(args.data_root + line.strip().split()[0]) + # val_lbl.append(args.data_root + line.strip().split()[1]) + # breakpoint() + name = line.strip() + val_img.append(os.path.join(args.data_root, name, 'ct.nii.gz')) + val_lbl.append(os.path.join(args.data_root, name, 'segmentations/liver.nii.gz')) + val_name.append(name) + data_dicts_val = [{'image': image, 'raw_image':image, 'label': label, 'name': name} + for image, label, name in zip(val_img, val_lbl, val_name)] + + if args.end < len(data_dicts_val): + data_dicts_val = data_dicts_val[args.start:args.end] + else: + data_dicts_val = data_dicts_val[args.start:] + print('val len {}'.format(len(data_dicts_val))) + val_org_ds = data.Dataset(data_dicts_val, transform=val_org_transform) + val_org_loader = data.DataLoader(val_org_ds, batch_size=1, shuffle=False, num_workers=4, sampler=None, pin_memory=True) + + post_transforms = Compose([ + Invertd( + keys=['image'], + transform=val_org_transform, + orig_keys="image", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ), + Invertd( + keys=['label'], + transform=val_org_transform, + orig_keys="label", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ) + ]) + return val_org_loader, post_transforms + +def main(): + args = parser.parse_args() + output_dir = args.save_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print("MAIN Argument values:") + for k, v in vars(args).items(): + print(k, '=>', v) + print('-----------------') + + ## loader and post_transform + val_loader, post_transforms = _get_loader(args) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device) + model.load_state_dict(torch.load("../best_metric_model_classification3d_dict.pth")) + model.eval() + + start_time=0 + with torch.no_grad(): + for idx, val_data in enumerate(val_loader): + print('idx',idx) + if idx == 0: + start_time = time.time() + # val_inputs = val_data["image"] + # name = val_data['label_meta_dict']['filename_or_obj'][0].split('/')[-1].split('.')[0] + + vqgan, early_sampler, noearly_sampler= synt_model_prepare(device = torch.device("cuda"), fold=0, organ=args.organ_type) + + healthy_data, healthy_target, data_names, raw_data = val_data['image'], val_data['label'], val_data['name'], val_data['raw_image'] + case_name = data_names[0].split('/')[-1] + print('case_name', case_name) + original_affine = val_data["label_meta_dict"]["original_affine"][0].numpy() + if 
healthy_target.sum() == 0: + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + tumor_mask = val_data[0]['label'][0].cpu().numpy().astype(np.uint8) + tumor_mask_ = np.zeros_like(tumor_mask) + nib.save(nib.Nifti1Image(tumor_mask_, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/liver_tumor.nii.gz')) + continue + + healthy_data, healthy_target = healthy_data.cuda(), healthy_target.cuda() + healthy_target = (healthy_target==1).to(healthy_target) + + tumor_types = ['early', 'medium', 'large'] + # tumor_probs = np.array([0.45, 0.45, 0.1]) + # tumor_probs = np.array([1.0, 0.0, 0.0]) + tumor_probs = np.array([0.5, 0.4, 0.1]) + synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + print('synthetic_tumor_type',synthetic_tumor_type) + flag=0 + while 1: + flag+=1 + if synthetic_tumor_type == 'early': + synt_data, synt_target = synthesize_early_tumor(healthy_data, healthy_target, args.organ_type, vqgan, early_sampler) + elif synthetic_tumor_type == 'medium': + synt_data, synt_target = synthesize_medium_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + elif synthetic_tumor_type == 'large': + synt_data, synt_target = synthesize_large_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + + syn_confidence = model(synt_data).sigmoid()[:,1] + if syn_confidence>0.01: + break + elif flag > 40 and syn_confidence>0.005: + break + elif flag > 60 and syn_confidence>0.001: + break + val_data['image'] = synt_data.detach() + val_data['label'] = synt_target.detach() + + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + synt_data = val_data[0]['image'][0] + synt_target = val_data[0]['label'][0] + final_data = raw_data[0,0] + + synt_data = (synt_data*(250+175)-175) + final_data[synt_target>1] = synt_data[synt_target>1] + final_data = final_data.cpu().numpy() + final_label = (synt_target>=1.5).cpu().numpy().astype(np.uint8) + + os.makedirs(os.path.join(output_dir, f'{case_name}'), exist_ok=True) + os.makedirs(os.path.join(output_dir, f'{case_name}/segmentations'), exist_ok=True) + nib.save(nib.Nifti1Image(final_data, original_affine), os.path.join(output_dir, f'{case_name}/ct.nii.gz')) + nib.save(nib.Nifti1Image(final_label, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/liver_tumor.nii.gz')) + # breakpoint() + # nib.save(nib.Nifti1Image(synt_data.cpu().numpy(), original_affine), os.path.join(output_dir, 'synt_data.nii.gz')) + # nib.save(nib.Nifti1Image(synt_target.cpu().numpy(), original_affine), os.path.join(output_dir, 'synt_target.nii.gz')) + print('time = ', time.time()-start_time) + start_time = time.time() + + # breakpoint() +if __name__ == "__main__": + main() diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/.DS_Store b/Generation_Pipeline_filter/syn_liver/TumorGeneration/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..fb0d207e124cfecc777faae5b4a56c5ca1b9cd2e Binary files /dev/null and b/Generation_Pipeline_filter/syn_liver/TumorGeneration/.DS_Store differ diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/README.md b/Generation_Pipeline_filter/syn_liver/TumorGeneration/README.md new file mode 100644 index 0000000000000000000000000000000000000000..201911031ca988c410781a60dc3c192a65ee56b3 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/README.md @@ -0,0 +1,5 @@ +```bash +wget 
https://www.dropbox.com/scl/fi/k856fhk60kck8uqxtxazw/model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll +mv model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll model_weight.tar.gz +tar -xzvf model_weight.tar.gz +``` diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/TumorGenerated.py b/Generation_Pipeline_filter/syn_liver/TumorGeneration/TumorGenerated.py new file mode 100644 index 0000000000000000000000000000000000000000..9983cf047b1532f5edd20c0d4c78102d6d611eb9 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/TumorGenerated.py @@ -0,0 +1,39 @@ +import random +from typing import Hashable, Mapping, Dict + +from monai.config import KeysCollection +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.transform import MapTransform, RandomizableTransform + +from .utils_ import SynthesisTumor +import numpy as np + +class TumorGenerated(RandomizableTransform, MapTransform): + def __init__(self, + keys: KeysCollection, + prob: float = 0.1, + tumor_prob = [0.2, 0.2, 0.2, 0.2, 0.2], + allow_missing_keys: bool = False + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + random.seed(0) + np.random.seed(0) + + self.tumor_types = ['tiny', 'small', 'medium', 'large', 'mix'] + + assert len(tumor_prob) == 5 + self.tumor_prob = np.array(tumor_prob) + + + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + + if self._do_transform and (np.max(d['label']) <= 1): + tumor_type = np.random.choice(self.tumor_types, p=self.tumor_prob.ravel()) + + d['image'][0], d['label'][0] = SynthesisTumor(d['image'][0], d['label'][0], tumor_type) + # print(tumor_type, d['image'].shape, np.max(d['label'])) + return d diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/__init__.py b/Generation_Pipeline_filter/syn_liver/TumorGeneration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d9bbd5e8cede113145b2742ebdd63d7226fe6396 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/__init__.py @@ -0,0 +1,5 @@ +### Online Version TumorGeneration ### + +from .TumorGenerated import TumorGenerated + +# from .utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor \ No newline at end of file diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc b/Generation_Pipeline_filter/syn_liver/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a762bebf847cb7dff5e0796b57befdc1fcedb6c8 Binary files /dev/null and b/Generation_Pipeline_filter/syn_liver/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter/syn_liver/TumorGeneration/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd297ba7cd9667fbf7855b316c40abb0e1ca9876 Binary files /dev/null and b/Generation_Pipeline_filter/syn_liver/TumorGeneration/__pycache__/__init__.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/__pycache__/utils.cpython-38.pyc b/Generation_Pipeline_filter/syn_liver/TumorGeneration/__pycache__/utils.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..31e0888502e562dd2f5b9041fe35612ba66fa973 Binary files /dev/null and b/Generation_Pipeline_filter/syn_liver/TumorGeneration/__pycache__/utils.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/__pycache__/utils_.cpython-38.pyc b/Generation_Pipeline_filter/syn_liver/TumorGeneration/__pycache__/utils_.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93b14c3f1cbb06b068dc9f9fd400173339c742f8 Binary files /dev/null and b/Generation_Pipeline_filter/syn_liver/TumorGeneration/__pycache__/utils_.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/diffusion_config/ddpm.yaml b/Generation_Pipeline_filter/syn_liver/TumorGeneration/diffusion_config/ddpm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..474a973b1f76c2a026e2df916c4a153b5bbea05f --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/diffusion_config/ddpm.yaml @@ -0,0 +1,29 @@ +vqgan_ckpt: None + +# Have to be derived from VQ-GAN Latent space dimensions +diffusion_img_size: 24 +diffusion_depth_size: 24 +diffusion_num_channels: 17 # 17 +out_dim: 8 +dim_mults: [1,2,4,8] +results_folder: checkpoints/ddpm/ +results_folder_postfix: 'own_dataset_t2' +load_milestone: False # False + +batch_size: 2 # 40 +num_workers: 20 +logger: wandb +objective: pred_x0 +save_and_sample_every: 1000 +denoising_fn: Unet3D +train_lr: 1e-4 +timesteps: 2 # number of steps +sampling_timesteps: 250 # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper]) +loss_type: l1 # L1 or L2 +train_num_steps: 700000 # total training steps +gradient_accumulate_every: 2 # gradient accumulation steps +ema_decay: 0.995 # exponential moving average decay +amp: False # turn on mixed precision +num_sample_rows: 1 +gpus: 0 + diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/diffusion_config/vq_gan_3d.yaml b/Generation_Pipeline_filter/syn_liver/TumorGeneration/diffusion_config/vq_gan_3d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29940377c5cadb0c06322de8ac60d0713f799024 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/diffusion_config/vq_gan_3d.yaml @@ -0,0 +1,37 @@ +seed: 1234 +batch_size: 2 # 30 +num_workers: 32 # 30 + +gpus: 1 +accumulate_grad_batches: 1 +default_root_dir: checkpoints/vq_gan/ +default_root_dir_postfix: 'flair' +resume_from_checkpoint: +max_steps: -1 +max_epochs: -1 +precision: 16 +gradient_clip_val: 1.0 + + +embedding_dim: 8 # 256 +n_codes: 16384 # 2048 +n_hiddens: 16 +lr: 3e-4 +downsample: [2, 2, 2] # [4, 4, 4] +disc_channels: 64 +disc_layers: 3 +discriminator_iter_start: 10000 # 50000 +disc_loss_type: hinge +image_gan_weight: 1.0 +video_gan_weight: 1.0 +l1_weight: 4.0 +gan_feat_weight: 4.0 # 0.0 +perceptual_weight: 4.0 # 0.0 +i3d_feat: False +restart_thres: 1.0 +no_random_restart: False +norm_type: group +padding_type: replicate +num_groups: 32 + + diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__init__.py b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a01336e46e37471097d8e1420d0ffd5d803a1edd --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__init__.py @@ -0,0 +1 @@ +from .diffusion import Unet3D, GaussianDiffusion, Tester diff --git 
a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..119359052658ded8dc8641a04f45d1a2d725bbe7 Binary files /dev/null and b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61c35bc9d5d5f16280ac8a6d409ad3997055ad61 Binary files /dev/null and b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfebf650dd0d824a0faf7abdf4905042d8e8c9f1 Binary files /dev/null and b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..075077773ff5fe7438b8c047c686c8a2ab66f1e5 Binary files /dev/null and b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7ac1a6aea44a1d1741652f2f0a598fb18b2730d Binary files /dev/null and b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20c6786f9995cbd455c900944b0ab9501a97c202 Binary files /dev/null and b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e2d3c12dd574f975924551cf46dcaa864b9f8c9 Binary files /dev/null and b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/ddim.py b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..1bc51f94355732aedb0ffd254b8166144c468370 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/ddim.py @@ -0,0 +1,206 @@ +"""SAMPLING ONLY.""" + +import 
torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from .util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): # "uniform" 'quad' + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + # breakpoint() + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, T, H, W = shape + # breakpoint() + size = (batch_size, C, T, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + # iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(time_range): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. 
- mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + # breakpoint() + e_t = self.model.denoise_fn(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.denoise_fn(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/diffusion.py b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..87111dc3b2ab671e798efc3a320ae81d024f20b9 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/diffusion.py @@ -0,0 +1,1016 @@ +"Largely taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch" + +import math +import copy +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial + +from torch.utils import data +from pathlib import Path +from torch.optim import Adam +from torchvision import transforms as T, utils +from torch.cuda.amp import autocast, GradScaler +from PIL import Image + +from tqdm import tqdm +from einops import rearrange +from einops_exts import check_shape, rearrange_many + +from rotary_embedding_torch import RotaryEmbedding + +from .text import tokenize, bert_embed, BERT_MODEL_DIM +from torch.utils.data import Dataset, DataLoader +from ..vq_gan_3d.model.vqgan import VQGAN + +import matplotlib.pyplot as plt + +# helpers functions + + +def exists(x): + return x is not None + + +def noop(*args, **kwargs): + pass + + +def is_odd(n): + return (n % 2) == 1 + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +def cycle(dl): + while True: + for data in dl: + yield data + + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + + +def prob_mask_like(shape, prob, device): + if prob == 1: + return torch.ones(shape, device=device, dtype=torch.bool) + elif prob == 0: + return torch.zeros(shape, device=device, dtype=torch.bool) + else: + return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob + + +def is_list_str(x): + if not isinstance(x, (list, tuple)): + return False + return all([type(el) == str for el in x]) + +# relative positional bias + + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128 + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / + max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min( + val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype=torch.long, device=device) + k_pos = torch.arange(n, dtype=torch.long, device=device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket( + rel_pos, num_buckets=self.num_buckets, 
max_distance=self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') + +# small helper modules + + +class EMA(): + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_model_average(self, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = self.update_average(old_weight, up_weight) + + def update_average(self, old, new): + if old is None: + return new + return old * self.beta + (1 - self.beta) * new + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +def Upsample(dim): + return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) + + +def Downsample(dim): + return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) + + +class LayerNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1)) + + def forward(self, x): + var = torch.var(x, dim=1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=1, keepdim=True) + return (x - mean) / (var + self.eps).sqrt() * self.gamma + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = LayerNorm(dim) + + def forward(self, x, **kwargs): + x = self.norm(x) + return self.fn(x, **kwargs) + +# building block modules + + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1)) + self.norm = nn.GroupNorm(groups, dim_out) + self.act = nn.SiLU() + + def forward(self, x, scale_shift=None): + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + return self.act(x) + + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, dim_out * 2) + ) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) + self.res_conv = nn.Conv3d( + dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb=None): + + scale_shift = None + if exists(self.mlp): + assert exists(time_emb), 'time emb must be passed in' + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') + scale_shift = time_emb.chunk(2, dim=1) + + h = self.block1(x, scale_shift=scale_shift) + + h = self.block2(h) + return h + self.res_conv(x) + + +class SpatialLinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, f, h, w = 
x.shape + x = rearrange(x, 'b c f h w -> (b f) c h w') + + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = rearrange_many( + qkv, 'b (h c) x y -> b h c (x y)', h=self.heads) + + q = q.softmax(dim=-2) + k = k.softmax(dim=-1) + + q = q * self.scale + context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + + out = torch.einsum('b h d e, b h d n -> b h e n', context, q) + out = rearrange(out, 'b h c (x y) -> b (h c) x y', + h=self.heads, x=h, y=w) + out = self.to_out(out) + return rearrange(out, '(b f) c h w -> b c f h w', b=b) + +# attention along space and time + + +class EinopsToAndFrom(nn.Module): + def __init__(self, from_einops, to_einops, fn): + super().__init__() + self.from_einops = from_einops + self.to_einops = to_einops + self.fn = fn + + def forward(self, x, **kwargs): + shape = x.shape + reconstitute_kwargs = dict( + tuple(zip(self.from_einops.split(' '), shape))) + x = rearrange(x, f'{self.from_einops} -> {self.to_einops}') + x = self.fn(x, **kwargs) + x = rearrange( + x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + heads=4, + dim_head=32, + rotary_emb=None + ): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + + self.rotary_emb = rotary_emb + self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False) + self.to_out = nn.Linear(hidden_dim, dim, bias=False) + + def forward( + self, + x, + pos_bias=None, + focus_present_mask=None + ): + n, device = x.shape[-2], x.device + + qkv = self.to_qkv(x).chunk(3, dim=-1) + + if exists(focus_present_mask) and focus_present_mask.all(): + # if all batch samples are focusing on present + # it would be equivalent to passing that token's values through to the output + values = qkv[-1] + return self.to_out(values) + + # split out heads + + q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) + + # scale + + q = q * self.scale + + # rotate positions into queries and keys for time attention + + if exists(self.rotary_emb): + q = self.rotary_emb.rotate_queries_or_keys(q) + k = self.rotary_emb.rotate_queries_or_keys(k) + + # similarity + + sim = einsum('... h i d, ... h j d -> ... h i j', q, k) + + # relative positional bias + + if exists(pos_bias): + sim = sim + pos_bias + + if exists(focus_present_mask) and not (~focus_present_mask).all(): + attend_all_mask = torch.ones( + (n, n), device=device, dtype=torch.bool) + attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) + + mask = torch.where( + rearrange(focus_present_mask, 'b -> b 1 1 1 1'), + rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), + rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), + ) + + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + # numerical stability + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + # aggregate values + + out = einsum('... h i j, ... h j d -> ... h i d', attn, v) + out = rearrange(out, '... h n d -> ... 
n (h d)') + return self.to_out(out) + +# model + + +class Unet3D(nn.Module): + def __init__( + self, + dim, + cond_dim=None, + out_dim=None, + dim_mults=(1, 2, 4, 8), + channels=3, + attn_heads=8, + attn_dim_head=32, + use_bert_text_cond=False, + init_dim=None, + init_kernel_size=7, + use_sparse_linear_attn=True, + block_type='resnet', + resnet_groups=8 + ): + super().__init__() + self.channels = channels + + # temporal attention and its relative positional encoding + + rotary_emb = RotaryEmbedding(min(32, attn_dim_head)) + + def temporal_attn(dim): return EinopsToAndFrom('b c f h w', 'b (h w) f c', Attention( + dim, heads=attn_heads, dim_head=attn_dim_head, rotary_emb=rotary_emb)) + + # realistically will not be able to generate that many frames of video... yet + self.time_rel_pos_bias = RelativePositionBias( + heads=attn_heads, max_distance=32) + + # initial conv + + init_dim = default(init_dim, dim) + assert is_odd(init_kernel_size) + + init_padding = init_kernel_size // 2 + self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, + init_kernel_size), padding=(0, init_padding, init_padding)) + + self.init_temporal_attn = Residual( + PreNorm(init_dim, temporal_attn(init_dim))) + + # dimensions + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + # time conditioning + + time_dim = dim * 4 + self.time_mlp = nn.Sequential( + SinusoidalPosEmb(dim), + nn.Linear(dim, time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim) + ) + + # text conditioning + + self.has_cond = exists(cond_dim) or use_bert_text_cond + cond_dim = BERT_MODEL_DIM if use_bert_text_cond else cond_dim + + self.null_cond_emb = nn.Parameter( + torch.randn(1, cond_dim)) if self.has_cond else None + + cond_dim = time_dim + int(cond_dim or 0) + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + + num_resolutions = len(in_out) + # block type + + block_klass = partial(ResnetBlock, groups=resnet_groups) + block_klass_cond = partial(block_klass, time_emb_dim=cond_dim) + + # modules for all layers + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + block_klass_cond(dim_in, dim_out), + block_klass_cond(dim_out, dim_out), + Residual(PreNorm(dim_out, SpatialLinearAttention( + dim_out, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), + Residual(PreNorm(dim_out, temporal_attn(dim_out))), + Downsample(dim_out) if not is_last else nn.Identity() + ])) + + mid_dim = dims[-1] + self.mid_block1 = block_klass_cond(mid_dim, mid_dim) + + spatial_attn = EinopsToAndFrom( + 'b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads)) + + self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn)) + self.mid_temporal_attn = Residual( + PreNorm(mid_dim, temporal_attn(mid_dim))) + + self.mid_block2 = block_klass_cond(mid_dim, mid_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind >= (num_resolutions - 1) + + self.ups.append(nn.ModuleList([ + block_klass_cond(dim_out * 2, dim_in), + block_klass_cond(dim_in, dim_in), + Residual(PreNorm(dim_in, SpatialLinearAttention( + dim_in, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), + Residual(PreNorm(dim_in, temporal_attn(dim_in))), + Upsample(dim_in) if not is_last else nn.Identity() + ])) + + out_dim = default(out_dim, channels) + self.final_conv = nn.Sequential( + block_klass(dim * 2, dim), + nn.Conv3d(dim, out_dim, 1) + ) + + def forward_with_cond_scale( + self, + 
*args, + cond_scale=2., + **kwargs + ): + logits = self.forward(*args, null_cond_prob=0., **kwargs) + if cond_scale == 1 or not self.has_cond: + return logits + + null_logits = self.forward(*args, null_cond_prob=1., **kwargs) + return null_logits + (logits - null_logits) * cond_scale + + def forward( + self, + x, + time, + cond=None, + null_cond_prob=0., + focus_present_mask=None, + # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time) + prob_focus_present=0. + ): + assert not (self.has_cond and not exists(cond) + ), 'cond must be passed in if cond_dim specified' + x = torch.cat([x, cond], dim=1) + + batch, device = x.shape[0], x.device + + focus_present_mask = default(focus_present_mask, lambda: prob_mask_like( + (batch,), prob_focus_present, device=device)) + + time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device) + + x = self.init_conv(x) + r = x.clone() + + x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias) + + t = self.time_mlp(time) if exists(self.time_mlp) else None # [2, 128] + + # classifier free guidance + + if self.has_cond: + batch, device = x.shape[0], x.device + mask = prob_mask_like((batch,), null_cond_prob, device=device) + cond = torch.where(rearrange(mask, 'b -> b 1'), + self.null_cond_emb, cond) + t = torch.cat((t, cond), dim=-1) + + h = [] + + for block1, block2, spatial_attn, temporal_attn, downsample in self.downs: + x = block1(x, t) + x = block2(x, t) + x = spatial_attn(x) + x = temporal_attn(x, pos_bias=time_rel_pos_bias, + focus_present_mask=focus_present_mask) + h.append(x) + x = downsample(x) + + # [2, 256, 32, 4, 4] + x = self.mid_block1(x, t) + x = self.mid_spatial_attn(x) + x = self.mid_temporal_attn( + x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask) + x = self.mid_block2(x, t) + + for block1, block2, spatial_attn, temporal_attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim=1) + x = block1(x, t) + x = block2(x, t) + x = spatial_attn(x) + x = temporal_attn(x, pos_bias=time_rel_pos_bias, + focus_present_mask=focus_present_mask) + x = upsample(x) + + x = torch.cat((x, r), dim=1) + return self.final_conv(x) + +# gaussian diffusion trainer class + + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = torch.linspace(0, timesteps, steps, dtype=torch.float64) + alphas_cumprod = torch.cos( + ((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0, 0.9999) + + +class GaussianDiffusion(nn.Module): + def __init__( + self, + denoise_fn, + *, + image_size, + num_frames, + text_use_bert_cls=False, + channels=3, + timesteps=1000, + loss_type='l1', + use_dynamic_thres=False, # from the Imagen paper + dynamic_thres_percentile=0.9, + vqgan_ckpt=None, + device=None + ): + super().__init__() + self.channels = channels + self.image_size = image_size + self.num_frames = num_frames + self.denoise_fn = denoise_fn + self.device = device + + if vqgan_ckpt: + self.vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt).cuda() + self.vqgan.eval() + else: + self.vqgan = None + + betas = cosine_beta_schedule(timesteps) + + alphas = 1. 
- betas + alphas_cumprod = torch.cumprod(alphas, axis=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + # register buffer helper function that casts float64 to float32 + + def register_buffer(name, val): return self.register_buffer( + name, val.to(torch.float32)) + + register_buffer('betas', betas) + register_buffer('alphas_cumprod', alphas_cumprod) + register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + + register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) + register_buffer('sqrt_one_minus_alphas_cumprod', + torch.sqrt(1. - alphas_cumprod)) + register_buffer('log_one_minus_alphas_cumprod', + torch.log(1. - alphas_cumprod)) + register_buffer('sqrt_recip_alphas_cumprod', + torch.sqrt(1. / alphas_cumprod)) + register_buffer('sqrt_recipm1_alphas_cumprod', + torch.sqrt(1. / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + + posterior_variance = betas * \ + (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + + register_buffer('posterior_variance', posterior_variance) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + + register_buffer('posterior_log_variance_clipped', + torch.log(posterior_variance.clamp(min=1e-20))) + register_buffer('posterior_mean_coef1', betas * + torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) + register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) + * torch.sqrt(alphas) / (1. - alphas_cumprod)) + + # text conditioning parameters + + self.text_use_bert_cls = text_use_bert_cls + + # dynamic thresholding when sampling + + self.use_dynamic_thres = use_dynamic_thres + self.dynamic_thres_percentile = dynamic_thres_percentile + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. - self.alphas_cumprod, t, x_start.shape) + log_variance = extract( + self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract( + self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool, cond=None, cond_scale=1.): + x_recon = self.predict_start_from_noise( + x, t=t, noise=self.denoise_fn.forward_with_cond_scale(x, t, cond=cond, cond_scale=cond_scale)) + + if clip_denoised: + s = 1. + if self.use_dynamic_thres: + s = torch.quantile( + rearrange(x_recon, 'b ... -> b (...)').abs(), + self.dynamic_thres_percentile, + dim=-1 + ) + + s.clamp_(min=1.) 
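+ # Note: Imagen-style dynamic thresholding (see use_dynamic_thres above).
+ # `s` is the per-sample `dynamic_thres_percentile` quantile of |x_recon|,
+ # clamped to at least 1 so predictions already inside [-1, 1] are left
+ # untouched; below, x_recon is clipped to [-s, s] and divided by s to map
+ # it back into [-1, 1]. With use_dynamic_thres=False, s stays 1 and this
+ # reduces to a plain clamp of x_recon to [-1, 1].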
+ s = s.view(-1, *((1,) * (x_recon.ndim - 1))) + + # clip by threshold, depending on whether static or dynamic + x_recon = x_recon.clamp(-s, s) / s + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.inference_mode() + def p_sample(self, x, t, cond=None, cond_scale=1., clip_denoised=True): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised, cond=cond, cond_scale=cond_scale) + noise = torch.randn_like(x) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, + *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.inference_mode() + def p_sample_loop(self, shape, cond=None, cond_scale=1.): + device = self.betas.device + + b = shape[0] + img = torch.randn(shape, device=device) + # print('cond', cond.shape) + for i in reversed(range(0, self.num_timesteps)): + img = self.p_sample(img, torch.full( + (b,), i, device=device, dtype=torch.long), cond=cond, cond_scale=cond_scale) + + return img + + @torch.inference_mode() + def sample(self, cond=None, cond_scale=1., batch_size=16): + device = next(self.denoise_fn.parameters()).device + + if is_list_str(cond): + cond = bert_embed(tokenize(cond)).to(device) + + # batch_size = cond.shape[0] if exists(cond) else batch_size + batch_size = batch_size + image_size = self.image_size + channels = 8 # self.channels + num_frames = self.num_frames + # print((batch_size, channels, num_frames, image_size, image_size)) + # print('cond_',cond.shape) + _sample = self.p_sample_loop( + (batch_size, channels, num_frames, image_size, image_size), cond=cond, cond_scale=cond_scale) + + if isinstance(self.vqgan, VQGAN): + # denormalize TODO: Remove eventually + _sample = (((_sample + 1.0) / 2.0) * (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) + self.vqgan.codebook.embeddings.min() + + _sample = self.vqgan.decode(_sample, quantize=True) + else: + unnormalize_img(_sample) + + return _sample + + @torch.inference_mode() + def interpolate(self, x1, x2, t=None, lam=0.5): + b, *_, device = *x1.shape, x1.device + t = default(t, self.num_timesteps - 1) + + assert x1.shape == x2.shape + + t_batched = torch.stack([torch.tensor(t, device=device)] * b) + xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) + + img = (1 - lam) * xt1 + lam * xt2 + for i in reversed(range(0, t)): + img = self.p_sample(img, torch.full( + (b,), i, device=device, dtype=torch.long)) + + return img + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, + t, x_start.shape) * noise + ) + + def p_losses(self, x_start, t, cond=None, noise=None, **kwargs): + b, c, f, h, w, device = *x_start.shape, x_start.device + noise = default(noise, lambda: torch.randn_like(x_start)) + # breakpoint() + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # [2, 8, 32, 32, 32] + + if is_list_str(cond): + cond = bert_embed( + tokenize(cond), return_cls_repr=self.text_use_bert_cls) + cond = cond.to(device) + + x_recon = self.denoise_fn(x_noisy, t, cond=cond, **kwargs) + + if self.loss_type == 'l1': + loss = F.l1_loss(noise, x_recon) + elif self.loss_type == 'l2': + loss = F.mse_loss(noise, x_recon) + else: 
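+ # only 'l1' and 'l2' losses are implemented; any other loss_type raises below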
+ raise NotImplementedError() + + return loss + + def forward(self, x, *args, **kwargs): + bs = int(x.shape[0]/2) + img=x[:bs,...] + mask=x[bs:,...] + mask_=(1-mask).detach() + masked_img = (img*mask_).detach() + masked_img=masked_img.permute(0,1,-1,-3,-2) + img=img.permute(0,1,-1,-3,-2) + mask=mask.permute(0,1,-1,-3,-2) + # breakpoint() + if isinstance(self.vqgan, VQGAN): + with torch.no_grad(): + img = self.vqgan.encode( + img, quantize=False, include_embeddings=True) + # normalize to -1 and 1 + img = ((img - self.vqgan.codebook.embeddings.min()) / + (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + masked_img = self.vqgan.encode( + masked_img, quantize=False, include_embeddings=True) + # normalize to -1 and 1 + masked_img = ((masked_img - self.vqgan.codebook.embeddings.min()) / + (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + else: + print("Hi") + img = normalize_img(img) + masked_img = normalize_img(masked_img) + mask = mask*2.0 - 1.0 + cc = torch.nn.functional.interpolate(mask, size=masked_img.shape[-3:]) + cond = torch.cat((masked_img, cc), dim=1) + + b, device, img_size, = img.shape[0], img.device, self.image_size + t = torch.randint(0, self.num_timesteps, (b,), device=device).long() + # breakpoint() + return self.p_losses(img, t, cond=cond, *args, **kwargs) + +# trainer class + + +CHANNELS_TO_MODE = { + 1: 'L', + 3: 'RGB', + 4: 'RGBA' +} + + +def seek_all_images(img, channels=3): + assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid' + mode = CHANNELS_TO_MODE[channels] + + i = 0 + while True: + try: + img.seek(i) + yield img.convert(mode) + except EOFError: + break + i += 1 + +# tensor of shape (channels, frames, height, width) -> gif + + +def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True): + tensor = ((tensor - tensor.min()) / (tensor.max() - tensor.min())) * 1.0 + images = map(T.ToPILImage(), tensor.unbind(dim=1)) + first_img, *rest_imgs = images + first_img.save(path, save_all=True, append_images=rest_imgs, + duration=duration, loop=loop, optimize=optimize) + return images + +# gif -> (channels, frame, height, width) tensor + + +def gif_to_tensor(path, channels=3, transform=T.ToTensor()): + img = Image.open(path) + tensors = tuple(map(transform, seek_all_images(img, channels=channels))) + return torch.stack(tensors, dim=1) + + +def identity(t, *args, **kwargs): + return t + + +def normalize_img(t): + return t * 2 - 1 + + +def unnormalize_img(t): + return (t + 1) * 0.5 + + +def cast_num_frames(t, *, frames): + f = t.shape[1] + + if f == frames: + return t + + if f > frames: + return t[:, :frames] + + return F.pad(t, (0, 0, 0, 0, 0, frames - f)) + + +class Dataset(data.Dataset): + def __init__( + self, + folder, + image_size, + channels=3, + num_frames=16, + horizontal_flip=False, + force_num_frames=True, + exts=['gif'] + ): + super().__init__() + self.folder = folder + self.image_size = image_size + self.channels = channels + self.paths = [p for ext in exts for p in Path( + f'{folder}').glob(f'**/*.{ext}')] + + self.cast_num_frames_fn = partial( + cast_num_frames, frames=num_frames) if force_num_frames else identity + + self.transform = T.Compose([ + T.Resize(image_size), + T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity), + T.CenterCrop(image_size), + T.ToTensor() + ]) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = self.paths[index] + tensor = gif_to_tensor(path, self.channels, 
transform=self.transform) + return self.cast_num_frames_fn(tensor) + +# trainer class + + +class Tester(object): + def __init__( + self, + diffusion_model, + ): + super().__init__() + self.model = diffusion_model + self.ema_model = copy.deepcopy(self.model) + self.step=0 + self.image_size = diffusion_model.image_size + + self.reset_parameters() + + def reset_parameters(self): + self.ema_model.load_state_dict(self.model.state_dict()) + + + def load(self, milestone, map_location=None, **kwargs): + if milestone == -1: + all_milestones = [int(p.stem.split('-')[-1]) + for p in Path(self.results_folder).glob('**/*.pt')] + assert len( + all_milestones) > 0, 'need to have at least one milestone to load from latest checkpoint (milestone == -1)' + milestone = max(all_milestones) + + if map_location: + data = torch.load(milestone, map_location=map_location) + else: + data = torch.load(milestone) + + self.step = data['step'] + self.model.load_state_dict(data['model'], **kwargs) + self.ema_model.load_state_dict(data['ema'], **kwargs) + + diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/text.py b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/text.py new file mode 100644 index 0000000000000000000000000000000000000000..42dca78bec5075fb4f59c522aa3d00cc395d7536 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/text.py @@ -0,0 +1,94 @@ +"Taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch" + +import torch +from einops import rearrange + + +def exists(val): + return val is not None + +# singleton globals + + +MODEL = None +TOKENIZER = None +BERT_MODEL_DIM = 768 + + +def get_tokenizer(): + global TOKENIZER + if not exists(TOKENIZER): + TOKENIZER = torch.hub.load( + 'huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased') + return TOKENIZER + + +def get_bert(): + global MODEL + if not exists(MODEL): + MODEL = torch.hub.load( + 'huggingface/pytorch-transformers', 'model', 'bert-base-cased') + if torch.cuda.is_available(): + MODEL = MODEL.cuda() + + return MODEL + +# tokenize + + +def tokenize(texts, add_special_tokens=True): + if not isinstance(texts, (list, tuple)): + texts = [texts] + + tokenizer = get_tokenizer() + + encoding = tokenizer.batch_encode_plus( + texts, + add_special_tokens=add_special_tokens, + padding=True, + return_tensors='pt' + ) + + token_ids = encoding.input_ids + return token_ids + +# embedding function + + +@torch.no_grad() +def bert_embed( + token_ids, + return_cls_repr=False, + eps=1e-8, + pad_id=0. 
+): + model = get_bert() + mask = token_ids != pad_id + + if torch.cuda.is_available(): + token_ids = token_ids.cuda() + mask = mask.cuda() + + outputs = model( + input_ids=token_ids, + attention_mask=mask, + output_hidden_states=True + ) + + hidden_state = outputs.hidden_states[-1] + + if return_cls_repr: + # return [cls] as representation + return hidden_state[:, 0] + + if not exists(mask): + return hidden_state.mean(dim=1) + + # mean all tokens excluding [cls], accounting for length + mask = mask[:, 1:] + mask = rearrange(mask, 'b n -> b n 1') + + numer = (hidden_state[:, 1:] * mask).sum(dim=1) + denom = mask.sum(dim=1) + masked_mean = numer / (denom + eps) + return diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/time_embedding.py b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/time_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..7c911fa4f485295005cda9c9c2099db3ddbda15c --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/time_embedding.py @@ -0,0 +1,75 @@ +import math +import torch +import torch.nn as nn + +from monai.networks.layers.utils import get_act_layer + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, emb_dim=16, downscale_freq_shift=1, max_period=10000, flip_sin_to_cos=False): + super().__init__() + self.emb_dim = emb_dim + self.downscale_freq_shift = downscale_freq_shift + self.max_period = max_period + self.flip_sin_to_cos = flip_sin_to_cos + + def forward(self, x): + device = x.device + half_dim = self.emb_dim // 2 + emb = math.log(self.max_period) / \ + (half_dim - self.downscale_freq_shift) + emb = torch.exp(-emb*torch.arange(half_dim, device=device)) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + + if self.flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + if self.emb_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class LearnedSinusoidalPosEmb(nn.Module): + """ following @crowsonkb 's lead with learned sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, emb_dim): + super().__init__() + self.emb_dim = emb_dim + half_dim = emb_dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = x[:, None] + freqs = x * self.weights[None, :] * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + if self.emb_dim % 2 == 1: + fouriered = torch.nn.functional.pad(fouriered, (0, 1, 0, 0)) + return fouriered + + +class TimeEmbbeding(nn.Module): + def __init__( + self, + emb_dim=64, + pos_embedder=SinusoidalPosEmb, + pos_embedder_kwargs={}, + act_name=("SWISH", {}) # Swish = SiLU + ): + super().__init__() + self.emb_dim = emb_dim + self.pos_emb_dim = pos_embedder_kwargs.get('emb_dim', emb_dim//4) + pos_embedder_kwargs['emb_dim'] = self.pos_emb_dim + self.pos_embedder = pos_embedder(**pos_embedder_kwargs) + + self.time_emb = nn.Sequential( + self.pos_embedder, + nn.Linear(self.pos_emb_dim, self.emb_dim), + get_act_layer(act_name), + nn.Linear(self.emb_dim, self.emb_dim) + ) + + def forward(self, time): + return self.time_emb(time) diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/unet.py b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/unet.py new file mode 100644 index 
0000000000000000000000000000000000000000..e9cae8021e874a92b6baaaf4decd802516c2dc87 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/unet.py @@ -0,0 +1,226 @@ +from ddpm.time_embedding import TimeEmbbeding + +import monai.networks.nets as nets +import torch +import torch.nn as nn +from einops import rearrange + +from monai.networks.blocks import UnetBasicBlock, UnetResBlock, UnetUpBlock, Convolution, UnetOutBlock +from monai.networks.layers.utils import get_act_layer + + +class DownBlock(nn.Module): + def __init__( + self, + spatial_dims, + in_ch, + out_ch, + time_emb_dim, + cond_emb_dim, + act_name=("swish", {}), + **kwargs): + super(DownBlock, self).__init__() + self.loca_time_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(time_emb_dim, in_ch) # in_ch * 2 + ) + if cond_emb_dim is not None: + self.loca_cond_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(cond_emb_dim, in_ch), + ) + self.down_op = UnetBasicBlock( + spatial_dims, in_ch, out_ch, act_name=act_name, **kwargs) + + def forward(self, x, time_emb, cond_emb): + b, c, *_ = x.shape + sp_dim = x.ndim-2 + + # ------------ Time ---------- + time_emb = self.loca_time_embedder(time_emb) + time_emb = time_emb.reshape(b, c, *((1,)*sp_dim)) + # scale, shift = time_emb.chunk(2, dim = 1) + + # ------------ Combine ------------ + # x = x * (scale + 1) + shift + x = x + time_emb + + # ----------- Condition ------------ + if cond_emb is not None: + cond_emb = self.loca_cond_embedder(cond_emb) + cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim)) + x = x + cond_emb + + # ----------- Image --------- + y = self.down_op(x) + return y + + +class UpBlock(nn.Module): + def __init__( + self, + spatial_dims, + skip_ch, + enc_ch, + time_emb_dim, + cond_emb_dim, + act_name=("swish", {}), + **kwargs): + super(UpBlock, self).__init__() + self.up_op = UnetUpBlock(spatial_dims, enc_ch, + skip_ch, act_name=act_name, **kwargs) + self.loca_time_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(time_emb_dim, skip_ch * 2), + ) + if cond_emb_dim is not None: + self.loca_cond_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(cond_emb_dim, skip_ch * 2), + ) + + def forward(self, x_skip, x_enc, time_emb, cond_emb): + b, c, *_ = x_enc.shape + sp_dim = x_enc.ndim-2 + + # ----------- Time -------------- + time_emb = self.loca_time_embedder(time_emb) + time_emb = time_emb.reshape(b, c, *((1,)*sp_dim)) + # scale, shift = time_emb.chunk(2, dim = 1) + + # -------- Combine ------------- + # y = x * (scale + 1) + shift + x_enc = x_enc + time_emb + + # ----------- Condition ------------ + if cond_emb is not None: + cond_emb = self.loca_cond_embedder(cond_emb) + cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim)) + x_enc = x_enc + cond_emb + + # ----------- Image ------------- + y = self.up_op(x_enc, x_skip) + + # -------- Combine ------------- + # y = y * (scale + 1) + shift + + return y + + +class UNet(nn.Module): + + def __init__(self, + in_ch=1, + out_ch=1, + spatial_dims=3, + hid_chs=[32, 64, 128, 256, 512], + kernel_sizes=[(1, 3, 3), (1, 3, 3), (1, 3, 3), 3, 3], + strides=[1, (1, 2, 2), (1, 2, 2), 2, 2], + upsample_kernel_sizes=None, + act_name=("SWISH", {}), + norm_name=("INSTANCE", {"affine": True}), + time_embedder=TimeEmbbeding, + time_embedder_kwargs={}, + cond_embedder=None, + cond_embedder_kwargs={}, + # True = all but last layer, 0/False=disable, 1=only first layer, ... 
+ deep_ver_supervision=True, + estimate_variance=False, + use_self_conditioning=False, + **kwargs + ): + super().__init__() + if upsample_kernel_sizes is None: + upsample_kernel_sizes = strides[1:] + + # ------------- Time-Embedder----------- + self.time_embedder = time_embedder(**time_embedder_kwargs) + + # ------------- Condition-Embedder----------- + if cond_embedder is not None: + self.cond_embedder = cond_embedder(**cond_embedder_kwargs) + cond_emb_dim = self.cond_embedder.emb_dim + else: + self.cond_embedder = None + cond_emb_dim = None + + # ----------- In-Convolution ------------ + in_ch = in_ch*2 if use_self_conditioning else in_ch + self.inc = UnetBasicBlock(spatial_dims, in_ch, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0], + act_name=act_name, norm_name=norm_name, **kwargs) + + # ----------- Encoder ---------------- + self.encoders = nn.ModuleList([ + DownBlock(spatial_dims, hid_chs[i-1], hid_chs[i], time_emb_dim=self.time_embedder.emb_dim, + cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[ + i], stride=strides[i], act_name=act_name, + norm_name=norm_name, **kwargs) + for i in range(1, len(strides)) + ]) + + # ------------ Decoder ---------- + self.decoders = nn.ModuleList([ + UpBlock(spatial_dims, hid_chs[i], hid_chs[i+1], time_emb_dim=self.time_embedder.emb_dim, + cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[i + + 1], stride=strides[i+1], act_name=act_name, + norm_name=norm_name, upsample_kernel_size=upsample_kernel_sizes[i], **kwargs) + for i in range(len(strides)-1) + ]) + + # --------------- Out-Convolution ---------------- + out_ch_hor = out_ch*2 if estimate_variance else out_ch + self.outc = UnetOutBlock( + spatial_dims, hid_chs[0], out_ch_hor, dropout=None) + if isinstance(deep_ver_supervision, bool): + deep_ver_supervision = len( + strides)-2 if deep_ver_supervision else 0 + self.outc_ver = nn.ModuleList([ + UnetOutBlock(spatial_dims, hid_chs[i], out_ch, dropout=None) + for i in range(1, deep_ver_supervision+1) + ]) + + def forward(self, x_t, t, cond=None, self_cond=None, **kwargs): + condition = cond + # x_t [B, C, (D), H, W] + # t [B,] + + # -------- In-Convolution -------------- + x = [None for _ in range(len(self.encoders)+1)] + x_t = torch.cat([x_t, self_cond], + dim=1) if self_cond is not None else x_t + x[0] = self.inc(x_t) + + # -------- Time Embedding (Gloabl) ----------- + time_emb = self.time_embedder(t) # [B, C] + + # -------- Condition Embedding (Gloabl) ----------- + if (condition is None) or (self.cond_embedder is None): + cond_emb = None + else: + cond_emb = self.cond_embedder(condition) # [B, C] + + # --------- Encoder -------------- + for i in range(len(self.encoders)): + x[i+1] = self.encoders[i](x[i], time_emb, cond_emb) + + # -------- Decoder ----------- + for i in range(len(self.decoders), 0, -1): + x[i-1] = self.decoders[i-1](x[i-1], x[i], time_emb, cond_emb) + + # ---------Out-Convolution ------------ + y_hor = self.outc(x[0]) + y_ver = [outc_ver_i(x[i+1]) + for i, outc_ver_i in enumerate(self.outc_ver)] + + return y_hor # , y_ver + + def forward_with_cond_scale(self, *args, cond_scale=0., **kwargs): + return self.forward(*args, **kwargs) + + +if __name__ == '__main__': + model = UNet(in_ch=3) + input = torch.randn((1, 3, 16, 128, 128)) + time = torch.randn((1,)) + out_hor, out_ver = model(input, time) + print(out_hor[0].shape) diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/util.py b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/util.py new file mode 100644 index 
0000000000000000000000000000000000000000..2047cf56802046bbf37cfd33f819ed75a239a7df --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/ddpm/util.py @@ -0,0 +1,271 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +# from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + if c != 1: + steps_out = ddim_timesteps + 1 + else: + steps_out = ddim_timesteps + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. 
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. 
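+ For example, a tensor of shape [B, C, D, H, W] is reduced to a length-B vector.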
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c6af8bd00ace53e29b1c8aa504c8696b4ea924a Binary files /dev/null and b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__init__.py b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38c08845c1ee2f7b4b52bf95fa282b0808931a3a --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__init__.py @@ -0,0 +1,3 @@ +from .vqgan import VQGAN +from .codebook import Codebook +from .lpips import LPIPS diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef785b63e9430a140fa1101717453f142aa695b7 Binary files /dev/null and b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc differ diff --git 
a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0be5fb0699b289341cfa55b655484adb3fae6e11 Binary files /dev/null and b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af7ef665204041d05e7be07f870b62ac7b3fd22d Binary files /dev/null and b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffd5c347cf3d51640c4b259d6e55547116c64015 Binary files /dev/null and b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth new file mode 100644 index 0000000000000000000000000000000000000000..f57dcf5cc764d61c8a460365847fb2137ff0a62d --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868 +size 7289 diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/codebook.py b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/codebook.py new file mode 100644 index 0000000000000000000000000000000000000000..c17ed701a2824cc0dcd75670d47a6ab3842e9d35 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/codebook.py @@ -0,0 +1,109 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from ..utils import shift_dim + + +class Codebook(nn.Module): + def __init__(self, n_codes, embedding_dim, no_random_restart=False, restart_thres=1.0): + super().__init__() + self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim)) + self.register_buffer('N', torch.zeros(n_codes)) + self.register_buffer('z_avg', self.embeddings.data.clone()) + + self.n_codes = n_codes + self.embedding_dim = embedding_dim + self._need_init = True + self.no_random_restart = no_random_restart + self.restart_thres = restart_thres + + def _tile(self, x): + d, ew = x.shape + if d < self.n_codes: + n_repeats = (self.n_codes + d - 1) // d + std = 0.01 / np.sqrt(ew) + x = x.repeat(n_repeats, 1) + x = x + torch.randn_like(x) * std + return x + + def _init_embeddings(self, z): + # z: [b, c, t, h, w] + self._need_init = False + breakpoint() + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [65536, 8] [2, 8, 32, 32, 32] + y = self._tile(flat_inputs) # [65536, 8] + + d = y.shape[0] + _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + self.embeddings.data.copy_(_k_rand) + self.z_avg.data.copy_(_k_rand) + self.N.data.copy_(torch.ones(self.n_codes)) + + def forward(self, z): + # z: [b, c, t, h, w] + if self._need_init and self.training: + self._init_embeddings(z) + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [bthw, c] [65536, 8] + distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \ + - 2 * flat_inputs @ self.embeddings.t() \ + + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) # [bthw, c] [65536, 8] + + encoding_indices = torch.argmin(distances, dim=1) # [65536] + encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as( + flat_inputs) # [bthw, ncode] [65536, 16384] + encoding_indices = encoding_indices.view( + z.shape[0], *z.shape[2:]) # [b, t, h, w, ncode] [2, 32, 32, 32] + + embeddings = F.embedding( + encoding_indices, self.embeddings) # [b, t, h, w, c] self.embeddings [16384, 8] + embeddings = shift_dim(embeddings, -1, 1) # [b, c, t, h, w] [2, 8, 32, 32, 32] + + commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) + + # EMA codebook update + if self.training: + n_total = encode_onehot.sum(dim=0) # [16384] + encode_sum = flat_inputs.t() @ encode_onehot # [8, 16384] + if dist.is_initialized(): + dist.all_reduce(n_total) + dist.all_reduce(encode_sum) + + self.N.data.mul_(0.99).add_(n_total, alpha=0.01) + self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) + + n = self.N.sum() + weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n + encode_normalized = self.z_avg / weights.unsqueeze(1) + self.embeddings.data.copy_(encode_normalized) + + y = self._tile(flat_inputs) + _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + + if not self.no_random_restart: + usage = (self.N.view(self.n_codes, 1) + >= self.restart_thres).float() + self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) + + embeddings_st = (embeddings - z).detach() + z + + avg_probs = torch.mean(encode_onehot, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * + torch.log(avg_probs + 1e-10))) + + return dict(embeddings=embeddings_st, encodings=encoding_indices, + commitment_loss=commitment_loss, perplexity=perplexity) + + def dictionary_lookup(self, encodings): + embeddings = F.embedding(encodings, 
self.embeddings) + return embeddings diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/lpips.py b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..019434b9b29945d67e6e0a95dec324df7ff908f3 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/lpips.py @@ -0,0 +1,181 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" + +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + + +from collections import namedtuple +from torchvision import models +import torch.nn as nn +import torch +from tqdm import tqdm +import requests +import os +import hashlib +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format( + name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + # self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + self.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name is not "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + model.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer( + input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in 
range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor( + outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor( + [-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor( + [.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, + padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, + h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..a64ec8878c0ecbf418775e2958b2cbf578ce2918 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py @@ -0,0 +1,561 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import math +import argparse +import numpy as np +import pickle as pkl + +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from ..utils import shift_dim, adopt_weight, comp_getattr +from .lpips import LPIPS +from .codebook import Codebook + + +def silu(x): + return x*torch.sigmoid(x) + + +class SiLU(nn.Module): + def __init__(self): + super(SiLU, self).__init__() + + def forward(self, x): + return silu(x) + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake))) + return d_loss + + +class VQGAN(pl.LightningModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.embedding_dim = cfg.model.embedding_dim # 8 + self.n_codes = cfg.model.n_codes # 16384 + + self.encoder = Encoder(cfg.model.n_hiddens, # 16 + cfg.model.downsample, # [2, 2, 2] + cfg.dataset.image_channels, # 1 + cfg.model.norm_type, # group + cfg.model.padding_type, # replicate + cfg.model.num_groups, # 32 + ) + self.decoder = Decoder( + cfg.model.n_hiddens, cfg.model.downsample, cfg.dataset.image_channels, cfg.model.norm_type, cfg.model.num_groups) + self.enc_out_ch = self.encoder.out_channels + self.pre_vq_conv = SamePadConv3d( + self.enc_out_ch, cfg.model.embedding_dim, 1, padding_type=cfg.model.padding_type) + self.post_vq_conv = SamePadConv3d( + cfg.model.embedding_dim, self.enc_out_ch, 1) + + self.codebook = Codebook(cfg.model.n_codes, cfg.model.embedding_dim, + no_random_restart=cfg.model.no_random_restart, restart_thres=cfg.model.restart_thres) + + self.gan_feat_weight = cfg.model.gan_feat_weight + # TODO: Changed batchnorm from sync to normal + self.image_discriminator = NLayerDiscriminator( + cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm2d) + self.video_discriminator = NLayerDiscriminator3D( + cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm3d) + + if cfg.model.disc_loss_type == 'vanilla': + self.disc_loss = vanilla_d_loss + elif cfg.model.disc_loss_type == 'hinge': + self.disc_loss = hinge_d_loss + + self.perceptual_model = LPIPS().eval() + + self.image_gan_weight = cfg.model.image_gan_weight + self.video_gan_weight = cfg.model.video_gan_weight + + self.perceptual_weight = cfg.model.perceptual_weight + + self.l1_weight = cfg.model.l1_weight + self.save_hyperparameters() + + def encode(self, x, include_embeddings=False, quantize=True): + h = self.pre_vq_conv(self.encoder(x)) + if quantize: + vq_output = self.codebook(h) + if include_embeddings: + return vq_output['embeddings'], vq_output['encodings'] + else: + return vq_output['encodings'] + return h + + def decode(self, latent, quantize=False): + if quantize: + vq_output = self.codebook(latent) + latent = vq_output['encodings'] + h = F.embedding(latent, self.codebook.embeddings) + h = self.post_vq_conv(shift_dim(h, -1, 1)) + return self.decoder(h) + + def forward(self, x, optimizer_idx=None, log_image=False): + B, C, T, H, W = x.shape + z = self.pre_vq_conv(self.encoder(x)) # [2, 32, 32, 32, 32] [2, 8, 32, 32, 32] + vq_output = self.codebook(z) # ['embeddings', 'encodings', 'commitment_loss', 
'perplexity'] + x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings'])) # [2, 8, 32, 32, 32] [2, 32, 32, 32, 32] + + recon_loss = F.l1_loss(x_recon, x) * self.l1_weight + + # Selects one random 2D image from each 3D Image + frame_idx = torch.randint(0, T, [B]).cuda() + frame_idx_selected = frame_idx.reshape(-1, + 1, 1, 1, 1).repeat(1, C, 1, H, W) # [2, 1, 1, 64, 64] + frames = torch.gather(x, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64] + frames_recon = torch.gather(x_recon, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64] + + if log_image: + return frames, frames_recon, x, x_recon + + if optimizer_idx == 0: + # Autoencoder - train the "generator" + + # Perceptual loss + perceptual_loss = 0 + if self.perceptual_weight > 0: + perceptual_loss = self.perceptual_model( + frames, frames_recon).mean() * self.perceptual_weight + + # Discriminator loss (turned on after a certain epoch) + logits_image_fake, pred_image_fake = self.image_discriminator( + frames_recon) + logits_video_fake, pred_video_fake = self.video_discriminator( + x_recon) + g_image_loss = -torch.mean(logits_image_fake) + g_video_loss = -torch.mean(logits_video_fake) + g_loss = self.image_gan_weight*g_image_loss + self.video_gan_weight*g_video_loss + disc_factor = adopt_weight( + self.global_step, threshold=self.cfg.model.discriminator_iter_start) + aeloss = disc_factor * g_loss + + # GAN feature matching loss - tune features such that we get the same prediction result on the discriminator + image_gan_feat_loss = 0 + video_gan_feat_loss = 0 + feat_weights = 4.0 / (3 + 1) + if self.image_gan_weight > 0: + logits_image_real, pred_image_real = self.image_discriminator( + frames) + for i in range(len(pred_image_fake)-1): + image_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_image_fake[i], pred_image_real[i].detach( + )) * (self.image_gan_weight > 0) + if self.video_gan_weight > 0: + logits_video_real, pred_video_real = self.video_discriminator( + x) + for i in range(len(pred_video_fake)-1): + video_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_video_fake[i], pred_video_real[i].detach( + )) * (self.video_gan_weight > 0) + gan_feat_loss = disc_factor * self.gan_feat_weight * \ + (image_gan_feat_loss + video_gan_feat_loss) + + self.log("train/g_image_loss", g_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/g_video_loss", g_video_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/image_gan_feat_loss", image_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/video_gan_feat_loss", video_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/perceptual_loss", perceptual_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log("train/recon_loss", recon_loss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + self.log("train/aeloss", aeloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + self.log("train/commitment_loss", vq_output['commitment_loss'], + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log('train/perplexity', vq_output['perplexity'], + prog_bar=True, logger=True, on_step=True, on_epoch=True) + return recon_loss, x_recon, vq_output, aeloss, perceptual_loss, gan_feat_loss + + if optimizer_idx == 1: + # Train discriminator + logits_image_real, _ = self.image_discriminator(frames.detach()) + logits_video_real, _ = self.video_discriminator(x.detach()) + + logits_image_fake, _ = self.image_discriminator( + frames_recon.detach()) + logits_video_fake, _ = 
self.video_discriminator(x_recon.detach()) + + d_image_loss = self.disc_loss(logits_image_real, logits_image_fake) + d_video_loss = self.disc_loss(logits_video_real, logits_video_fake) + disc_factor = adopt_weight( + self.global_step, threshold=self.cfg.model.discriminator_iter_start) + discloss = disc_factor * \ + (self.image_gan_weight*d_image_loss + + self.video_gan_weight*d_video_loss) + + self.log("train/logits_image_real", logits_image_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_image_fake", logits_image_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_video_real", logits_video_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_video_fake", logits_video_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/d_image_loss", d_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/d_video_loss", d_video_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/discloss", discloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + return discloss + + perceptual_loss = self.perceptual_model( + frames, frames_recon) * self.perceptual_weight + return recon_loss, x_recon, vq_output, perceptual_loss + + def training_step(self, batch, batch_idx, optimizer_idx): + x = batch['image'] + if optimizer_idx == 0: + recon_loss, _, vq_output, aeloss, perceptual_loss, gan_feat_loss = self.forward( + x, optimizer_idx) + commitment_loss = vq_output['commitment_loss'] + loss = recon_loss + commitment_loss + aeloss + perceptual_loss + gan_feat_loss + if optimizer_idx == 1: + discloss = self.forward(x, optimizer_idx) + loss = discloss + return loss + + def validation_step(self, batch, batch_idx): + x = batch['image'] # TODO: batch['stft'] + recon_loss, _, vq_output, perceptual_loss = self.forward(x) + self.log('val/recon_loss', recon_loss, prog_bar=True) + self.log('val/perceptual_loss', perceptual_loss, prog_bar=True) + self.log('val/perplexity', vq_output['perplexity'], prog_bar=True) + self.log('val/commitment_loss', + vq_output['commitment_loss'], prog_bar=True) + + def configure_optimizers(self): + lr = self.cfg.model.lr + opt_ae = torch.optim.Adam(list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.pre_vq_conv.parameters()) + + list(self.post_vq_conv.parameters()) + + list(self.codebook.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(list(self.image_discriminator.parameters()) + + list(self.video_discriminator.parameters()), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def log_images(self, batch, **kwargs): + log = dict() + x = batch['image'] + x = x.to(self.device) + frames, frames_rec, _, _ = self(x, log_image=True) + log["inputs"] = frames + log["reconstructions"] = frames_rec + #log['mean_org'] = batch['mean_org'] + #log['std_org'] = batch['std_org'] + return log + + def log_videos(self, batch, **kwargs): + log = dict() + x = batch['image'] + _, _, x, x_rec = self(x, log_image=True) + log["inputs"] = x + log["reconstructions"] = x_rec + #log['mean_org'] = batch['mean_org'] + #log['std_org'] = batch['std_org'] + return log + + +def Normalize(in_channels, norm_type='group', num_groups=32): + assert norm_type in ['group', 'batch'] + if norm_type == 'group': + # TODO Changed num_groups from 32 to 8 + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == 'batch': + return 
torch.nn.SyncBatchNorm(in_channels) + + +class Encoder(nn.Module): + def __init__(self, n_hiddens, downsample, image_channel=3, norm_type='group', padding_type='replicate', num_groups=32): + super().__init__() + n_times_downsample = np.array([int(math.log2(d)) for d in downsample]) + self.conv_blocks = nn.ModuleList() + max_ds = n_times_downsample.max() + + self.conv_first = SamePadConv3d( + image_channel, n_hiddens, kernel_size=3, padding_type=padding_type) + + for i in range(max_ds): + block = nn.Module() + in_channels = n_hiddens * 2**i + out_channels = n_hiddens * 2**(i+1) + stride = tuple([2 if d > 0 else 1 for d in n_times_downsample]) + block.down = SamePadConv3d( + in_channels, out_channels, 4, stride=stride, padding_type=padding_type) + block.res = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + self.conv_blocks.append(block) + n_times_downsample -= 1 + + self.final_block = nn.Sequential( + Normalize(out_channels, norm_type, num_groups=num_groups), + SiLU() + ) + + self.out_channels = out_channels + + def forward(self, x): + h = self.conv_first(x) + for block in self.conv_blocks: + h = block.down(h) + h = block.res(h) + h = self.final_block(h) + return h + + +class Decoder(nn.Module): + def __init__(self, n_hiddens, upsample, image_channel, norm_type='group', num_groups=32): + super().__init__() + + n_times_upsample = np.array([int(math.log2(d)) for d in upsample]) + max_us = n_times_upsample.max() + + in_channels = n_hiddens*2**max_us + self.final_block = nn.Sequential( + Normalize(in_channels, norm_type, num_groups=num_groups), + SiLU() + ) + + self.conv_blocks = nn.ModuleList() + for i in range(max_us): + block = nn.Module() + in_channels = in_channels if i == 0 else n_hiddens*2**(max_us-i+1) + out_channels = n_hiddens*2**(max_us-i) + us = tuple([2 if d > 0 else 1 for d in n_times_upsample]) + block.up = SamePadConvTranspose3d( + in_channels, out_channels, 4, stride=us) + block.res1 = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + block.res2 = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + self.conv_blocks.append(block) + n_times_upsample -= 1 + + self.conv_last = SamePadConv3d( + out_channels, image_channel, kernel_size=3) + + def forward(self, x): + h = self.final_block(x) + for i, block in enumerate(self.conv_blocks): + h = block.up(h) + h = block.res1(h) + h = block.res2(h) + h = self.conv_last(h) + return h + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group', padding_type='replicate', num_groups=32): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, norm_type, num_groups=num_groups) + self.conv1 = SamePadConv3d( + in_channels, out_channels, kernel_size=3, padding_type=padding_type) + self.dropout = torch.nn.Dropout(dropout) + self.norm2 = Normalize(in_channels, norm_type, num_groups=num_groups) + self.conv2 = SamePadConv3d( + out_channels, out_channels, kernel_size=3, padding_type=padding_type) + if self.in_channels != self.out_channels: + self.conv_shortcut = SamePadConv3d( + in_channels, out_channels, kernel_size=3, padding_type=padding_type) + + def forward(self, x): + h = x + h = self.norm1(h) + h = silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = silu(h) + h = self.conv2(h) + + 
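+        # Residual path: when the input and output channel counts differ, x is projected
+        # through the SamePadConv3d shortcut created in __init__ so the x + h addition below
+        # is shape-compatible.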
if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) + + return x+h + + +# Does not support dilation +class SamePadConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + # assumes that the input shape is divisible by stride + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + self.padding_type = padding_type + + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, + stride=stride, padding=0, bias=bias) + + def forward(self, x): + return self.conv(F.pad(x, self.pad_input, mode=self.padding_type)) + + +class SamePadConvTranspose3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + self.padding_type = padding_type + + self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, + stride=stride, bias=bias, + padding=tuple([k - 1 for k in kernel_size])) + + def forward(self, x): + return self.convt(F.pad(x, self.pad_input, mode=self.padding_type)) + + +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + # def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=True): + super(NLayerDiscriminator, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ + + +class NLayerDiscriminator3D(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + 
super(NLayerDiscriminator3D, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv3d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv3d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/utils.py b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..acf1a21270d98d0a1ea9935d00f9f16d89d54551 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/ldm/vq_gan_3d/utils.py @@ -0,0 +1,177 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import warnings +import torch +import imageio + +import math +import numpy as np + +import sys +import pdb as pdb_original +import logging + +import imageio.core.util +logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR) + + +class ForkedPdb(pdb_original.Pdb): + """A Pdb subclass that may be used + from a forked multiprocessing child + + """ + + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + pdb_original.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + +# Shifts src_tf dim to dest dim +# i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + + dims = list(range(n_dims)) + del dims[src_dim] + + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x + + +# reshapes tensor start from dim i (inclusive) +# to dim j (exclusive) to the desired shape +# e.g. 
if x.shape = (b, thw, c) then +# view_range(x, 1, 2, (t, h, w)) returns +# x of shape (b, t, h, w, c) +def view_range(x, i, j, shape): + shape = tuple(shape) + + n_dims = len(x.shape) + if i < 0: + i = n_dims + i + + if j is None: + j = n_dims + elif j < 0: + j = n_dims + j + + assert 0 <= i < j <= n_dims + + x_shape = x.shape + target_shape = x_shape[:i] + shape + x_shape[j:] + return x.view(target_shape) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def tensor_slice(x, begin, size): + assert all([b >= 0 for b in begin]) + size = [l - b if s == -1 else s + for s, b, l in zip(size, begin, x.shape)] + assert all([s >= 0 for s in size]) + + slices = [slice(b, b + s) for b, s in zip(begin, size)] + return x[slices] + + +def adopt_weight(global_step, threshold=0, value=0.): + weight = 1 + if global_step < threshold: + weight = value + return weight + + +def save_video_grid(video, fname, nrow=None, fps=6): + b, c, t, h, w = video.shape + video = video.permute(0, 2, 3, 4, 1) + video = (video.cpu().numpy() * 255).astype('uint8') + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = np.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype='uint8') + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + video = [] + for i in range(t): + video.append(video_grid[i]) + imageio.mimsave(fname, video, fps=fps) + ## skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '5'}) + #print('saved videos to', fname) + + +def comp_getattr(args, attr_name, default=None): + if hasattr(args, attr_name): + return getattr(args, attr_name) + else: + return default + + +def visualize_tensors(t, name=None, nest=0): + if name is not None: + print(name, "current nest: ", nest) + print("type: ", type(t)) + if 'dict' in str(type(t)): + print(t.keys()) + for k in t.keys(): + if t[k] is None: + print(k, "None") + else: + if 'Tensor' in str(type(t[k])): + print(k, t[k].shape) + elif 'dict' in str(type(t[k])): + print(k, 'dict') + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t[k])): + print(k, len(t[k])) + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t)): + print("list length: ", len(t)) + for t2 in t: + visualize_tensors(t2, name, nest + 1) + elif 'Tensor' in str(type(t)): + print(t.shape) + else: + print(t) + return "" diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_kidney_fold0_early.pt b/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_kidney_fold0_early.pt new file mode 100644 index 0000000000000000000000000000000000000000..b7f83060ec54f1e1a98f942ec85ab87ad560b6ff --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_kidney_fold0_early.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d889b3561803f7490f4050c03a02163f099633e4f00fea4cb10b5b993685e5cc +size 290138333 diff --git 
a/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_kidney_fold0_noearly_t200.pt b/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_kidney_fold0_noearly_t200.pt new file mode 100644 index 0000000000000000000000000000000000000000..63852b75a8b3878b9bc664b40a748a8f2f8344cd --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_kidney_fold0_noearly_t200.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:26bc7847ae15377a5586535cbb2e6a1ec5b6a98732f7f795c284d7dcda208c97 +size 290156765 diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_liver_fold0_early.pt b/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_liver_fold0_early.pt new file mode 100644 index 0000000000000000000000000000000000000000..e9ab9d5546979bb9389e3ff7e3cf626591115b2f --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_liver_fold0_early.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0135000f031f741252b3e706748b674d33e7278402a7cb2500fec5f4966847bd +size 290138333 diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_liver_fold0_noearly_t200.pt b/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_liver_fold0_noearly_t200.pt new file mode 100644 index 0000000000000000000000000000000000000000..94c05be2a43838624a24ec18e7d0b70b34221e79 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_liver_fold0_noearly_t200.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c2a21980e53efd6758ae92e79a82668f0e1e6d9b52fdf6b2a709cb929ebedb3b +size 290156765 diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_pancreas_fold0_early.pt b/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_pancreas_fold0_early.pt new file mode 100644 index 0000000000000000000000000000000000000000..a615d012889e61d79e5018dfad4a48f41bf440a4 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_pancreas_fold0_early.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9438e39a44af92bb0fbaf5cc50a3ac3aaa260978a69ac341ed7ec23512c080a5 +size 290138333 diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_pancreas_fold0_noearly_t200.pt b/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_pancreas_fold0_noearly_t200.pt new file mode 100644 index 0000000000000000000000000000000000000000..82486487bd145d8dac8c823ae22df401d19a1711 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/diff_pancreas_fold0_noearly_t200.pt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fb37011840156f548fd2348dbb5578f9bc81de16719ec226fbef2de6f0244f9d +size 290156765 diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/recon_96d4_all.ckpt b/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/recon_96d4_all.ckpt new file mode 100644 index 0000000000000000000000000000000000000000..f81576507c8976c66fad6f67a1e26b3b78f6cf46 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/model_weight/recon_96d4_all.ckpt @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ef88523af9590a7325bc9ca41999de191c3fbc41afc6186a8c4db5528446bb1f +size 242615727 diff --git 
a/Generation_Pipeline_filter/syn_liver/TumorGeneration/utils.py b/Generation_Pipeline_filter/syn_liver/TumorGeneration/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..88e303a085dc90228553170866ec732e2cd86bcd --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/utils.py @@ -0,0 +1,471 @@ +### Tumor Generateion +import random +import cv2 +import elasticdeform +import numpy as np +from scipy.ndimage import gaussian_filter +from TumorGeneration.ldm.ddpm.ddim import DDIMSampler + +# Step 1: Random select (numbers) location for tumor. +def random_select(mask_scan, organ_type): + # we first find z index and then sample point with z slice + # print('mask_scan',np.unique(mask_scan)) + # print('pixel num', (mask_scan == 1).sum()) + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0].min(), np.where(np.any(mask_scan, axis=(0, 1)))[0].max() + # print('z_start, z_end',z_start, z_end) + # we need to strict number z's position (0.3 - 0.7 in the middle of liver) + flag=0 + while 1: + z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start + liver_mask = mask_scan[..., z] + # erode the mask (we don't want the edge points) + if organ_type == 'liver': + flag+=1 + if flag <= 10: + kernel = np.ones((5,5), dtype=np.uint8) + liver_mask = cv2.erode(liver_mask, kernel, iterations=1) + elif flag >10 and flag <= 20: + kernel = np.ones((3,3), dtype=np.uint8) + liver_mask = cv2.erode(liver_mask, kernel, iterations=1) + else: + pass + print(flag) + if (liver_mask == 1).sum() > 0: + break + + # print('liver_mask', (liver_mask == 1).sum()) + coordinates = np.argwhere(liver_mask == 1) + random_index = np.random.randint(0, len(coordinates)) + xyz = coordinates[random_index].tolist() # get x,y + xyz.append(z) + potential_points = xyz + + return potential_points + +def center_select(mask_scan): + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0].min(), np.where(np.any(mask_scan, axis=(0, 1)))[0].max() + x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0].min(), np.where(np.any(mask_scan, axis=(1, 2)))[0].max() + y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0].min(), np.where(np.any(mask_scan, axis=(0, 2)))[0].max() + + z = round(0.5 * (z_end - z_start)) + z_start + x = round(0.5 * (x_end - x_start)) + x_start + y = round(0.5 * (y_end - y_start)) + y_start + + xyz = [x, y, z] + potential_points = xyz + + return potential_points + +# Step 2 : generate the ellipsoid +def get_ellipsoid(x, y, z): + """" + x, y, z is the radius of this ellipsoid in x, y, z direction respectly. 
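+    Returns a binary numpy volume of shape (4*x, 4*y, 4*z) with voxels inside the ellipsoid set to 1.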
+ """ + sh = (4*x, 4*y, 4*z) + out = np.zeros(sh, int) + aux = np.zeros(sh) + radii = np.array([x, y, z]) + com = np.array([2*x, 2*y, 2*z]) # center point + + # calculate the ellipsoid + bboxl = np.floor(com-radii).clip(0,None).astype(int) + bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int) + roi = out[tuple(map(slice,bboxl,bboxh))] + roiaux = aux[tuple(map(slice,bboxl,bboxh))] + logrid = *map(np.square,np.ogrid[tuple( + map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]), + dst = (1-sum(logrid)).clip(0,None) + mask = dst>roiaux + roi[mask] = 1 + np.copyto(roiaux,dst,where=mask) + + return out + +def get_fixed_geo(mask_scan, tumor_type, organ_type): + if tumor_type == 'large': + enlarge_x, enlarge_y, enlarge_z = 280, 280, 280 + else: + enlarge_x, enlarge_y, enlarge_z = 160, 160, 160 + geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8) + tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32 + + if tumor_type == 'tiny': + num_tumor = random.randint(1,3) + # num_tumor = 1 + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + if tumor_type == 'small': + num_tumor = random.randint(1,3) + # num_tumor = 1 + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'medium': + # num_tumor = random.randint(1, 3) + num_tumor = 1 + for _ in range(num_tumor): + # 
medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'large': + num_tumor = 1 # random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + if organ_type == 'liver' or organ_type == 'kidney' : + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + else: + x_start, x_end = np.where(np.any(geo, axis=(1, 2)))[0].min(), np.where(np.any(geo, axis=(1, 2)))[0].max() + y_start, y_end = np.where(np.any(geo, axis=(0, 2)))[0].min(), np.where(np.any(geo, axis=(0, 2)))[0].max() + z_start, z_end = np.where(np.any(geo, axis=(0, 1)))[0].min(), np.where(np.any(geo, axis=(0, 1)))[0].max() + geo = geo[x_start:x_end, y_start:y_end, z_start:z_end] + + point = center_select(mask_scan) + + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low = new_point[0] - geo.shape[0]//2 + y_low = new_point[1] - geo.shape[1]//2 + z_low = new_point[2] - geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_low+geo.shape[0], y_low:y_low+geo.shape[1], z_low:z_low+geo.shape[2]] += geo + + geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + # texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + if ((tumor_type == 'medium') or (tumor_type == 'large')) and (organ_type == 'kidney'): + if random.random() > 0.5: + geo_mask = (geo_mask>=1) + else: + geo_mask = (geo_mask * mask_scan) >=1 + else: + geo_mask = (geo_mask * mask_scan) >=1 + + return geo_mask + + +from .ldm.vq_gan_3d.model.vqgan 
import VQGAN +import matplotlib.pyplot as plt +import SimpleITK as sitk +from .ldm.ddpm import Unet3D, GaussianDiffusion, Tester +from hydra import initialize, compose +import torch +import yaml +def synt_model_prepare(device, vqgan_ckpt='TumorGeneration/model_weight/recon_96d4_all.ckpt', diffusion_ckpt='TumorGeneration/model_weight/', fold=0, organ='liver'): + with initialize(config_path="diffusion_config/"): + cfg=compose(config_name="ddpm.yaml") + print('diffusion_ckpt',diffusion_ckpt) + vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt) + vqgan = vqgan.to(device) + vqgan.eval() + + early_Unet3D = Unet3D( + dim=cfg.diffusion_img_size, + dim_mults=cfg.dim_mults, + channels=cfg.diffusion_num_channels, + out_dim=cfg.out_dim + ).to(device) + + early_diffusion = GaussianDiffusion( + early_Unet3D, + vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt, + image_size=cfg.diffusion_img_size, + num_frames=cfg.diffusion_depth_size, + channels=cfg.diffusion_num_channels, + timesteps=4, # cfg.timesteps, + loss_type=cfg.loss_type, + device=device + ).to(device) + + noearly_Unet3D = Unet3D( + dim=cfg.diffusion_img_size, + dim_mults=cfg.dim_mults, + channels=cfg.diffusion_num_channels, + out_dim=cfg.out_dim + ).to(device) + + noearly_diffusion = GaussianDiffusion( + noearly_Unet3D, + vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt, + image_size=cfg.diffusion_img_size, + num_frames=cfg.diffusion_depth_size, + channels=cfg.diffusion_num_channels, + timesteps=200, # cfg.timesteps, + loss_type=cfg.loss_type, + device=device + ).to(device) + + early_tester = Tester(early_diffusion) + # noearly_tester = Tester(noearly_diffusion) + early_tester.load(diffusion_ckpt+'diff_{}_fold{}_early.pt'.format(organ, fold), map_location=device) + # noearly_tester.load(diffusion_ckpt+'diff_liver_fold{}_noearly_t200.pt'.format(fold), map_location=device) + + # early_checkpoint = torch.load(diffusion_ckpt+'diff_liver_fold{}_early.pt'.format(fold), map_location=device) + noearly_checkpoint = torch.load(diffusion_ckpt+'diff_{}_fold{}_noearly_t200.pt'.format(organ, fold), map_location=device) + # early_diffusion.load_state_dict(early_checkpoint['ema']) + noearly_diffusion.load_state_dict(noearly_checkpoint['ema']) + # early_sampler = DDIMSampler(early_diffusion, schedule="cosine") + noearly_sampler = DDIMSampler(noearly_diffusion, schedule="cosine") + # breakpoint() + return vqgan, early_tester, noearly_sampler + +def synthesize_early_tumor(ct_volume, organ_mask, organ_type, vqgan, tester): + device=ct_volume.device + + # generate tumor mask + tumor_types = ['tiny', 'small'] + tumor_probs = np.array([0.5, 0.5]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) / + 
(vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + tester.ema_model.eval() + sample = tester.ema_model.sample(batch_size=volume.shape[0], cond=cond) + + # if organ_type == 'liver' or organ_type == 'kidney' : + + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask + +def synthesize_medium_tumor(ct_volume, organ_mask, organ_type, vqgan, sampler, ddim_ts=50): + device=ct_volume.device + + # generate tumor mask + # tumor_types = ['large'] + # tumor_probs = np.array([1.0]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + # synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + synthetic_tumor_type = 'medium' + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + shape = masked_volume_feat.shape[-4:] + samples_ddim, _ = sampler.sample(S=ddim_ts, + conditioning=cond, + batch_size=1, + shape=shape, + verbose=False) + samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() - + vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min() + + sample = vqgan.decode(samples_ddim, quantize=True) + + # if organ_type == 'liver' or organ_type == 'kidney': + # post-process + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0) + + mask_01_blur = 
torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask + +def synthesize_large_tumor(ct_volume, organ_mask, organ_type, vqgan, sampler, ddim_ts=50): + device=ct_volume.device + + # generate tumor mask + # tumor_types = ['large'] + # tumor_probs = np.array([1.0]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + # synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + synthetic_tumor_type = 'large' + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + shape = masked_volume_feat.shape[-4:] + samples_ddim, _ = sampler.sample(S=ddim_ts, + conditioning=cond, + batch_size=1, + shape=shape, + verbose=False) + samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() - + vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min() + + sample = vqgan.decode(samples_ddim, quantize=True) + + # if organ_type == 'liver' or organ_type == 'kidney': + # post-process + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask \ No newline at end of file diff --git a/Generation_Pipeline_filter/syn_liver/TumorGeneration/utils_.py b/Generation_Pipeline_filter/syn_liver/TumorGeneration/utils_.py new file mode 100644 index 0000000000000000000000000000000000000000..312e72d1571b3cd996fd567bcedf939443a4e182 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/TumorGeneration/utils_.py @@ -0,0 +1,298 @@ 
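+# NOTE: this module appears to be a geometry-only variant of utils.py: it imports no VQ-GAN or
+# diffusion components and provides random_select / get_ellipsoid / get_fixed_geo, including an
+# additional 'mix' tumor type.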
+### Tumor Generateion +import random +import cv2 +import elasticdeform +import numpy as np +from scipy.ndimage import gaussian_filter + +# Step 1: Random select (numbers) location for tumor. +def random_select(mask_scan): + # we first find z index and then sample point with z slice + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]] + + # we need to strict number z's position (0.3 - 0.7 in the middle of liver) + z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start + + liver_mask = mask_scan[..., z] + + # erode the mask (we don't want the edge points) + kernel = np.ones((5,5), dtype=np.uint8) + liver_mask = cv2.erode(liver_mask, kernel, iterations=1) + + coordinates = np.argwhere(liver_mask == 1) + random_index = np.random.randint(0, len(coordinates)) + xyz = coordinates[random_index].tolist() # get x,y + xyz.append(z) + potential_points = xyz + + return potential_points + +# Step 2 : generate the ellipsoid +def get_ellipsoid(x, y, z): + """" + x, y, z is the radius of this ellipsoid in x, y, z direction respectly. + """ + sh = (4*x, 4*y, 4*z) + out = np.zeros(sh, int) + aux = np.zeros(sh) + radii = np.array([x, y, z]) + com = np.array([2*x, 2*y, 2*z]) # center point + + # calculate the ellipsoid + bboxl = np.floor(com-radii).clip(0,None).astype(int) + bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int) + roi = out[tuple(map(slice,bboxl,bboxh))] + roiaux = aux[tuple(map(slice,bboxl,bboxh))] + logrid = *map(np.square,np.ogrid[tuple( + map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]), + dst = (1-sum(logrid)).clip(0,None) + mask = dst>roiaux + roi[mask] = 1 + np.copyto(roiaux,dst,where=mask) + + return out + +def get_fixed_geo(mask_scan, tumor_type): + + enlarge_x, enlarge_y, enlarge_z = 160, 160, 160 + geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8) + tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32 + + if tumor_type == 'tiny': + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + if tumor_type == 'small': + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, 
points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'medium': + num_tumor = random.randint(2, 5) + for _ in range(num_tumor): + # medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'large': + num_tumor = random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == "mix": + # tiny + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = 
random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + # small + num_tumor = random.randint(5,10) + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + # medium + num_tumor = random.randint(2, 5) + for _ in range(num_tumor): + # medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + # large + num_tumor = random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = 
random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste large tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + # texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + geo_mask = (geo_mask * mask_scan) >=1 + + return geo_mask + + +def get_tumor(volume_scan, mask_scan, tumor_type): + tumor_mask = get_fixed_geo(mask_scan, tumor_type) + + sigma = np.random.uniform(1, 2) + # difference = np.random.uniform(65, 145) + difference = 1 + + # blur the boundary + tumor_mask_blur = gaussian_filter(tumor_mask*1.0, sigma) + + abnormally = (volume_scan - tumor_mask_blur * difference) * mask_scan # assumed intensity model: subtract the blurred tumor mask (scaled by difference) inside the organ + abnormally_full = volume_scan * (1 - mask_scan) + abnormally + abnormally_mask = mask_scan + tumor_mask + + return abnormally_full, abnormally_mask + +def SynthesisTumor(volume_scan, mask_scan, tumor_type): + # for speed, only the liver region is sent into the generation routine + x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0][[0, -1]] + y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0][[0, -1]] + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]] + + # shrink the boundary + x_start, x_end = max(0, x_start+1), min(mask_scan.shape[0], x_end-1) + y_start, y_end = max(0, y_start+1), min(mask_scan.shape[1], y_end-1) + z_start, z_end = max(0, z_start+1), min(mask_scan.shape[2], z_end-1) + + ct_volume = volume_scan[x_start:x_end, y_start:y_end, z_start:z_end] + organ_mask = mask_scan[x_start:x_end, y_start:y_end, z_start:z_end] + + # input texture shape: 420, 300, 320 + # we need to cut it into the shape of liver_mask + # for example, the liver_mask.shape = 286, 173, 46; we should change the texture shape + x_length, y_length, z_length = 64, 64, 64 + crop_x = random.randint(x_start, x_end - x_length - 1) # randomly select the start point; -1 avoids the boundary + crop_y = random.randint(y_start, y_end - y_length - 1) + crop_z = random.randint(z_start, z_end - z_length - 1) + + ct_volume, organ_tumor_mask = get_tumor(ct_volume, organ_mask, tumor_type) + volume_scan[x_start:x_end, y_start:y_end, z_start:z_end] = ct_volume + mask_scan[x_start:x_end, y_start:y_end, z_start:z_end] = organ_tumor_mask + + return volume_scan, mask_scan diff --git a/Generation_Pipeline_filter/syn_liver/healthy_liver_1k.txt b/Generation_Pipeline_filter/syn_liver/healthy_liver_1k.txt new file mode 100644 index 0000000000000000000000000000000000000000..74ed74167da166c49bfa98088ad2683251771bb4 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/healthy_liver_1k.txt @@ -0,0 +1,895 @@ +BDMAP_00004578 +BDMAP_00004183 +BDMAP_00002690
+BDMAP_00004295 +BDMAP_00001736 +BDMAP_00000411 +BDMAP_00003277 +BDMAP_00000696 +BDMAP_00004196 +BDMAP_00001598 +BDMAP_00001183 +BDMAP_00002626 +BDMAP_00004793 +BDMAP_00003385 +BDMAP_00005037 +BDMAP_00004652 +BDMAP_00001383 +BDMAP_00001092 +BDMAP_00004927 +BDMAP_00001618 +BDMAP_00004087 +BDMAP_00002273 +BDMAP_00001288 +BDMAP_00000043 +BDMAP_00003356 +BDMAP_00002776 +BDMAP_00003961 +BDMAP_00002422 +BDMAP_00000345 +BDMAP_00000438 +BDMAP_00001517 +BDMAP_00003564 +BDMAP_00001275 +BDMAP_00003315 +BDMAP_00002986 +BDMAP_00003514 +BDMAP_00000190 +BDMAP_00001434 +BDMAP_00003608 +BDMAP_00001995 +BDMAP_00000414 +BDMAP_00003451 +BDMAP_00002612 +BDMAP_00003744 +BDMAP_00005170 +BDMAP_00002328 +BDMAP_00002940 +BDMAP_00005020 +BDMAP_00000562 +BDMAP_00000810 +BDMAP_00003833 +BDMAP_00000320 +BDMAP_00001791 +BDMAP_00004895 +BDMAP_00003576 +BDMAP_00001924 +BDMAP_00005140 +BDMAP_00003946 +BDMAP_00005067 +BDMAP_00001102 +BDMAP_00001826 +BDMAP_00004131 +BDMAP_00003141 +BDMAP_00002758 +BDMAP_00004969 +BDMAP_00003633 +BDMAP_00004195 +BDMAP_00000030 +BDMAP_00000939 +BDMAP_00001835 +BDMAP_00003762 +BDMAP_00003215 +BDMAP_00003396 +BDMAP_00001078 +BDMAP_00003484 +BDMAP_00001096 +BDMAP_00001688 +BDMAP_00005155 +BDMAP_00005064 +BDMAP_00001862 +BDMAP_00004867 +BDMAP_00001982 +BDMAP_00002295 +BDMAP_00000062 +BDMAP_00000715 +BDMAP_00004608 +BDMAP_00000162 +BDMAP_00003558 +BDMAP_00005070 +BDMAP_00003812 +BDMAP_00000725 +BDMAP_00004624 +BDMAP_00003752 +BDMAP_00001557 +BDMAP_00002185 +BDMAP_00000093 +BDMAP_00003774 +BDMAP_00001701 +BDMAP_00004184 +BDMAP_00000873 +BDMAP_00000236 +BDMAP_00001676 +BDMAP_00001635 +BDMAP_00002475 +BDMAP_00002653 +BDMAP_00003400 +BDMAP_00001863 +BDMAP_00003017 +BDMAP_00001283 +BDMAP_00001359 +BDMAP_00001281 +BDMAP_00004293 +BDMAP_00000582 +BDMAP_00001752 +BDMAP_00004910 +BDMAP_00003373 +BDMAP_00004297 +BDMAP_00003947 +BDMAP_00003612 +BDMAP_00003598 +BDMAP_00002746 +BDMAP_00004552 +BDMAP_00002333 +BDMAP_00002580 +BDMAP_00002871 +BDMAP_00001565 +BDMAP_00003549 +BDMAP_00003976 +BDMAP_00001712 +BDMAP_00001602 +BDMAP_00000812 +BDMAP_00000353 +BDMAP_00001251 +BDMAP_00004841 +BDMAP_00000429 +BDMAP_00000432 +BDMAP_00000159 +BDMAP_00002347 +BDMAP_00002496 +BDMAP_00004735 +BDMAP_00001514 +BDMAP_00003560 +BDMAP_00001209 +BDMAP_00002313 +BDMAP_00005092 +BDMAP_00005009 +BDMAP_00004673 +BDMAP_00000547 +BDMAP_00003255 +BDMAP_00000229 +BDMAP_00001522 +BDMAP_00002426 +BDMAP_00004015 +BDMAP_00004541 +BDMAP_00003952 +BDMAP_00003853 +BDMAP_00001119 +BDMAP_00004198 +BDMAP_00004427 +BDMAP_00004417 +BDMAP_00000833 +BDMAP_00002487 +BDMAP_00002981 +BDMAP_00000653 +BDMAP_00003815 +BDMAP_00003972 +BDMAP_00000373 +BDMAP_00002864 +BDMAP_00002902 +BDMAP_00001836 +BDMAP_00004897 +BDMAP_00002889 +BDMAP_00003493 +BDMAP_00000667 +BDMAP_00004163 +BDMAP_00004586 +BDMAP_00001704 +BDMAP_00002152 +BDMAP_00001258 +BDMAP_00003827 +BDMAP_00001265 +BDMAP_00001040 +BDMAP_00004106 +BDMAP_00000059 +BDMAP_00002363 +BDMAP_00000161 +BDMAP_00001475 +BDMAP_00001747 +BDMAP_00001027 +BDMAP_00000279 +BDMAP_00002242 +BDMAP_00004175 +BDMAP_00003358 +BDMAP_00004815 +BDMAP_00003580 +BDMAP_00001068 +BDMAP_00003327 +BDMAP_00004616 +BDMAP_00000197 +BDMAP_00003740 +BDMAP_00005074 +BDMAP_00001261 +BDMAP_00002775 +BDMAP_00002545 +BDMAP_00000104 +BDMAP_00004738 +BDMAP_00005099 +BDMAP_00004672 +BDMAP_00004074 +BDMAP_00004288 +BDMAP_00003590 +BDMAP_00001545 +BDMAP_00004922 +BDMAP_00002619 +BDMAP_00000874 +BDMAP_00001438 +BDMAP_00003138 +BDMAP_00002251 +BDMAP_00003769 +BDMAP_00003267 +BDMAP_00002216 +BDMAP_00003994 +BDMAP_00002742 +BDMAP_00001089 +BDMAP_00003957 
+BDMAP_00001533 +BDMAP_00004636 +BDMAP_00004499 +BDMAP_00000698 +BDMAP_00002232 +BDMAP_00004250 +BDMAP_00004491 +BDMAP_00001636 +BDMAP_00005078 +BDMAP_00004121 +BDMAP_00001845 +BDMAP_00004264 +BDMAP_00000137 +BDMAP_00003516 +BDMAP_00005017 +BDMAP_00000087 +BDMAP_00000319 +BDMAP_00001828 +BDMAP_00000948 +BDMAP_00001977 +BDMAP_00003457 +BDMAP_00005157 +BDMAP_00003150 +BDMAP_00002166 +BDMAP_00003301 +BDMAP_00003680 +BDMAP_00003133 +BDMAP_00000574 +BDMAP_00002305 +BDMAP_00004843 +BDMAP_00002230 +BDMAP_00000332 +BDMAP_00003063 +BDMAP_00002076 +BDMAP_00003319 +BDMAP_00004373 +BDMAP_00004880 +BDMAP_00000623 +BDMAP_00003631 +BDMAP_00001737 +BDMAP_00001057 +BDMAP_00002173 +BDMAP_00000139 +BDMAP_00001891 +BDMAP_00000552 +BDMAP_00004717 +BDMAP_00003172 +BDMAP_00003955 +BDMAP_00001664 +BDMAP_00003070 +BDMAP_00004550 +BDMAP_00002057 +BDMAP_00000616 +BDMAP_00000913 +BDMAP_00000388 +BDMAP_00000355 +BDMAP_00003333 +BDMAP_00004148 +BDMAP_00001985 +BDMAP_00001921 +BDMAP_00001624 +BDMAP_00004129 +BDMAP_00002598 +BDMAP_00000859 +BDMAP_00000558 +BDMAP_00002226 +BDMAP_00000452 +BDMAP_00004829 +BDMAP_00003455 +BDMAP_00002402 +BDMAP_00000117 +BDMAP_00000826 +BDMAP_00000243 +BDMAP_00002319 +BDMAP_00002737 +BDMAP_00002318 +BDMAP_00003357 +BDMAP_00000692 +BDMAP_00003427 +BDMAP_00001441 +BDMAP_00004796 +BDMAP_00002171 +BDMAP_00001296 +BDMAP_00004296 +BDMAP_00003808 +BDMAP_00003058 +BDMAP_00003502 +BDMAP_00001045 +BDMAP_00003438 +BDMAP_00002884 +BDMAP_00004561 +BDMAP_00000462 +BDMAP_00001785 +BDMAP_00000794 +BDMAP_00000942 +BDMAP_00002947 +BDMAP_00004744 +BDMAP_00004328 +BDMAP_00004671 +BDMAP_00005108 +BDMAP_00002278 +BDMAP_00000679 +BDMAP_00004903 +BDMAP_00001732 +BDMAP_00001095 +BDMAP_00003343 +BDMAP_00001289 +BDMAP_00001109 +BDMAP_00003650 +BDMAP_00001710 +BDMAP_00003031 +BDMAP_00001617 +BDMAP_00001246 +BDMAP_00004894 +BDMAP_00003520 +BDMAP_00004097 +BDMAP_00001020 +BDMAP_00003600 +BDMAP_00001518 +BDMAP_00000416 +BDMAP_00004990 +BDMAP_00005151 +BDMAP_00000132 +BDMAP_00000138 +BDMAP_00004885 +BDMAP_00000771 +BDMAP_00003928 +BDMAP_00001419 +BDMAP_00003130 +BDMAP_00001892 +BDMAP_00003886 +BDMAP_00004479 +BDMAP_00003918 +BDMAP_00003324 +BDMAP_00002410 +BDMAP_00002509 +BDMAP_00000701 +BDMAP_00003847 +BDMAP_00004450 +BDMAP_00003363 +BDMAP_00002875 +BDMAP_00002793 +BDMAP_00005113 +BDMAP_00000465 +BDMAP_00004847 +BDMAP_00004294 +BDMAP_00000936 +BDMAP_00002476 +BDMAP_00003840 +BDMAP_00004130 +BDMAP_00003614 +BDMAP_00000883 +BDMAP_00000542 +BDMAP_00002562 +BDMAP_00000285 +BDMAP_00001256 +BDMAP_00004597 +BDMAP_00002260 +BDMAP_00001067 +BDMAP_00000968 +BDMAP_00005085 +BDMAP_00003412 +BDMAP_00003884 +BDMAP_00001420 +BDMAP_00003268 +BDMAP_00001735 +BDMAP_00003392 +BDMAP_00000241 +BDMAP_00003326 +BDMAP_00001853 +BDMAP_00001126 +BDMAP_00002237 +BDMAP_00003809 +BDMAP_00001584 +BDMAP_00003359 +BDMAP_00002730 +BDMAP_00000923 +BDMAP_00000687 +BDMAP_00003281 +BDMAP_00004431 +BDMAP_00001440 +BDMAP_00001410 +BDMAP_00004650 +BDMAP_00004065 +BDMAP_00001806 +BDMAP_00002227 +BDMAP_00001906 +BDMAP_00000331 +BDMAP_00001130 +BDMAP_00003178 +BDMAP_00002707 +BDMAP_00001646 +BDMAP_00001707 +BDMAP_00003592 +BDMAP_00003943 +BDMAP_00002361 +BDMAP_00004901 +BDMAP_00003329 +BDMAP_00005075 +BDMAP_00002326 +BDMAP_00003713 +BDMAP_00003832 +BDMAP_00004165 +BDMAP_00004415 +BDMAP_00004331 +BDMAP_00001035 +BDMAP_00004457 +BDMAP_00003347 +BDMAP_00001422 +BDMAP_00002437 +BDMAP_00003996 +BDMAP_00003461 +BDMAP_00002751 +BDMAP_00002523 +BDMAP_00000439 +BDMAP_00004746 +BDMAP_00002188 +BDMAP_00004253 +BDMAP_00000935 +BDMAP_00002451 +BDMAP_00003971 +BDMAP_00000926 
+BDMAP_00003109 +BDMAP_00000660 +BDMAP_00001169 +BDMAP_00001331 +BDMAP_00001175 +BDMAP_00000881 +BDMAP_00000263 +BDMAP_00002401 +BDMAP_00005167 +BDMAP_00002041 +BDMAP_00000656 +BDMAP_00000366 +BDMAP_00002582 +BDMAP_00001238 +BDMAP_00001590 +BDMAP_00001784 +BDMAP_00001564 +BDMAP_00004719 +BDMAP_00001917 +BDMAP_00003956 +BDMAP_00003225 +BDMAP_00000982 +BDMAP_00004992 +BDMAP_00003479 +BDMAP_00001215 +BDMAP_00004147 +BDMAP_00001711 +BDMAP_00000626 +BDMAP_00000516 +BDMAP_00004876 +BDMAP_00003376 +BDMAP_00001628 +BDMAP_00001148 +BDMAP_00003672 +BDMAP_00001205 +BDMAP_00004651 +BDMAP_00000987 +BDMAP_00004104 +BDMAP_00001647 +BDMAP_00000998 +BDMAP_00002244 +BDMAP_00004676 +BDMAP_00001908 +BDMAP_00000714 +BDMAP_00001104 +BDMAP_00001911 +BDMAP_00000882 +BDMAP_00003930 +BDMAP_00000368 +BDMAP_00003923 +BDMAP_00002099 +BDMAP_00000240 +BDMAP_00003658 +BDMAP_00005077 +BDMAP_00002696 +BDMAP_00002184 +BDMAP_00003890 +BDMAP_00002704 +BDMAP_00000066 +BDMAP_00005006 +BDMAP_00001242 +BDMAP_00002396 +BDMAP_00004389 +BDMAP_00002656 +BDMAP_00000469 +BDMAP_00001138 +BDMAP_00004773 +BDMAP_00004033 +BDMAP_00004128 +BDMAP_00002631 +BDMAP_00004925 +BDMAP_00004475 +BDMAP_00001521 +BDMAP_00000364 +BDMAP_00002953 +BDMAP_00003776 +BDMAP_00004154 +BDMAP_00002654 +BDMAP_00002959 +BDMAP_00002199 +BDMAP_00003551 +BDMAP_00002465 +BDMAP_00005154 +BDMAP_00002648 +BDMAP_00000128 +BDMAP_00001001 +BDMAP_00002017 +BDMAP_00004712 +BDMAP_00004286 +BDMAP_00000568 +BDMAP_00004858 +BDMAP_00001782 +BDMAP_00001496 +BDMAP_00004407 +BDMAP_00002250 +BDMAP_00001212 +BDMAP_00000972 +BDMAP_00004374 +BDMAP_00002846 +BDMAP_00002472 +BDMAP_00000569 +BDMAP_00004981 +BDMAP_00000176 +BDMAP_00003510 +BDMAP_00003771 +BDMAP_00002804 +BDMAP_00004558 +BDMAP_00003411 +BDMAP_00001563 +BDMAP_00000604 +BDMAP_00002075 +BDMAP_00005160 +BDMAP_00001511 +BDMAP_00001273 +BDMAP_00002603 +BDMAP_00001656 +BDMAP_00003822 +BDMAP_00004510 +BDMAP_00001809 +BDMAP_00002944 +BDMAP_00002739 +BDMAP_00002609 +BDMAP_00003849 +BDMAP_00001128 +BDMAP_00003717 +BDMAP_00000036 +BDMAP_00002863 +BDMAP_00004956 +BDMAP_00004229 +BDMAP_00003425 +BDMAP_00001865 +BDMAP_00000608 +BDMAP_00004620 +BDMAP_00000589 +BDMAP_00001597 +BDMAP_00003543 +BDMAP_00004645 +BDMAP_00004395 +BDMAP_00005105 +BDMAP_00001426 +BDMAP_00000264 +BDMAP_00001504 +BDMAP_00001649 +BDMAP_00000662 +BDMAP_00002854 +BDMAP_00004060 +BDMAP_00003440 +BDMAP_00003367 +BDMAP_00004011 +BDMAP_00003634 +BDMAP_00003443 +BDMAP_00000828 +BDMAP_00000889 +BDMAP_00000321 +BDMAP_00004615 +BDMAP_00000244 +BDMAP_00003685 +BDMAP_00001461 +BDMAP_00001396 +BDMAP_00004262 +BDMAP_00004579 +BDMAP_00005022 +BDMAP_00004804 +BDMAP_00001632 +BDMAP_00002661 +BDMAP_00000980 +BDMAP_00001445 +BDMAP_00000809 +BDMAP_00004384 +BDMAP_00003114 +BDMAP_00000435 +BDMAP_00003406 +BDMAP_00002899 +BDMAP_00002164 +BDMAP_00002498 +BDMAP_00000039 +BDMAP_00002524 +BDMAP_00000805 +BDMAP_00004604 +BDMAP_00000338 +BDMAP_00002990 +BDMAP_00001516 +BDMAP_00002896 +BDMAP_00004549 +BDMAP_00000259 +BDMAP_00001945 +BDMAP_00002695 +BDMAP_00005141 +BDMAP_00002828 +BDMAP_00003781 +BDMAP_00003900 +BDMAP_00004278 +BDMAP_00004551 +BDMAP_00000532 +BDMAP_00002844 +BDMAP_00001476 +BDMAP_00004887 +BDMAP_00005174 +BDMAP_00000836 +BDMAP_00001456 +BDMAP_00001607 +BDMAP_00003164 +BDMAP_00002404 +BDMAP_00003036 +BDMAP_00001225 +BDMAP_00002022 +BDMAP_00004030 +BDMAP_00000329 +BDMAP_00002253 +BDMAP_00000154 +BDMAP_00003111 +BDMAP_00003384 +BDMAP_00000023 +BDMAP_00001125 +BDMAP_00001414 +BDMAP_00002383 +BDMAP_00003483 +BDMAP_00000034 +BDMAP_00001413 +BDMAP_00003767 +BDMAP_00001368 +BDMAP_00003448 
+BDMAP_00000940 +BDMAP_00000430 +BDMAP_00003153 +BDMAP_00003603 +BDMAP_00003202 +BDMAP_00002421 +BDMAP_00005001 +BDMAP_00004447 +BDMAP_00001325 +BDMAP_00003168 +BDMAP_00000887 +BDMAP_00004481 +BDMAP_00001324 +BDMAP_00004066 +BDMAP_00001474 +BDMAP_00004850 +BDMAP_00002233 +BDMAP_00000511 +BDMAP_00001223 +BDMAP_00003581 +BDMAP_00002930 +BDMAP_00001305 +BDMAP_00002689 +BDMAP_00002332 +BDMAP_00000683 +BDMAP_00003300 +BDMAP_00003701 +BDMAP_00001015 +BDMAP_00001562 +BDMAP_00001898 +BDMAP_00001247 +BDMAP_00001941 +BDMAP_00002840 +BDMAP_00002440 +BDMAP_00000245 +BDMAP_00002855 +BDMAP_00004493 +BDMAP_00000989 +BDMAP_00003736 +BDMAP_00002265 +BDMAP_00004039 +BDMAP_00002826 +BDMAP_00002924 +BDMAP_00003299 +BDMAP_00001361 +BDMAP_00004014 +BDMAP_00001444 +BDMAP_00001370 +BDMAP_00002304 +BDMAP_00000774 +BDMAP_00000614 +BDMAP_00000434 +BDMAP_00001230 +BDMAP_00000044 +BDMAP_00001768 +BDMAP_00004783 +BDMAP_00004494 +BDMAP_00001905 +BDMAP_00003824 +BDMAP_00002309 +BDMAP_00004511 +BDMAP_00000233 +BDMAP_00002845 +BDMAP_00005016 +BDMAP_00002829 +BDMAP_00001059 +BDMAP_00001549 +BDMAP_00002403 +BDMAP_00001794 +BDMAP_00001286 +BDMAP_00003294 +BDMAP_00003722 +BDMAP_00000902 +BDMAP_00002298 +BDMAP_00005191 +BDMAP_00001487 +BDMAP_00003364 +BDMAP_00001605 +BDMAP_00001483 +BDMAP_00000676 +BDMAP_00002945 +BDMAP_00005073 +BDMAP_00002085 +BDMAP_00000716 +BDMAP_00003435 +BDMAP_00002803 +BDMAP_00002663 +BDMAP_00003727 +BDMAP_00000839 +BDMAP_00002068 +BDMAP_00004764 +BDMAP_00002114 +BDMAP_00004741 +BDMAP_00004077 +BDMAP_00004870 +BDMAP_00000571 +BDMAP_00004115 +BDMAP_00001868 +BDMAP_00004113 +BDMAP_00002039 +BDMAP_00004257 +BDMAP_00001620 +BDMAP_00000470 +BDMAP_00000149 +BDMAP_00002815 +BDMAP_00000304 +BDMAP_00005185 +BDMAP_00003113 +BDMAP_00005063 +BDMAP_00000122 +BDMAP_00004482 +BDMAP_00002471 +BDMAP_00004023 +BDMAP_00000225 +BDMAP_00003657 +BDMAP_00001255 +BDMAP_00002616 +BDMAP_00002407 +BDMAP_00002060 +BDMAP_00004546 +BDMAP_00004917 +BDMAP_00003615 +BDMAP_00003525 +BDMAP_00002120 +BDMAP_00000481 +BDMAP_00004770 +BDMAP_00003683 +BDMAP_00000618 +BDMAP_00001875 +BDMAP_00003409 +BDMAP_00003381 +BDMAP_00004398 +BDMAP_00000867 +BDMAP_00000487 +BDMAP_00003073 +BDMAP_00002592 +BDMAP_00005120 +BDMAP_00003128 +BDMAP_00001754 +BDMAP_00004232 +BDMAP_00000855 +BDMAP_00000069 +BDMAP_00002744 +BDMAP_00004808 +BDMAP_00004031 +BDMAP_00001842 +BDMAP_00000324 +BDMAP_00002933 +BDMAP_00004954 +BDMAP_00000541 +BDMAP_00002458 +BDMAP_00002288 +BDMAP_00002807 +BDMAP_00000837 +BDMAP_00002065 +BDMAP_00000152 +BDMAP_00003491 +BDMAP_00001464 +BDMAP_00003486 +BDMAP_00003244 +BDMAP_00000871 +BDMAP_00002362 +BDMAP_00000993 +BDMAP_00000219 +BDMAP_00000192 +BDMAP_00001218 +BDMAP_00001024 +BDMAP_00004980 +BDMAP_00000713 +BDMAP_00001523 +BDMAP_00002688 +BDMAP_00003143 +BDMAP_00005114 +BDMAP_00003749 +BDMAP_00002354 +BDMAP_00000052 +BDMAP_00002710 +BDMAP_00004817 +BDMAP_00004964 +BDMAP_00004775 +BDMAP_00005005 +BDMAP_00004216 +BDMAP_00002936 +BDMAP_00000956 +BDMAP_00002942 +BDMAP_00001705 +BDMAP_00001823 +BDMAP_00002387 +BDMAP_00000690 +BDMAP_00002021 +BDMAP_00000851 +BDMAP_00000427 +BDMAP_00002133 +BDMAP_00004231 +BDMAP_00005169 +BDMAP_00003640 +BDMAP_00000977 +BDMAP_00002103 +BDMAP_00000449 +BDMAP_00001214 +BDMAP_00003506 +BDMAP_00002411 +BDMAP_00003973 +BDMAP_00001912 +BDMAP_00000710 +BDMAP_00004514 +BDMAP_00001807 +BDMAP_00001769 +BDMAP_00001746 +BDMAP_00001804 +BDMAP_00002484 +BDMAP_00003444 +BDMAP_00002029 +BDMAP_00001237 +BDMAP_00004420 +BDMAP_00000431 +BDMAP_00003252 +BDMAP_00005081 +BDMAP_00003694 +BDMAP_00002655 +BDMAP_00004641 +BDMAP_00000297 
+BDMAP_00001077 +BDMAP_00003254 +BDMAP_00000447 +BDMAP_00004834 diff --git a/Generation_Pipeline_filter/syn_liver/out/BDMAP_00000411/ct.nii.gz b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00000411/ct.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..7edc8b76382e04f5c986db3413027fc4c0aad857 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00000411/ct.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ccce1a9d60005f0ebedea7a9bfc7f4ca0228d2ba6285d52d6ef40eda6714f6f3 +size 20965188 diff --git a/Generation_Pipeline_filter/syn_liver/out/BDMAP_00000411/segmentations/liver_tumor.nii.gz b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00000411/segmentations/liver_tumor.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..081f89a3ef28cf7e7129292e724f9598792ed6af --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00000411/segmentations/liver_tumor.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:1ddbe505108fb18bca8b6edf1150249647080ad0ef05bfa750409590ef2880f9 +size 64437 diff --git a/Generation_Pipeline_filter/syn_liver/out/BDMAP_00000696/ct.nii.gz b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00000696/ct.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..ee3e0ec8edab76d7ec29fb2c636701cf76cae3ca --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00000696/ct.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b6e057c93fe99ecf8b9d2c81fc2f4535df63b19fefd623b8f7ad96b8d1463e1a +size 27445316 diff --git a/Generation_Pipeline_filter/syn_liver/out/BDMAP_00000696/segmentations/liver_tumor.nii.gz b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00000696/segmentations/liver_tumor.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..c8ea5f042f4dad2a0bb30099741c801d6219b812 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00000696/segmentations/liver_tumor.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:12e2b1662f93f991b04ac06b4c50573b8321a49671a65366d1de28769ea37805 +size 76815 diff --git a/Generation_Pipeline_filter/syn_liver/out/BDMAP_00001736/ct.nii.gz b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00001736/ct.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..c9ed74289806b22ccfa5587d9d5a509738340012 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00001736/ct.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f2eb22ba25ca011daff851e2a3c2d96e0378df78586caadde56217a049fc37f5 +size 16868219 diff --git a/Generation_Pipeline_filter/syn_liver/out/BDMAP_00001736/segmentations/liver_tumor.nii.gz b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00001736/segmentations/liver_tumor.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..b1b1024bfa6be3254254b3362eb6bfb21c91bdce --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00001736/segmentations/liver_tumor.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ca7474ac04c3fb9ff09bf080f4b0c3c9d054d9f6b3f56a8b9a09c4283fa2ac31 +size 46592 diff --git a/Generation_Pipeline_filter/syn_liver/out/BDMAP_00002690/ct.nii.gz b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00002690/ct.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..f04d51e1b78baa500ba5b0365631a1ec99b89576 --- /dev/null +++ 
b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00002690/ct.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:fad3428ec7670ad5d3d277275732c148cf9b7510a261e47dd3f967f17cae3511 +size 20584373 diff --git a/Generation_Pipeline_filter/syn_liver/out/BDMAP_00002690/segmentations/liver_tumor.nii.gz b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00002690/segmentations/liver_tumor.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..1135e837e6436cdb526e4e9a75b1226c29d37b28 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00002690/segmentations/liver_tumor.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:00e316a875b7b1ed74b42023fbf58dab903030e65be627f0857038598a902d83 +size 53855 diff --git a/Generation_Pipeline_filter/syn_liver/out/BDMAP_00003277/ct.nii.gz b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00003277/ct.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..a240cb76dd09fd1e4da05a51cce12c4796e0deec --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00003277/ct.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:30e2630e9acb2526793d24b67ee57a80beb430d09980444cc3ca4cd61504214e +size 26242723 diff --git a/Generation_Pipeline_filter/syn_liver/out/BDMAP_00003277/segmentations/liver_tumor.nii.gz b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00003277/segmentations/liver_tumor.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..cdb627effd08232f3e8a2e2239c508fae3d9eb23 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00003277/segmentations/liver_tumor.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b993b9c74d95feebd2a4c7a1617cd2c4d21ca926096e5d1df39e760a802aca70 +size 76864 diff --git a/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004183/ct.nii.gz b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004183/ct.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..52ba74996742ffcd0e59674b4ab9f4c5bff036b2 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004183/ct.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:13a7ee052fc2fbaf0139f3a2fcd280723330e986407038348e23658097d0cc44 +size 25002393 diff --git a/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004183/segmentations/liver_tumor.nii.gz b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004183/segmentations/liver_tumor.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..21185b4710e5bd59041c7669f3f8238fe24896f2 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004183/segmentations/liver_tumor.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6d88c86229effcde00b19dc23ae357edb5656f0a702cabca6613450201652702 +size 66834 diff --git a/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004295/ct.nii.gz b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004295/ct.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..e29995333fdf7abc66c9e4b60db7659d57610d58 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004295/ct.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b07ba1abc9f9a530487ce875a9362a8fa83123cccb6210a474c5a4a6b664a566 +size 29591389 diff --git a/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004295/segmentations/liver_tumor.nii.gz 
b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004295/segmentations/liver_tumor.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..7b6328910407068433d61c26b1d172ddcddc6e3e --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004295/segmentations/liver_tumor.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8227c09d33bb25cac6668f8d354501049da6884a3be8f5c38c2bfcb3120adab1 +size 83006 diff --git a/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004578/ct.nii.gz b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004578/ct.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..e0ad7f3c8326ae9254548bdb723ea816cbeb4226 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004578/ct.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:6afaac2847ba7b1be2e6a8fff1b285d7ffbe6aa3a21cb4a4daea2232bb7ad9ab +size 22153425 diff --git a/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004578/segmentations/liver_tumor.nii.gz b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004578/segmentations/liver_tumor.nii.gz new file mode 100644 index 0000000000000000000000000000000000000000..4a8c58f425c5713a43f0ba1d39b34b75d768a58a --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/out/BDMAP_00004578/segmentations/liver_tumor.nii.gz @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4ff6073b1f1266c2fed6a0caeac0102fb157b6ffcccced9779c2e23a3b4aba09 +size 66167 diff --git a/Generation_Pipeline_filter/syn_liver/requirements.txt b/Generation_Pipeline_filter/syn_liver/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9c1067c8eb4f44699bbf517601396113cea4b370 --- /dev/null +++ b/Generation_Pipeline_filter/syn_liver/requirements.txt @@ -0,0 +1,94 @@ +absl-py==1.1.0 +accelerate==0.11.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +antlr4-python3-runtime==4.9.3 +async-timeout==4.0.2 +attrs==21.4.0 +autopep8==1.6.0 +cachetools==5.2.0 +certifi==2022.6.15 +charset-normalizer==2.0.12 +click==8.1.3 +cycler==0.11.0 +Deprecated==1.2.13 +docker-pycreds==0.4.0 +einops==0.4.1 +einops-exts==0.0.3 +ema-pytorch==0.0.8 +fonttools==4.34.4 +frozenlist==1.3.0 +fsspec==2022.5.0 +ftfy==6.1.1 +future==0.18.2 +gitdb==4.0.9 +GitPython==3.1.27 +google-auth==2.9.0 +google-auth-oauthlib==0.4.6 +grpcio==1.47.0 +h5py==3.7.0 +humanize==4.2.2 +hydra-core==1.2.0 +idna==3.3 +imageio==2.19.3 +imageio-ffmpeg==0.4.7 +importlib-metadata==4.12.0 +importlib-resources==5.9.0 +joblib==1.1.0 +kiwisolver==1.4.3 +lxml==4.9.1 +Markdown==3.3.7 +matplotlib==3.5.2 +multidict==6.0.2 +networkx==2.8.5 +nibabel==4.0.1 +nilearn==0.9.1 +numpy==1.23.0 +oauthlib==3.2.0 +omegaconf==2.2.3 +pandas==1.4.3 +Pillow==9.1.1 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycodestyle==2.8.0 +pyDeprecate==0.3.1 +pydicom==2.3.0 +pytorch-lightning==1.6.4 +pytz==2022.1 +PyWavelets==1.3.0 +PyYAML==6.0 +pyzmq==19.0.2 +regex==2022.6.2 +requests==2.28.0 +requests-oauthlib==1.3.1 +rotary-embedding-torch==0.1.5 +rsa==4.8 +scikit-image==0.19.3 +scikit-learn==1.1.2 +scikit-video==1.1.11 +scipy==1.8.1 +seaborn==0.11.2 +sentry-sdk==1.7.2 +setproctitle==1.2.3 +shortuuid==1.0.9 +SimpleITK==2.1.1.2 +smmap==5.0.0 +tensorboard==2.9.1 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +threadpoolctl==3.1.0 +tifffile==2022.8.3 +toml==0.10.2 +torch-tb-profiler==0.4.0 +torchio==0.18.80 +torchmetrics==0.9.1 +tqdm==4.64.0 +typing_extensions==4.2.0 +urllib3==1.26.9 +wandb==0.12.21 +Werkzeug==2.1.2 +wrapt==1.14.1 +yarl==1.7.2 
+zipp==3.8.0 +wandb +tensorboardX==2.4.1 diff --git a/Generation_Pipeline_filter/syn_pancreas/CT_syn_pancreas_data_new.py b/Generation_Pipeline_filter/syn_pancreas/CT_syn_pancreas_data_new.py new file mode 100644 index 0000000000000000000000000000000000000000..bb5272568ad1183b75c4aa820c7bc6fd59cee6b1 --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/CT_syn_pancreas_data_new.py @@ -0,0 +1,242 @@ +import os, time, csv +import numpy as np +import torch +from sklearn.metrics import confusion_matrix +from scipy import ndimage +from scipy.ndimage import label +from functools import partial +import monai +from monai.transforms import AsDiscrete,AsDiscreted,Compose,Invertd,SaveImaged +from monai import transforms, data +from TumorGeneration.utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor, synt_model_prepare +import nibabel as nib + +import warnings +warnings.filterwarnings("ignore") + +import argparse +parser = argparse.ArgumentParser(description='pancreas tumor validation') + +# file dir +parser.add_argument('--data_root', default=None, type=str) +parser.add_argument('--organ_type', default='pancreas', type=str) +parser.add_argument('--save_dir', default='out', type=str) +parser.add_argument('--data_file', default='out', type=str) +parser.add_argument('--ddim_ts', default=50, type=int) +parser.add_argument('--fg_thresh', default=30, type=int) +parser.add_argument('--start', default=0, type=int) +parser.add_argument('--end', default=1000, type=int) + +def voxel2R(A): + return (np.array(A)/4*3/np.pi)**(1/3) + +class RandCropByPosNegLabeld_select(transforms.RandCropByPosNegLabeld): + def __init__(self, keys, label_key, spatial_size, + pos=1.0, neg=1.0, num_samples=1, + image_key=None, image_threshold=0.0, allow_missing_keys=True, + fg_thresh=0): + super().__init__(keys=keys, label_key=label_key, spatial_size=spatial_size, + pos=pos, neg=neg, num_samples=num_samples, + image_key=image_key, image_threshold=image_threshold, allow_missing_keys=allow_missing_keys) + self.fg_thresh = fg_thresh + + def R2voxel(self,R): + return (4/3*np.pi)*(R)**(3) + + def __call__(self, data): + d = dict(data) + data_name = d['name'] + d.pop('name') + + if '10_Decathlon' in data_name or '05_KiTS' in data_name: + d_crop = super().__call__(d) + + else: + flag=0 + while 1: + flag+=1 + + d_crop = super().__call__(d) + pixel_num = (d_crop[0]['label']>0).sum() + + if pixel_num > self.R2voxel(self.fg_thresh): + break + if flag>5 and pixel_num > self.R2voxel(max(self.fg_thresh-5, 5)): + break + if flag>10 and pixel_num > self.R2voxel(max(self.fg_thresh-10, 5)): + break + if flag>15 and pixel_num > self.R2voxel(max(self.fg_thresh-15, 5)): + break + if flag>20 and pixel_num > self.R2voxel(max(self.fg_thresh-20, 5)): + break + if flag>25 and pixel_num > self.R2voxel(max(self.fg_thresh-25, 5)): + break + if flag>50: + break + + d_crop[0]['name'] = data_name + + return d_crop + +def _get_loader(args): + # val_data_dir = args.val_dir + # datalist_json = args.json_dir + val_org_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label", "raw_image"]), + transforms.AddChanneld(keys=["image", "label", "raw_image"]), + transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), + transforms.Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear")), + transforms.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True), + transforms.SpatialPadd(keys=["image", "label"], mode=["minimum", "constant"], 
spatial_size=[96, 96, 96]), + RandCropByPosNegLabeld_select( + keys=["image", "label", "name"], + label_key="label", + spatial_size=(96, 96, 96), + pos=1, + neg=0, + num_samples=1, + image_key="image", + image_threshold=0, + fg_thresh = args.fg_thresh, + ), + transforms.ToTensord(keys=["image", "label", "raw_image"]), + ] + ) + + val_img=[] + val_lbl=[] + val_name=[] + + for line in open(args.data_file): + # name = line.strip().split()[1].split('.')[0] + # val_img.append(args.data_root + line.strip().split()[0]) + # val_lbl.append(args.data_root + line.strip().split()[1]) + # breakpoint() + name = line.strip() + val_img.append(os.path.join(args.data_root, name, 'ct.nii.gz')) + val_lbl.append(os.path.join(args.data_root, name, 'segmentations/pancreas.nii.gz')) + val_name.append(name) + data_dicts_val = [{'image': image, 'raw_image':image, 'label': label, 'name': name} + for image, label, name in zip(val_img, val_lbl, val_name)] + + if args.end < len(data_dicts_val): + data_dicts_val = data_dicts_val[args.start:args.end] + else: + data_dicts_val = data_dicts_val[args.start:] + print('val len {}'.format(len(data_dicts_val))) + val_org_ds = data.Dataset(data_dicts_val, transform=val_org_transform) + val_org_loader = data.DataLoader(val_org_ds, batch_size=1, shuffle=False, num_workers=4, sampler=None, pin_memory=True) + + post_transforms = Compose([ + Invertd( + keys=['image'], + transform=val_org_transform, + orig_keys="image", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ), + Invertd( + keys=['label'], + transform=val_org_transform, + orig_keys="label", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ) + ]) + return val_org_loader, post_transforms + +def main(): + args = parser.parse_args() + output_dir = args.save_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print("MAIN Argument values:") + for k, v in vars(args).items(): + print(k, '=>', v) + print('-----------------') + + ## loader and post_transform + val_loader, post_transforms = _get_loader(args) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device) + model.load_state_dict(torch.load("../best_metric_model_classification3d_dict.pth")) + model.eval() + + start_time=0 + with torch.no_grad(): + for idx, val_data in enumerate(val_loader): + print('idx',idx) + if idx == 0: + start_time = time.time() + # val_inputs = val_data["image"] + # name = val_data['label_meta_dict']['filename_or_obj'][0].split('/')[-1].split('.')[0] + + vqgan, early_sampler, noearly_sampler= synt_model_prepare(device = torch.device("cuda"), fold=0, organ=args.organ_type) + + healthy_data, healthy_target, data_names, raw_data = val_data['image'], val_data['label'], val_data['name'], val_data['raw_image'] + case_name = data_names[0].split('/')[-1] + print('case_name', case_name) + original_affine = val_data["label_meta_dict"]["original_affine"][0].numpy() + + if healthy_target.sum() == 0: + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + tumor_mask = val_data[0]['label'][0].cpu().numpy().astype(np.uint8) + tumor_mask_ = np.zeros_like(tumor_mask) + nib.save(nib.Nifti1Image(tumor_mask_, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/pancreas_tumor.nii.gz')) + continue + + healthy_data, healthy_target = healthy_data.cuda(), healthy_target.cuda() + healthy_target = (healthy_target>0).to(healthy_target) + + tumor_types = ['early', 'medium', 
'large'] + # tumor_probs = np.array([0.45, 0.45, 0.1]) + # tumor_probs = np.array([1.0, 0.0, 0.0]) + tumor_probs = np.array([0.5, 0.4, 0.1]) + synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + print('synthetic_tumor_type',synthetic_tumor_type) + flag=0 + while 1: + flag+=1 + if synthetic_tumor_type == 'early': + synt_data, synt_target = synthesize_early_tumor(healthy_data, healthy_target, args.organ_type, vqgan, early_sampler) + elif synthetic_tumor_type == 'medium': + synt_data, synt_target = synthesize_medium_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + elif synthetic_tumor_type == 'large': + synt_data, synt_target = synthesize_large_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + + syn_confidence = model(synt_data).sigmoid()[:,1] + if syn_confidence>0.01: + break + elif flag > 40 and syn_confidence>0.005: + break + elif flag > 60 and syn_confidence>0.001: + break + + val_data['image'] = synt_data.detach() + val_data['label'] = synt_target.detach() + + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + synt_data = val_data[0]['image'][0] + synt_target = val_data[0]['label'][0] + final_data = raw_data[0,0] + + synt_data = (synt_data*(250+175)-175) + final_data[synt_target>1] = synt_data[synt_target>1] + final_data = final_data.cpu().numpy() + final_label = (synt_target>=1.5).cpu().numpy().astype(np.uint8) + + os.makedirs(os.path.join(output_dir, f'{case_name}'), exist_ok=True) + os.makedirs(os.path.join(output_dir, f'{case_name}/segmentations'), exist_ok=True) + nib.save(nib.Nifti1Image(final_data, original_affine), os.path.join(output_dir, f'{case_name}/ct.nii.gz')) + nib.save(nib.Nifti1Image(final_label, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/pancreas_tumor.nii.gz')) + + print('time = ', time.time()-start_time) + start_time = time.time() + + +if __name__ == "__main__": + main() diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/.DS_Store b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..fb0d207e124cfecc777faae5b4a56c5ca1b9cd2e Binary files /dev/null and b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/.DS_Store differ diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/README.md b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/README.md new file mode 100644 index 0000000000000000000000000000000000000000..201911031ca988c410781a60dc3c192a65ee56b3 --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/README.md @@ -0,0 +1,5 @@ +```bash +wget https://www.dropbox.com/scl/fi/k856fhk60kck8uqxtxazw/model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll +mv model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll model_weight.tar.gz +tar -xzvf model_weight.tar.gz +``` diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/TumorGenerated.py b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/TumorGenerated.py new file mode 100644 index 0000000000000000000000000000000000000000..9983cf047b1532f5edd20c0d4c78102d6d611eb9 --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/TumorGenerated.py @@ -0,0 +1,39 @@ +import random +from typing import Hashable, Mapping, Dict + +from monai.config import KeysCollection +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.transform import MapTransform, 
RandomizableTransform + +from .utils_ import SynthesisTumor +import numpy as np + +class TumorGenerated(RandomizableTransform, MapTransform): + def __init__(self, + keys: KeysCollection, + prob: float = 0.1, + tumor_prob = [0.2, 0.2, 0.2, 0.2, 0.2], + allow_missing_keys: bool = False + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + random.seed(0) + np.random.seed(0) + + self.tumor_types = ['tiny', 'small', 'medium', 'large', 'mix'] + + assert len(tumor_prob) == 5 + self.tumor_prob = np.array(tumor_prob) + + + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + + if self._do_transform and (np.max(d['label']) <= 1): + tumor_type = np.random.choice(self.tumor_types, p=self.tumor_prob.ravel()) + + d['image'][0], d['label'][0] = SynthesisTumor(d['image'][0], d['label'][0], tumor_type) + # print(tumor_type, d['image'].shape, np.max(d['label'])) + return d diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/__init__.py b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1afc8a195ba5fd106ca18d4e219c123a75e6e831 --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/__init__.py @@ -0,0 +1,5 @@ +### Online Version TumorGeneration ### + +from .TumorGenerated import TumorGenerated + +from .utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor \ No newline at end of file diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02cb9a3bda15ccd819e927754d80c3538090b4d2 Binary files /dev/null and b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc1142ed95f5cd5ccc2ab25fb5f501ffd782f335 Binary files /dev/null and b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/__pycache__/__init__.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/__pycache__/utils.cpython-38.pyc b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad72b834e61a83099d491b2a359824c51a42beb4 Binary files /dev/null and b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/__pycache__/utils.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/__pycache__/utils_.cpython-38.pyc b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/__pycache__/utils_.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c67d98211992fb7cb17035aac7241ee0afb0ae44 Binary files /dev/null and b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/__pycache__/utils_.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/diffusion_config/ddpm.yaml b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/diffusion_config/ddpm.yaml new file mode 100644 index 
0000000000000000000000000000000000000000..474a973b1f76c2a026e2df916c4a153b5bbea05f --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/diffusion_config/ddpm.yaml @@ -0,0 +1,29 @@ +vqgan_ckpt: None + +# Have to be derived from VQ-GAN Latent space dimensions +diffusion_img_size: 24 +diffusion_depth_size: 24 +diffusion_num_channels: 17 # 17 +out_dim: 8 +dim_mults: [1,2,4,8] +results_folder: checkpoints/ddpm/ +results_folder_postfix: 'own_dataset_t2' +load_milestone: False # False + +batch_size: 2 # 40 +num_workers: 20 +logger: wandb +objective: pred_x0 +save_and_sample_every: 1000 +denoising_fn: Unet3D +train_lr: 1e-4 +timesteps: 2 # number of steps +sampling_timesteps: 250 # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper]) +loss_type: l1 # L1 or L2 +train_num_steps: 700000 # total training steps +gradient_accumulate_every: 2 # gradient accumulation steps +ema_decay: 0.995 # exponential moving average decay +amp: False # turn on mixed precision +num_sample_rows: 1 +gpus: 0 + diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/diffusion_config/vq_gan_3d.yaml b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/diffusion_config/vq_gan_3d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29940377c5cadb0c06322de8ac60d0713f799024 --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/diffusion_config/vq_gan_3d.yaml @@ -0,0 +1,37 @@ +seed: 1234 +batch_size: 2 # 30 +num_workers: 32 # 30 + +gpus: 1 +accumulate_grad_batches: 1 +default_root_dir: checkpoints/vq_gan/ +default_root_dir_postfix: 'flair' +resume_from_checkpoint: +max_steps: -1 +max_epochs: -1 +precision: 16 +gradient_clip_val: 1.0 + + +embedding_dim: 8 # 256 +n_codes: 16384 # 2048 +n_hiddens: 16 +lr: 3e-4 +downsample: [2, 2, 2] # [4, 4, 4] +disc_channels: 64 +disc_layers: 3 +discriminator_iter_start: 10000 # 50000 +disc_loss_type: hinge +image_gan_weight: 1.0 +video_gan_weight: 1.0 +l1_weight: 4.0 +gan_feat_weight: 4.0 # 0.0 +perceptual_weight: 4.0 # 0.0 +i3d_feat: False +restart_thres: 1.0 +no_random_restart: False +norm_type: group +padding_type: replicate +num_groups: 32 + + diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__init__.py b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a01336e46e37471097d8e1420d0ffd5d803a1edd --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__init__.py @@ -0,0 +1 @@ +from .diffusion import Unet3D, GaussianDiffusion, Tester diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fe5bab6ebc89b70a4b1a01afcaff758e59ed73b Binary files /dev/null and b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c25d2c3abacd1a912c6753385896dd607193ab33 Binary files /dev/null and 
b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8479abbdbbc5bcbd3be9dda38a303ff83cef8fc5 Binary files /dev/null and b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d933fb26b3a9bad8718b0fc74180a3e595d19854 Binary files /dev/null and b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7ac1a6aea44a1d1741652f2f0a598fb18b2730d Binary files /dev/null and b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20c6786f9995cbd455c900944b0ab9501a97c202 Binary files /dev/null and b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12e044abbf40919d2be472579cb32f762b2d8f4e Binary files /dev/null and b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/ddim.py b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..1bc51f94355732aedb0ffd254b8166144c468370 --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/ddim.py @@ -0,0 +1,206 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from .util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): # "uniform" 'quad' + self.ddim_timesteps = 
make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + # breakpoint() + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, T, H, W = shape + # breakpoint() + size = (batch_size, C, T, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + # iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(time_range): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. 
- mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + # breakpoint() + e_t = self.model.denoise_fn(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.denoise_fn(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/diffusion.py b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..87111dc3b2ab671e798efc3a320ae81d024f20b9 --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/diffusion.py @@ -0,0 +1,1016 @@ +"Largely taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch" + +import math +import copy +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial + +from torch.utils import data +from pathlib import Path +from torch.optim import Adam +from torchvision import transforms as T, utils +from torch.cuda.amp import autocast, GradScaler +from PIL import Image + +from tqdm import tqdm +from einops import rearrange +from einops_exts import check_shape, rearrange_many + +from rotary_embedding_torch import RotaryEmbedding + +from .text import tokenize, bert_embed, BERT_MODEL_DIM +from torch.utils.data import Dataset, DataLoader +from ..vq_gan_3d.model.vqgan import VQGAN + +import matplotlib.pyplot as plt + +# helpers functions + + +def exists(x): + return x is not None + + +def noop(*args, **kwargs): + pass + + +def is_odd(n): + return (n % 2) == 1 + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +def cycle(dl): + while True: + for data in dl: + yield data + + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + + +def prob_mask_like(shape, prob, device): + if prob == 1: + return torch.ones(shape, device=device, dtype=torch.bool) + elif prob == 0: + return torch.zeros(shape, device=device, dtype=torch.bool) + else: + return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob + + +def is_list_str(x): + if not isinstance(x, (list, tuple)): + return False + return all([type(el) == str for el in x]) + +# relative positional bias + + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128 + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / + max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min( + val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype=torch.long, device=device) + k_pos = torch.arange(n, dtype=torch.long, device=device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket( + rel_pos, num_buckets=self.num_buckets, 
max_distance=self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') + +# small helper modules + + +class EMA(): + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_model_average(self, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = self.update_average(old_weight, up_weight) + + def update_average(self, old, new): + if old is None: + return new + return old * self.beta + (1 - self.beta) * new + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +def Upsample(dim): + return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) + + +def Downsample(dim): + return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) + + +class LayerNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1)) + + def forward(self, x): + var = torch.var(x, dim=1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=1, keepdim=True) + return (x - mean) / (var + self.eps).sqrt() * self.gamma + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = LayerNorm(dim) + + def forward(self, x, **kwargs): + x = self.norm(x) + return self.fn(x, **kwargs) + +# building block modules + + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1)) + self.norm = nn.GroupNorm(groups, dim_out) + self.act = nn.SiLU() + + def forward(self, x, scale_shift=None): + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + return self.act(x) + + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, dim_out * 2) + ) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) + self.res_conv = nn.Conv3d( + dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb=None): + + scale_shift = None + if exists(self.mlp): + assert exists(time_emb), 'time emb must be passed in' + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') + scale_shift = time_emb.chunk(2, dim=1) + + h = self.block1(x, scale_shift=scale_shift) + + h = self.block2(h) + return h + self.res_conv(x) + + +class SpatialLinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, f, h, w = 
x.shape + x = rearrange(x, 'b c f h w -> (b f) c h w') + + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = rearrange_many( + qkv, 'b (h c) x y -> b h c (x y)', h=self.heads) + + q = q.softmax(dim=-2) + k = k.softmax(dim=-1) + + q = q * self.scale + context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + + out = torch.einsum('b h d e, b h d n -> b h e n', context, q) + out = rearrange(out, 'b h c (x y) -> b (h c) x y', + h=self.heads, x=h, y=w) + out = self.to_out(out) + return rearrange(out, '(b f) c h w -> b c f h w', b=b) + +# attention along space and time + + +class EinopsToAndFrom(nn.Module): + def __init__(self, from_einops, to_einops, fn): + super().__init__() + self.from_einops = from_einops + self.to_einops = to_einops + self.fn = fn + + def forward(self, x, **kwargs): + shape = x.shape + reconstitute_kwargs = dict( + tuple(zip(self.from_einops.split(' '), shape))) + x = rearrange(x, f'{self.from_einops} -> {self.to_einops}') + x = self.fn(x, **kwargs) + x = rearrange( + x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + heads=4, + dim_head=32, + rotary_emb=None + ): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + + self.rotary_emb = rotary_emb + self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False) + self.to_out = nn.Linear(hidden_dim, dim, bias=False) + + def forward( + self, + x, + pos_bias=None, + focus_present_mask=None + ): + n, device = x.shape[-2], x.device + + qkv = self.to_qkv(x).chunk(3, dim=-1) + + if exists(focus_present_mask) and focus_present_mask.all(): + # if all batch samples are focusing on present + # it would be equivalent to passing that token's values through to the output + values = qkv[-1] + return self.to_out(values) + + # split out heads + + q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) + + # scale + + q = q * self.scale + + # rotate positions into queries and keys for time attention + + if exists(self.rotary_emb): + q = self.rotary_emb.rotate_queries_or_keys(q) + k = self.rotary_emb.rotate_queries_or_keys(k) + + # similarity + + sim = einsum('... h i d, ... h j d -> ... h i j', q, k) + + # relative positional bias + + if exists(pos_bias): + sim = sim + pos_bias + + if exists(focus_present_mask) and not (~focus_present_mask).all(): + attend_all_mask = torch.ones( + (n, n), device=device, dtype=torch.bool) + attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) + + mask = torch.where( + rearrange(focus_present_mask, 'b -> b 1 1 1 1'), + rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), + rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), + ) + + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + # numerical stability + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + # aggregate values + + out = einsum('... h i j, ... h j d -> ... h i d', attn, v) + out = rearrange(out, '... h n d -> ... 
n (h d)') + return self.to_out(out) + +# model + + +class Unet3D(nn.Module): + def __init__( + self, + dim, + cond_dim=None, + out_dim=None, + dim_mults=(1, 2, 4, 8), + channels=3, + attn_heads=8, + attn_dim_head=32, + use_bert_text_cond=False, + init_dim=None, + init_kernel_size=7, + use_sparse_linear_attn=True, + block_type='resnet', + resnet_groups=8 + ): + super().__init__() + self.channels = channels + + # temporal attention and its relative positional encoding + + rotary_emb = RotaryEmbedding(min(32, attn_dim_head)) + + def temporal_attn(dim): return EinopsToAndFrom('b c f h w', 'b (h w) f c', Attention( + dim, heads=attn_heads, dim_head=attn_dim_head, rotary_emb=rotary_emb)) + + # realistically will not be able to generate that many frames of video... yet + self.time_rel_pos_bias = RelativePositionBias( + heads=attn_heads, max_distance=32) + + # initial conv + + init_dim = default(init_dim, dim) + assert is_odd(init_kernel_size) + + init_padding = init_kernel_size // 2 + self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, + init_kernel_size), padding=(0, init_padding, init_padding)) + + self.init_temporal_attn = Residual( + PreNorm(init_dim, temporal_attn(init_dim))) + + # dimensions + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + # time conditioning + + time_dim = dim * 4 + self.time_mlp = nn.Sequential( + SinusoidalPosEmb(dim), + nn.Linear(dim, time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim) + ) + + # text conditioning + + self.has_cond = exists(cond_dim) or use_bert_text_cond + cond_dim = BERT_MODEL_DIM if use_bert_text_cond else cond_dim + + self.null_cond_emb = nn.Parameter( + torch.randn(1, cond_dim)) if self.has_cond else None + + cond_dim = time_dim + int(cond_dim or 0) + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + + num_resolutions = len(in_out) + # block type + + block_klass = partial(ResnetBlock, groups=resnet_groups) + block_klass_cond = partial(block_klass, time_emb_dim=cond_dim) + + # modules for all layers + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + block_klass_cond(dim_in, dim_out), + block_klass_cond(dim_out, dim_out), + Residual(PreNorm(dim_out, SpatialLinearAttention( + dim_out, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), + Residual(PreNorm(dim_out, temporal_attn(dim_out))), + Downsample(dim_out) if not is_last else nn.Identity() + ])) + + mid_dim = dims[-1] + self.mid_block1 = block_klass_cond(mid_dim, mid_dim) + + spatial_attn = EinopsToAndFrom( + 'b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads)) + + self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn)) + self.mid_temporal_attn = Residual( + PreNorm(mid_dim, temporal_attn(mid_dim))) + + self.mid_block2 = block_klass_cond(mid_dim, mid_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind >= (num_resolutions - 1) + + self.ups.append(nn.ModuleList([ + block_klass_cond(dim_out * 2, dim_in), + block_klass_cond(dim_in, dim_in), + Residual(PreNorm(dim_in, SpatialLinearAttention( + dim_in, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), + Residual(PreNorm(dim_in, temporal_attn(dim_in))), + Upsample(dim_in) if not is_last else nn.Identity() + ])) + + out_dim = default(out_dim, channels) + self.final_conv = nn.Sequential( + block_klass(dim * 2, dim), + nn.Conv3d(dim, out_dim, 1) + ) + + def forward_with_cond_scale( + self, + 
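+            # classifier-free guidance helper: blends the conditional and null-conditioned predictions by cond_scale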
*args, + cond_scale=2., + **kwargs + ): + logits = self.forward(*args, null_cond_prob=0., **kwargs) + if cond_scale == 1 or not self.has_cond: + return logits + + null_logits = self.forward(*args, null_cond_prob=1., **kwargs) + return null_logits + (logits - null_logits) * cond_scale + + def forward( + self, + x, + time, + cond=None, + null_cond_prob=0., + focus_present_mask=None, + # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time) + prob_focus_present=0. + ): + assert not (self.has_cond and not exists(cond) + ), 'cond must be passed in if cond_dim specified' + x = torch.cat([x, cond], dim=1) + + batch, device = x.shape[0], x.device + + focus_present_mask = default(focus_present_mask, lambda: prob_mask_like( + (batch,), prob_focus_present, device=device)) + + time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device) + + x = self.init_conv(x) + r = x.clone() + + x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias) + + t = self.time_mlp(time) if exists(self.time_mlp) else None # [2, 128] + + # classifier free guidance + + if self.has_cond: + batch, device = x.shape[0], x.device + mask = prob_mask_like((batch,), null_cond_prob, device=device) + cond = torch.where(rearrange(mask, 'b -> b 1'), + self.null_cond_emb, cond) + t = torch.cat((t, cond), dim=-1) + + h = [] + + for block1, block2, spatial_attn, temporal_attn, downsample in self.downs: + x = block1(x, t) + x = block2(x, t) + x = spatial_attn(x) + x = temporal_attn(x, pos_bias=time_rel_pos_bias, + focus_present_mask=focus_present_mask) + h.append(x) + x = downsample(x) + + # [2, 256, 32, 4, 4] + x = self.mid_block1(x, t) + x = self.mid_spatial_attn(x) + x = self.mid_temporal_attn( + x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask) + x = self.mid_block2(x, t) + + for block1, block2, spatial_attn, temporal_attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim=1) + x = block1(x, t) + x = block2(x, t) + x = spatial_attn(x) + x = temporal_attn(x, pos_bias=time_rel_pos_bias, + focus_present_mask=focus_present_mask) + x = upsample(x) + + x = torch.cat((x, r), dim=1) + return self.final_conv(x) + +# gaussian diffusion trainer class + + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = torch.linspace(0, timesteps, steps, dtype=torch.float64) + alphas_cumprod = torch.cos( + ((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0, 0.9999) + + +class GaussianDiffusion(nn.Module): + def __init__( + self, + denoise_fn, + *, + image_size, + num_frames, + text_use_bert_cls=False, + channels=3, + timesteps=1000, + loss_type='l1', + use_dynamic_thres=False, # from the Imagen paper + dynamic_thres_percentile=0.9, + vqgan_ckpt=None, + device=None + ): + super().__init__() + self.channels = channels + self.image_size = image_size + self.num_frames = num_frames + self.denoise_fn = denoise_fn + self.device = device + + if vqgan_ckpt: + self.vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt).cuda() + self.vqgan.eval() + else: + self.vqgan = None + + betas = cosine_beta_schedule(timesteps) + + alphas = 1. 
- betas + alphas_cumprod = torch.cumprod(alphas, axis=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + # register buffer helper function that casts float64 to float32 + + def register_buffer(name, val): return self.register_buffer( + name, val.to(torch.float32)) + + register_buffer('betas', betas) + register_buffer('alphas_cumprod', alphas_cumprod) + register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + + register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) + register_buffer('sqrt_one_minus_alphas_cumprod', + torch.sqrt(1. - alphas_cumprod)) + register_buffer('log_one_minus_alphas_cumprod', + torch.log(1. - alphas_cumprod)) + register_buffer('sqrt_recip_alphas_cumprod', + torch.sqrt(1. / alphas_cumprod)) + register_buffer('sqrt_recipm1_alphas_cumprod', + torch.sqrt(1. / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + + posterior_variance = betas * \ + (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + + register_buffer('posterior_variance', posterior_variance) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + + register_buffer('posterior_log_variance_clipped', + torch.log(posterior_variance.clamp(min=1e-20))) + register_buffer('posterior_mean_coef1', betas * + torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) + register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) + * torch.sqrt(alphas) / (1. - alphas_cumprod)) + + # text conditioning parameters + + self.text_use_bert_cls = text_use_bert_cls + + # dynamic thresholding when sampling + + self.use_dynamic_thres = use_dynamic_thres + self.dynamic_thres_percentile = dynamic_thres_percentile + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. - self.alphas_cumprod, t, x_start.shape) + log_variance = extract( + self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract( + self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool, cond=None, cond_scale=1.): + x_recon = self.predict_start_from_noise( + x, t=t, noise=self.denoise_fn.forward_with_cond_scale(x, t, cond=cond, cond_scale=cond_scale)) + + if clip_denoised: + s = 1. + if self.use_dynamic_thres: + s = torch.quantile( + rearrange(x_recon, 'b ... -> b (...)').abs(), + self.dynamic_thres_percentile, + dim=-1 + ) + + s.clamp_(min=1.) 
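+                 # reshape the per-sample threshold so it broadcasts over all non-batch dimensions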
+ s = s.view(-1, *((1,) * (x_recon.ndim - 1))) + + # clip by threshold, depending on whether static or dynamic + x_recon = x_recon.clamp(-s, s) / s + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.inference_mode() + def p_sample(self, x, t, cond=None, cond_scale=1., clip_denoised=True): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised, cond=cond, cond_scale=cond_scale) + noise = torch.randn_like(x) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, + *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.inference_mode() + def p_sample_loop(self, shape, cond=None, cond_scale=1.): + device = self.betas.device + + b = shape[0] + img = torch.randn(shape, device=device) + # print('cond', cond.shape) + for i in reversed(range(0, self.num_timesteps)): + img = self.p_sample(img, torch.full( + (b,), i, device=device, dtype=torch.long), cond=cond, cond_scale=cond_scale) + + return img + + @torch.inference_mode() + def sample(self, cond=None, cond_scale=1., batch_size=16): + device = next(self.denoise_fn.parameters()).device + + if is_list_str(cond): + cond = bert_embed(tokenize(cond)).to(device) + + # batch_size = cond.shape[0] if exists(cond) else batch_size + batch_size = batch_size + image_size = self.image_size + channels = 8 # self.channels + num_frames = self.num_frames + # print((batch_size, channels, num_frames, image_size, image_size)) + # print('cond_',cond.shape) + _sample = self.p_sample_loop( + (batch_size, channels, num_frames, image_size, image_size), cond=cond, cond_scale=cond_scale) + + if isinstance(self.vqgan, VQGAN): + # denormalize TODO: Remove eventually + _sample = (((_sample + 1.0) / 2.0) * (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) + self.vqgan.codebook.embeddings.min() + + _sample = self.vqgan.decode(_sample, quantize=True) + else: + unnormalize_img(_sample) + + return _sample + + @torch.inference_mode() + def interpolate(self, x1, x2, t=None, lam=0.5): + b, *_, device = *x1.shape, x1.device + t = default(t, self.num_timesteps - 1) + + assert x1.shape == x2.shape + + t_batched = torch.stack([torch.tensor(t, device=device)] * b) + xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) + + img = (1 - lam) * xt1 + lam * xt2 + for i in reversed(range(0, t)): + img = self.p_sample(img, torch.full( + (b,), i, device=device, dtype=torch.long)) + + return img + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, + t, x_start.shape) * noise + ) + + def p_losses(self, x_start, t, cond=None, noise=None, **kwargs): + b, c, f, h, w, device = *x_start.shape, x_start.device + noise = default(noise, lambda: torch.randn_like(x_start)) + # breakpoint() + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # [2, 8, 32, 32, 32] + + if is_list_str(cond): + cond = bert_embed( + tokenize(cond), return_cls_repr=self.text_use_bert_cls) + cond = cond.to(device) + + x_recon = self.denoise_fn(x_noisy, t, cond=cond, **kwargs) + + if self.loss_type == 'l1': + loss = F.l1_loss(noise, x_recon) + elif self.loss_type == 'l2': + loss = F.mse_loss(noise, x_recon) + else: 
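+             # only 'l1' and 'l2' reconstruction losses are implemented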
+ raise NotImplementedError() + + return loss + + def forward(self, x, *args, **kwargs): + bs = int(x.shape[0]/2) + img=x[:bs,...] + mask=x[bs:,...] + mask_=(1-mask).detach() + masked_img = (img*mask_).detach() + masked_img=masked_img.permute(0,1,-1,-3,-2) + img=img.permute(0,1,-1,-3,-2) + mask=mask.permute(0,1,-1,-3,-2) + # breakpoint() + if isinstance(self.vqgan, VQGAN): + with torch.no_grad(): + img = self.vqgan.encode( + img, quantize=False, include_embeddings=True) + # normalize to -1 and 1 + img = ((img - self.vqgan.codebook.embeddings.min()) / + (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + masked_img = self.vqgan.encode( + masked_img, quantize=False, include_embeddings=True) + # normalize to -1 and 1 + masked_img = ((masked_img - self.vqgan.codebook.embeddings.min()) / + (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + else: + print("Hi") + img = normalize_img(img) + masked_img = normalize_img(masked_img) + mask = mask*2.0 - 1.0 + cc = torch.nn.functional.interpolate(mask, size=masked_img.shape[-3:]) + cond = torch.cat((masked_img, cc), dim=1) + + b, device, img_size, = img.shape[0], img.device, self.image_size + t = torch.randint(0, self.num_timesteps, (b,), device=device).long() + # breakpoint() + return self.p_losses(img, t, cond=cond, *args, **kwargs) + +# trainer class + + +CHANNELS_TO_MODE = { + 1: 'L', + 3: 'RGB', + 4: 'RGBA' +} + + +def seek_all_images(img, channels=3): + assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid' + mode = CHANNELS_TO_MODE[channels] + + i = 0 + while True: + try: + img.seek(i) + yield img.convert(mode) + except EOFError: + break + i += 1 + +# tensor of shape (channels, frames, height, width) -> gif + + +def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True): + tensor = ((tensor - tensor.min()) / (tensor.max() - tensor.min())) * 1.0 + images = map(T.ToPILImage(), tensor.unbind(dim=1)) + first_img, *rest_imgs = images + first_img.save(path, save_all=True, append_images=rest_imgs, + duration=duration, loop=loop, optimize=optimize) + return images + +# gif -> (channels, frame, height, width) tensor + + +def gif_to_tensor(path, channels=3, transform=T.ToTensor()): + img = Image.open(path) + tensors = tuple(map(transform, seek_all_images(img, channels=channels))) + return torch.stack(tensors, dim=1) + + +def identity(t, *args, **kwargs): + return t + + +def normalize_img(t): + return t * 2 - 1 + + +def unnormalize_img(t): + return (t + 1) * 0.5 + + +def cast_num_frames(t, *, frames): + f = t.shape[1] + + if f == frames: + return t + + if f > frames: + return t[:, :frames] + + return F.pad(t, (0, 0, 0, 0, 0, frames - f)) + + +class Dataset(data.Dataset): + def __init__( + self, + folder, + image_size, + channels=3, + num_frames=16, + horizontal_flip=False, + force_num_frames=True, + exts=['gif'] + ): + super().__init__() + self.folder = folder + self.image_size = image_size + self.channels = channels + self.paths = [p for ext in exts for p in Path( + f'{folder}').glob(f'**/*.{ext}')] + + self.cast_num_frames_fn = partial( + cast_num_frames, frames=num_frames) if force_num_frames else identity + + self.transform = T.Compose([ + T.Resize(image_size), + T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity), + T.CenterCrop(image_size), + T.ToTensor() + ]) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = self.paths[index] + tensor = gif_to_tensor(path, self.channels, 
transform=self.transform) + return self.cast_num_frames_fn(tensor) + +# trainer class + + +class Tester(object): + def __init__( + self, + diffusion_model, + ): + super().__init__() + self.model = diffusion_model + self.ema_model = copy.deepcopy(self.model) + self.step=0 + self.image_size = diffusion_model.image_size + + self.reset_parameters() + + def reset_parameters(self): + self.ema_model.load_state_dict(self.model.state_dict()) + + + def load(self, milestone, map_location=None, **kwargs): + if milestone == -1: + all_milestones = [int(p.stem.split('-')[-1]) + for p in Path(self.results_folder).glob('**/*.pt')] + assert len( + all_milestones) > 0, 'need to have at least one milestone to load from latest checkpoint (milestone == -1)' + milestone = max(all_milestones) + + if map_location: + data = torch.load(milestone, map_location=map_location) + else: + data = torch.load(milestone) + + self.step = data['step'] + self.model.load_state_dict(data['model'], **kwargs) + self.ema_model.load_state_dict(data['ema'], **kwargs) + + diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/text.py b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/text.py new file mode 100644 index 0000000000000000000000000000000000000000..42dca78bec5075fb4f59c522aa3d00cc395d7536 --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/text.py @@ -0,0 +1,94 @@ +"Taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch" + +import torch +from einops import rearrange + + +def exists(val): + return val is not None + +# singleton globals + + +MODEL = None +TOKENIZER = None +BERT_MODEL_DIM = 768 + + +def get_tokenizer(): + global TOKENIZER + if not exists(TOKENIZER): + TOKENIZER = torch.hub.load( + 'huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased') + return TOKENIZER + + +def get_bert(): + global MODEL + if not exists(MODEL): + MODEL = torch.hub.load( + 'huggingface/pytorch-transformers', 'model', 'bert-base-cased') + if torch.cuda.is_available(): + MODEL = MODEL.cuda() + + return MODEL + +# tokenize + + +def tokenize(texts, add_special_tokens=True): + if not isinstance(texts, (list, tuple)): + texts = [texts] + + tokenizer = get_tokenizer() + + encoding = tokenizer.batch_encode_plus( + texts, + add_special_tokens=add_special_tokens, + padding=True, + return_tensors='pt' + ) + + token_ids = encoding.input_ids + return token_ids + +# embedding function + + +@torch.no_grad() +def bert_embed( + token_ids, + return_cls_repr=False, + eps=1e-8, + pad_id=0. 
+): + model = get_bert() + mask = token_ids != pad_id + + if torch.cuda.is_available(): + token_ids = token_ids.cuda() + mask = mask.cuda() + + outputs = model( + input_ids=token_ids, + attention_mask=mask, + output_hidden_states=True + ) + + hidden_state = outputs.hidden_states[-1] + + if return_cls_repr: + # return [cls] as representation + return hidden_state[:, 0] + + if not exists(mask): + return hidden_state.mean(dim=1) + + # mean all tokens excluding [cls], accounting for length + mask = mask[:, 1:] + mask = rearrange(mask, 'b n -> b n 1') + + numer = (hidden_state[:, 1:] * mask).sum(dim=1) + denom = mask.sum(dim=1) + masked_mean = numer / (denom + eps) + return diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/time_embedding.py b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/time_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..7c911fa4f485295005cda9c9c2099db3ddbda15c --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/time_embedding.py @@ -0,0 +1,75 @@ +import math +import torch +import torch.nn as nn + +from monai.networks.layers.utils import get_act_layer + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, emb_dim=16, downscale_freq_shift=1, max_period=10000, flip_sin_to_cos=False): + super().__init__() + self.emb_dim = emb_dim + self.downscale_freq_shift = downscale_freq_shift + self.max_period = max_period + self.flip_sin_to_cos = flip_sin_to_cos + + def forward(self, x): + device = x.device + half_dim = self.emb_dim // 2 + emb = math.log(self.max_period) / \ + (half_dim - self.downscale_freq_shift) + emb = torch.exp(-emb*torch.arange(half_dim, device=device)) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + + if self.flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + if self.emb_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class LearnedSinusoidalPosEmb(nn.Module): + """ following @crowsonkb 's lead with learned sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, emb_dim): + super().__init__() + self.emb_dim = emb_dim + half_dim = emb_dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = x[:, None] + freqs = x * self.weights[None, :] * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + if self.emb_dim % 2 == 1: + fouriered = torch.nn.functional.pad(fouriered, (0, 1, 0, 0)) + return fouriered + + +class TimeEmbbeding(nn.Module): + def __init__( + self, + emb_dim=64, + pos_embedder=SinusoidalPosEmb, + pos_embedder_kwargs={}, + act_name=("SWISH", {}) # Swish = SiLU + ): + super().__init__() + self.emb_dim = emb_dim + self.pos_emb_dim = pos_embedder_kwargs.get('emb_dim', emb_dim//4) + pos_embedder_kwargs['emb_dim'] = self.pos_emb_dim + self.pos_embedder = pos_embedder(**pos_embedder_kwargs) + + self.time_emb = nn.Sequential( + self.pos_embedder, + nn.Linear(self.pos_emb_dim, self.emb_dim), + get_act_layer(act_name), + nn.Linear(self.emb_dim, self.emb_dim) + ) + + def forward(self, time): + return self.time_emb(time) diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/unet.py b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/unet.py new file mode 100644 index 
0000000000000000000000000000000000000000..e9cae8021e874a92b6baaaf4decd802516c2dc87 --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/unet.py @@ -0,0 +1,226 @@ +from ddpm.time_embedding import TimeEmbbeding + +import monai.networks.nets as nets +import torch +import torch.nn as nn +from einops import rearrange + +from monai.networks.blocks import UnetBasicBlock, UnetResBlock, UnetUpBlock, Convolution, UnetOutBlock +from monai.networks.layers.utils import get_act_layer + + +class DownBlock(nn.Module): + def __init__( + self, + spatial_dims, + in_ch, + out_ch, + time_emb_dim, + cond_emb_dim, + act_name=("swish", {}), + **kwargs): + super(DownBlock, self).__init__() + self.loca_time_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(time_emb_dim, in_ch) # in_ch * 2 + ) + if cond_emb_dim is not None: + self.loca_cond_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(cond_emb_dim, in_ch), + ) + self.down_op = UnetBasicBlock( + spatial_dims, in_ch, out_ch, act_name=act_name, **kwargs) + + def forward(self, x, time_emb, cond_emb): + b, c, *_ = x.shape + sp_dim = x.ndim-2 + + # ------------ Time ---------- + time_emb = self.loca_time_embedder(time_emb) + time_emb = time_emb.reshape(b, c, *((1,)*sp_dim)) + # scale, shift = time_emb.chunk(2, dim = 1) + + # ------------ Combine ------------ + # x = x * (scale + 1) + shift + x = x + time_emb + + # ----------- Condition ------------ + if cond_emb is not None: + cond_emb = self.loca_cond_embedder(cond_emb) + cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim)) + x = x + cond_emb + + # ----------- Image --------- + y = self.down_op(x) + return y + + +class UpBlock(nn.Module): + def __init__( + self, + spatial_dims, + skip_ch, + enc_ch, + time_emb_dim, + cond_emb_dim, + act_name=("swish", {}), + **kwargs): + super(UpBlock, self).__init__() + self.up_op = UnetUpBlock(spatial_dims, enc_ch, + skip_ch, act_name=act_name, **kwargs) + self.loca_time_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(time_emb_dim, skip_ch * 2), + ) + if cond_emb_dim is not None: + self.loca_cond_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(cond_emb_dim, skip_ch * 2), + ) + + def forward(self, x_skip, x_enc, time_emb, cond_emb): + b, c, *_ = x_enc.shape + sp_dim = x_enc.ndim-2 + + # ----------- Time -------------- + time_emb = self.loca_time_embedder(time_emb) + time_emb = time_emb.reshape(b, c, *((1,)*sp_dim)) + # scale, shift = time_emb.chunk(2, dim = 1) + + # -------- Combine ------------- + # y = x * (scale + 1) + shift + x_enc = x_enc + time_emb + + # ----------- Condition ------------ + if cond_emb is not None: + cond_emb = self.loca_cond_embedder(cond_emb) + cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim)) + x_enc = x_enc + cond_emb + + # ----------- Image ------------- + y = self.up_op(x_enc, x_skip) + + # -------- Combine ------------- + # y = y * (scale + 1) + shift + + return y + + +class UNet(nn.Module): + + def __init__(self, + in_ch=1, + out_ch=1, + spatial_dims=3, + hid_chs=[32, 64, 128, 256, 512], + kernel_sizes=[(1, 3, 3), (1, 3, 3), (1, 3, 3), 3, 3], + strides=[1, (1, 2, 2), (1, 2, 2), 2, 2], + upsample_kernel_sizes=None, + act_name=("SWISH", {}), + norm_name=("INSTANCE", {"affine": True}), + time_embedder=TimeEmbbeding, + time_embedder_kwargs={}, + cond_embedder=None, + cond_embedder_kwargs={}, + # True = all but last layer, 0/False=disable, 1=only first layer, ... 
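+                  # an integer n attaches auxiliary output heads to the first n intermediate decoder feature maps (forward currently returns only the final head)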
+ deep_ver_supervision=True, + estimate_variance=False, + use_self_conditioning=False, + **kwargs + ): + super().__init__() + if upsample_kernel_sizes is None: + upsample_kernel_sizes = strides[1:] + + # ------------- Time-Embedder----------- + self.time_embedder = time_embedder(**time_embedder_kwargs) + + # ------------- Condition-Embedder----------- + if cond_embedder is not None: + self.cond_embedder = cond_embedder(**cond_embedder_kwargs) + cond_emb_dim = self.cond_embedder.emb_dim + else: + self.cond_embedder = None + cond_emb_dim = None + + # ----------- In-Convolution ------------ + in_ch = in_ch*2 if use_self_conditioning else in_ch + self.inc = UnetBasicBlock(spatial_dims, in_ch, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0], + act_name=act_name, norm_name=norm_name, **kwargs) + + # ----------- Encoder ---------------- + self.encoders = nn.ModuleList([ + DownBlock(spatial_dims, hid_chs[i-1], hid_chs[i], time_emb_dim=self.time_embedder.emb_dim, + cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[ + i], stride=strides[i], act_name=act_name, + norm_name=norm_name, **kwargs) + for i in range(1, len(strides)) + ]) + + # ------------ Decoder ---------- + self.decoders = nn.ModuleList([ + UpBlock(spatial_dims, hid_chs[i], hid_chs[i+1], time_emb_dim=self.time_embedder.emb_dim, + cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[i + + 1], stride=strides[i+1], act_name=act_name, + norm_name=norm_name, upsample_kernel_size=upsample_kernel_sizes[i], **kwargs) + for i in range(len(strides)-1) + ]) + + # --------------- Out-Convolution ---------------- + out_ch_hor = out_ch*2 if estimate_variance else out_ch + self.outc = UnetOutBlock( + spatial_dims, hid_chs[0], out_ch_hor, dropout=None) + if isinstance(deep_ver_supervision, bool): + deep_ver_supervision = len( + strides)-2 if deep_ver_supervision else 0 + self.outc_ver = nn.ModuleList([ + UnetOutBlock(spatial_dims, hid_chs[i], out_ch, dropout=None) + for i in range(1, deep_ver_supervision+1) + ]) + + def forward(self, x_t, t, cond=None, self_cond=None, **kwargs): + condition = cond + # x_t [B, C, (D), H, W] + # t [B,] + + # -------- In-Convolution -------------- + x = [None for _ in range(len(self.encoders)+1)] + x_t = torch.cat([x_t, self_cond], + dim=1) if self_cond is not None else x_t + x[0] = self.inc(x_t) + + # -------- Time Embedding (Gloabl) ----------- + time_emb = self.time_embedder(t) # [B, C] + + # -------- Condition Embedding (Gloabl) ----------- + if (condition is None) or (self.cond_embedder is None): + cond_emb = None + else: + cond_emb = self.cond_embedder(condition) # [B, C] + + # --------- Encoder -------------- + for i in range(len(self.encoders)): + x[i+1] = self.encoders[i](x[i], time_emb, cond_emb) + + # -------- Decoder ----------- + for i in range(len(self.decoders), 0, -1): + x[i-1] = self.decoders[i-1](x[i-1], x[i], time_emb, cond_emb) + + # ---------Out-Convolution ------------ + y_hor = self.outc(x[0]) + y_ver = [outc_ver_i(x[i+1]) + for i, outc_ver_i in enumerate(self.outc_ver)] + + return y_hor # , y_ver + + def forward_with_cond_scale(self, *args, cond_scale=0., **kwargs): + return self.forward(*args, **kwargs) + + +if __name__ == '__main__': + model = UNet(in_ch=3) + input = torch.randn((1, 3, 16, 128, 128)) + time = torch.randn((1,)) + out_hor, out_ver = model(input, time) + print(out_hor[0].shape) diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/util.py b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/util.py new file mode 100644 index 
0000000000000000000000000000000000000000..2047cf56802046bbf37cfd33f819ed75a239a7df --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/ddpm/util.py @@ -0,0 +1,271 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +# from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + if c != 1: + steps_out = ddim_timesteps + 1 + else: + steps_out = ddim_timesteps + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. 
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. 
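+     e.g. a tensor of shape [B, C, T, H, W] is reduced to a tensor of shape [B].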
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c9267820ac3147285b3dbc3eca053f259ca015c Binary files /dev/null and b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__init__.py b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38c08845c1ee2f7b4b52bf95fa282b0808931a3a --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__init__.py @@ -0,0 +1,3 @@ +from .vqgan import VQGAN +from .codebook import Codebook +from .lpips import LPIPS diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5af25919e8562ffd0096c9a9795af5e6311d5dc Binary files /dev/null and b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc differ diff --git 
a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c14899b2127c37d8cdba304c2380cbf6fc3ccd3 Binary files /dev/null and b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6fd1b98c409d33c2783dcf7b43ee55e86cd775d Binary files /dev/null and b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..708f349ec739b22404f5242cc4945d38baf8872f Binary files /dev/null and b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth new file mode 100644 index 0000000000000000000000000000000000000000..f57dcf5cc764d61c8a460365847fb2137ff0a62d --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868 +size 7289 diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/codebook.py b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/codebook.py new file mode 100644 index 0000000000000000000000000000000000000000..c17ed701a2824cc0dcd75670d47a6ab3842e9d35 --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/codebook.py @@ -0,0 +1,109 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from ..utils import shift_dim + + +class Codebook(nn.Module): + def __init__(self, n_codes, embedding_dim, no_random_restart=False, restart_thres=1.0): + super().__init__() + self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim)) + self.register_buffer('N', torch.zeros(n_codes)) + self.register_buffer('z_avg', self.embeddings.data.clone()) + + self.n_codes = n_codes + self.embedding_dim = embedding_dim + self._need_init = True + self.no_random_restart = no_random_restart + self.restart_thres = restart_thres + + def _tile(self, x): + d, ew = x.shape + if d < self.n_codes: + n_repeats = (self.n_codes + d - 1) // d + std = 0.01 / np.sqrt(ew) + x = x.repeat(n_repeats, 1) + x = x + torch.randn_like(x) * std + return x + + def _init_embeddings(self, z): + # z: [b, c, t, h, w] + self._need_init = False + breakpoint() + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [65536, 8] [2, 8, 32, 32, 32] + y = self._tile(flat_inputs) # [65536, 8] + + d = y.shape[0] + _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + self.embeddings.data.copy_(_k_rand) + self.z_avg.data.copy_(_k_rand) + self.N.data.copy_(torch.ones(self.n_codes)) + + def forward(self, z): + # z: [b, c, t, h, w] + if self._need_init and self.training: + self._init_embeddings(z) + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [bthw, c] [65536, 8] + distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \ + - 2 * flat_inputs @ self.embeddings.t() \ + + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) # [bthw, c] [65536, 8] + + encoding_indices = torch.argmin(distances, dim=1) # [65536] + encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as( + flat_inputs) # [bthw, ncode] [65536, 16384] + encoding_indices = encoding_indices.view( + z.shape[0], *z.shape[2:]) # [b, t, h, w, ncode] [2, 32, 32, 32] + + embeddings = F.embedding( + encoding_indices, self.embeddings) # [b, t, h, w, c] self.embeddings [16384, 8] + embeddings = shift_dim(embeddings, -1, 1) # [b, c, t, h, w] [2, 8, 32, 32, 32] + + commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) + + # EMA codebook update + if self.training: + n_total = encode_onehot.sum(dim=0) # [16384] + encode_sum = flat_inputs.t() @ encode_onehot # [8, 16384] + if dist.is_initialized(): + dist.all_reduce(n_total) + dist.all_reduce(encode_sum) + + self.N.data.mul_(0.99).add_(n_total, alpha=0.01) + self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) + + n = self.N.sum() + weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n + encode_normalized = self.z_avg / weights.unsqueeze(1) + self.embeddings.data.copy_(encode_normalized) + + y = self._tile(flat_inputs) + _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + + if not self.no_random_restart: + usage = (self.N.view(self.n_codes, 1) + >= self.restart_thres).float() + self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) + + embeddings_st = (embeddings - z).detach() + z + + avg_probs = torch.mean(encode_onehot, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * + torch.log(avg_probs + 1e-10))) + + return dict(embeddings=embeddings_st, encodings=encoding_indices, + commitment_loss=commitment_loss, perplexity=perplexity) + + def dictionary_lookup(self, encodings): + embeddings = F.embedding(encodings, 
self.embeddings) + return embeddings diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/lpips.py b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..019434b9b29945d67e6e0a95dec324df7ff908f3 --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/lpips.py @@ -0,0 +1,181 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" + +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + + +from collections import namedtuple +from torchvision import models +import torch.nn as nn +import torch +from tqdm import tqdm +import requests +import os +import hashlib +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format( + name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + # self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + self.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name is not "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + model.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer( + input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in 
range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor( + outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor( + [-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor( + [.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, + padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, + h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..a64ec8878c0ecbf418775e2958b2cbf578ce2918 --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py @@ -0,0 +1,561 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import math +import argparse +import numpy as np +import pickle as pkl + +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from ..utils import shift_dim, adopt_weight, comp_getattr +from .lpips import LPIPS +from .codebook import Codebook + + +def silu(x): + return x*torch.sigmoid(x) + + +class SiLU(nn.Module): + def __init__(self): + super(SiLU, self).__init__() + + def forward(self, x): + return silu(x) + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake))) + return d_loss + + +class VQGAN(pl.LightningModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.embedding_dim = cfg.model.embedding_dim # 8 + self.n_codes = cfg.model.n_codes # 16384 + + self.encoder = Encoder(cfg.model.n_hiddens, # 16 + cfg.model.downsample, # [2, 2, 2] + cfg.dataset.image_channels, # 1 + cfg.model.norm_type, # group + cfg.model.padding_type, # replicate + cfg.model.num_groups, # 32 + ) + self.decoder = Decoder( + cfg.model.n_hiddens, cfg.model.downsample, cfg.dataset.image_channels, cfg.model.norm_type, cfg.model.num_groups) + self.enc_out_ch = self.encoder.out_channels + self.pre_vq_conv = SamePadConv3d( + self.enc_out_ch, cfg.model.embedding_dim, 1, padding_type=cfg.model.padding_type) + self.post_vq_conv = SamePadConv3d( + cfg.model.embedding_dim, self.enc_out_ch, 1) + + self.codebook = Codebook(cfg.model.n_codes, cfg.model.embedding_dim, + no_random_restart=cfg.model.no_random_restart, restart_thres=cfg.model.restart_thres) + + self.gan_feat_weight = cfg.model.gan_feat_weight + # TODO: Changed batchnorm from sync to normal + self.image_discriminator = NLayerDiscriminator( + cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm2d) + self.video_discriminator = NLayerDiscriminator3D( + cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm3d) + + if cfg.model.disc_loss_type == 'vanilla': + self.disc_loss = vanilla_d_loss + elif cfg.model.disc_loss_type == 'hinge': + self.disc_loss = hinge_d_loss + + self.perceptual_model = LPIPS().eval() + + self.image_gan_weight = cfg.model.image_gan_weight + self.video_gan_weight = cfg.model.video_gan_weight + + self.perceptual_weight = cfg.model.perceptual_weight + + self.l1_weight = cfg.model.l1_weight + self.save_hyperparameters() + + def encode(self, x, include_embeddings=False, quantize=True): + h = self.pre_vq_conv(self.encoder(x)) + if quantize: + vq_output = self.codebook(h) + if include_embeddings: + return vq_output['embeddings'], vq_output['encodings'] + else: + return vq_output['encodings'] + return h + + def decode(self, latent, quantize=False): + if quantize: + vq_output = self.codebook(latent) + latent = vq_output['encodings'] + h = F.embedding(latent, self.codebook.embeddings) + h = self.post_vq_conv(shift_dim(h, -1, 1)) + return self.decoder(h) + + def forward(self, x, optimizer_idx=None, log_image=False): + B, C, T, H, W = x.shape + z = self.pre_vq_conv(self.encoder(x)) # [2, 32, 32, 32, 32] [2, 8, 32, 32, 32] + vq_output = self.codebook(z) # ['embeddings', 'encodings', 'commitment_loss', 
'perplexity'] + x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings'])) # [2, 8, 32, 32, 32] [2, 32, 32, 32, 32] + + recon_loss = F.l1_loss(x_recon, x) * self.l1_weight + + # Selects one random 2D image from each 3D Image + frame_idx = torch.randint(0, T, [B]).cuda() + frame_idx_selected = frame_idx.reshape(-1, + 1, 1, 1, 1).repeat(1, C, 1, H, W) # [2, 1, 1, 64, 64] + frames = torch.gather(x, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64] + frames_recon = torch.gather(x_recon, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64] + + if log_image: + return frames, frames_recon, x, x_recon + + if optimizer_idx == 0: + # Autoencoder - train the "generator" + + # Perceptual loss + perceptual_loss = 0 + if self.perceptual_weight > 0: + perceptual_loss = self.perceptual_model( + frames, frames_recon).mean() * self.perceptual_weight + + # Discriminator loss (turned on after a certain epoch) + logits_image_fake, pred_image_fake = self.image_discriminator( + frames_recon) + logits_video_fake, pred_video_fake = self.video_discriminator( + x_recon) + g_image_loss = -torch.mean(logits_image_fake) + g_video_loss = -torch.mean(logits_video_fake) + g_loss = self.image_gan_weight*g_image_loss + self.video_gan_weight*g_video_loss + disc_factor = adopt_weight( + self.global_step, threshold=self.cfg.model.discriminator_iter_start) + aeloss = disc_factor * g_loss + + # GAN feature matching loss - tune features such that we get the same prediction result on the discriminator + image_gan_feat_loss = 0 + video_gan_feat_loss = 0 + feat_weights = 4.0 / (3 + 1) + if self.image_gan_weight > 0: + logits_image_real, pred_image_real = self.image_discriminator( + frames) + for i in range(len(pred_image_fake)-1): + image_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_image_fake[i], pred_image_real[i].detach( + )) * (self.image_gan_weight > 0) + if self.video_gan_weight > 0: + logits_video_real, pred_video_real = self.video_discriminator( + x) + for i in range(len(pred_video_fake)-1): + video_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_video_fake[i], pred_video_real[i].detach( + )) * (self.video_gan_weight > 0) + gan_feat_loss = disc_factor * self.gan_feat_weight * \ + (image_gan_feat_loss + video_gan_feat_loss) + + self.log("train/g_image_loss", g_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/g_video_loss", g_video_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/image_gan_feat_loss", image_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/video_gan_feat_loss", video_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/perceptual_loss", perceptual_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log("train/recon_loss", recon_loss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + self.log("train/aeloss", aeloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + self.log("train/commitment_loss", vq_output['commitment_loss'], + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log('train/perplexity', vq_output['perplexity'], + prog_bar=True, logger=True, on_step=True, on_epoch=True) + return recon_loss, x_recon, vq_output, aeloss, perceptual_loss, gan_feat_loss + + if optimizer_idx == 1: + # Train discriminator + logits_image_real, _ = self.image_discriminator(frames.detach()) + logits_video_real, _ = self.video_discriminator(x.detach()) + + logits_image_fake, _ = self.image_discriminator( + frames_recon.detach()) + logits_video_fake, _ = 
self.video_discriminator(x_recon.detach()) + + d_image_loss = self.disc_loss(logits_image_real, logits_image_fake) + d_video_loss = self.disc_loss(logits_video_real, logits_video_fake) + disc_factor = adopt_weight( + self.global_step, threshold=self.cfg.model.discriminator_iter_start) + discloss = disc_factor * \ + (self.image_gan_weight*d_image_loss + + self.video_gan_weight*d_video_loss) + + self.log("train/logits_image_real", logits_image_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_image_fake", logits_image_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_video_real", logits_video_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_video_fake", logits_video_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/d_image_loss", d_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/d_video_loss", d_video_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/discloss", discloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + return discloss + + perceptual_loss = self.perceptual_model( + frames, frames_recon) * self.perceptual_weight + return recon_loss, x_recon, vq_output, perceptual_loss + + def training_step(self, batch, batch_idx, optimizer_idx): + x = batch['image'] + if optimizer_idx == 0: + recon_loss, _, vq_output, aeloss, perceptual_loss, gan_feat_loss = self.forward( + x, optimizer_idx) + commitment_loss = vq_output['commitment_loss'] + loss = recon_loss + commitment_loss + aeloss + perceptual_loss + gan_feat_loss + if optimizer_idx == 1: + discloss = self.forward(x, optimizer_idx) + loss = discloss + return loss + + def validation_step(self, batch, batch_idx): + x = batch['image'] # TODO: batch['stft'] + recon_loss, _, vq_output, perceptual_loss = self.forward(x) + self.log('val/recon_loss', recon_loss, prog_bar=True) + self.log('val/perceptual_loss', perceptual_loss, prog_bar=True) + self.log('val/perplexity', vq_output['perplexity'], prog_bar=True) + self.log('val/commitment_loss', + vq_output['commitment_loss'], prog_bar=True) + + def configure_optimizers(self): + lr = self.cfg.model.lr + opt_ae = torch.optim.Adam(list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.pre_vq_conv.parameters()) + + list(self.post_vq_conv.parameters()) + + list(self.codebook.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(list(self.image_discriminator.parameters()) + + list(self.video_discriminator.parameters()), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def log_images(self, batch, **kwargs): + log = dict() + x = batch['image'] + x = x.to(self.device) + frames, frames_rec, _, _ = self(x, log_image=True) + log["inputs"] = frames + log["reconstructions"] = frames_rec + #log['mean_org'] = batch['mean_org'] + #log['std_org'] = batch['std_org'] + return log + + def log_videos(self, batch, **kwargs): + log = dict() + x = batch['image'] + _, _, x, x_rec = self(x, log_image=True) + log["inputs"] = x + log["reconstructions"] = x_rec + #log['mean_org'] = batch['mean_org'] + #log['std_org'] = batch['std_org'] + return log + + +def Normalize(in_channels, norm_type='group', num_groups=32): + assert norm_type in ['group', 'batch'] + if norm_type == 'group': + # TODO Changed num_groups from 32 to 8 + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == 'batch': + return 
torch.nn.SyncBatchNorm(in_channels) + + +class Encoder(nn.Module): + def __init__(self, n_hiddens, downsample, image_channel=3, norm_type='group', padding_type='replicate', num_groups=32): + super().__init__() + n_times_downsample = np.array([int(math.log2(d)) for d in downsample]) + self.conv_blocks = nn.ModuleList() + max_ds = n_times_downsample.max() + + self.conv_first = SamePadConv3d( + image_channel, n_hiddens, kernel_size=3, padding_type=padding_type) + + for i in range(max_ds): + block = nn.Module() + in_channels = n_hiddens * 2**i + out_channels = n_hiddens * 2**(i+1) + stride = tuple([2 if d > 0 else 1 for d in n_times_downsample]) + block.down = SamePadConv3d( + in_channels, out_channels, 4, stride=stride, padding_type=padding_type) + block.res = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + self.conv_blocks.append(block) + n_times_downsample -= 1 + + self.final_block = nn.Sequential( + Normalize(out_channels, norm_type, num_groups=num_groups), + SiLU() + ) + + self.out_channels = out_channels + + def forward(self, x): + h = self.conv_first(x) + for block in self.conv_blocks: + h = block.down(h) + h = block.res(h) + h = self.final_block(h) + return h + + +class Decoder(nn.Module): + def __init__(self, n_hiddens, upsample, image_channel, norm_type='group', num_groups=32): + super().__init__() + + n_times_upsample = np.array([int(math.log2(d)) for d in upsample]) + max_us = n_times_upsample.max() + + in_channels = n_hiddens*2**max_us + self.final_block = nn.Sequential( + Normalize(in_channels, norm_type, num_groups=num_groups), + SiLU() + ) + + self.conv_blocks = nn.ModuleList() + for i in range(max_us): + block = nn.Module() + in_channels = in_channels if i == 0 else n_hiddens*2**(max_us-i+1) + out_channels = n_hiddens*2**(max_us-i) + us = tuple([2 if d > 0 else 1 for d in n_times_upsample]) + block.up = SamePadConvTranspose3d( + in_channels, out_channels, 4, stride=us) + block.res1 = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + block.res2 = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + self.conv_blocks.append(block) + n_times_upsample -= 1 + + self.conv_last = SamePadConv3d( + out_channels, image_channel, kernel_size=3) + + def forward(self, x): + h = self.final_block(x) + for i, block in enumerate(self.conv_blocks): + h = block.up(h) + h = block.res1(h) + h = block.res2(h) + h = self.conv_last(h) + return h + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group', padding_type='replicate', num_groups=32): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, norm_type, num_groups=num_groups) + self.conv1 = SamePadConv3d( + in_channels, out_channels, kernel_size=3, padding_type=padding_type) + self.dropout = torch.nn.Dropout(dropout) + self.norm2 = Normalize(in_channels, norm_type, num_groups=num_groups) + self.conv2 = SamePadConv3d( + out_channels, out_channels, kernel_size=3, padding_type=padding_type) + if self.in_channels != self.out_channels: + self.conv_shortcut = SamePadConv3d( + in_channels, out_channels, kernel_size=3, padding_type=padding_type) + + def forward(self, x): + h = x + h = self.norm1(h) + h = silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = silu(h) + h = self.conv2(h) + + 
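+        # When the block changes channel count, the identity branch is projected with
+        # self.conv_shortcut (a SamePadConv3d) so x and the residual h can be summed below.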
if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) + + return x+h + + +# Does not support dilation +class SamePadConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + # assumes that the input shape is divisible by stride + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + self.padding_type = padding_type + + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, + stride=stride, padding=0, bias=bias) + + def forward(self, x): + return self.conv(F.pad(x, self.pad_input, mode=self.padding_type)) + + +class SamePadConvTranspose3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + self.padding_type = padding_type + + self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, + stride=stride, bias=bias, + padding=tuple([k - 1 for k in kernel_size])) + + def forward(self, x): + return self.convt(F.pad(x, self.pad_input, mode=self.padding_type)) + + +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + # def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=True): + super(NLayerDiscriminator, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ + + +class NLayerDiscriminator3D(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + 
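+        # 3D counterpart of NLayerDiscriminator: a PatchGAN-style stack of Conv3d layers used as
+        # the volume discriminator; with getIntermFeat=True it also returns intermediate feature
+        # maps, which the VQGAN training step uses for the GAN feature-matching loss.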
super(NLayerDiscriminator3D, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv3d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv3d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/utils.py b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..acf1a21270d98d0a1ea9935d00f9f16d89d54551 --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/utils.py @@ -0,0 +1,177 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import warnings +import torch +import imageio + +import math +import numpy as np + +import sys +import pdb as pdb_original +import logging + +import imageio.core.util +logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR) + + +class ForkedPdb(pdb_original.Pdb): + """A Pdb subclass that may be used + from a forked multiprocessing child + + """ + + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + pdb_original.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + +# Shifts src_tf dim to dest dim +# i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + + dims = list(range(n_dims)) + del dims[src_dim] + + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x + + +# reshapes tensor start from dim i (inclusive) +# to dim j (exclusive) to the desired shape +# e.g. 
if x.shape = (b, thw, c) then +# view_range(x, 1, 2, (t, h, w)) returns +# x of shape (b, t, h, w, c) +def view_range(x, i, j, shape): + shape = tuple(shape) + + n_dims = len(x.shape) + if i < 0: + i = n_dims + i + + if j is None: + j = n_dims + elif j < 0: + j = n_dims + j + + assert 0 <= i < j <= n_dims + + x_shape = x.shape + target_shape = x_shape[:i] + shape + x_shape[j:] + return x.view(target_shape) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def tensor_slice(x, begin, size): + assert all([b >= 0 for b in begin]) + size = [l - b if s == -1 else s + for s, b, l in zip(size, begin, x.shape)] + assert all([s >= 0 for s in size]) + + slices = [slice(b, b + s) for b, s in zip(begin, size)] + return x[slices] + + +def adopt_weight(global_step, threshold=0, value=0.): + weight = 1 + if global_step < threshold: + weight = value + return weight + + +def save_video_grid(video, fname, nrow=None, fps=6): + b, c, t, h, w = video.shape + video = video.permute(0, 2, 3, 4, 1) + video = (video.cpu().numpy() * 255).astype('uint8') + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = np.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype='uint8') + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + video = [] + for i in range(t): + video.append(video_grid[i]) + imageio.mimsave(fname, video, fps=fps) + ## skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '5'}) + #print('saved videos to', fname) + + +def comp_getattr(args, attr_name, default=None): + if hasattr(args, attr_name): + return getattr(args, attr_name) + else: + return default + + +def visualize_tensors(t, name=None, nest=0): + if name is not None: + print(name, "current nest: ", nest) + print("type: ", type(t)) + if 'dict' in str(type(t)): + print(t.keys()) + for k in t.keys(): + if t[k] is None: + print(k, "None") + else: + if 'Tensor' in str(type(t[k])): + print(k, t[k].shape) + elif 'dict' in str(type(t[k])): + print(k, 'dict') + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t[k])): + print(k, len(t[k])) + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t)): + print("list length: ", len(t)) + for t2 in t: + visualize_tensors(t2, name, nest + 1) + elif 'Tensor' in str(type(t)): + print(t.shape) + else: + print(t) + return "" diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/utils.py b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..54a3d68432d165ab2895859a89d7be4d150e9721 --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/utils.py @@ -0,0 +1,465 @@ +### Tumor Generateion +import random +import cv2 +import elasticdeform +import numpy as np +from scipy.ndimage import gaussian_filter +from TumorGeneration.ldm.ddpm.ddim import DDIMSampler + +# Step 1: Random select (numbers) location for tumor. 
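+# random_select() picks a z slice in the middle 30-70% of the organ's z extent, erodes the 2D
+# organ mask (liver only) to avoid edge voxels, then samples an in-mask (x, y) and returns [x, y, z].
+#
+# Rough end-to-end usage sketch of this module (variable names here are illustrative assumptions,
+# not defined in this file):
+#   vqgan, early_tester, noearly_sampler = synt_model_prepare(device, organ='pancreas')
+#   synt_vol, synt_mask = synthesize_early_tumor(ct_volume, organ_mask, 'pancreas', vqgan, early_tester)
+#   synt_vol, synt_mask = synthesize_medium_tumor(ct_volume, organ_mask, 'pancreas', vqgan, noearly_sampler, ddim_ts=50)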
+def random_select(mask_scan, organ_type): + # we first find z index and then sample point with z slice + # print('mask_scan',np.unique(mask_scan)) + # print('pixel num', (mask_scan == 1).sum()) + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0].min(), np.where(np.any(mask_scan, axis=(0, 1)))[0].max() + # print('z_start, z_end',z_start, z_end) + # we need to strict number z's position (0.3 - 0.7 in the middle of liver) + while 1: + z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start + liver_mask = mask_scan[..., z] + # erode the mask (we don't want the edge points) + if organ_type == 'liver': + kernel = np.ones((5,5), dtype=np.uint8) + liver_mask = cv2.erode(liver_mask, kernel, iterations=1) + if (liver_mask == 1).sum() > 0: + break + + + + # print('liver_mask', (liver_mask == 1).sum()) + coordinates = np.argwhere(liver_mask == 1) + random_index = np.random.randint(0, len(coordinates)) + xyz = coordinates[random_index].tolist() # get x,y + xyz.append(z) + potential_points = xyz + + return potential_points + +def center_select(mask_scan): + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0].min(), np.where(np.any(mask_scan, axis=(0, 1)))[0].max() + x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0].min(), np.where(np.any(mask_scan, axis=(1, 2)))[0].max() + y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0].min(), np.where(np.any(mask_scan, axis=(0, 2)))[0].max() + + z = round(0.5 * (z_end - z_start)) + z_start + x = round(0.5 * (x_end - x_start)) + x_start + y = round(0.5 * (y_end - y_start)) + y_start + + xyz = [x, y, z] + potential_points = xyz + + return potential_points + +# Step 2 : generate the ellipsoid +def get_ellipsoid(x, y, z): + """" + x, y, z is the radius of this ellipsoid in x, y, z direction respectly. 
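+    The returned array has shape (4*x, 4*y, 4*z) with ellipsoid voxels set to 1, centered at (2*x, 2*y, 2*z).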
+ """ + sh = (4*x, 4*y, 4*z) + out = np.zeros(sh, int) + aux = np.zeros(sh) + radii = np.array([x, y, z]) + com = np.array([2*x, 2*y, 2*z]) # center point + + # calculate the ellipsoid + bboxl = np.floor(com-radii).clip(0,None).astype(int) + bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int) + roi = out[tuple(map(slice,bboxl,bboxh))] + roiaux = aux[tuple(map(slice,bboxl,bboxh))] + logrid = *map(np.square,np.ogrid[tuple( + map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]), + dst = (1-sum(logrid)).clip(0,None) + mask = dst>roiaux + roi[mask] = 1 + np.copyto(roiaux,dst,where=mask) + + return out + +def get_fixed_geo(mask_scan, tumor_type, organ_type): + if tumor_type == 'large': + enlarge_x, enlarge_y, enlarge_z = 280, 280, 280 + else: + enlarge_x, enlarge_y, enlarge_z = 160, 160, 160 + geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8) + tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32 + + if tumor_type == 'tiny': + num_tumor = random.randint(1,3) + # num_tumor = 1 + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + if tumor_type == 'small': + num_tumor = random.randint(1,3) + # num_tumor = 1 + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'medium': + # num_tumor = random.randint(1, 3) + num_tumor = 1 + for _ in range(num_tumor): + # 
medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'large': + num_tumor = 1 # random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + if organ_type == 'liver' or organ_type == 'kidney' : + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + else: + x_start, x_end = np.where(np.any(geo, axis=(1, 2)))[0].min(), np.where(np.any(geo, axis=(1, 2)))[0].max() + y_start, y_end = np.where(np.any(geo, axis=(0, 2)))[0].min(), np.where(np.any(geo, axis=(0, 2)))[0].max() + z_start, z_end = np.where(np.any(geo, axis=(0, 1)))[0].min(), np.where(np.any(geo, axis=(0, 1)))[0].max() + geo = geo[x_start:x_end, y_start:y_end, z_start:z_end] + + point = center_select(mask_scan) + + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low = new_point[0] - geo.shape[0]//2 + y_low = new_point[1] - geo.shape[1]//2 + z_low = new_point[2] - geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_low+geo.shape[0], y_low:y_low+geo.shape[1], z_low:z_low+geo.shape[2]] += geo + + geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + # texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + if ((tumor_type == 'medium') or (tumor_type == 'large')) and (organ_type == 'kidney'): + if random.random() > 0.5: + geo_mask = (geo_mask>=1) + else: + geo_mask = (geo_mask * mask_scan) >=1 + else: + geo_mask = (geo_mask * mask_scan) >=1 + + return geo_mask + + +from .ldm.vq_gan_3d.model.vqgan 
import VQGAN +import matplotlib.pyplot as plt +import SimpleITK as sitk +from .ldm.ddpm import Unet3D, GaussianDiffusion, Tester +from hydra import initialize, compose +import torch +import yaml +def synt_model_prepare(device, vqgan_ckpt='TumorGeneration/model_weight/recon_96d4_all.ckpt', diffusion_ckpt='TumorGeneration/model_weight/', fold=0, organ='liver'): + with initialize(config_path="diffusion_config/"): + cfg=compose(config_name="ddpm.yaml") + print('diffusion_ckpt',diffusion_ckpt) + vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt) + vqgan = vqgan.to(device) + vqgan.eval() + + early_Unet3D = Unet3D( + dim=cfg.diffusion_img_size, + dim_mults=cfg.dim_mults, + channels=cfg.diffusion_num_channels, + out_dim=cfg.out_dim + ).to(device) + + early_diffusion = GaussianDiffusion( + early_Unet3D, + vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt, + image_size=cfg.diffusion_img_size, + num_frames=cfg.diffusion_depth_size, + channels=cfg.diffusion_num_channels, + timesteps=4, # cfg.timesteps, + loss_type=cfg.loss_type, + device=device + ).to(device) + + noearly_Unet3D = Unet3D( + dim=cfg.diffusion_img_size, + dim_mults=cfg.dim_mults, + channels=cfg.diffusion_num_channels, + out_dim=cfg.out_dim + ).to(device) + + noearly_diffusion = GaussianDiffusion( + noearly_Unet3D, + vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt, + image_size=cfg.diffusion_img_size, + num_frames=cfg.diffusion_depth_size, + channels=cfg.diffusion_num_channels, + timesteps=200, # cfg.timesteps, + loss_type=cfg.loss_type, + device=device + ).to(device) + + early_tester = Tester(early_diffusion) + # noearly_tester = Tester(noearly_diffusion) + early_tester.load(diffusion_ckpt+'diff_{}_fold{}_early.pt'.format(organ, fold), map_location=device) + # noearly_tester.load(diffusion_ckpt+'diff_liver_fold{}_noearly_t200.pt'.format(fold), map_location=device) + + # early_checkpoint = torch.load(diffusion_ckpt+'diff_liver_fold{}_early.pt'.format(fold), map_location=device) + noearly_checkpoint = torch.load(diffusion_ckpt+'diff_{}_fold{}_noearly_t200.pt'.format(organ, fold), map_location=device) + # early_diffusion.load_state_dict(early_checkpoint['ema']) + noearly_diffusion.load_state_dict(noearly_checkpoint['ema']) + # early_sampler = DDIMSampler(early_diffusion, schedule="cosine") + noearly_sampler = DDIMSampler(noearly_diffusion, schedule="cosine") + # breakpoint() + return vqgan, early_tester, noearly_sampler + +def synthesize_early_tumor(ct_volume, organ_mask, organ_type, vqgan, tester): + device=ct_volume.device + + # generate tumor mask + tumor_types = ['tiny', 'small'] + # tumor_probs = np.array([0.5, 0.5]) + tumor_probs = np.array([0.2, 0.8]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - 
vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + tester.ema_model.eval() + sample = tester.ema_model.sample(batch_size=volume.shape[0], cond=cond) + + # if organ_type == 'liver' or organ_type == 'kidney' : + + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask + +def synthesize_medium_tumor(ct_volume, organ_mask, organ_type, vqgan, sampler, ddim_ts=50): + device=ct_volume.device + + # generate tumor mask + # tumor_types = ['large'] + # tumor_probs = np.array([1.0]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + # synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + synthetic_tumor_type = 'medium' + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + shape = masked_volume_feat.shape[-4:] + samples_ddim, _ = sampler.sample(S=ddim_ts, + conditioning=cond, + batch_size=1, + shape=shape, + verbose=False) + samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() - + vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min() + + sample = vqgan.decode(samples_ddim, quantize=True) + + # if organ_type == 'liver' or organ_type == 'kidney': + # post-process + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, 
max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask + +def synthesize_large_tumor(ct_volume, organ_mask, organ_type, vqgan, sampler, ddim_ts=50): + device=ct_volume.device + + # generate tumor mask + # tumor_types = ['large'] + # tumor_probs = np.array([1.0]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + # synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + synthetic_tumor_type = 'large' + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + shape = masked_volume_feat.shape[-4:] + samples_ddim, _ = sampler.sample(S=ddim_ts, + conditioning=cond, + batch_size=1, + shape=shape, + verbose=False) + samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() - + vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min() + + sample = vqgan.decode(samples_ddim, quantize=True) + + # if organ_type == 'liver' or organ_type == 'kidney': + # post-process + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask \ No newline at end of file diff --git a/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/utils_.py b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/utils_.py new file mode 100644 index 0000000000000000000000000000000000000000..312e72d1571b3cd996fd567bcedf939443a4e182 --- /dev/null +++ 
b/Generation_Pipeline_filter/syn_pancreas/TumorGeneration/utils_.py @@ -0,0 +1,298 @@ +### Tumor Generateion +import random +import cv2 +import elasticdeform +import numpy as np +from scipy.ndimage import gaussian_filter + +# Step 1: Random select (numbers) location for tumor. +def random_select(mask_scan): + # we first find z index and then sample point with z slice + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]] + + # we need to strict number z's position (0.3 - 0.7 in the middle of liver) + z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start + + liver_mask = mask_scan[..., z] + + # erode the mask (we don't want the edge points) + kernel = np.ones((5,5), dtype=np.uint8) + liver_mask = cv2.erode(liver_mask, kernel, iterations=1) + + coordinates = np.argwhere(liver_mask == 1) + random_index = np.random.randint(0, len(coordinates)) + xyz = coordinates[random_index].tolist() # get x,y + xyz.append(z) + potential_points = xyz + + return potential_points + +# Step 2 : generate the ellipsoid +def get_ellipsoid(x, y, z): + """" + x, y, z is the radius of this ellipsoid in x, y, z direction respectly. + """ + sh = (4*x, 4*y, 4*z) + out = np.zeros(sh, int) + aux = np.zeros(sh) + radii = np.array([x, y, z]) + com = np.array([2*x, 2*y, 2*z]) # center point + + # calculate the ellipsoid + bboxl = np.floor(com-radii).clip(0,None).astype(int) + bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int) + roi = out[tuple(map(slice,bboxl,bboxh))] + roiaux = aux[tuple(map(slice,bboxl,bboxh))] + logrid = *map(np.square,np.ogrid[tuple( + map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]), + dst = (1-sum(logrid)).clip(0,None) + mask = dst>roiaux + roi[mask] = 1 + np.copyto(roiaux,dst,where=mask) + + return out + +def get_fixed_geo(mask_scan, tumor_type): + + enlarge_x, enlarge_y, enlarge_z = 160, 160, 160 + geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8) + tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32 + + if tumor_type == 'tiny': + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + if tumor_type == 'small': + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = 
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
+            # texture = get_texture((4*x, 4*y, 4*z))
+            point = random_select(mask_scan)
+            new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
+            x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
+            y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
+            z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
+
+            # paste small tumor geo into test sample
+            geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
+            # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
+
+    if tumor_type == 'medium':
+        num_tumor = random.randint(2, 5)
+        for _ in range(num_tumor):
+            # medium tumor
+            x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
+            y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
+            z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
+            sigma = random.randint(3, 6)
+
+            geo = get_ellipsoid(x, y, z)
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
+            # texture = get_texture((4*x, 4*y, 4*z))
+            point = random_select(mask_scan)
+            new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
+            x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
+            y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
+            z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
+
+            # paste medium tumor geo into test sample
+            geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
+            # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
+
+    if tumor_type == 'large':
+        num_tumor = random.randint(1,3)
+        for _ in range(num_tumor):
+            # Large tumor
+            x = random.randint(int(0.75*large_radius), int(1.25*large_radius))
+            y = random.randint(int(0.75*large_radius), int(1.25*large_radius))
+            z = random.randint(int(0.75*large_radius), int(1.25*large_radius))
+            sigma = random.randint(5, 10)
+
+            geo = get_ellipsoid(x, y, z)
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
+            # texture = get_texture((4*x, 4*y, 4*z))
+            point = random_select(mask_scan)
+            new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
+            x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
+            y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
+            z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
+
+            # paste large tumor geo into test sample
+            geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
+            # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
+
+    if tumor_type == "mix":
+        # tiny
+        num_tumor = random.randint(3,10)
+        for _ in range(num_tumor):
+            # Tiny tumor
+            x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
+            y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
+            z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius))
+            sigma = random.uniform(0.5,1)
+
+            geo = get_ellipsoid(x, y, z)
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
+            point = random_select(mask_scan)
+            new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
+            x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
+            y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
+            z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
+
+            # paste tiny tumor geo into test sample
+            geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
+
+        # small
+        num_tumor = random.randint(5,10)
+        for _ in range(num_tumor):
+            # Small tumor
+            x = random.randint(int(0.75*small_radius), int(1.25*small_radius))
+            y = random.randint(int(0.75*small_radius), int(1.25*small_radius))
+            z = random.randint(int(0.75*small_radius), int(1.25*small_radius))
+            sigma = random.randint(1, 2)
+
+            geo = get_ellipsoid(x, y, z)
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
+            # texture = get_texture((4*x, 4*y, 4*z))
+            point = random_select(mask_scan)
+            new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
+            x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
+            y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
+            z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
+
+            # paste small tumor geo into test sample
+            geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
+            # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
+
+        # medium
+        num_tumor = random.randint(2, 5)
+        for _ in range(num_tumor):
+            # medium tumor
+            x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
+            y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
+            z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius))
+            sigma = random.randint(3, 6)
+
+            geo = get_ellipsoid(x, y, z)
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
+            # texture = get_texture((4*x, 4*y, 4*z))
+            point = random_select(mask_scan)
+            new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
+            x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
+            y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
+            z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
+
+            # paste medium tumor geo into test sample
+            geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
+            # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
+
+        # large
+        num_tumor = random.randint(1,3)
+        for _ in range(num_tumor):
+            # Large tumor
+            x = random.randint(int(0.75*large_radius), int(1.25*large_radius))
+            y = random.randint(int(0.75*large_radius), int(1.25*large_radius))
+            z = random.randint(int(0.75*large_radius), int(1.25*large_radius))
+            sigma = random.randint(5, 10)
+            geo = get_ellipsoid(x, y, z)
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2))
+            geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2))
+            # texture = get_texture((4*x, 4*y, 4*z))
+            point = random_select(mask_scan)
+            new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2]
+            x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2
+            y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2
+            z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2
+
+            # paste large tumor geo into test sample
+            geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo
+            # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture
+
+    # crop back to the original volume size and keep only tumor voxels inside the organ
+    geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2]
+    # texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2]
+    geo_mask = (geo_mask * mask_scan) >= 1
+
+    return geo_mask
+
+
+def get_tumor(volume_scan, mask_scan, tumor_type):
+    tumor_mask = get_fixed_geo(mask_scan, tumor_type)
+
+    sigma = np.random.uniform(1, 2)
+    # difference = np.random.uniform(65, 145)
+    difference = 1
+
+    # blur the boundary
+    tumor_mask_blur = gaussian_filter(tumor_mask*1.0, sigma)
+
+    # darken the blurred tumor region inside the organ; voxels outside the organ stay unchanged
+    abnormally = (volume_scan - tumor_mask_blur * difference) * mask_scan
+    abnormally_full = volume_scan * (1 - mask_scan) + abnormally
+    # organ voxels stay 1, tumor voxels become 2
+    abnormally_mask = mask_scan + tumor_mask
+
+    return abnormally_full, abnormally_mask
+
+def SynthesisTumor(volume_scan, mask_scan, tumor_type):
+    # for speed, we only send the organ part into the generation program
+    x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0][[0, -1]]
+    y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0][[0, -1]]
+    z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]]
+
+    # shrink the boundary
+    x_start, x_end = max(0, x_start+1), min(mask_scan.shape[0], x_end-1)
+    y_start, y_end = max(0, y_start+1), min(mask_scan.shape[1], y_end-1)
+    z_start, z_end = max(0, z_start+1), min(mask_scan.shape[2], z_end-1)
+
+    ct_volume = volume_scan[x_start:x_end, y_start:y_end, z_start:z_end]
+    organ_mask = mask_scan[x_start:x_end, y_start:y_end, z_start:z_end]
+
+    # input texture shape: 420, 300, 320
+    # we need to cut it into the shape of liver_mask
+    # for example, if liver_mask.shape = 286, 173, 46, we should change the texture shape accordingly
+    # (leftover texture-cropping setup; crop_x/crop_y/crop_z are not used below)
+    x_length, y_length, z_length = 64, 64, 64
+    crop_x = random.randint(x_start, x_end - x_length - 1) # randomly select the start point; -1 avoids the boundary check
+    crop_y = random.randint(y_start, y_end - y_length - 1)
+    crop_z = random.randint(z_start, z_end - z_length - 1)
+
+    ct_volume, organ_tumor_mask = get_tumor(ct_volume, organ_mask, tumor_type)
+    volume_scan[x_start:x_end, y_start:y_end, z_start:z_end] = ct_volume
+    mask_scan[x_start:x_end, y_start:y_end, z_start:z_end] = organ_tumor_mask
+
+    return volume_scan, mask_scan
diff --git a/Generation_Pipeline_filter/syn_pancreas/healthy_pancreas_1k.txt b/Generation_Pipeline_filter/syn_pancreas/healthy_pancreas_1k.txt
new file mode 100644
index 0000000000000000000000000000000000000000..5b54b3f0c568b2953320a4691f0196f8315a79fe
--- /dev/null
+++ 
b/Generation_Pipeline_filter/syn_pancreas/healthy_pancreas_1k.txt @@ -0,0 +1,774 @@ +BDMAP_00004652 +BDMAP_00002164 +BDMAP_00000439 +BDMAP_00004775 +BDMAP_00001148 +BDMAP_00003384 +BDMAP_00003634 +BDMAP_00000030 +BDMAP_00005119 +BDMAP_00003580 +BDMAP_00000320 +BDMAP_00002185 +BDMAP_00002487 +BDMAP_00002775 +BDMAP_00000159 +BDMAP_00004475 +BDMAP_00004278 +BDMAP_00004131 +BDMAP_00000244 +BDMAP_00001966 +BDMAP_00000939 +BDMAP_00003928 +BDMAP_00003031 +BDMAP_00000568 +BDMAP_00001309 +BDMAP_00004232 +BDMAP_00004431 +BDMAP_00001794 +BDMAP_00001752 +BDMAP_00004228 +BDMAP_00002363 +BDMAP_00001020 +BDMAP_00002739 +BDMAP_00004462 +BDMAP_00003749 +BDMAP_00002472 +BDMAP_00003455 +BDMAP_00003911 +BDMAP_00003486 +BDMAP_00001496 +BDMAP_00004378 +BDMAP_00002282 +BDMAP_00001343 +BDMAP_00001747 +BDMAP_00004450 +BDMAP_00000607 +BDMAP_00000100 +BDMAP_00001620 +BDMAP_00004415 +BDMAP_00004897 +BDMAP_00001737 +BDMAP_00000604 +BDMAP_00004922 +BDMAP_00004870 +BDMAP_00002060 +BDMAP_00003412 +BDMAP_00004850 +BDMAP_00000794 +BDMAP_00000989 +BDMAP_00000205 +BDMAP_00003151 +BDMAP_00001255 +BDMAP_00000667 +BDMAP_00003343 +BDMAP_00001237 +BDMAP_00000023 +BDMAP_00003281 +BDMAP_00000907 +BDMAP_00000432 +BDMAP_00001782 +BDMAP_00002166 +BDMAP_00004481 +BDMAP_00003363 +BDMAP_00001474 +BDMAP_00001995 +BDMAP_00002854 +BDMAP_00003603 +BDMAP_00001383 +BDMAP_00002656 +BDMAP_00004427 +BDMAP_00001563 +BDMAP_00001809 +BDMAP_00002114 +BDMAP_00000304 +BDMAP_00001692 +BDMAP_00001688 +BDMAP_00001119 +BDMAP_00000449 +BDMAP_00004014 +BDMAP_00002524 +BDMAP_00000725 +BDMAP_00000918 +BDMAP_00000642 +BDMAP_00001238 +BDMAP_00002373 +BDMAP_00002326 +BDMAP_00004373 +BDMAP_00003615 +BDMAP_00003324 +BDMAP_00002654 +BDMAP_00002849 +BDMAP_00003491 +BDMAP_00002655 +BDMAP_00002712 +BDMAP_00001516 +BDMAP_00000469 +BDMAP_00001549 +BDMAP_00000713 +BDMAP_00000745 +BDMAP_00000259 +BDMAP_00003569 +BDMAP_00005067 +BDMAP_00004185 +BDMAP_00003357 +BDMAP_00002419 +BDMAP_00002598 +BDMAP_00002167 +BDMAP_00000137 +BDMAP_00003448 +BDMAP_00000965 +BDMAP_00000232 +BDMAP_00004608 +BDMAP_00003680 +BDMAP_00000716 +BDMAP_00002403 +BDMAP_00004216 +BDMAP_00001359 +BDMAP_00004175 +BDMAP_00002791 +BDMAP_00002940 +BDMAP_00000355 +BDMAP_00004294 +BDMAP_00001426 +BDMAP_00001475 +BDMAP_00002986 +BDMAP_00002884 +BDMAP_00000400 +BDMAP_00002410 +BDMAP_00000297 +BDMAP_00001636 +BDMAP_00005113 +BDMAP_00004074 +BDMAP_00002333 +BDMAP_00003976 +BDMAP_00002383 +BDMAP_00000161 +BDMAP_00001212 +BDMAP_00000366 +BDMAP_00003070 +BDMAP_00003943 +BDMAP_00003930 +BDMAP_00003164 +BDMAP_00001906 +BDMAP_00002889 +BDMAP_00004163 +BDMAP_00001456 +BDMAP_00003972 +BDMAP_00004586 +BDMAP_00000626 +BDMAP_00001095 +BDMAP_00000532 +BDMAP_00003377 +BDMAP_00003225 +BDMAP_00001289 +BDMAP_00001275 +BDMAP_00004509 +BDMAP_00000998 +BDMAP_00000836 +BDMAP_00001015 +BDMAP_00004650 +BDMAP_00005186 +BDMAP_00000608 +BDMAP_00003898 +BDMAP_00002696 +BDMAP_00003560 +BDMAP_00004578 +BDMAP_00000828 +BDMAP_00000690 +BDMAP_00003564 +BDMAP_00005174 +BDMAP_00000132 +BDMAP_00005105 +BDMAP_00000902 +BDMAP_00003947 +BDMAP_00002184 +BDMAP_00001785 +BDMAP_00002361 +BDMAP_00003255 +BDMAP_00000971 +BDMAP_00003493 +BDMAP_00002267 +BDMAP_00005154 +BDMAP_00000982 +BDMAP_00005157 +BDMAP_00004384 +BDMAP_00003063 +BDMAP_00001982 +BDMAP_00002273 +BDMAP_00001102 +BDMAP_00002689 +BDMAP_00000034 +BDMAP_00001514 +BDMAP_00005081 +BDMAP_00001786 +BDMAP_00004033 +BDMAP_00004457 +BDMAP_00000710 +BDMAP_00001198 +BDMAP_00004479 +BDMAP_00000873 +BDMAP_00000362 +BDMAP_00004616 +BDMAP_00003128 +BDMAP_00001607 +BDMAP_00004104 +BDMAP_00001517 
+BDMAP_00004639 +BDMAP_00005170 +BDMAP_00002305 +BDMAP_00004746 +BDMAP_00003333 +BDMAP_00001807 +BDMAP_00004579 +BDMAP_00002260 +BDMAP_00004416 +BDMAP_00003932 +BDMAP_00001316 +BDMAP_00003411 +BDMAP_00000839 +BDMAP_00004738 +BDMAP_00001438 +BDMAP_00003435 +BDMAP_00001697 +BDMAP_00001911 +BDMAP_00001735 +BDMAP_00002902 +BDMAP_00001834 +BDMAP_00000069 +BDMAP_00004066 +BDMAP_00000434 +BDMAP_00004744 +BDMAP_00000347 +BDMAP_00001246 +BDMAP_00003150 +BDMAP_00003957 +BDMAP_00001768 +BDMAP_00002663 +BDMAP_00004147 +BDMAP_00003510 +BDMAP_00002242 +BDMAP_00005016 +BDMAP_00002275 +BDMAP_00001924 +BDMAP_00002214 +BDMAP_00002529 +BDMAP_00000562 +BDMAP_00000122 +BDMAP_00002707 +BDMAP_00000874 +BDMAP_00000176 +BDMAP_00002804 +BDMAP_00005005 +BDMAP_00001422 +BDMAP_00005017 +BDMAP_00000653 +BDMAP_00002609 +BDMAP_00003327 +BDMAP_00002484 +BDMAP_00004673 +BDMAP_00004493 +BDMAP_00003740 +BDMAP_00002271 +BDMAP_00002742 +BDMAP_00002826 +BDMAP_00001035 +BDMAP_00002068 +BDMAP_00003815 +BDMAP_00003052 +BDMAP_00004499 +BDMAP_00002065 +BDMAP_00001025 +BDMAP_00004888 +BDMAP_00002592 +BDMAP_00004030 +BDMAP_00001024 +BDMAP_00002041 +BDMAP_00002807 +BDMAP_00002751 +BDMAP_00003272 +BDMAP_00004600 +BDMAP_00004154 +BDMAP_00003774 +BDMAP_00000948 +BDMAP_00002173 +BDMAP_00004510 +BDMAP_00000104 +BDMAP_00004374 +BDMAP_00000429 +BDMAP_00004420 +BDMAP_00001853 +BDMAP_00003600 +BDMAP_00002349 +BDMAP_00001863 +BDMAP_00004830 +BDMAP_00002981 +BDMAP_00001941 +BDMAP_00001128 +BDMAP_00005151 +BDMAP_00003890 +BDMAP_00003640 +BDMAP_00004257 +BDMAP_00004943 +BDMAP_00001068 +BDMAP_00001305 +BDMAP_00000414 +BDMAP_00000465 +BDMAP_00003727 +BDMAP_00002199 +BDMAP_00001769 +BDMAP_00004187 +BDMAP_00001891 +BDMAP_00000980 +BDMAP_00003923 +BDMAP_00000942 +BDMAP_00001114 +BDMAP_00001602 +BDMAP_00002845 +BDMAP_00003178 +BDMAP_00003409 +BDMAP_00001562 +BDMAP_00002909 +BDMAP_00003808 +BDMAP_00001169 +BDMAP_00001104 +BDMAP_00001483 +BDMAP_00005009 +BDMAP_00001957 +BDMAP_00003153 +BDMAP_00001444 +BDMAP_00000851 +BDMAP_00005191 +BDMAP_00000687 +BDMAP_00003722 +BDMAP_00003330 +BDMAP_00002347 +BDMAP_00002955 +BDMAP_00001089 +BDMAP_00004529 +BDMAP_00003268 +BDMAP_00001522 +BDMAP_00001502 +BDMAP_00000240 +BDMAP_00004867 +BDMAP_00000480 +BDMAP_00000452 +BDMAP_00002918 +BDMAP_00002953 +BDMAP_00002039 +BDMAP_00000889 +BDMAP_00002746 +BDMAP_00003608 +BDMAP_00003664 +BDMAP_00003299 +BDMAP_00001445 +BDMAP_00000113 +BDMAP_00001705 +BDMAP_00000044 +BDMAP_00003513 +BDMAP_00001261 +BDMAP_00004990 +BDMAP_00003143 +BDMAP_00003111 +BDMAP_00002319 +BDMAP_00004664 +BDMAP_00003717 +BDMAP_00004717 +BDMAP_00004745 +BDMAP_00000671 +BDMAP_00002990 +BDMAP_00004901 +BDMAP_00002545 +BDMAP_00004980 +BDMAP_00000913 +BDMAP_00000437 +BDMAP_00002864 +BDMAP_00000364 +BDMAP_00004195 +BDMAP_00000162 +BDMAP_00002840 +BDMAP_00000233 +BDMAP_00002744 +BDMAP_00001218 +BDMAP_00002289 +BDMAP_00000229 +BDMAP_00005114 +BDMAP_00000279 +BDMAP_00003832 +BDMAP_00000241 +BDMAP_00002251 +BDMAP_00001676 +BDMAP_00001635 +BDMAP_00003444 +BDMAP_00002265 +BDMAP_00002498 +BDMAP_00001209 +BDMAP_00001138 +BDMAP_00002407 +BDMAP_00003798 +BDMAP_00001325 +BDMAP_00002631 +BDMAP_00004304 +BDMAP_00001078 +BDMAP_00002562 +BDMAP_00003576 +BDMAP_00001977 +BDMAP_00002396 +BDMAP_00001333 +BDMAP_00004925 +BDMAP_00004903 +BDMAP_00000273 +BDMAP_00000571 +BDMAP_00001027 +BDMAP_00000149 +BDMAP_00001962 +BDMAP_00003481 +BDMAP_00001256 +BDMAP_00000871 +BDMAP_00000926 +BDMAP_00000572 +BDMAP_00004558 +BDMAP_00000435 +BDMAP_00000837 +BDMAP_00003713 +BDMAP_00002875 +BDMAP_00004645 +BDMAP_00001711 +BDMAP_00001296 +BDMAP_00002648 
+BDMAP_00004561 +BDMAP_00002318 +BDMAP_00001835 +BDMAP_00003524 +BDMAP_00002959 +BDMAP_00002422 +BDMAP_00004597 +BDMAP_00000487 +BDMAP_00002359 +BDMAP_00005001 +BDMAP_00004817 +BDMAP_00001539 +BDMAP_00002936 +BDMAP_00002719 +BDMAP_00005167 +BDMAP_00001265 +BDMAP_00001471 +BDMAP_00001511 +BDMAP_00005139 +BDMAP_00002426 +BDMAP_00002288 +BDMAP_00004808 +BDMAP_00002085 +BDMAP_00004435 +BDMAP_00000319 +BDMAP_00003614 +BDMAP_00001109 +BDMAP_00000331 +BDMAP_00004491 +BDMAP_00002440 +BDMAP_00003373 +BDMAP_00005065 +BDMAP_00005006 +BDMAP_00002509 +BDMAP_00003973 +BDMAP_00004417 +BDMAP_00000935 +BDMAP_00004624 +BDMAP_00003364 +BDMAP_00005085 +BDMAP_00003073 +BDMAP_00002730 +BDMAP_00004825 +BDMAP_00000039 +BDMAP_00004615 +BDMAP_00003736 +BDMAP_00005097 +BDMAP_00003074 +BDMAP_00000662 +BDMAP_00001122 +BDMAP_00002252 +BDMAP_00001396 +BDMAP_00004011 +BDMAP_00004981 +BDMAP_00004165 +BDMAP_00003920 +BDMAP_00001215 +BDMAP_00003867 +BDMAP_00000923 +BDMAP_00002626 +BDMAP_00003315 +BDMAP_00000660 +BDMAP_00000329 +BDMAP_00004508 +BDMAP_00001518 +BDMAP_00003849 +BDMAP_00003897 +BDMAP_00003300 +BDMAP_00002253 +BDMAP_00003514 +BDMAP_00000117 +BDMAP_00002421 +BDMAP_00001413 +BDMAP_00004328 +BDMAP_00001130 +BDMAP_00000043 +BDMAP_00001410 +BDMAP_00000245 +BDMAP_00004117 +BDMAP_00002401 +BDMAP_00003857 +BDMAP_00000921 +BDMAP_00000138 +BDMAP_00003113 +BDMAP_00003358 +BDMAP_00002099 +BDMAP_00004016 +BDMAP_00003439 +BDMAP_00002152 +BDMAP_00003767 +BDMAP_00001598 +BDMAP_00003482 +BDMAP_00003520 +BDMAP_00002075 +BDMAP_00000987 +BDMAP_00003946 +BDMAP_00005160 +BDMAP_00001286 +BDMAP_00003359 +BDMAP_00002661 +BDMAP_00004704 +BDMAP_00003994 +BDMAP_00002226 +BDMAP_00000968 +BDMAP_00003556 +BDMAP_00003236 +BDMAP_00001791 +BDMAP_00004712 +BDMAP_00001077 +BDMAP_00003955 +BDMAP_00002479 +BDMAP_00001865 +BDMAP_00001059 +BDMAP_00002704 +BDMAP_00000656 +BDMAP_00001379 +BDMAP_00000883 +BDMAP_00002856 +BDMAP_00004199 +BDMAP_00001200 +BDMAP_00005083 +BDMAP_00004552 +BDMAP_00000616 +BDMAP_00004834 +BDMAP_00004815 +BDMAP_00001826 +BDMAP_00000615 +BDMAP_00001045 +BDMAP_00002695 +BDMAP_00004017 +BDMAP_00002103 +BDMAP_00002057 +BDMAP_00004620 +BDMAP_00000128 +BDMAP_00001185 +BDMAP_00002612 +BDMAP_00005073 +BDMAP_00001753 +BDMAP_00004196 +BDMAP_00004281 +BDMAP_00002717 +BDMAP_00000263 +BDMAP_00004103 +BDMAP_00003381 +BDMAP_00001093 +BDMAP_00000373 +BDMAP_00000881 +BDMAP_00002230 +BDMAP_00001707 +BDMAP_00002476 +BDMAP_00003294 +BDMAP_00004482 +BDMAP_00003267 +BDMAP_00002710 +BDMAP_00002451 +BDMAP_00001270 +BDMAP_00004878 +BDMAP_00001784 +BDMAP_00001281 +BDMAP_00002283 +BDMAP_00001183 +BDMAP_00001945 +BDMAP_00004604 +BDMAP_00000413 +BDMAP_00003506 +BDMAP_00002458 +BDMAP_00000977 +BDMAP_00000833 +BDMAP_00001055 +BDMAP_00002495 +BDMAP_00000887 +BDMAP_00002496 +BDMAP_00002942 +BDMAP_00000574 +BDMAP_00001868 +BDMAP_00000547 +BDMAP_00001230 +BDMAP_00003762 +BDMAP_00003971 +BDMAP_00000321 +BDMAP_00004876 +BDMAP_00003833 +BDMAP_00003461 +BDMAP_00003301 +BDMAP_00002846 +BDMAP_00002582 +BDMAP_00001710 +BDMAP_00001487 +BDMAP_00000936 +BDMAP_00004121 +BDMAP_00004459 +BDMAP_00000219 +BDMAP_00000091 +BDMAP_00001283 +BDMAP_00000084 +BDMAP_00000516 +BDMAP_00004250 +BDMAP_00001732 +BDMAP_00003694 +BDMAP_00004031 +BDMAP_00001557 +BDMAP_00002437 +BDMAP_00002933 +BDMAP_00000264 +BDMAP_00005099 +BDMAP_00004296 +BDMAP_00001917 +BDMAP_00003252 +BDMAP_00004389 +BDMAP_00002463 +BDMAP_00004253 +BDMAP_00004910 +BDMAP_00003172 +BDMAP_00001624 +BDMAP_00003484 +BDMAP_00001907 +BDMAP_00003952 +BDMAP_00002653 +BDMAP_00000368 +BDMAP_00000569 +BDMAP_00004995 +BDMAP_00003956 
+BDMAP_00003497 +BDMAP_00003058 +BDMAP_00000552 +BDMAP_00000481 +BDMAP_00000805 +BDMAP_00003002 +BDMAP_00000698 +BDMAP_00004783 +BDMAP_00001324 +BDMAP_00002133 +BDMAP_00005120 +BDMAP_00003581 +BDMAP_00004890 +BDMAP_00001533 +BDMAP_00004039 +BDMAP_00000190 +BDMAP_00004028 +BDMAP_00004130 +BDMAP_00001370 +BDMAP_00002805 +BDMAP_00001397 +BDMAP_00001126 +BDMAP_00001875 +BDMAP_00005130 +BDMAP_00003361 +BDMAP_00002485 +BDMAP_00001273 +BDMAP_00000582 +BDMAP_00003672 +BDMAP_00000778 +BDMAP_00002841 +BDMAP_00001242 +BDMAP_00000345 +BDMAP_00000036 +BDMAP_00003996 +BDMAP_00003701 +BDMAP_00003425 +BDMAP_00001656 +BDMAP_00001802 +BDMAP_00001420 +BDMAP_00003752 +BDMAP_00002924 +BDMAP_00003202 +BDMAP_00000831 +BDMAP_00003392 +BDMAP_00002022 +BDMAP_00001223 +BDMAP_00003457 +BDMAP_00001236 +BDMAP_00000810 +BDMAP_00004676 +BDMAP_00003847 +BDMAP_00001225 +BDMAP_00005168 +BDMAP_00004113 +BDMAP_00002828 +BDMAP_00004087 +BDMAP_00004407 +BDMAP_00002748 +BDMAP_00003516 +BDMAP_00004395 +BDMAP_00001985 +BDMAP_00001171 +BDMAP_00000101 +BDMAP_00002117 +BDMAP_00001434 +BDMAP_00000139 +BDMAP_00002465 +BDMAP_00001251 +BDMAP_00001908 +BDMAP_00002354 +BDMAP_00002776 +BDMAP_00004887 +BDMAP_00000066 +BDMAP_00003549 +BDMAP_00000812 +BDMAP_00000353 +BDMAP_00004894 +BDMAP_00004956 +BDMAP_00002871 +BDMAP_00004764 +BDMAP_00004551 +BDMAP_00002404 +BDMAP_00000059 +BDMAP_00002017 +BDMAP_00003558 +BDMAP_00004065 +BDMAP_00003406 +BDMAP_00002471 +BDMAP_00000941 +BDMAP_00003109 +BDMAP_00000511 +BDMAP_00000826 +BDMAP_00004839 +BDMAP_00004671 +BDMAP_00002930 +BDMAP_00004331 +BDMAP_00001664 +BDMAP_00001001 +BDMAP_00001766 +BDMAP_00003827 +BDMAP_00001258 +BDMAP_00001892 +BDMAP_00000062 +BDMAP_00000867 +BDMAP_00002803 +BDMAP_00000285 +BDMAP_00001647 +BDMAP_00005077 +BDMAP_00000152 +BDMAP_00000709 +BDMAP_00002172 +BDMAP_00004148 +BDMAP_00001010 diff --git a/Generation_Pipeline_filter/syn_pancreas/requirements.txt b/Generation_Pipeline_filter/syn_pancreas/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9c1067c8eb4f44699bbf517601396113cea4b370 --- /dev/null +++ b/Generation_Pipeline_filter/syn_pancreas/requirements.txt @@ -0,0 +1,94 @@ +absl-py==1.1.0 +accelerate==0.11.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +antlr4-python3-runtime==4.9.3 +async-timeout==4.0.2 +attrs==21.4.0 +autopep8==1.6.0 +cachetools==5.2.0 +certifi==2022.6.15 +charset-normalizer==2.0.12 +click==8.1.3 +cycler==0.11.0 +Deprecated==1.2.13 +docker-pycreds==0.4.0 +einops==0.4.1 +einops-exts==0.0.3 +ema-pytorch==0.0.8 +fonttools==4.34.4 +frozenlist==1.3.0 +fsspec==2022.5.0 +ftfy==6.1.1 +future==0.18.2 +gitdb==4.0.9 +GitPython==3.1.27 +google-auth==2.9.0 +google-auth-oauthlib==0.4.6 +grpcio==1.47.0 +h5py==3.7.0 +humanize==4.2.2 +hydra-core==1.2.0 +idna==3.3 +imageio==2.19.3 +imageio-ffmpeg==0.4.7 +importlib-metadata==4.12.0 +importlib-resources==5.9.0 +joblib==1.1.0 +kiwisolver==1.4.3 +lxml==4.9.1 +Markdown==3.3.7 +matplotlib==3.5.2 +multidict==6.0.2 +networkx==2.8.5 +nibabel==4.0.1 +nilearn==0.9.1 +numpy==1.23.0 +oauthlib==3.2.0 +omegaconf==2.2.3 +pandas==1.4.3 +Pillow==9.1.1 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycodestyle==2.8.0 +pyDeprecate==0.3.1 +pydicom==2.3.0 +pytorch-lightning==1.6.4 +pytz==2022.1 +PyWavelets==1.3.0 +PyYAML==6.0 +pyzmq==19.0.2 +regex==2022.6.2 +requests==2.28.0 +requests-oauthlib==1.3.1 +rotary-embedding-torch==0.1.5 +rsa==4.8 +scikit-image==0.19.3 +scikit-learn==1.1.2 +scikit-video==1.1.11 +scipy==1.8.1 +seaborn==0.11.2 +sentry-sdk==1.7.2 +setproctitle==1.2.3 +shortuuid==1.0.9 +SimpleITK==2.1.1.2 +smmap==5.0.0 
+tensorboard==2.9.1 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +threadpoolctl==3.1.0 +tifffile==2022.8.3 +toml==0.10.2 +torch-tb-profiler==0.4.0 +torchio==0.18.80 +torchmetrics==0.9.1 +tqdm==4.64.0 +typing_extensions==4.2.0 +urllib3==1.26.9 +wandb==0.12.21 +Werkzeug==2.1.2 +wrapt==1.14.1 +yarl==1.7.2 +zipp==3.8.0 +wandb +tensorboardX==2.4.1 diff --git a/Generation_Pipeline_filter/val_set/bodymap_colon.txt b/Generation_Pipeline_filter/val_set/bodymap_colon.txt new file mode 100644 index 0000000000000000000000000000000000000000..c529ab8278b77ca39b67afc9ea118c67cb4ebcec --- /dev/null +++ b/Generation_Pipeline_filter/val_set/bodymap_colon.txt @@ -0,0 +1,25 @@ +BDMAP_00004910 +BDMAP_00001438 +BDMAP_00000568 +BDMAP_00002828 +BDMAP_00003634 +BDMAP_00004121 +BDMAP_00004764 +BDMAP_00003972 +BDMAP_00003113 +BDMAP_00005001 +BDMAP_00001785 +BDMAP_00005016 +BDMAP_00002739 +BDMAP_00003299 +BDMAP_00003357 +BDMAP_00001078 +BDMAP_00000874 +BDMAP_00003560 +BDMAP_00003373 +BDMAP_00003172 +BDMAP_00002875 +BDMAP_00000552 +BDMAP_00003510 +BDMAP_00004604 +BDMAP_00002598 diff --git a/Generation_Pipeline_filter/val_set/bodymap_kidney.txt b/Generation_Pipeline_filter/val_set/bodymap_kidney.txt new file mode 100644 index 0000000000000000000000000000000000000000..f420fbe5a18b65775d1eb8bc4362eb4e9cc1421b --- /dev/null +++ b/Generation_Pipeline_filter/val_set/bodymap_kidney.txt @@ -0,0 +1,24 @@ +BDMAP_00000487 +BDMAP_00002631 +BDMAP_00002744 +BDMAP_00000833 +BDMAP_00002648 +BDMAP_00002840 +BDMAP_00000608 +BDMAP_00002804 +BDMAP_00002775 +BDMAP_00004551 +BDMAP_00001413 +BDMAP_00000511 +BDMAP_00003150 +BDMAP_00000794 +BDMAP_00001255 +BDMAP_00002242 +BDMAP_00004746 +BDMAP_00002864 +BDMAP_00003486 +BDMAP_00004250 +BDMAP_00003143 +BDMAP_00003164 +BDMAP_00004578 +BDMAP_00001735 diff --git a/Generation_Pipeline_filter/val_set/bodymap_liver.txt b/Generation_Pipeline_filter/val_set/bodymap_liver.txt new file mode 100644 index 0000000000000000000000000000000000000000..d16f55abd27cc01769b6e96a3a9606a3abb568c9 --- /dev/null +++ b/Generation_Pipeline_filter/val_set/bodymap_liver.txt @@ -0,0 +1,25 @@ +BDMAP_00004281 +BDMAP_00003481 +BDMAP_00004890 +BDMAP_00001786 +BDMAP_00000101 +BDMAP_00004117 +BDMAP_00000615 +BDMAP_00000921 +BDMAP_00005130 +BDMAP_00004378 +BDMAP_00004704 +BDMAP_00003439 +BDMAP_00002717 +BDMAP_00004878 +BDMAP_00000100 +BDMAP_00001309 +BDMAP_00002214 +BDMAP_00001198 +BDMAP_00001962 +BDMAP_00002463 +BDMAP_00005139 +BDMAP_00000831 +BDMAP_00002955 +BDMAP_00003272 +BDMAP_00000745 diff --git a/Generation_Pipeline_filter/val_set/bodymap_pancreas.txt b/Generation_Pipeline_filter/val_set/bodymap_pancreas.txt new file mode 100644 index 0000000000000000000000000000000000000000..6da3361e4b14cfd1237e52dc2fd014ae4681c135 --- /dev/null +++ b/Generation_Pipeline_filter/val_set/bodymap_pancreas.txt @@ -0,0 +1,24 @@ +BDMAP_00000332 +BDMAP_00004858 +BDMAP_00005155 +BDMAP_00001205 +BDMAP_00004770 +BDMAP_00001361 +BDMAP_00002944 +BDMAP_00003961 +BDMAP_00000430 +BDMAP_00000679 +BDMAP_00003809 +BDMAP_00004115 +BDMAP_00003367 +BDMAP_00002899 +BDMAP_00003771 +BDMAP_00003502 +BDMAP_00001628 +BDMAP_00003884 +BDMAP_00005074 +BDMAP_00003114 +BDMAP_00004741 +BDMAP_00001746 +BDMAP_00002603 +BDMAP_00004128 diff --git a/Generation_Pipeline_filter_all/Atlas_X_1k.txt b/Generation_Pipeline_filter_all/Atlas_X_1k.txt new file mode 100644 index 0000000000000000000000000000000000000000..9184dd0cbff7a9e917f585256ad7abb4afc6ef31 --- /dev/null +++ b/Generation_Pipeline_filter_all/Atlas_X_1k.txt @@ -0,0 +1,956 @@ +BDMAP_00002654 
+BDMAP_00002173 +BDMAP_00003294 +BDMAP_00001597 +BDMAP_00001557 +BDMAP_00003327 +BDMAP_00002075 +BDMAP_00004887 +BDMAP_00001434 +BDMAP_00001705 +BDMAP_00000710 +BDMAP_00002271 +BDMAP_00003406 +BDMAP_00003556 +BDMAP_00002103 +BDMAP_00002230 +BDMAP_00000427 +BDMAP_00002746 +BDMAP_00003483 +BDMAP_00003543 +BDMAP_00001396 +BDMAP_00000836 +BDMAP_00003808 +BDMAP_00002619 +BDMAP_00004183 +BDMAP_00001562 +BDMAP_00001414 +BDMAP_00004087 +BDMAP_00002704 +BDMAP_00004198 +BDMAP_00000285 +BDMAP_00005077 +BDMAP_00001343 +BDMAP_00002909 +BDMAP_00002849 +BDMAP_00002655 +BDMAP_00001015 +BDMAP_00003592 +BDMAP_00001676 +BDMAP_00001863 +BDMAP_00002404 +BDMAP_00001035 +BDMAP_00003457 +BDMAP_00001782 +BDMAP_00004586 +BDMAP_00004514 +BDMAP_00004165 +BDMAP_00001171 +BDMAP_00005140 +BDMAP_00005037 +BDMAP_00001769 +BDMAP_00004482 +BDMAP_00003551 +BDMAP_00000887 +BDMAP_00004103 +BDMAP_00002689 +BDMAP_00003727 +BDMAP_00002653 +BDMAP_00000034 +BDMAP_00001504 +BDMAP_00000889 +BDMAP_00004992 +BDMAP_00002065 +BDMAP_00003815 +BDMAP_00004494 +BDMAP_00001545 +BDMAP_00004954 +BDMAP_00002332 +BDMAP_00004288 +BDMAP_00005006 +BDMAP_00001865 +BDMAP_00000604 +BDMAP_00004616 +BDMAP_00001359 +BDMAP_00003956 +BDMAP_00004148 +BDMAP_00001426 +BDMAP_00003301 +BDMAP_00003300 +BDMAP_00000104 +BDMAP_00001185 +BDMAP_00004459 +BDMAP_00000805 +BDMAP_00001238 +BDMAP_00004066 +BDMAP_00001020 +BDMAP_00002626 +BDMAP_00002730 +BDMAP_00000241 +BDMAP_00002017 +BDMAP_00001055 +BDMAP_00005073 +BDMAP_00004296 +BDMAP_00003425 +BDMAP_00003749 +BDMAP_00004775 +BDMAP_00004843 +BDMAP_00003752 +BDMAP_00005105 +BDMAP_00003832 +BDMAP_00004262 +BDMAP_00002085 +BDMAP_00003824 +BDMAP_00001057 +BDMAP_00003812 +BDMAP_00000993 +BDMAP_00000176 +BDMAP_00000618 +BDMAP_00003133 +BDMAP_00004652 +BDMAP_00002437 +BDMAP_00001461 +BDMAP_00003847 +BDMAP_00003381 +BDMAP_00004229 +BDMAP_00001109 +BDMAP_00002930 +BDMAP_00003664 +BDMAP_00001853 +BDMAP_00000851 +BDMAP_00002152 +BDMAP_00004510 +BDMAP_00000362 +BDMAP_00003178 +BDMAP_00003168 +BDMAP_00000465 +BDMAP_00003603 +BDMAP_00002776 +BDMAP_00000480 +BDMAP_00003822 +BDMAP_00004113 +BDMAP_00002695 +BDMAP_00003513 +BDMAP_00001590 +BDMAP_00000826 +BDMAP_00002403 +BDMAP_00001169 +BDMAP_00002661 +BDMAP_00003920 +BDMAP_00000122 +BDMAP_00004130 +BDMAP_00002133 +BDMAP_00002612 +BDMAP_00003923 +BDMAP_00004278 +BDMAP_00004888 +BDMAP_00002422 +BDMAP_00004639 +BDMAP_00002856 +BDMAP_00001907 +BDMAP_00004175 +BDMAP_00002896 +BDMAP_00004257 +BDMAP_00003017 +BDMAP_00004509 +BDMAP_00003377 +BDMAP_00001704 +BDMAP_00002283 +BDMAP_00004664 +BDMAP_00001305 +BDMAP_00004481 +BDMAP_00000696 +BDMAP_00000716 +BDMAP_00002807 +BDMAP_00003608 +BDMAP_00000881 +BDMAP_00004561 +BDMAP_00001027 +BDMAP_00003002 +BDMAP_00002361 +BDMAP_00002289 +BDMAP_00000159 +BDMAP_00000809 +BDMAP_00003918 +BDMAP_00001636 +BDMAP_00003153 +BDMAP_00000413 +BDMAP_00000137 +BDMAP_00002472 +BDMAP_00001281 +BDMAP_00000965 +BDMAP_00002226 +BDMAP_00001605 +BDMAP_00003347 +BDMAP_00002471 +BDMAP_00002582 +BDMAP_00002114 +BDMAP_00005083 +BDMAP_00000438 +BDMAP_00002354 +BDMAP_00003580 +BDMAP_00003315 +BDMAP_00003612 +BDMAP_00004829 +BDMAP_00004395 +BDMAP_00000709 +BDMAP_00000273 +BDMAP_00004636 +BDMAP_00001732 +BDMAP_00004331 +BDMAP_00001868 +BDMAP_00001214 +BDMAP_00001275 +BDMAP_00001809 +BDMAP_00004374 +BDMAP_00005009 +BDMAP_00001807 +BDMAP_00004294 +BDMAP_00004499 +BDMAP_00001251 +BDMAP_00004457 +BDMAP_00002495 +BDMAP_00001331 +BDMAP_00000481 +BDMAP_00000236 +BDMAP_00001862 +BDMAP_00002288 +BDMAP_00004620 +BDMAP_00001122 +BDMAP_00000882 +BDMAP_00002164 +BDMAP_00004196 +BDMAP_00003384 
+BDMAP_00001710 +BDMAP_00003701 +BDMAP_00000607 +BDMAP_00000161 +BDMAP_00004065 +BDMAP_00003031 +BDMAP_00002216 +BDMAP_00001995 +BDMAP_00001584 +BDMAP_00000066 +BDMAP_00004475 +BDMAP_00001620 +BDMAP_00003658 +BDMAP_00003615 +BDMAP_00005113 +BDMAP_00004903 +BDMAP_00001125 +BDMAP_00003484 +BDMAP_00001325 +BDMAP_00000036 +BDMAP_00001370 +BDMAP_00002387 +BDMAP_00002396 +BDMAP_00003514 +BDMAP_00002918 +BDMAP_00004990 +BDMAP_00004106 +BDMAP_00000321 +BDMAP_00000713 +BDMAP_00002363 +BDMAP_00001445 +BDMAP_00000980 +BDMAP_00002485 +BDMAP_00002260 +BDMAP_00000388 +BDMAP_00001476 +BDMAP_00002592 +BDMAP_00003058 +BDMAP_00003364 +BDMAP_00000810 +BDMAP_00003329 +BDMAP_00001891 +BDMAP_00000117 +BDMAP_00001283 +BDMAP_00001128 +BDMAP_00005114 +BDMAP_00000692 +BDMAP_00000190 +BDMAP_00004579 +BDMAP_00005174 +BDMAP_00002690 +BDMAP_00004231 +BDMAP_00000219 +BDMAP_00002846 +BDMAP_00002057 +BDMAP_00001518 +BDMAP_00000589 +BDMAP_00003482 +BDMAP_00004817 +BDMAP_00003633 +BDMAP_00003890 +BDMAP_00002401 +BDMAP_00001223 +BDMAP_00004017 +BDMAP_00003400 +BDMAP_00000091 +BDMAP_00003363 +BDMAP_00004839 +BDMAP_00002383 +BDMAP_00004927 +BDMAP_00002451 +BDMAP_00004815 +BDMAP_00004783 +BDMAP_00005157 +BDMAP_00002373 +BDMAP_00001736 +BDMAP_00004943 +BDMAP_00004015 +BDMAP_00004773 +BDMAP_00001522 +BDMAP_00002171 +BDMAP_00002945 +BDMAP_00002990 +BDMAP_00001802 +BDMAP_00002326 +BDMAP_00000069 +BDMAP_00002185 +BDMAP_00001093 +BDMAP_00001487 +BDMAP_00001456 +BDMAP_00001045 +BDMAP_00001024 +BDMAP_00004615 +BDMAP_00000232 +BDMAP_00003722 +BDMAP_00001383 +BDMAP_00003267 +BDMAP_00002844 +BDMAP_00000030 +BDMAP_00001288 +BDMAP_00001483 +BDMAP_00000437 +BDMAP_00002855 +BDMAP_00003427 +BDMAP_00000771 +BDMAP_00004185 +BDMAP_00003740 +BDMAP_00004841 +BDMAP_00000062 +BDMAP_00004546 +BDMAP_00000662 +BDMAP_00002663 +BDMAP_00000936 +BDMAP_00002758 +BDMAP_00001892 +BDMAP_00002609 +BDMAP_00001982 +BDMAP_00005167 +BDMAP_00001945 +BDMAP_00001102 +BDMAP_00005170 +BDMAP_00000982 +BDMAP_00004129 +BDMAP_00001875 +BDMAP_00004735 +BDMAP_00000366 +BDMAP_00001175 +BDMAP_00002902 +BDMAP_00003558 +BDMAP_00002476 +BDMAP_00003694 +BDMAP_00000304 +BDMAP_00000225 +BDMAP_00002411 +BDMAP_00002304 +BDMAP_00000452 +BDMAP_00003598 +BDMAP_00001212 +BDMAP_00000683 +BDMAP_00005075 +BDMAP_00000162 +BDMAP_00002748 +BDMAP_00005099 +BDMAP_00002854 +BDMAP_00001289 +BDMAP_00000714 +BDMAP_00003849 +BDMAP_00003268 +BDMAP_00002529 +BDMAP_00001258 +BDMAP_00003438 +BDMAP_00000571 +BDMAP_00003853 +BDMAP_00003744 +BDMAP_00002829 +BDMAP_00000364 +BDMAP_00004039 +BDMAP_00000774 +BDMAP_00001834 +BDMAP_00001183 +BDMAP_00002458 +BDMAP_00004511 +BDMAP_00003255 +BDMAP_00003976 +BDMAP_00001924 +BDMAP_00004804 +BDMAP_00004163 +BDMAP_00001646 +BDMAP_00000435 +BDMAP_00002347 +BDMAP_00004297 +BDMAP_00002184 +BDMAP_00004712 +BDMAP_00003683 +BDMAP_00003657 +BDMAP_00004885 +BDMAP_00002947 +BDMAP_00002545 +BDMAP_00001119 +BDMAP_00001754 +BDMAP_00002267 +BDMAP_00003202 +BDMAP_00005108 +BDMAP_00001265 +BDMAP_00001092 +BDMAP_00004253 +BDMAP_00001563 +BDMAP_00001966 +BDMAP_00004304 +BDMAP_00000197 +BDMAP_00001273 +BDMAP_00003867 +BDMAP_00000859 +BDMAP_00001649 +BDMAP_00001664 +BDMAP_00003833 +BDMAP_00002710 +BDMAP_00001791 +BDMAP_00003932 +BDMAP_00002523 +BDMAP_00001632 +BDMAP_00002863 +BDMAP_00003762 +BDMAP_00001040 +BDMAP_00003971 +BDMAP_00005097 +BDMAP_00001845 +BDMAP_00000989 +BDMAP_00003672 +BDMAP_00001114 +BDMAP_00002742 +BDMAP_00004373 +BDMAP_00004850 +BDMAP_00002278 +BDMAP_00001701 +BDMAP_00001804 +BDMAP_00002349 +BDMAP_00002167 +BDMAP_00002265 +BDMAP_00004417 +BDMAP_00000245 +BDMAP_00005022 
+BDMAP_00000871 +BDMAP_00002803 +BDMAP_00000656 +BDMAP_00001095 +BDMAP_00003506 +BDMAP_00003359 +BDMAP_00005141 +BDMAP_00001617 +BDMAP_00002479 +BDMAP_00000778 +BDMAP_00000113 +BDMAP_00000439 +BDMAP_00003409 +BDMAP_00003769 +BDMAP_00001025 +BDMAP_00000469 +BDMAP_00002841 +BDMAP_00001906 +BDMAP_00002426 +BDMAP_00004228 +BDMAP_00000616 +BDMAP_00000547 +BDMAP_00002440 +BDMAP_00002188 +BDMAP_00002484 +BDMAP_00003385 +BDMAP_00001261 +BDMAP_00001441 +BDMAP_00001324 +BDMAP_00003549 +BDMAP_00002465 +BDMAP_00004014 +BDMAP_00000432 +BDMAP_00001067 +BDMAP_00001001 +BDMAP_00000940 +BDMAP_00004597 +BDMAP_00001104 +BDMAP_00001296 +BDMAP_00002562 +BDMAP_00001692 +BDMAP_00005151 +BDMAP_00000883 +BDMAP_00001533 +BDMAP_00001921 +BDMAP_00002410 +BDMAP_00002237 +BDMAP_00002328 +BDMAP_00003614 +BDMAP_00000562 +BDMAP_00001237 +BDMAP_00003333 +BDMAP_00004847 +BDMAP_00005119 +BDMAP_00003277 +BDMAP_00005120 +BDMAP_00005081 +BDMAP_00001607 +BDMAP_00001523 +BDMAP_00005017 +BDMAP_00001010 +BDMAP_00001126 +BDMAP_00001957 +BDMAP_00003776 +BDMAP_00000368 +BDMAP_00002199 +BDMAP_00000956 +BDMAP_00001752 +BDMAP_00005168 +BDMAP_00000205 +BDMAP_00002309 +BDMAP_00002419 +BDMAP_00000093 +BDMAP_00000698 +BDMAP_00004917 +BDMAP_00000434 +BDMAP_00004867 +BDMAP_00000429 +BDMAP_00003947 +BDMAP_00004030 +BDMAP_00001270 +BDMAP_00002402 +BDMAP_00000972 +BDMAP_00003330 +BDMAP_00003244 +BDMAP_00001200 +BDMAP_00000149 +BDMAP_00003252 +BDMAP_00002029 +BDMAP_00000154 +BDMAP_00002940 +BDMAP_00000152 +BDMAP_00001471 +BDMAP_00002737 +BDMAP_00000023 +BDMAP_00002251 +BDMAP_00000701 +BDMAP_00002166 +BDMAP_00001236 +BDMAP_00000329 +BDMAP_00000642 +BDMAP_00001397 +BDMAP_00003435 +BDMAP_00000913 +BDMAP_00005092 +BDMAP_00004925 +BDMAP_00003412 +BDMAP_00003957 +BDMAP_00003897 +BDMAP_00004398 +BDMAP_00001539 +BDMAP_00001911 +BDMAP_00002421 +BDMAP_00004745 +BDMAP_00002318 +BDMAP_00000470 +BDMAP_00002889 +BDMAP_00001912 +BDMAP_00003326 +BDMAP_00002275 +BDMAP_00002227 +BDMAP_00000926 +BDMAP_00004187 +BDMAP_00001148 +BDMAP_00003376 +BDMAP_00003774 +BDMAP_00003857 +BDMAP_00003650 +BDMAP_00005078 +BDMAP_00003151 +BDMAP_00001242 +BDMAP_00003215 +BDMAP_00000676 +BDMAP_00003396 +BDMAP_00003479 +BDMAP_00003781 +BDMAP_00005070 +BDMAP_00003631 +BDMAP_00003840 +BDMAP_00003640 +BDMAP_00000347 +BDMAP_00004645 +BDMAP_00000715 +BDMAP_00002871 +BDMAP_00004834 +BDMAP_00004493 +BDMAP_00001828 +BDMAP_00001565 +BDMAP_00000902 +BDMAP_00001908 +BDMAP_00002688 +BDMAP_00003130 +BDMAP_00000971 +BDMAP_00000192 +BDMAP_00002924 +BDMAP_00002845 +BDMAP_00000660 +BDMAP_00000324 +BDMAP_00004895 +BDMAP_00002751 +BDMAP_00001474 +BDMAP_00001218 +BDMAP_00001130 +BDMAP_00001697 +BDMAP_00002498 +BDMAP_00001768 +BDMAP_00000233 +BDMAP_00004416 +BDMAP_00003138 +BDMAP_00000138 +BDMAP_00004508 +BDMAP_00001514 +BDMAP_00000243 +BDMAP_00001747 +BDMAP_00002487 +BDMAP_00003943 +BDMAP_00000043 +BDMAP_00001835 +BDMAP_00002233 +BDMAP_00004897 +BDMAP_00001230 +BDMAP_00004956 +BDMAP_00005191 +BDMAP_00001444 +BDMAP_00002117 +BDMAP_00001598 +BDMAP_00000087 +BDMAP_00000725 +BDMAP_00004552 +BDMAP_00005064 +BDMAP_00003111 +BDMAP_00004420 +BDMAP_00004293 +BDMAP_00000449 +BDMAP_00001905 +BDMAP_00003569 +BDMAP_00005005 +BDMAP_00004600 +BDMAP_00001766 +BDMAP_00001656 +BDMAP_00000345 +BDMAP_00001753 +BDMAP_00004028 +BDMAP_00000084 +BDMAP_00002253 +BDMAP_00004808 +BDMAP_00003052 +BDMAP_00002362 +BDMAP_00004435 +BDMAP_00004964 +BDMAP_00000516 +BDMAP_00004876 +BDMAP_00004651 +BDMAP_00000431 +BDMAP_00002022 +BDMAP_00001316 +BDMAP_00002359 +BDMAP_00004147 +BDMAP_00004264 +BDMAP_00004980 +BDMAP_00003685 +BDMAP_00004384 
+BDMAP_00004199 +BDMAP_00002791 +BDMAP_00002120 +BDMAP_00002244 +BDMAP_00004462 +BDMAP_00000279 +BDMAP_00004676 +BDMAP_00000569 +BDMAP_00001517 +BDMAP_00004450 +BDMAP_00000414 +BDMAP_00000582 +BDMAP_00004558 +BDMAP_00001712 +BDMAP_00004796 +BDMAP_00004295 +BDMAP_00001842 +BDMAP_00001422 +BDMAP_00003036 +BDMAP_00001419 +BDMAP_00003576 +BDMAP_00000331 +BDMAP_00001225 +BDMAP_00004673 +BDMAP_00000977 +BDMAP_00000044 +BDMAP_00001826 +BDMAP_00001440 +BDMAP_00000574 +BDMAP_00004672 +BDMAP_00004830 +BDMAP_00004077 +BDMAP_00004793 +BDMAP_00004074 +BDMAP_00000139 +BDMAP_00003356 +BDMAP_00003713 +BDMAP_00003254 +BDMAP_00001333 +BDMAP_00004023 +BDMAP_00004880 +BDMAP_00002981 +BDMAP_00005160 +BDMAP_00001096 +BDMAP_00003109 +BDMAP_00003063 +BDMAP_00003973 +BDMAP_00004719 +BDMAP_00000542 +BDMAP_00004491 +BDMAP_00002172 +BDMAP_00000907 +BDMAP_00005154 +BDMAP_00003827 +BDMAP_00004541 +BDMAP_00003493 +BDMAP_00003461 +BDMAP_00000338 +BDMAP_00004016 +BDMAP_00002815 +BDMAP_00002805 +BDMAP_00000918 +BDMAP_00003141 +BDMAP_00001564 +BDMAP_00003392 +BDMAP_00000939 +BDMAP_00001368 +BDMAP_00004549 +BDMAP_00001707 +BDMAP_00001475 +BDMAP_00002232 +BDMAP_00000923 +BDMAP_00004104 +BDMAP_00004608 +BDMAP_00004825 +BDMAP_00001209 +BDMAP_00005185 +BDMAP_00002696 +BDMAP_00000828 +BDMAP_00001059 +BDMAP_00001647 +BDMAP_00000039 +BDMAP_00000935 +BDMAP_00002712 +BDMAP_00003451 +BDMAP_00000059 +BDMAP_00003516 +BDMAP_00002295 +BDMAP_00001516 +BDMAP_00002319 +BDMAP_00001077 +BDMAP_00003581 +BDMAP_00002884 +BDMAP_00003324 +BDMAP_00000128 +BDMAP_00002959 +BDMAP_00000411 +BDMAP_00003717 +BDMAP_00004995 +BDMAP_00000653 +BDMAP_00004031 +BDMAP_00003590 +BDMAP_00001215 +BDMAP_00001256 +BDMAP_00002273 +BDMAP_00000667 +BDMAP_00000373 +BDMAP_00003680 +BDMAP_00001784 +BDMAP_00001286 +BDMAP_00001246 +BDMAP_00003440 +BDMAP_00002656 +BDMAP_00003955 +BDMAP_00003930 +BDMAP_00001985 +BDMAP_00004328 +BDMAP_00004744 +BDMAP_00004529 +BDMAP_00004447 +BDMAP_00002252 +BDMAP_00003994 +BDMAP_00001711 +BDMAP_00000355 +BDMAP_00001836 +BDMAP_00003448 +BDMAP_00000855 +BDMAP_00002039 +BDMAP_00005063 +BDMAP_00004286 +BDMAP_00001823 +BDMAP_00002407 +BDMAP_00002933 +BDMAP_00003928 +BDMAP_00000447 +BDMAP_00003411 +BDMAP_00004641 +BDMAP_00003886 +BDMAP_00000240 +BDMAP_00001917 +BDMAP_00003952 +BDMAP_00001464 +BDMAP_00000614 +BDMAP_00003491 +BDMAP_00004427 +BDMAP_00004131 +BDMAP_00004011 +BDMAP_00000297 +BDMAP_00001511 +BDMAP_00000812 +BDMAP_00005020 +BDMAP_00004060 +BDMAP_00002496 +BDMAP_00003455 +BDMAP_00005169 +BDMAP_00000462 +BDMAP_00001502 +BDMAP_00000558 +BDMAP_00004216 +BDMAP_00000244 +BDMAP_00001602 +BDMAP_00003073 +BDMAP_00001618 +BDMAP_00000839 +BDMAP_00002333 +BDMAP_00002298 +BDMAP_00000873 +BDMAP_00001521 +BDMAP_00003946 +BDMAP_00000690 +BDMAP_00004969 +BDMAP_00000320 +BDMAP_00003074 +BDMAP_00004154 +BDMAP_00001420 +BDMAP_00002826 +BDMAP_00002076 +BDMAP_00002021 +BDMAP_00000837 +BDMAP_00000968 +BDMAP_00001138 +BDMAP_00002524 +BDMAP_00000532 +BDMAP_00002250 +BDMAP_00002282 +BDMAP_00003281 +BDMAP_00004738 +BDMAP_00004389 +BDMAP_00004922 +BDMAP_00002305 +BDMAP_00003070 +BDMAP_00002793 +BDMAP_00002986 +BDMAP_00000623 +BDMAP_00001794 +BDMAP_00002475 +BDMAP_00004415 +BDMAP_00001898 +BDMAP_00002936 +BDMAP_00003443 +BDMAP_00004550 +BDMAP_00004479 +BDMAP_00002041 +BDMAP_00001806 +BDMAP_00002509 +BDMAP_00002616 +BDMAP_00005065 +BDMAP_00005085 +BDMAP_00001379 +BDMAP_00003911 +BDMAP_00002707 +BDMAP_00004097 +BDMAP_00003128 +BDMAP_00003996 +BDMAP_00000626 +BDMAP_00000263 +BDMAP_00001549 +BDMAP_00000229 +BDMAP_00001688 +BDMAP_00002313 +BDMAP_00003319 +BDMAP_00003343 
+BDMAP_00004624 +BDMAP_00001737 +BDMAP_00001624 +BDMAP_00003358 +BDMAP_00000998 +BDMAP_00004195 +BDMAP_00001941 +BDMAP_00004870 +BDMAP_00000948 +BDMAP_00001496 +BDMAP_00000687 +BDMAP_00004033 +BDMAP_00001068 +BDMAP_00003520 +BDMAP_00000941 +BDMAP_00000867 +BDMAP_00000264 +BDMAP_00005067 +BDMAP_00000132 +BDMAP_00004650 +BDMAP_00003736 +BDMAP_00003564 +BDMAP_00001635 +BDMAP_00003898 +BDMAP_00004901 +BDMAP_00000400 +BDMAP_00004671 +BDMAP_00000353 +BDMAP_00001089 +BDMAP_00000572 +BDMAP_00002953 +BDMAP_00003600 +BDMAP_00003798 +BDMAP_00000987 +BDMAP_00000541 +BDMAP_00004717 +BDMAP_00002068 +BDMAP_00001977 +BDMAP_00002942 +BDMAP_00000416 +BDMAP_00002580 +BDMAP_00001410 +BDMAP_00000052 +BDMAP_00003361 +BDMAP_00001247 +BDMAP_00004894 +BDMAP_00002060 +BDMAP_00000319 +BDMAP_00004407 +BDMAP_00002099 +BDMAP_00004431 +BDMAP_00003225 +BDMAP_00003236 +BDMAP_00004981 +BDMAP_00000671 +BDMAP_00003444 +BDMAP_00003525 +BDMAP_00000259 +BDMAP_00003497 +BDMAP_00003767 +BDMAP_00004184 +BDMAP_00003524 +BDMAP_00000942 +BDMAP_00002719 +BDMAP_00004232 +BDMAP_00005186 +BDMAP_00003900 diff --git a/Generation_Pipeline_filter_all/best_metric_model_classification3d_dict.pth b/Generation_Pipeline_filter_all/best_metric_model_classification3d_dict.pth new file mode 100644 index 0000000000000000000000000000000000000000..d60083fe12376e385f304bcac9f1d17ae0742aba --- /dev/null +++ b/Generation_Pipeline_filter_all/best_metric_model_classification3d_dict.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c435f0e8efa89f02ac5fb0bd1092608332f2de6f8f7aabbe17fcc6d1a83d35c6 +size 45594067 diff --git a/Generation_Pipeline_filter_all/get_syn_list.py b/Generation_Pipeline_filter_all/get_syn_list.py new file mode 100644 index 0000000000000000000000000000000000000000..a03571f9f096b48bd73a6cd8236435aac77dca75 --- /dev/null +++ b/Generation_Pipeline_filter_all/get_syn_list.py @@ -0,0 +1,26 @@ +import os + +organ = 'colon' +real_organ = [] +with open(f'real_set/{organ}.txt', 'r') as f: + real_organ=f.readlines() +real_organ = [i.split('\n')[0] for i in real_organ] + + +total_case = [] +with open(f'real_total.txt', 'r') as f: + total_case=f.readlines() +total_case = [i.split('\n')[0] for i in total_case] + + +absence2= list(set(total_case) - set(real_organ)) +absence2 = [i for i in absence2] +# breakpoint() + +filename = open(f'syn_{organ}/healthy_{organ}_1k.txt','a+')#dict转txt +for i in absence2: + filename.write(i) + filename.write('\n') +filename.close() + + diff --git a/Generation_Pipeline_filter_all/get_training_list.py b/Generation_Pipeline_filter_all/get_training_list.py new file mode 100644 index 0000000000000000000000000000000000000000..dba70b6bd3d7081cf58d741719b87d4ce4170511 --- /dev/null +++ b/Generation_Pipeline_filter_all/get_training_list.py @@ -0,0 +1,45 @@ +import os + +total_case = [] +with open(f'real_total.txt', 'r') as f: + total_case=f.readlines() +total_case = [i.split('\n')[0] for i in total_case] + + +real_organ = [] +with open(f'val_set/bodymap_liver.txt', 'r') as f: + real_organ=f.readlines() +real_organ = [i.split('\n')[0] for i in real_organ] +total_case= list(set(total_case) - set(real_organ)) +total_case = [i for i in total_case] + +real_organ = [] +with open(f'val_set/bodymap_pancreas.txt', 'r') as f: + real_organ=f.readlines() +real_organ = [i.split('\n')[0] for i in real_organ] +total_case= list(set(total_case) - set(real_organ)) +total_case = [i for i in total_case] + +real_organ = [] +with open(f'val_set/bodymap_kidney.txt', 'r') as f: + real_organ=f.readlines() +real_organ = 
[i.split('\n')[0] for i in real_organ] +total_case= list(set(total_case) - set(real_organ)) +total_case = [i for i in total_case] + +real_organ = [] +with open(f'val_set/bodymap_colon.txt', 'r') as f: + real_organ=f.readlines() +real_organ = [i.split('\n')[0] for i in real_organ] +total_case= list(set(total_case) - set(real_organ)) +total_case = [i for i in total_case] + + + +filename = open(f'Atlas_X_1k.txt','a+')#dict转txt +for i in total_case: + filename.write(i) + filename.write('\n') +filename.close() + + diff --git a/Generation_Pipeline_filter_all/real_set/colon.txt b/Generation_Pipeline_filter_all/real_set/colon.txt new file mode 100644 index 0000000000000000000000000000000000000000..c19f373284fab2de8fa90f1ccc891b5b514d9b72 --- /dev/null +++ b/Generation_Pipeline_filter_all/real_set/colon.txt @@ -0,0 +1,126 @@ +BDMAP_00001078 +BDMAP_00003031 +BDMAP_00002253 +BDMAP_00001732 +BDMAP_00000874 +BDMAP_00003847 +BDMAP_00003268 +BDMAP_00002846 +BDMAP_00001438 +BDMAP_00004650 +BDMAP_00003109 +BDMAP_00004121 +BDMAP_00004165 +BDMAP_00004676 +BDMAP_00003890 +BDMAP_00003327 +BDMAP_00000132 +BDMAP_00001215 +BDMAP_00001769 +BDMAP_00003412 +BDMAP_00002318 +BDMAP_00004624 +BDMAP_00000345 +BDMAP_00002230 +BDMAP_00003111 +BDMAP_00001015 +BDMAP_00001514 +BDMAP_00001924 +BDMAP_00002845 +BDMAP_00002598 +BDMAP_00001209 +BDMAP_00000373 +BDMAP_00001737 +BDMAP_00003113 +BDMAP_00004876 +BDMAP_00003640 +BDMAP_00001985 +BDMAP_00000138 +BDMAP_00000881 +BDMAP_00002739 +BDMAP_00003560 +BDMAP_00002612 +BDMAP_00001445 +BDMAP_00003827 +BDMAP_00001024 +BDMAP_00000568 +BDMAP_00001095 +BDMAP_00002458 +BDMAP_00002986 +BDMAP_00000913 +BDMAP_00000264 +BDMAP_00000690 +BDMAP_00002039 +BDMAP_00001426 +BDMAP_00002730 +BDMAP_00001917 +BDMAP_00005067 +BDMAP_00002924 +BDMAP_00005160 +BDMAP_00005073 +BDMAP_00000547 +BDMAP_00000942 +BDMAP_00002103 +BDMAP_00002654 +BDMAP_00004374 +BDMAP_00003510 +BDMAP_00004910 +BDMAP_00004558 +BDMAP_00004450 +BDMAP_00000152 +BDMAP_00004491 +BDMAP_00001237 +BDMAP_00001785 +BDMAP_00001865 +BDMAP_00000851 +BDMAP_00003357 +BDMAP_00004415 +BDMAP_00004615 +BDMAP_00003680 +BDMAP_00001875 +BDMAP_00004894 +BDMAP_00001835 +BDMAP_00000069 +BDMAP_00001809 +BDMAP_00004431 +BDMAP_00002704 +BDMAP_00002185 +BDMAP_00004384 +BDMAP_00003299 +BDMAP_00003333 +BDMAP_00002305 +BDMAP_00001598 +BDMAP_00002465 +BDMAP_00002199 +BDMAP_00002875 +BDMAP_00000828 +BDMAP_00003564 +BDMAP_00005001 +BDMAP_00004493 +BDMAP_00000190 +BDMAP_00000873 +BDMAP_00005170 +BDMAP_00002152 +BDMAP_00004163 +BDMAP_00000939 +BDMAP_00001212 +BDMAP_00001982 +BDMAP_00000552 +BDMAP_00004764 +BDMAP_00002401 +BDMAP_00002451 +BDMAP_00003634 +BDMAP_00005016 +BDMAP_00000716 +BDMAP_00003373 +BDMAP_00000030 +BDMAP_00003946 +BDMAP_00002828 +BDMAP_00004196 +BDMAP_00005005 +BDMAP_00003972 +BDMAP_00003172 +BDMAP_00004783 +BDMAP_00001102 +BDMAP_00004147 +BDMAP_00004604 diff --git a/Generation_Pipeline_filter_all/real_set/kidney.txt b/Generation_Pipeline_filter_all/real_set/kidney.txt new file mode 100644 index 0000000000000000000000000000000000000000..dc40eda19f629c971abb354633257a95ca7cea14 --- /dev/null +++ b/Generation_Pipeline_filter_all/real_set/kidney.txt @@ -0,0 +1,489 @@ +BDMAP_00000245 +BDMAP_00000036 +BDMAP_00003833 +BDMAP_00001517 +BDMAP_00004087 +BDMAP_00002807 +BDMAP_00002099 +BDMAP_00001602 +BDMAP_00001035 +BDMAP_00002422 +BDMAP_00000626 +BDMAP_00002173 +BDMAP_00000240 +BDMAP_00001246 +BDMAP_00000582 +BDMAP_00003996 +BDMAP_00001707 +BDMAP_00000923 +BDMAP_00003411 +BDMAP_00004113 +BDMAP_00002582 +BDMAP_00001261 +BDMAP_00005167 +BDMAP_00004897 
+BDMAP_00001169 +BDMAP_00001148 +BDMAP_00002164 +BDMAP_00002041 +BDMAP_00000889 +BDMAP_00001109 +BDMAP_00005009 +BDMAP_00001286 +BDMAP_00000297 +BDMAP_00005099 +BDMAP_00004257 +BDMAP_00005017 +BDMAP_00000604 +BDMAP_00002472 +BDMAP_00001225 +BDMAP_00005081 +BDMAP_00003491 +BDMAP_00001635 +BDMAP_00002075 +BDMAP_00000660 +BDMAP_00001238 +BDMAP_00002656 +BDMAP_00003558 +BDMAP_00001104 +BDMAP_00004066 +BDMAP_00003294 +BDMAP_00001607 +BDMAP_00001077 +BDMAP_00000653 +BDMAP_00001273 +BDMAP_00000616 +BDMAP_00002057 +BDMAP_00004586 +BDMAP_00004407 +BDMAP_00004922 +BDMAP_00002592 +BDMAP_00000149 +BDMAP_00000320 +BDMAP_00001511 +BDMAP_00000435 +BDMAP_00002746 +BDMAP_00004457 +BDMAP_00000805 +BDMAP_00002661 +BDMAP_00004552 +BDMAP_00004154 +BDMAP_00002902 +BDMAP_00000839 +BDMAP_00000233 +BDMAP_00000122 +BDMAP_00005151 +BDMAP_00004427 +BDMAP_00002936 +BDMAP_00003955 +BDMAP_00001863 +BDMAP_00002326 +BDMAP_00001420 +BDMAP_00000329 +BDMAP_00004561 +BDMAP_00003971 +BDMAP_00000935 +BDMAP_00000569 +BDMAP_00004956 +BDMAP_00000285 +BDMAP_00004597 +BDMAP_00001747 +BDMAP_00001059 +BDMAP_00002354 +BDMAP_00001656 +BDMAP_00004395 +BDMAP_00002942 +BDMAP_00004981 +BDMAP_00001768 +BDMAP_00002319 +BDMAP_00003947 +BDMAP_00001868 +BDMAP_00002065 +BDMAP_00002333 +BDMAP_00003358 +BDMAP_00001265 +BDMAP_00003952 +BDMAP_00001891 +BDMAP_00003576 +BDMAP_00000980 +BDMAP_00003300 +BDMAP_00001782 +BDMAP_00003717 +BDMAP_00001251 +BDMAP_00000044 +BDMAP_00004510 +BDMAP_00003315 +BDMAP_00002653 +BDMAP_00001045 +BDMAP_00003694 +BDMAP_00004216 +BDMAP_00001794 +BDMAP_00000532 +BDMAP_00002288 +BDMAP_00001256 +BDMAP_00000219 +BDMAP_00000710 +BDMAP_00003930 +BDMAP_00001636 +BDMAP_00003749 +BDMAP_00000998 +BDMAP_00000176 +BDMAP_00000429 +BDMAP_00001001 +BDMAP_00001908 +BDMAP_00003363 +BDMAP_00004903 +BDMAP_00004482 +BDMAP_00003178 +BDMAP_00003202 +BDMAP_00001230 +BDMAP_00003461 +BDMAP_00003281 +BDMAP_00000434 +BDMAP_00001218 +BDMAP_00003976 +BDMAP_00003455 +BDMAP_00001183 +BDMAP_00002609 +BDMAP_00001305 +BDMAP_00000364 +BDMAP_00003516 +BDMAP_00003956 +BDMAP_00000977 +BDMAP_00001784 +BDMAP_00004389 +BDMAP_00001711 +BDMAP_00000698 +BDMAP_00003153 +BDMAP_00001995 +BDMAP_00001549 +BDMAP_00001324 +BDMAP_00004195 +BDMAP_00001562 +BDMAP_00004074 +BDMAP_00001483 +BDMAP_00002085 +BDMAP_00001396 +BDMAP_00000241 +BDMAP_00004031 +BDMAP_00004775 +BDMAP_00001807 +BDMAP_00005120 +BDMAP_00004065 +BDMAP_00003943 +BDMAP_00002953 +BDMAP_00004232 +BDMAP_00002184 +BDMAP_00002407 +BDMAP_00003252 +BDMAP_00004296 +BDMAP_00000161 +BDMAP_00002981 +BDMAP_00003608 +BDMAP_00003128 +BDMAP_00000571 +BDMAP_00000259 +BDMAP_00003444 +BDMAP_00001647 +BDMAP_00000662 +BDMAP_00003774 +BDMAP_00001383 +BDMAP_00004616 +BDMAP_00001906 +BDMAP_00003740 +BDMAP_00001422 +BDMAP_00002631 +BDMAP_00004294 +BDMAP_00003994 +BDMAP_00004475 +BDMAP_00002744 +BDMAP_00001068 +BDMAP_00000667 +BDMAP_00001945 +BDMAP_00002710 +BDMAP_00002440 +BDMAP_00000833 +BDMAP_00003143 +BDMAP_00000062 +BDMAP_00003392 +BDMAP_00004373 +BDMAP_00001020 +BDMAP_00003603 +BDMAP_00001027 +BDMAP_00005114 +BDMAP_00003384 +BDMAP_00000794 +BDMAP_00001911 +BDMAP_00002437 +BDMAP_00004579 +BDMAP_00004250 +BDMAP_00002068 +BDMAP_00000608 +BDMAP_00004551 +BDMAP_00002884 +BDMAP_00004033 +BDMAP_00005105 +BDMAP_00002776 +BDMAP_00000414 +BDMAP_00003580 +BDMAP_00004712 +BDMAP_00002114 +BDMAP_00002226 +BDMAP_00003923 +BDMAP_00002854 +BDMAP_00004039 +BDMAP_00004014 +BDMAP_00001289 +BDMAP_00003435 +BDMAP_00004578 +BDMAP_00002940 +BDMAP_00003164 +BDMAP_00002751 +BDMAP_00001516 +BDMAP_00003486 +BDMAP_00000279 +BDMAP_00001664 +BDMAP_00004738 
+BDMAP_00001735 +BDMAP_00000562 +BDMAP_00000812 +BDMAP_00000511 +BDMAP_00004746 +BDMAP_00000452 +BDMAP_00004328 +BDMAP_00002017 +BDMAP_00002840 +BDMAP_00000039 +BDMAP_00002242 +BDMAP_00002775 +BDMAP_00003762 +BDMAP_00000229 +BDMAP_00003520 +BDMAP_00000725 +BDMAP_00000516 +BDMAP_00001941 +BDMAP_00003928 +BDMAP_00001255 +BDMAP_00001456 +BDMAP_00002410 +BDMAP_00002742 +BDMAP_00001688 +BDMAP_00000487 +BDMAP_00000469 +BDMAP_00002022 +BDMAP_00003058 +BDMAP_00004148 +BDMAP_00001977 +BDMAP_00000887 +BDMAP_00003448 +BDMAP_00001410 +BDMAP_00002383 +BDMAP_00003736 +BDMAP_00002626 +BDMAP_00001710 +BDMAP_00001130 +BDMAP_00001138 +BDMAP_00001413 +BDMAP_00003815 +BDMAP_00004130 +BDMAP_00004652 +BDMAP_00002864 +BDMAP_00000574 +BDMAP_00003493 +BDMAP_00003364 +BDMAP_00002648 +BDMAP_00001281 +BDMAP_00002655 +BDMAP_00001126 +BDMAP_00002804 +BDMAP_00000321 +BDMAP_00005191 +BDMAP_00004420 +BDMAP_00000304 +BDMAP_00003150 +BDMAP_00004620 +BDMAP_00000368 +BDMAP_00000066 +BDMAP_00003701 +BDMAP_00005174 +BDMAP_00002545 +BDMAP_00003957 +BDMAP_00004331 +BDMAP_00000687 +BDMAP_00001791 +BDMAP_00002959 +BDMAP_00004104 +BDMAP_00003073 +BDMAP_00003713 +BDMAP_00002363 +BDMAP_00000137 +BDMAP_00000104 +BDMAP_00002689 +BDMAP_00004990 +BDMAP_00003301 +BDMAP_00001434 +BDMAP_00000449 +BDMAP_00005113 +BDMAP_00003225 +BDMAP_00001359 +BDMAP_00001223 +BDMAP_00002803 +BDMAP_00000355 +BDMAP_00001826 +BDMAP_00004673 +BDMAP_00002251 +BDMAP_00000439 +BDMAP_00005085 +BDMAP_00003381 +BDMAP_00004645 +BDMAP_00000432 +BDMAP_00001444 +BDMAP_00001705 +BDMAP_00001892 +BDMAP_00002826 +BDMAP_00004671 +BDMAP_00000926 +BDMAP_00004817 +BDMAP_00004175 +BDMAP_00003484 +BDMAP_00003672 +BDMAP_00003267 +BDMAP_00001089 +BDMAP_00001496 +BDMAP_00003615 +BDMAP_00003832 +BDMAP_00002695 +BDMAP_00002696 +BDMAP_00004499 +BDMAP_00004867 +BDMAP_00004479 +BDMAP_00003600 +BDMAP_00000989 +BDMAP_00002421 +BDMAP_00003406 +BDMAP_00000263 +BDMAP_00002396 +BDMAP_00002265 +BDMAP_00000713 +BDMAP_00000883 +BDMAP_00001258 +BDMAP_00004253 +BDMAP_00004870 +BDMAP_00000331 +BDMAP_00004608 +BDMAP_00001518 +BDMAP_00002562 +BDMAP_00002889 +BDMAP_00001676 +BDMAP_00000117 +BDMAP_00003973 +BDMAP_00002509 +BDMAP_00002487 +BDMAP_00003457 +BDMAP_00000982 +BDMAP_00002260 +BDMAP_00001283 +BDMAP_00003506 +BDMAP_00000366 +BDMAP_00002133 +BDMAP_00000465 +BDMAP_00003767 +BDMAP_00001853 +BDMAP_00002361 +BDMAP_00004815 +BDMAP_00002933 +BDMAP_00000162 +BDMAP_00004925 +BDMAP_00005077 +BDMAP_00001533 +BDMAP_00001242 +BDMAP_00000871 +BDMAP_00000948 +BDMAP_00001119 +BDMAP_00004887 +BDMAP_00002404 +BDMAP_00003722 +BDMAP_00002426 +BDMAP_00002060 +BDMAP_00004850 +BDMAP_00003343 +BDMAP_00001624 +BDMAP_00000481 +BDMAP_00002166 +BDMAP_00003849 +BDMAP_00004808 +BDMAP_00002471 +BDMAP_00000656 +BDMAP_00003581 +BDMAP_00000023 +BDMAP_00003727 +BDMAP_00000319 +BDMAP_00003255 +BDMAP_00003752 +BDMAP_00000139 +BDMAP_00003614 +BDMAP_00003549 +BDMAP_00003808 +BDMAP_00002930 +BDMAP_00001128 +BDMAP_00004717 +BDMAP_00000826 +BDMAP_00002663 +BDMAP_00000837 +BDMAP_00000159 +BDMAP_00005154 +BDMAP_00002524 +BDMAP_00000968 +BDMAP_00004278 +BDMAP_00001325 +BDMAP_00000987 +BDMAP_00004901 +BDMAP_00003425 +BDMAP_00005006 +BDMAP_00004131 +BDMAP_00002403 +BDMAP_00001620 +BDMAP_00002347 +BDMAP_00001522 +BDMAP_00004011 +BDMAP_00001474 +BDMAP_00004744 +BDMAP_00002484 +BDMAP_00001370 +BDMAP_00003324 +BDMAP_00001557 +BDMAP_00000867 +BDMAP_00001487 +BDMAP_00004980 +BDMAP_00000034 +BDMAP_00000936 +BDMAP_00000128 +BDMAP_00001275 +BDMAP_00004030 +BDMAP_00003359 +BDMAP_00003070 +BDMAP_00002476 +BDMAP_00002990 +BDMAP_00000810 +BDMAP_00003514 
+BDMAP_00004834 +BDMAP_00003409 +BDMAP_00002498 +BDMAP_00004481 +BDMAP_00002273 +BDMAP_00002496 +BDMAP_00002871 +BDMAP_00000059 +BDMAP_00001475 +BDMAP_00000902 +BDMAP_00004417 +BDMAP_00005157 +BDMAP_00001752 +BDMAP_00001563 +BDMAP_00003063 +BDMAP_00001296 +BDMAP_00002707 +BDMAP_00000836 +BDMAP_00000353 +BDMAP_00000043 +BDMAP_00000244 diff --git a/Generation_Pipeline_filter_all/real_set/liver.txt b/Generation_Pipeline_filter_all/real_set/liver.txt new file mode 100644 index 0000000000000000000000000000000000000000..4e722fe9f916144960815ec7f01a5d3ba64a5d1d --- /dev/null +++ b/Generation_Pipeline_filter_all/real_set/liver.txt @@ -0,0 +1,159 @@ +BDMAP_00000400 +BDMAP_00003497 +BDMAP_00001270 +BDMAP_00001766 +BDMAP_00001309 +BDMAP_00004745 +BDMAP_00003002 +BDMAP_00004825 +BDMAP_00004416 +BDMAP_00002712 +BDMAP_00004830 +BDMAP_00000907 +BDMAP_00001957 +BDMAP_00000941 +BDMAP_00002841 +BDMAP_00001962 +BDMAP_00004462 +BDMAP_00004281 +BDMAP_00004890 +BDMAP_00003272 +BDMAP_00003377 +BDMAP_00005186 +BDMAP_00002172 +BDMAP_00000091 +BDMAP_00004639 +BDMAP_00000918 +BDMAP_00000671 +BDMAP_00004028 +BDMAP_00004529 +BDMAP_00001907 +BDMAP_00001122 +BDMAP_00003151 +BDMAP_00002252 +BDMAP_00003524 +BDMAP_00004704 +BDMAP_00000362 +BDMAP_00003932 +BDMAP_00004995 +BDMAP_00002748 +BDMAP_00004117 +BDMAP_00000480 +BDMAP_00001010 +BDMAP_00000100 +BDMAP_00001200 +BDMAP_00004103 +BDMAP_00004878 +BDMAP_00002282 +BDMAP_00001471 +BDMAP_00000232 +BDMAP_00003439 +BDMAP_00003857 +BDMAP_00004943 +BDMAP_00005130 +BDMAP_00002479 +BDMAP_00002909 +BDMAP_00004185 +BDMAP_00003569 +BDMAP_00001185 +BDMAP_00002849 +BDMAP_00003556 +BDMAP_00003052 +BDMAP_00000971 +BDMAP_00003330 +BDMAP_00000113 +BDMAP_00004600 +BDMAP_00002529 +BDMAP_00000437 +BDMAP_00003074 +BDMAP_00005139 +BDMAP_00001966 +BDMAP_00002791 +BDMAP_00001692 +BDMAP_00001786 +BDMAP_00001697 +BDMAP_00003798 +BDMAP_00000273 +BDMAP_00001114 +BDMAP_00003898 +BDMAP_00001397 +BDMAP_00003867 +BDMAP_00005065 +BDMAP_00001802 +BDMAP_00001539 +BDMAP_00000084 +BDMAP_00002955 +BDMAP_00002271 +BDMAP_00004459 +BDMAP_00004378 +BDMAP_00004435 +BDMAP_00001093 +BDMAP_00003897 +BDMAP_00003236 +BDMAP_00001502 +BDMAP_00001834 +BDMAP_00000347 +BDMAP_00000831 +BDMAP_00002717 +BDMAP_00002856 +BDMAP_00004199 +BDMAP_00000709 +BDMAP_00003481 +BDMAP_00002719 +BDMAP_00005083 +BDMAP_00002359 +BDMAP_00000642 +BDMAP_00000778 +BDMAP_00000745 +BDMAP_00000607 +BDMAP_00001236 +BDMAP_00001333 +BDMAP_00003920 +BDMAP_00003664 +BDMAP_00003911 +BDMAP_00002463 +BDMAP_00002419 +BDMAP_00000965 +BDMAP_00003513 +BDMAP_00004508 +BDMAP_00002283 +BDMAP_00004509 +BDMAP_00000615 +BDMAP_00001171 +BDMAP_00001343 +BDMAP_00002167 +BDMAP_00000205 +BDMAP_00002805 +BDMAP_00002275 +BDMAP_00002485 +BDMAP_00004228 +BDMAP_00004304 +BDMAP_00004187 +BDMAP_00001379 +BDMAP_00001753 +BDMAP_00000413 +BDMAP_00002289 +BDMAP_00000572 +BDMAP_00005119 +BDMAP_00004017 +BDMAP_00004016 +BDMAP_00002349 +BDMAP_00000101 +BDMAP_00003482 +BDMAP_00004839 +BDMAP_00001025 +BDMAP_00003361 +BDMAP_00002495 +BDMAP_00001055 +BDMAP_00002214 +BDMAP_00005097 +BDMAP_00005168 +BDMAP_00002267 +BDMAP_00001198 +BDMAP_00002918 +BDMAP_00004664 +BDMAP_00004888 +BDMAP_00000921 +BDMAP_00002373 +BDMAP_00001316 +BDMAP_00002117 diff --git a/Generation_Pipeline_filter_all/real_set/pancreas.txt b/Generation_Pipeline_filter_all/real_set/pancreas.txt new file mode 100644 index 0000000000000000000000000000000000000000..2f0d669103ff6594f0e57d4a55cadffb2c4a8d9c --- /dev/null +++ b/Generation_Pipeline_filter_all/real_set/pancreas.txt @@ -0,0 +1,281 @@ +BDMAP_00003244 +BDMAP_00005074 
+BDMAP_00004804 +BDMAP_00004672 +BDMAP_00003133 +BDMAP_00004969 +BDMAP_00002278 +BDMAP_00001862 +BDMAP_00005185 +BDMAP_00004880 +BDMAP_00004770 +BDMAP_00002690 +BDMAP_00002944 +BDMAP_00003744 +BDMAP_00002021 +BDMAP_00003141 +BDMAP_00004927 +BDMAP_00001476 +BDMAP_00003551 +BDMAP_00004964 +BDMAP_00001605 +BDMAP_00002298 +BDMAP_00001746 +BDMAP_00000332 +BDMAP_00003590 +BDMAP_00000956 +BDMAP_00001649 +BDMAP_00003781 +BDMAP_00001523 +BDMAP_00003347 +BDMAP_00005022 +BDMAP_00004128 +BDMAP_00003612 +BDMAP_00003658 +BDMAP_00003812 +BDMAP_00003427 +BDMAP_00003502 +BDMAP_00001823 +BDMAP_00004847 +BDMAP_00003776 +BDMAP_00001205 +BDMAP_00000192 +BDMAP_00004511 +BDMAP_00001564 +BDMAP_00000416 +BDMAP_00005070 +BDMAP_00001040 +BDMAP_00004231 +BDMAP_00002945 +BDMAP_00001704 +BDMAP_00002402 +BDMAP_00000940 +BDMAP_00000243 +BDMAP_00001464 +BDMAP_00002793 +BDMAP_00001646 +BDMAP_00005020 +BDMAP_00004992 +BDMAP_00003017 +BDMAP_00001096 +BDMAP_00003451 +BDMAP_00001067 +BDMAP_00001331 +BDMAP_00000696 +BDMAP_00001461 +BDMAP_00003326 +BDMAP_00000715 +BDMAP_00000855 +BDMAP_00000087 +BDMAP_00000093 +BDMAP_00000324 +BDMAP_00003440 +BDMAP_00002387 +BDMAP_00004060 +BDMAP_00000714 +BDMAP_00001617 +BDMAP_00004494 +BDMAP_00002616 +BDMAP_00000225 +BDMAP_00001754 +BDMAP_00005075 +BDMAP_00002328 +BDMAP_00004229 +BDMAP_00000541 +BDMAP_00004447 +BDMAP_00004106 +BDMAP_00003592 +BDMAP_00003036 +BDMAP_00001125 +BDMAP_00001361 +BDMAP_00002863 +BDMAP_00002309 +BDMAP_00001905 +BDMAP_00004115 +BDMAP_00002216 +BDMAP_00004829 +BDMAP_00003443 +BDMAP_00001504 +BDMAP_00004885 +BDMAP_00003451 +BDMAP_00000679 +BDMAP_00002362 +BDMAP_00000388 +BDMAP_00003769 +BDMAP_00004198 +BDMAP_00004719 +BDMAP_00000809 +BDMAP_00003525 +BDMAP_00003138 +BDMAP_00005063 +BDMAP_00000676 +BDMAP_00000411 +BDMAP_00002523 +BDMAP_00003367 +BDMAP_00003961 +BDMAP_00003822 +BDMAP_00000462 +BDMAP_00001632 +BDMAP_00003840 +BDMAP_00003483 +BDMAP_00002313 +BDMAP_00000154 +BDMAP_00001828 +BDMAP_00003771 +BDMAP_00004550 +BDMAP_00001628 +BDMAP_00003479 +BDMAP_00003396 +BDMAP_00000431 +BDMAP_00004077 +BDMAP_00002899 +BDMAP_00000542 +BDMAP_00000438 +BDMAP_00003277 +BDMAP_00002295 +BDMAP_00005140 +BDMAP_00004183 +BDMAP_00002029 +BDMAP_00003385 +BDMAP_00000447 +BDMAP_00004262 +BDMAP_00000430 +BDMAP_00001247 +BDMAP_00003809 +BDMAP_00000771 +BDMAP_00004773 +BDMAP_00001175 +BDMAP_00000774 +BDMAP_00001419 +BDMAP_00003319 +BDMAP_00001712 +BDMAP_00004129 +BDMAP_00002688 +BDMAP_00004858 +BDMAP_00003886 +BDMAP_00004184 +BDMAP_00000589 +BDMAP_00001414 +BDMAP_00001590 +BDMAP_00002896 +BDMAP_00005064 +BDMAP_00004514 +BDMAP_00003884 +BDMAP_00001565 +BDMAP_00000236 +BDMAP_00001736 +BDMAP_00004895 +BDMAP_00001597 +BDMAP_00003631 +BDMAP_00000692 +BDMAP_00004843 +BDMAP_00004288 +BDMAP_00000623 +BDMAP_00004398 +BDMAP_00001368 +BDMAP_00000701 +BDMAP_00002855 +BDMAP_00004293 +BDMAP_00001806 +BDMAP_00000882 +BDMAP_00004796 +BDMAP_00002603 +BDMAP_00005155 +BDMAP_00001836 +BDMAP_00001440 +BDMAP_00004295 +BDMAP_00000859 +BDMAP_00002120 +BDMAP_00001092 +BDMAP_00002171 +BDMAP_00002947 +BDMAP_00005169 +BDMAP_00004015 +BDMAP_00001804 +BDMAP_00003329 +BDMAP_00003657 +BDMAP_00000427 +BDMAP_00001921 +BDMAP_00003215 +BDMAP_00001521 +BDMAP_00001288 +BDMAP_00003918 +BDMAP_00004097 +BDMAP_00003598 +BDMAP_00000614 +BDMAP_00004541 +BDMAP_00004264 +BDMAP_00001618 +BDMAP_00001842 +BDMAP_00002076 +BDMAP_00002332 +BDMAP_00003683 +BDMAP_00001214 +BDMAP_00003685 +BDMAP_00002244 +BDMAP_00003114 +BDMAP_00001057 +BDMAP_00004917 +BDMAP_00003543 +BDMAP_00003633 +BDMAP_00001898 +BDMAP_00000683 +BDMAP_00005141 +BDMAP_00003853 
+BDMAP_00003650 +BDMAP_00002619 +BDMAP_00002250 +BDMAP_00002304 +BDMAP_00002815 +BDMAP_00002188 +BDMAP_00001701 +BDMAP_00004023 +BDMAP_00002233 +BDMAP_00003130 +BDMAP_00004286 +BDMAP_00002227 +BDMAP_00003254 +BDMAP_00003376 +BDMAP_00001441 +BDMAP_00004954 +BDMAP_00000052 +BDMAP_00000558 +BDMAP_00005092 +BDMAP_00000993 +BDMAP_00001912 +BDMAP_00003168 +BDMAP_00001545 +BDMAP_00005078 +BDMAP_00000618 +BDMAP_00004546 +BDMAP_00002580 +BDMAP_00000197 +BDMAP_00000972 +BDMAP_00002237 +BDMAP_00004549 +BDMAP_00004841 +BDMAP_00004741 +BDMAP_00003824 +BDMAP_00005108 +BDMAP_00004651 +BDMAP_00005037 +BDMAP_00000470 +BDMAP_00002829 +BDMAP_00003438 +BDMAP_00002411 +BDMAP_00004793 +BDMAP_00004636 +BDMAP_00004641 +BDMAP_00002737 +BDMAP_00003356 +BDMAP_00001845 +BDMAP_00004735 +BDMAP_00000338 +BDMAP_00002844 +BDMAP_00001584 +BDMAP_00003900 +BDMAP_00002232 +BDMAP_00004297 +BDMAP_00003400 +BDMAP_00002758 +BDMAP_00002475 diff --git a/Generation_Pipeline_filter_all/real_total.txt b/Generation_Pipeline_filter_all/real_total.txt new file mode 100644 index 0000000000000000000000000000000000000000..6ad82b735dc6c762eb481d2f0d29160601f75127 --- /dev/null +++ b/Generation_Pipeline_filter_all/real_total.txt @@ -0,0 +1,1054 @@ +BDMAP_00002856 +BDMAP_00004199 +BDMAP_00000709 +BDMAP_00003481 +BDMAP_00002719 +BDMAP_00005083 +BDMAP_00002359 +BDMAP_00000642 +BDMAP_00000778 +BDMAP_00000745 +BDMAP_00000607 +BDMAP_00001236 +BDMAP_00001333 +BDMAP_00003920 +BDMAP_00003664 +BDMAP_00003911 +BDMAP_00002463 +BDMAP_00002419 +BDMAP_00000965 +BDMAP_00003513 +BDMAP_00004508 +BDMAP_00002283 +BDMAP_00004509 +BDMAP_00000615 +BDMAP_00001171 +BDMAP_00001343 +BDMAP_00002167 +BDMAP_00000205 +BDMAP_00002805 +BDMAP_00002275 +BDMAP_00002485 +BDMAP_00004228 +BDMAP_00004304 +BDMAP_00004187 +BDMAP_00001379 +BDMAP_00001753 +BDMAP_00000413 +BDMAP_00002289 +BDMAP_00000572 +BDMAP_00005119 +BDMAP_00004017 +BDMAP_00004016 +BDMAP_00002349 +BDMAP_00000101 +BDMAP_00003482 +BDMAP_00004839 +BDMAP_00001025 +BDMAP_00003361 +BDMAP_00002495 +BDMAP_00001055 +BDMAP_00002214 +BDMAP_00005097 +BDMAP_00005168 +BDMAP_00002267 +BDMAP_00001198 +BDMAP_00002918 +BDMAP_00004664 +BDMAP_00004888 +BDMAP_00000921 +BDMAP_00002373 +BDMAP_00001316 +BDMAP_00002117 +BDMAP_00001361 +BDMAP_00002863 +BDMAP_00002309 +BDMAP_00001905 +BDMAP_00004115 +BDMAP_00002216 +BDMAP_00004829 +BDMAP_00003443 +BDMAP_00001504 +BDMAP_00004885 +BDMAP_00003451 +BDMAP_00000679 +BDMAP_00002362 +BDMAP_00000388 +BDMAP_00003769 +BDMAP_00004198 +BDMAP_00004719 +BDMAP_00000809 +BDMAP_00003525 +BDMAP_00003138 +BDMAP_00005063 +BDMAP_00000676 +BDMAP_00000411 +BDMAP_00002523 +BDMAP_00003367 +BDMAP_00003961 +BDMAP_00003822 +BDMAP_00000462 +BDMAP_00001632 +BDMAP_00003840 +BDMAP_00003483 +BDMAP_00002313 +BDMAP_00000154 +BDMAP_00001828 +BDMAP_00003771 +BDMAP_00004550 +BDMAP_00001628 +BDMAP_00003479 +BDMAP_00003396 +BDMAP_00000431 +BDMAP_00004077 +BDMAP_00002899 +BDMAP_00000542 +BDMAP_00000438 +BDMAP_00003277 +BDMAP_00002295 +BDMAP_00005140 +BDMAP_00004183 +BDMAP_00002029 +BDMAP_00003385 +BDMAP_00000447 +BDMAP_00004262 +BDMAP_00000430 +BDMAP_00001247 +BDMAP_00003809 +BDMAP_00000771 +BDMAP_00004773 +BDMAP_00001175 +BDMAP_00000774 +BDMAP_00001419 +BDMAP_00003319 +BDMAP_00001712 +BDMAP_00004129 +BDMAP_00002688 +BDMAP_00004858 +BDMAP_00003886 +BDMAP_00004184 +BDMAP_00000589 +BDMAP_00001414 +BDMAP_00001590 +BDMAP_00002896 +BDMAP_00005064 +BDMAP_00004514 +BDMAP_00003884 +BDMAP_00001565 +BDMAP_00000236 +BDMAP_00001736 +BDMAP_00004895 +BDMAP_00001597 +BDMAP_00003631 +BDMAP_00000692 +BDMAP_00004843 +BDMAP_00004288 +BDMAP_00000623 
+BDMAP_00004398 +BDMAP_00001368 +BDMAP_00000701 +BDMAP_00002855 +BDMAP_00004293 +BDMAP_00001806 +BDMAP_00000882 +BDMAP_00004796 +BDMAP_00002603 +BDMAP_00005155 +BDMAP_00001836 +BDMAP_00001440 +BDMAP_00004295 +BDMAP_00000859 +BDMAP_00002120 +BDMAP_00001092 +BDMAP_00002171 +BDMAP_00002947 +BDMAP_00005169 +BDMAP_00004015 +BDMAP_00001804 +BDMAP_00003329 +BDMAP_00003657 +BDMAP_00000427 +BDMAP_00001921 +BDMAP_00003215 +BDMAP_00001521 +BDMAP_00001288 +BDMAP_00003918 +BDMAP_00004097 +BDMAP_00003598 +BDMAP_00000614 +BDMAP_00004541 +BDMAP_00004264 +BDMAP_00001618 +BDMAP_00001842 +BDMAP_00002076 +BDMAP_00002332 +BDMAP_00003683 +BDMAP_00001214 +BDMAP_00003685 +BDMAP_00002244 +BDMAP_00003114 +BDMAP_00001057 +BDMAP_00004917 +BDMAP_00003543 +BDMAP_00003633 +BDMAP_00001898 +BDMAP_00000683 +BDMAP_00005141 +BDMAP_00003853 +BDMAP_00003650 +BDMAP_00002619 +BDMAP_00002250 +BDMAP_00002304 +BDMAP_00002815 +BDMAP_00002188 +BDMAP_00001701 +BDMAP_00004023 +BDMAP_00002233 +BDMAP_00003130 +BDMAP_00004286 +BDMAP_00002227 +BDMAP_00003254 +BDMAP_00003376 +BDMAP_00001441 +BDMAP_00004954 +BDMAP_00000052 +BDMAP_00000558 +BDMAP_00005092 +BDMAP_00000993 +BDMAP_00001912 +BDMAP_00003168 +BDMAP_00001545 +BDMAP_00005078 +BDMAP_00000618 +BDMAP_00004546 +BDMAP_00002580 +BDMAP_00000197 +BDMAP_00000972 +BDMAP_00002237 +BDMAP_00004549 +BDMAP_00004841 +BDMAP_00004741 +BDMAP_00003824 +BDMAP_00005108 +BDMAP_00004651 +BDMAP_00005037 +BDMAP_00000470 +BDMAP_00002829 +BDMAP_00003438 +BDMAP_00002411 +BDMAP_00004793 +BDMAP_00004636 +BDMAP_00004641 +BDMAP_00002737 +BDMAP_00003356 +BDMAP_00001845 +BDMAP_00004735 +BDMAP_00000338 +BDMAP_00002844 +BDMAP_00001584 +BDMAP_00003900 +BDMAP_00002232 +BDMAP_00004297 +BDMAP_00003400 +BDMAP_00002758 +BDMAP_00002475 +BDMAP_00000245 +BDMAP_00000036 +BDMAP_00003833 +BDMAP_00001517 +BDMAP_00004087 +BDMAP_00002807 +BDMAP_00002099 +BDMAP_00001602 +BDMAP_00001035 +BDMAP_00002422 +BDMAP_00000626 +BDMAP_00002173 +BDMAP_00000240 +BDMAP_00001246 +BDMAP_00000582 +BDMAP_00003996 +BDMAP_00001707 +BDMAP_00000923 +BDMAP_00003411 +BDMAP_00004113 +BDMAP_00002582 +BDMAP_00001261 +BDMAP_00005167 +BDMAP_00004897 +BDMAP_00001169 +BDMAP_00001148 +BDMAP_00002164 +BDMAP_00002041 +BDMAP_00000889 +BDMAP_00001109 +BDMAP_00005009 +BDMAP_00001286 +BDMAP_00000297 +BDMAP_00005099 +BDMAP_00004257 +BDMAP_00005017 +BDMAP_00000604 +BDMAP_00002472 +BDMAP_00001225 +BDMAP_00005081 +BDMAP_00003491 +BDMAP_00001635 +BDMAP_00002075 +BDMAP_00000660 +BDMAP_00001238 +BDMAP_00002656 +BDMAP_00003558 +BDMAP_00001104 +BDMAP_00004066 +BDMAP_00003294 +BDMAP_00001607 +BDMAP_00001077 +BDMAP_00000653 +BDMAP_00001273 +BDMAP_00000616 +BDMAP_00002057 +BDMAP_00004586 +BDMAP_00004407 +BDMAP_00004922 +BDMAP_00002592 +BDMAP_00000149 +BDMAP_00000320 +BDMAP_00001511 +BDMAP_00000435 +BDMAP_00002746 +BDMAP_00004457 +BDMAP_00000805 +BDMAP_00002661 +BDMAP_00004552 +BDMAP_00004154 +BDMAP_00002902 +BDMAP_00000839 +BDMAP_00000233 +BDMAP_00000122 +BDMAP_00005151 +BDMAP_00004427 +BDMAP_00002936 +BDMAP_00003955 +BDMAP_00001863 +BDMAP_00002326 +BDMAP_00001420 +BDMAP_00000329 +BDMAP_00004561 +BDMAP_00003971 +BDMAP_00000935 +BDMAP_00000569 +BDMAP_00004956 +BDMAP_00000285 +BDMAP_00004597 +BDMAP_00001747 +BDMAP_00001059 +BDMAP_00002354 +BDMAP_00001656 +BDMAP_00004395 +BDMAP_00002942 +BDMAP_00004981 +BDMAP_00001768 +BDMAP_00002319 +BDMAP_00003947 +BDMAP_00001868 +BDMAP_00002065 +BDMAP_00002333 +BDMAP_00003358 +BDMAP_00001265 +BDMAP_00003952 +BDMAP_00001891 +BDMAP_00003576 +BDMAP_00000980 +BDMAP_00003300 +BDMAP_00001782 +BDMAP_00003717 +BDMAP_00001251 +BDMAP_00000044 +BDMAP_00004510 
+BDMAP_00003315 +BDMAP_00002653 +BDMAP_00001045 +BDMAP_00003694 +BDMAP_00004216 +BDMAP_00001794 +BDMAP_00000532 +BDMAP_00002288 +BDMAP_00001256 +BDMAP_00000219 +BDMAP_00000710 +BDMAP_00003930 +BDMAP_00001636 +BDMAP_00003749 +BDMAP_00000998 +BDMAP_00000176 +BDMAP_00000429 +BDMAP_00001001 +BDMAP_00001908 +BDMAP_00003363 +BDMAP_00004903 +BDMAP_00004482 +BDMAP_00003178 +BDMAP_00003202 +BDMAP_00001230 +BDMAP_00003461 +BDMAP_00003281 +BDMAP_00000434 +BDMAP_00001218 +BDMAP_00003976 +BDMAP_00003455 +BDMAP_00001183 +BDMAP_00002609 +BDMAP_00001305 +BDMAP_00000364 +BDMAP_00003516 +BDMAP_00003956 +BDMAP_00000977 +BDMAP_00001784 +BDMAP_00004389 +BDMAP_00001711 +BDMAP_00000698 +BDMAP_00003153 +BDMAP_00001995 +BDMAP_00001549 +BDMAP_00001324 +BDMAP_00004195 +BDMAP_00001562 +BDMAP_00004074 +BDMAP_00001483 +BDMAP_00002085 +BDMAP_00001396 +BDMAP_00000241 +BDMAP_00004031 +BDMAP_00004775 +BDMAP_00001807 +BDMAP_00005120 +BDMAP_00004065 +BDMAP_00003943 +BDMAP_00002953 +BDMAP_00004232 +BDMAP_00002184 +BDMAP_00002407 +BDMAP_00003252 +BDMAP_00004296 +BDMAP_00000161 +BDMAP_00002981 +BDMAP_00003608 +BDMAP_00003128 +BDMAP_00000571 +BDMAP_00000259 +BDMAP_00003444 +BDMAP_00001647 +BDMAP_00000662 +BDMAP_00003774 +BDMAP_00001383 +BDMAP_00004616 +BDMAP_00001906 +BDMAP_00003740 +BDMAP_00001422 +BDMAP_00002631 +BDMAP_00004294 +BDMAP_00003994 +BDMAP_00004475 +BDMAP_00002744 +BDMAP_00001068 +BDMAP_00000667 +BDMAP_00001945 +BDMAP_00002710 +BDMAP_00002440 +BDMAP_00000833 +BDMAP_00003143 +BDMAP_00000062 +BDMAP_00003392 +BDMAP_00004373 +BDMAP_00001020 +BDMAP_00003603 +BDMAP_00001027 +BDMAP_00005114 +BDMAP_00003384 +BDMAP_00000794 +BDMAP_00001911 +BDMAP_00002437 +BDMAP_00004579 +BDMAP_00004250 +BDMAP_00002068 +BDMAP_00000608 +BDMAP_00004551 +BDMAP_00002884 +BDMAP_00004033 +BDMAP_00005105 +BDMAP_00002776 +BDMAP_00000414 +BDMAP_00003580 +BDMAP_00004712 +BDMAP_00002114 +BDMAP_00002226 +BDMAP_00003923 +BDMAP_00002854 +BDMAP_00004039 +BDMAP_00004014 +BDMAP_00001289 +BDMAP_00003435 +BDMAP_00004578 +BDMAP_00002940 +BDMAP_00003164 +BDMAP_00002751 +BDMAP_00001516 +BDMAP_00003486 +BDMAP_00000279 +BDMAP_00001664 +BDMAP_00004738 +BDMAP_00001735 +BDMAP_00000562 +BDMAP_00000812 +BDMAP_00000511 +BDMAP_00004746 +BDMAP_00000452 +BDMAP_00004328 +BDMAP_00002017 +BDMAP_00002840 +BDMAP_00000039 +BDMAP_00002242 +BDMAP_00002775 +BDMAP_00003762 +BDMAP_00000229 +BDMAP_00003520 +BDMAP_00000725 +BDMAP_00000516 +BDMAP_00001941 +BDMAP_00003928 +BDMAP_00001255 +BDMAP_00001456 +BDMAP_00002410 +BDMAP_00002742 +BDMAP_00001688 +BDMAP_00000487 +BDMAP_00000469 +BDMAP_00002022 +BDMAP_00003058 +BDMAP_00004148 +BDMAP_00001977 +BDMAP_00000887 +BDMAP_00003448 +BDMAP_00001410 +BDMAP_00002383 +BDMAP_00003736 +BDMAP_00002626 +BDMAP_00001710 +BDMAP_00001130 +BDMAP_00001138 +BDMAP_00001413 +BDMAP_00003815 +BDMAP_00004130 +BDMAP_00004652 +BDMAP_00002864 +BDMAP_00000574 +BDMAP_00003493 +BDMAP_00003364 +BDMAP_00002648 +BDMAP_00001281 +BDMAP_00002655 +BDMAP_00001126 +BDMAP_00002804 +BDMAP_00000321 +BDMAP_00005191 +BDMAP_00004420 +BDMAP_00000304 +BDMAP_00003150 +BDMAP_00004620 +BDMAP_00000368 +BDMAP_00000066 +BDMAP_00003701 +BDMAP_00005174 +BDMAP_00002545 +BDMAP_00003957 +BDMAP_00004331 +BDMAP_00000687 +BDMAP_00001791 +BDMAP_00002959 +BDMAP_00004104 +BDMAP_00003073 +BDMAP_00003713 +BDMAP_00002363 +BDMAP_00000137 +BDMAP_00000104 +BDMAP_00002689 +BDMAP_00004990 +BDMAP_00003301 +BDMAP_00001434 +BDMAP_00000449 +BDMAP_00005113 +BDMAP_00003225 +BDMAP_00001359 +BDMAP_00001223 +BDMAP_00002803 +BDMAP_00000355 +BDMAP_00001826 +BDMAP_00004673 +BDMAP_00002251 +BDMAP_00000439 +BDMAP_00005085 
+BDMAP_00003381 +BDMAP_00004645 +BDMAP_00000432 +BDMAP_00001444 +BDMAP_00001705 +BDMAP_00001892 +BDMAP_00002826 +BDMAP_00004671 +BDMAP_00000926 +BDMAP_00004817 +BDMAP_00004175 +BDMAP_00003484 +BDMAP_00003672 +BDMAP_00003267 +BDMAP_00001089 +BDMAP_00001496 +BDMAP_00003615 +BDMAP_00003832 +BDMAP_00002695 +BDMAP_00002696 +BDMAP_00004499 +BDMAP_00004867 +BDMAP_00004479 +BDMAP_00003600 +BDMAP_00000989 +BDMAP_00002421 +BDMAP_00003406 +BDMAP_00000263 +BDMAP_00002396 +BDMAP_00002265 +BDMAP_00000713 +BDMAP_00000883 +BDMAP_00001258 +BDMAP_00004253 +BDMAP_00004870 +BDMAP_00000331 +BDMAP_00004608 +BDMAP_00001518 +BDMAP_00002562 +BDMAP_00002889 +BDMAP_00001676 +BDMAP_00000117 +BDMAP_00003973 +BDMAP_00002509 +BDMAP_00002487 +BDMAP_00003457 +BDMAP_00000982 +BDMAP_00002260 +BDMAP_00001283 +BDMAP_00003506 +BDMAP_00000366 +BDMAP_00002133 +BDMAP_00000465 +BDMAP_00003767 +BDMAP_00001853 +BDMAP_00002361 +BDMAP_00004815 +BDMAP_00002933 +BDMAP_00000162 +BDMAP_00004925 +BDMAP_00005077 +BDMAP_00001533 +BDMAP_00001242 +BDMAP_00000871 +BDMAP_00000948 +BDMAP_00001119 +BDMAP_00004887 +BDMAP_00002404 +BDMAP_00003722 +BDMAP_00002426 +BDMAP_00002060 +BDMAP_00004850 +BDMAP_00003343 +BDMAP_00001624 +BDMAP_00000481 +BDMAP_00002166 +BDMAP_00003849 +BDMAP_00004808 +BDMAP_00002471 +BDMAP_00000656 +BDMAP_00003581 +BDMAP_00000023 +BDMAP_00003727 +BDMAP_00000319 +BDMAP_00003255 +BDMAP_00003752 +BDMAP_00000139 +BDMAP_00003614 +BDMAP_00003549 +BDMAP_00003808 +BDMAP_00002930 +BDMAP_00001128 +BDMAP_00004717 +BDMAP_00000826 +BDMAP_00002663 +BDMAP_00000837 +BDMAP_00000159 +BDMAP_00005154 +BDMAP_00002524 +BDMAP_00000968 +BDMAP_00004278 +BDMAP_00001325 +BDMAP_00000987 +BDMAP_00004901 +BDMAP_00003425 +BDMAP_00005006 +BDMAP_00004131 +BDMAP_00002403 +BDMAP_00001620 +BDMAP_00002347 +BDMAP_00001522 +BDMAP_00004011 +BDMAP_00001474 +BDMAP_00004744 +BDMAP_00002484 +BDMAP_00001370 +BDMAP_00003324 +BDMAP_00001557 +BDMAP_00000867 +BDMAP_00001487 +BDMAP_00004980 +BDMAP_00000034 +BDMAP_00000936 +BDMAP_00000128 +BDMAP_00001275 +BDMAP_00004030 +BDMAP_00003359 +BDMAP_00003070 +BDMAP_00002476 +BDMAP_00002990 +BDMAP_00000810 +BDMAP_00003514 +BDMAP_00004834 +BDMAP_00003409 +BDMAP_00002498 +BDMAP_00004481 +BDMAP_00002273 +BDMAP_00002496 +BDMAP_00002871 +BDMAP_00000059 +BDMAP_00001475 +BDMAP_00000902 +BDMAP_00004417 +BDMAP_00005157 +BDMAP_00001752 +BDMAP_00001563 +BDMAP_00003063 +BDMAP_00001296 +BDMAP_00002707 +BDMAP_00000836 +BDMAP_00000353 +BDMAP_00000043 +BDMAP_00000244 +BDMAP_00000264 +BDMAP_00000690 +BDMAP_00002039 +BDMAP_00001426 +BDMAP_00002730 +BDMAP_00001917 +BDMAP_00005067 +BDMAP_00002924 +BDMAP_00005160 +BDMAP_00005073 +BDMAP_00000547 +BDMAP_00000942 +BDMAP_00002103 +BDMAP_00002654 +BDMAP_00004374 +BDMAP_00003510 +BDMAP_00004910 +BDMAP_00004558 +BDMAP_00004450 +BDMAP_00000152 +BDMAP_00004491 +BDMAP_00001237 +BDMAP_00001785 +BDMAP_00001865 +BDMAP_00000851 +BDMAP_00003357 +BDMAP_00004415 +BDMAP_00004615 +BDMAP_00003680 +BDMAP_00001875 +BDMAP_00004894 +BDMAP_00001835 +BDMAP_00000069 +BDMAP_00001809 +BDMAP_00004431 +BDMAP_00002704 +BDMAP_00002185 +BDMAP_00004384 +BDMAP_00003299 +BDMAP_00003333 +BDMAP_00002305 +BDMAP_00001598 +BDMAP_00002465 +BDMAP_00002199 +BDMAP_00002875 +BDMAP_00000828 +BDMAP_00003564 +BDMAP_00005001 +BDMAP_00004493 +BDMAP_00000190 +BDMAP_00000873 +BDMAP_00005170 +BDMAP_00002152 +BDMAP_00004163 +BDMAP_00000939 +BDMAP_00001212 +BDMAP_00001982 +BDMAP_00000552 +BDMAP_00004764 +BDMAP_00002401 +BDMAP_00002451 +BDMAP_00003634 +BDMAP_00005016 +BDMAP_00000716 +BDMAP_00003373 +BDMAP_00000030 +BDMAP_00003946 +BDMAP_00002828 +BDMAP_00004196 
+BDMAP_00005005 +BDMAP_00003972 +BDMAP_00003172 +BDMAP_00004783 +BDMAP_00001102 +BDMAP_00004147 +BDMAP_00004604 +BDMAP_00000400 +BDMAP_00003497 +BDMAP_00001270 +BDMAP_00001766 +BDMAP_00001309 +BDMAP_00004745 +BDMAP_00003002 +BDMAP_00004825 +BDMAP_00004416 +BDMAP_00002712 +BDMAP_00004830 +BDMAP_00000907 +BDMAP_00001957 +BDMAP_00000941 +BDMAP_00002841 +BDMAP_00001962 +BDMAP_00004462 +BDMAP_00004281 +BDMAP_00004890 +BDMAP_00003272 +BDMAP_00003377 +BDMAP_00005186 +BDMAP_00002172 +BDMAP_00000091 +BDMAP_00004639 +BDMAP_00000918 +BDMAP_00000671 +BDMAP_00004028 +BDMAP_00004529 +BDMAP_00001907 +BDMAP_00001122 +BDMAP_00003151 +BDMAP_00002252 +BDMAP_00003524 +BDMAP_00004704 +BDMAP_00000362 +BDMAP_00003932 +BDMAP_00004995 +BDMAP_00002748 +BDMAP_00004117 +BDMAP_00000480 +BDMAP_00001010 +BDMAP_00000100 +BDMAP_00001200 +BDMAP_00004103 +BDMAP_00004878 +BDMAP_00002282 +BDMAP_00001471 +BDMAP_00000232 +BDMAP_00003439 +BDMAP_00003857 +BDMAP_00004943 +BDMAP_00005130 +BDMAP_00002479 +BDMAP_00002909 +BDMAP_00004185 +BDMAP_00003569 +BDMAP_00001185 +BDMAP_00001078 +BDMAP_00003031 +BDMAP_00002253 +BDMAP_00001732 +BDMAP_00000874 +BDMAP_00003847 +BDMAP_00003268 +BDMAP_00002846 +BDMAP_00001438 +BDMAP_00004650 +BDMAP_00003109 +BDMAP_00004121 +BDMAP_00004165 +BDMAP_00004676 +BDMAP_00003890 +BDMAP_00003327 +BDMAP_00000132 +BDMAP_00001215 +BDMAP_00001769 +BDMAP_00003412 +BDMAP_00002318 +BDMAP_00004624 +BDMAP_00000345 +BDMAP_00002230 +BDMAP_00003111 +BDMAP_00001015 +BDMAP_00001514 +BDMAP_00001924 +BDMAP_00002845 +BDMAP_00002598 +BDMAP_00001209 +BDMAP_00000373 +BDMAP_00001737 +BDMAP_00003113 +BDMAP_00004876 +BDMAP_00003640 +BDMAP_00001985 +BDMAP_00000138 +BDMAP_00000881 +BDMAP_00002739 +BDMAP_00003560 +BDMAP_00002612 +BDMAP_00001445 +BDMAP_00003827 +BDMAP_00001024 +BDMAP_00000568 +BDMAP_00001095 +BDMAP_00002458 +BDMAP_00002986 +BDMAP_00000913 +BDMAP_00002849 +BDMAP_00003556 +BDMAP_00003052 +BDMAP_00000971 +BDMAP_00003330 +BDMAP_00000113 +BDMAP_00004600 +BDMAP_00002529 +BDMAP_00000437 +BDMAP_00003074 +BDMAP_00005139 +BDMAP_00001966 +BDMAP_00002791 +BDMAP_00001692 +BDMAP_00001786 +BDMAP_00001697 +BDMAP_00003798 +BDMAP_00000273 +BDMAP_00001114 +BDMAP_00003898 +BDMAP_00001397 +BDMAP_00003867 +BDMAP_00005065 +BDMAP_00001802 +BDMAP_00001539 +BDMAP_00000084 +BDMAP_00002955 +BDMAP_00002271 +BDMAP_00004459 +BDMAP_00004378 +BDMAP_00004435 +BDMAP_00001093 +BDMAP_00003897 +BDMAP_00003236 +BDMAP_00001502 +BDMAP_00001834 +BDMAP_00000347 +BDMAP_00000831 +BDMAP_00002717 +BDMAP_00003244 +BDMAP_00005074 +BDMAP_00004804 +BDMAP_00004672 +BDMAP_00003133 +BDMAP_00004969 +BDMAP_00002278 +BDMAP_00001862 +BDMAP_00005185 +BDMAP_00004880 +BDMAP_00004770 +BDMAP_00002690 +BDMAP_00002944 +BDMAP_00003744 +BDMAP_00002021 +BDMAP_00003141 +BDMAP_00004927 +BDMAP_00001476 +BDMAP_00003551 +BDMAP_00004964 +BDMAP_00001605 +BDMAP_00002298 +BDMAP_00001746 +BDMAP_00000332 +BDMAP_00003590 +BDMAP_00000956 +BDMAP_00001649 +BDMAP_00003781 +BDMAP_00001523 +BDMAP_00003347 +BDMAP_00005022 +BDMAP_00004128 +BDMAP_00003612 +BDMAP_00003658 +BDMAP_00003812 +BDMAP_00003427 +BDMAP_00003502 +BDMAP_00001823 +BDMAP_00004847 +BDMAP_00003776 +BDMAP_00001205 +BDMAP_00000192 +BDMAP_00004511 +BDMAP_00001564 +BDMAP_00000416 +BDMAP_00005070 +BDMAP_00001040 +BDMAP_00004231 +BDMAP_00002945 +BDMAP_00001704 +BDMAP_00002402 +BDMAP_00000940 +BDMAP_00000243 +BDMAP_00001464 +BDMAP_00002793 +BDMAP_00001646 +BDMAP_00005020 +BDMAP_00004992 +BDMAP_00003017 +BDMAP_00001096 +BDMAP_00001067 +BDMAP_00001331 +BDMAP_00000696 +BDMAP_00001461 +BDMAP_00003326 +BDMAP_00000715 +BDMAP_00000855 +BDMAP_00000087 
+BDMAP_00000093 +BDMAP_00000324 +BDMAP_00003440 +BDMAP_00002387 +BDMAP_00004060 +BDMAP_00000714 +BDMAP_00001617 +BDMAP_00004494 +BDMAP_00002616 +BDMAP_00000225 +BDMAP_00001754 +BDMAP_00005075 +BDMAP_00002328 +BDMAP_00004229 +BDMAP_00000541 +BDMAP_00004447 +BDMAP_00004106 +BDMAP_00003592 +BDMAP_00003036 +BDMAP_00001125 diff --git a/Generation_Pipeline_filter_all/resample.py b/Generation_Pipeline_filter_all/resample.py new file mode 100644 index 0000000000000000000000000000000000000000..1179efedb86328d1832d50547f1b29877708d608 --- /dev/null +++ b/Generation_Pipeline_filter_all/resample.py @@ -0,0 +1,120 @@ +import os, time, csv +import numpy as np +import torch +from sklearn.metrics import confusion_matrix +from scipy import ndimage +from scipy.ndimage import label +from functools import partial +import monai +from monai.transforms import AsDiscrete,AsDiscreted,Compose,Invertd,SaveImaged +from monai import transforms, data +import nibabel as nib + +import warnings +warnings.filterwarnings("ignore") + +import argparse +parser = argparse.ArgumentParser(description='colon tumor validation') + +# file dir +parser.add_argument('--data_root', default=None, type=str) +parser.add_argument('--organ_type', default='colon', type=str) +parser.add_argument('--save_dir', default='out', type=str) +parser.add_argument('--data_file', default='out', type=str) +parser.add_argument('--ddim_ts', default=50, type=int) +parser.add_argument('--fg_thresh', default=30, type=int) +parser.add_argument('--start', default=0, type=int) +parser.add_argument('--end', default=1000, type=int) + +def voxel2R(A): + return (np.array(A)/4*3/np.pi)**(1/3) + +def _get_loader(args): + # val_data_dir = args.val_dir + # datalist_json = args.json_dir + val_org_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image"], dtype=np.int16), + transforms.AddChanneld(keys=["image"]), + transforms.Orientationd(keys=["image"], axcodes="RAS"), + transforms.Spacingd(keys=["image"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear")), + # transforms.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True), + transforms.ToTensord(keys=["image"]), + ] + ) + + val_img=[] + val_name=[] + + for line in open(args.data_file): + # name = line.strip().split()[1].split('.')[0] + # val_img.append(args.data_root + line.strip().split()[0]) + # val_lbl.append(args.data_root + line.strip().split()[1]) + # breakpoint() + name = line.strip() + val_img.append(os.path.join(args.data_root, name, 'ct.nii.gz')) + val_name.append(name) + data_dicts_val = [{'image': image, 'name': name} + for image, name in zip(val_img, val_name)] + + if args.end < len(data_dicts_val): + data_dicts_val = data_dicts_val[args.start:args.end] + else: + data_dicts_val = data_dicts_val[args.start:] + print('val len {}'.format(len(data_dicts_val))) + val_org_ds = data.Dataset(data_dicts_val, transform=val_org_transform) + val_org_loader = data.DataLoader(val_org_ds, batch_size=1, shuffle=False, num_workers=4, sampler=None, pin_memory=True) + + post_transforms = Compose([ + Invertd( + keys=['image'], + transform=val_org_transform, + orig_keys="image", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ) + ]) + return val_org_loader, post_transforms + +def main(): + args = parser.parse_args() + output_dir = args.save_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print("MAIN Argument values:") + for k, v in vars(args).items(): + print(k, '=>', v) + print('-----------------') + + ## loader and post_transform + 
val_loader, post_transforms = _get_loader(args) + + start_time=0 + with torch.no_grad(): + for idx, val_data in enumerate(val_loader): + print('idx',idx) + if idx == 0: + start_time = time.time() + + data_names = val_data['name'] + case_name = data_names[0].split('/')[-1] + print('case_name', case_name) + original_affine = val_data["image_meta_dict"]["original_affine"][0].numpy() + + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + synt_data = val_data[0]['image'][0] + final_data = synt_data.cpu().numpy() + + # synt_data = val_data['image'] + # final_data = synt_data.cpu().numpy()[0,0] + # breakpoint() + os.makedirs(os.path.join(output_dir, f'{case_name}'), exist_ok=True) + nib.save(nib.Nifti1Image(final_data, original_affine), os.path.join(output_dir, f'{case_name}/ct.nii.gz')) + # breakpoint() + print('time = ', time.time()-start_time) + start_time = time.time() + + +if __name__ == "__main__": + main() diff --git a/Generation_Pipeline_filter_all/syn_colon/CT_syn_colon_data_new.py b/Generation_Pipeline_filter_all/syn_colon/CT_syn_colon_data_new.py new file mode 100644 index 0000000000000000000000000000000000000000..6ee937d183358ae83aaf57de0e1b84db4e357193 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/CT_syn_colon_data_new.py @@ -0,0 +1,230 @@ +import os, time, csv +import numpy as np +import torch +from sklearn.metrics import confusion_matrix +from scipy import ndimage +from scipy.ndimage import label +from functools import partial +import monai +from monai.transforms import AsDiscrete,AsDiscreted,Compose,Invertd,SaveImaged +from monai import transforms, data +from TumorGeneration.utils import synthesize_colon_tumor, synt_model_prepare +import nibabel as nib + +import warnings +warnings.filterwarnings("ignore") + +import argparse +parser = argparse.ArgumentParser(description='colon tumor validation') + +# file dir +parser.add_argument('--data_root', default=None, type=str) +parser.add_argument('--organ_type', default='colon', type=str) +parser.add_argument('--save_dir', default='out', type=str) +parser.add_argument('--data_file', default='out', type=str) +parser.add_argument('--ddim_ts', default=50, type=int) +parser.add_argument('--fg_thresh', default=30, type=int) +parser.add_argument('--start', default=0, type=int) +parser.add_argument('--end', default=1000, type=int) + +def voxel2R(A): + return (np.array(A)/4*3/np.pi)**(1/3) + +class RandCropByPosNegLabeld_select(transforms.RandCropByPosNegLabeld): + def __init__(self, keys, label_key, spatial_size, + pos=1.0, neg=1.0, num_samples=1, + image_key=None, image_threshold=0.0, allow_missing_keys=True, + fg_thresh=0): + super().__init__(keys=keys, label_key=label_key, spatial_size=spatial_size, + pos=pos, neg=neg, num_samples=num_samples, + image_key=image_key, image_threshold=image_threshold, allow_missing_keys=allow_missing_keys) + self.fg_thresh = fg_thresh + + def R2voxel(self,R): + return (4/3*np.pi)*(R)**(3) + + def __call__(self, data): + d = dict(data) + data_name = d['name'] + d.pop('name') + + if '10_Decathlon' in data_name or '05_KiTS' in data_name: + d_crop = super().__call__(d) + + else: + flag=0 + while 1: + flag+=1 + + d_crop = super().__call__(d) + pixel_num = (d_crop[0]['label']>0).sum() + + if pixel_num > self.R2voxel(self.fg_thresh): + break + if flag>5 and pixel_num > self.R2voxel(max(self.fg_thresh-5, 5)): + break + if flag>10 and pixel_num > self.R2voxel(max(self.fg_thresh-10, 5)): + break + if flag>15 and pixel_num > self.R2voxel(max(self.fg_thresh-15, 5)): + break + if flag>20 
and pixel_num > self.R2voxel(max(self.fg_thresh-20, 5)): + break + if flag>25 and pixel_num > self.R2voxel(max(self.fg_thresh-25, 5)): + break + if flag>50: + break + + d_crop[0]['name'] = data_name + + return d_crop + +def _get_loader(args): + # val_data_dir = args.val_dir + # datalist_json = args.json_dir + val_org_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label", "raw_image"]), + transforms.AddChanneld(keys=["image", "label", "raw_image"]), + transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), + transforms.Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear")), + transforms.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True), + transforms.SpatialPadd(keys=["image", "label"], mode=["minimum", "constant"], spatial_size=[96, 96, 96]), + RandCropByPosNegLabeld_select( + keys=["image", "label", "name"], + label_key="label", + spatial_size=(96, 96, 96), + pos=1, + neg=0, + num_samples=1, + image_key="image", + image_threshold=0, + fg_thresh = args.fg_thresh, + ), + transforms.ToTensord(keys=["image", "label", "raw_image"]), + ] + ) + + val_img=[] + val_lbl=[] + val_name=[] + + for line in open(args.data_file): + # name = line.strip().split()[1].split('.')[0] + # val_img.append(args.data_root + line.strip().split()[0]) + # val_lbl.append(args.data_root + line.strip().split()[1]) + # breakpoint() + name = line.strip() + val_img.append(os.path.join(args.data_root, name, 'ct.nii.gz')) + val_lbl.append(os.path.join(args.data_root, name, 'segmentations/colon.nii.gz')) + val_name.append(name) + data_dicts_val = [{'image': image, 'raw_image':image, 'label': label, 'name': name} + for image, label, name in zip(val_img, val_lbl, val_name)] + + if args.end < len(data_dicts_val): + data_dicts_val = data_dicts_val[args.start:args.end] + else: + data_dicts_val = data_dicts_val[args.start:] + print('val len {}'.format(len(data_dicts_val))) + val_org_ds = data.Dataset(data_dicts_val, transform=val_org_transform) + val_org_loader = data.DataLoader(val_org_ds, batch_size=1, shuffle=False, num_workers=4, sampler=None, pin_memory=True) + + post_transforms = Compose([ + Invertd( + keys=['image'], + transform=val_org_transform, + orig_keys="image", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ), + Invertd( + keys=['label'], + transform=val_org_transform, + orig_keys="label", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ) + ]) + return val_org_loader, post_transforms + +def main(): + args = parser.parse_args() + output_dir = args.save_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print("MAIN Argument values:") + for k, v in vars(args).items(): + print(k, '=>', v) + print('-----------------') + + ## loader and post_transform + val_loader, post_transforms = _get_loader(args) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device) + model.load_state_dict(torch.load("../best_metric_model_classification3d_dict.pth")) + model.eval() + + start_time=0 + with torch.no_grad(): + for idx, val_data in enumerate(val_loader): + print('idx',idx) + if idx == 0: + start_time = time.time() + # val_inputs = val_data["image"] + # name = val_data['label_meta_dict']['filename_or_obj'][0].split('/')[-1].split('.')[0] + + vqgan, noearly_sampler= synt_model_prepare(device = torch.device("cuda"), fold=0, organ=args.organ_type) + + 
healthy_data, healthy_target, data_names, raw_data = val_data['image'], val_data['label'], val_data['name'], val_data['raw_image'] + case_name = data_names[0].split('/')[-1] + print('case_name', case_name) + original_affine = val_data["label_meta_dict"]["original_affine"][0].numpy() + + if healthy_target.sum() == 0: + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + tumor_mask = val_data[0]['label'][0].cpu().numpy().astype(np.uint8) + tumor_mask_ = np.zeros_like(tumor_mask) + os.makedirs(os.path.join(output_dir, f'{case_name}/segmentations'), exist_ok=True) # ensure the per-case output folder exists before saving the empty mask + nib.save(nib.Nifti1Image(tumor_mask_, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/colon_tumor.nii.gz')) + continue + + healthy_data, healthy_target = healthy_data.cuda(), healthy_target.cuda() + healthy_target = (healthy_target>0).to(healthy_target) + + flag = 0 # retry counter for the synthesis/confidence loop below + while 1: + flag+=1 + synt_data, synt_target = synthesize_colon_tumor(healthy_data, healthy_target, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + + syn_confidence = model(synt_data).sigmoid()[:,1] + if syn_confidence>0.01: + break + elif flag > 20 and syn_confidence>0.005: + break + elif flag > 40 and syn_confidence>0.001: + break + + val_data['image'] = synt_data.detach() + val_data['label'] = synt_target.detach() + + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + synt_data = val_data[0]['image'][0] + synt_target = val_data[0]['label'][0] + final_data = raw_data[0,0] + + synt_data = (synt_data*(250+175)-175) + final_data[synt_target>1] = synt_data[synt_target>1] + final_data = final_data.cpu().numpy() + final_label = (synt_target>=1.5).cpu().numpy().astype(np.uint8) + + os.makedirs(os.path.join(output_dir, f'{case_name}'), exist_ok=True) + os.makedirs(os.path.join(output_dir, f'{case_name}/segmentations'), exist_ok=True) + nib.save(nib.Nifti1Image(final_data, original_affine), os.path.join(output_dir, f'{case_name}/ct.nii.gz')) + nib.save(nib.Nifti1Image(final_label, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/colon_tumor.nii.gz')) + + print('time = ', time.time()-start_time) + start_time = time.time() + + +if __name__ == "__main__": + main() diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/.DS_Store b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..fb0d207e124cfecc777faae5b4a56c5ca1b9cd2e Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/.DS_Store differ diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/README.md b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/README.md new file mode 100644 index 0000000000000000000000000000000000000000..201911031ca988c410781a60dc3c192a65ee56b3 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/README.md @@ -0,0 +1,5 @@ +```bash +wget https://www.dropbox.com/scl/fi/k856fhk60kck8uqxtxazw/model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll +mv model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll model_weight.tar.gz +tar -xzvf model_weight.tar.gz +``` diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/TumorGenerated.py b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/TumorGenerated.py new file mode 100644 index 0000000000000000000000000000000000000000..9983cf047b1532f5edd20c0d4c78102d6d611eb9 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/TumorGenerated.py @@ -0,0 +1,39 @@ +import random +from typing import Hashable, Mapping, Dict + +from monai.config import KeysCollection
+from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.transform import MapTransform, RandomizableTransform + +from .utils_ import SynthesisTumor +import numpy as np + +class TumorGenerated(RandomizableTransform, MapTransform): + def __init__(self, + keys: KeysCollection, + prob: float = 0.1, + tumor_prob = [0.2, 0.2, 0.2, 0.2, 0.2], + allow_missing_keys: bool = False + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + random.seed(0) + np.random.seed(0) + + self.tumor_types = ['tiny', 'small', 'medium', 'large', 'mix'] + + assert len(tumor_prob) == 5 + self.tumor_prob = np.array(tumor_prob) + + + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + + if self._do_transform and (np.max(d['label']) <= 1): + tumor_type = np.random.choice(self.tumor_types, p=self.tumor_prob.ravel()) + + d['image'][0], d['label'][0] = SynthesisTumor(d['image'][0], d['label'][0], tumor_type) + # print(tumor_type, d['image'].shape, np.max(d['label'])) + return d diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/__init__.py b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d9bbd5e8cede113145b2742ebdd63d7226fe6396 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/__init__.py @@ -0,0 +1,5 @@ +### Online Version TumorGeneration ### + +from .TumorGenerated import TumorGenerated + +# from .utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor \ No newline at end of file diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0ebfed0cd33d072a0561f1b2c881ab987c39b98 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a0c6fe95cfdcd5c8b93417763b81cf141f23bef Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/__pycache__/__init__.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/__pycache__/utils.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0756ebf64f7e3068fc02220df45239da35516ff Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/__pycache__/utils.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/__pycache__/utils_.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/__pycache__/utils_.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2ccff6985d92f26b80eca3e6ab2d9a009aabe71 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/__pycache__/utils_.cpython-38.pyc differ diff --git 
a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/diffusion_config/ddpm.yaml b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/diffusion_config/ddpm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..474a973b1f76c2a026e2df916c4a153b5bbea05f --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/diffusion_config/ddpm.yaml @@ -0,0 +1,29 @@ +vqgan_ckpt: None + +# Have to be derived from VQ-GAN Latent space dimensions +diffusion_img_size: 24 +diffusion_depth_size: 24 +diffusion_num_channels: 17 # 17 +out_dim: 8 +dim_mults: [1,2,4,8] +results_folder: checkpoints/ddpm/ +results_folder_postfix: 'own_dataset_t2' +load_milestone: False # False + +batch_size: 2 # 40 +num_workers: 20 +logger: wandb +objective: pred_x0 +save_and_sample_every: 1000 +denoising_fn: Unet3D +train_lr: 1e-4 +timesteps: 2 # number of steps +sampling_timesteps: 250 # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper]) +loss_type: l1 # L1 or L2 +train_num_steps: 700000 # total training steps +gradient_accumulate_every: 2 # gradient accumulation steps +ema_decay: 0.995 # exponential moving average decay +amp: False # turn on mixed precision +num_sample_rows: 1 +gpus: 0 + diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/diffusion_config/vq_gan_3d.yaml b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/diffusion_config/vq_gan_3d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29940377c5cadb0c06322de8ac60d0713f799024 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/diffusion_config/vq_gan_3d.yaml @@ -0,0 +1,37 @@ +seed: 1234 +batch_size: 2 # 30 +num_workers: 32 # 30 + +gpus: 1 +accumulate_grad_batches: 1 +default_root_dir: checkpoints/vq_gan/ +default_root_dir_postfix: 'flair' +resume_from_checkpoint: +max_steps: -1 +max_epochs: -1 +precision: 16 +gradient_clip_val: 1.0 + + +embedding_dim: 8 # 256 +n_codes: 16384 # 2048 +n_hiddens: 16 +lr: 3e-4 +downsample: [2, 2, 2] # [4, 4, 4] +disc_channels: 64 +disc_layers: 3 +discriminator_iter_start: 10000 # 50000 +disc_loss_type: hinge +image_gan_weight: 1.0 +video_gan_weight: 1.0 +l1_weight: 4.0 +gan_feat_weight: 4.0 # 0.0 +perceptual_weight: 4.0 # 0.0 +i3d_feat: False +restart_thres: 1.0 +no_random_restart: False +norm_type: group +padding_type: replicate +num_groups: 32 + + diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__init__.py b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a01336e46e37471097d8e1420d0ffd5d803a1edd --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__init__.py @@ -0,0 +1 @@ +from .diffusion import Unet3D, GaussianDiffusion, Tester diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..800688be0198f7f33c5329c5a467a39ab6f58611 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc new file mode 100644 
index 0000000000000000000000000000000000000000..e795973913e71f310074432d53ecbe72a127b9a2 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37a2c3b0735cf9b01c20cfb80ae9ced9228687c4 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac6f099866a2f7b7329ec1923619ceb7a8114e14 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a4a010ee9f6d95cbccb3c1cb4897eb530d556419 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/ddim.py b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..1bc51f94355732aedb0ffd254b8166144c468370 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/ddim.py @@ -0,0 +1,206 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from .util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): # "uniform" 'quad' + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. 
- alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + # breakpoint() + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, T, H, W = shape + # breakpoint() + size = (batch_size, C, T, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + 
intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + # iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(time_range): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + # breakpoint() + e_t = self.model.denoise_fn(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.denoise_fn(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/diffusion.py b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..87111dc3b2ab671e798efc3a320ae81d024f20b9 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/diffusion.py @@ -0,0 +1,1016 @@ +"Largely taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch" + +import math +import copy +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial + +from torch.utils import data +from pathlib import Path +from torch.optim import Adam +from torchvision import transforms as T, utils +from torch.cuda.amp import autocast, GradScaler +from PIL import Image + +from tqdm import tqdm +from einops import rearrange +from einops_exts import check_shape, rearrange_many + +from rotary_embedding_torch import RotaryEmbedding + +from .text import tokenize, bert_embed, BERT_MODEL_DIM +from torch.utils.data import Dataset, DataLoader +from ..vq_gan_3d.model.vqgan import VQGAN + +import matplotlib.pyplot as plt + +# helpers functions + + +def exists(x): + return x is not None + + +def noop(*args, **kwargs): + pass + + +def is_odd(n): + return (n % 2) == 1 + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +def cycle(dl): + while True: + for data in dl: + yield data + + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + + +def prob_mask_like(shape, prob, device): + if prob == 1: + return torch.ones(shape, device=device, dtype=torch.bool) + elif prob == 0: + return torch.zeros(shape, device=device, dtype=torch.bool) + else: + return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob + + +def is_list_str(x): + if not isinstance(x, (list, tuple)): + return False + return all([type(el) == str for el in x]) + +# relative positional bias + + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128 + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / + max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min( + val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype=torch.long, device=device) + k_pos = torch.arange(n, dtype=torch.long, device=device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket( + rel_pos, num_buckets=self.num_buckets, 
max_distance=self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') + +# small helper modules + + +class EMA(): + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_model_average(self, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = self.update_average(old_weight, up_weight) + + def update_average(self, old, new): + if old is None: + return new + return old * self.beta + (1 - self.beta) * new + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +def Upsample(dim): + return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) + + +def Downsample(dim): + return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) + + +class LayerNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1)) + + def forward(self, x): + var = torch.var(x, dim=1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=1, keepdim=True) + return (x - mean) / (var + self.eps).sqrt() * self.gamma + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = LayerNorm(dim) + + def forward(self, x, **kwargs): + x = self.norm(x) + return self.fn(x, **kwargs) + +# building block modules + + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1)) + self.norm = nn.GroupNorm(groups, dim_out) + self.act = nn.SiLU() + + def forward(self, x, scale_shift=None): + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + return self.act(x) + + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, dim_out * 2) + ) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) + self.res_conv = nn.Conv3d( + dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb=None): + + scale_shift = None + if exists(self.mlp): + assert exists(time_emb), 'time emb must be passed in' + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') + scale_shift = time_emb.chunk(2, dim=1) + + h = self.block1(x, scale_shift=scale_shift) + + h = self.block2(h) + return h + self.res_conv(x) + + +class SpatialLinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, f, h, w = 
x.shape + x = rearrange(x, 'b c f h w -> (b f) c h w') + + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = rearrange_many( + qkv, 'b (h c) x y -> b h c (x y)', h=self.heads) + + q = q.softmax(dim=-2) + k = k.softmax(dim=-1) + + q = q * self.scale + context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + + out = torch.einsum('b h d e, b h d n -> b h e n', context, q) + out = rearrange(out, 'b h c (x y) -> b (h c) x y', + h=self.heads, x=h, y=w) + out = self.to_out(out) + return rearrange(out, '(b f) c h w -> b c f h w', b=b) + +# attention along space and time + + +class EinopsToAndFrom(nn.Module): + def __init__(self, from_einops, to_einops, fn): + super().__init__() + self.from_einops = from_einops + self.to_einops = to_einops + self.fn = fn + + def forward(self, x, **kwargs): + shape = x.shape + reconstitute_kwargs = dict( + tuple(zip(self.from_einops.split(' '), shape))) + x = rearrange(x, f'{self.from_einops} -> {self.to_einops}') + x = self.fn(x, **kwargs) + x = rearrange( + x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + heads=4, + dim_head=32, + rotary_emb=None + ): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + + self.rotary_emb = rotary_emb + self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False) + self.to_out = nn.Linear(hidden_dim, dim, bias=False) + + def forward( + self, + x, + pos_bias=None, + focus_present_mask=None + ): + n, device = x.shape[-2], x.device + + qkv = self.to_qkv(x).chunk(3, dim=-1) + + if exists(focus_present_mask) and focus_present_mask.all(): + # if all batch samples are focusing on present + # it would be equivalent to passing that token's values through to the output + values = qkv[-1] + return self.to_out(values) + + # split out heads + + q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) + + # scale + + q = q * self.scale + + # rotate positions into queries and keys for time attention + + if exists(self.rotary_emb): + q = self.rotary_emb.rotate_queries_or_keys(q) + k = self.rotary_emb.rotate_queries_or_keys(k) + + # similarity + + sim = einsum('... h i d, ... h j d -> ... h i j', q, k) + + # relative positional bias + + if exists(pos_bias): + sim = sim + pos_bias + + if exists(focus_present_mask) and not (~focus_present_mask).all(): + attend_all_mask = torch.ones( + (n, n), device=device, dtype=torch.bool) + attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) + + mask = torch.where( + rearrange(focus_present_mask, 'b -> b 1 1 1 1'), + rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), + rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), + ) + + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + # numerical stability + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + # aggregate values + + out = einsum('... h i j, ... h j d -> ... h i d', attn, v) + out = rearrange(out, '... h n d -> ... 
n (h d)') + return self.to_out(out) + +# model + + +class Unet3D(nn.Module): + def __init__( + self, + dim, + cond_dim=None, + out_dim=None, + dim_mults=(1, 2, 4, 8), + channels=3, + attn_heads=8, + attn_dim_head=32, + use_bert_text_cond=False, + init_dim=None, + init_kernel_size=7, + use_sparse_linear_attn=True, + block_type='resnet', + resnet_groups=8 + ): + super().__init__() + self.channels = channels + + # temporal attention and its relative positional encoding + + rotary_emb = RotaryEmbedding(min(32, attn_dim_head)) + + def temporal_attn(dim): return EinopsToAndFrom('b c f h w', 'b (h w) f c', Attention( + dim, heads=attn_heads, dim_head=attn_dim_head, rotary_emb=rotary_emb)) + + # realistically will not be able to generate that many frames of video... yet + self.time_rel_pos_bias = RelativePositionBias( + heads=attn_heads, max_distance=32) + + # initial conv + + init_dim = default(init_dim, dim) + assert is_odd(init_kernel_size) + + init_padding = init_kernel_size // 2 + self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, + init_kernel_size), padding=(0, init_padding, init_padding)) + + self.init_temporal_attn = Residual( + PreNorm(init_dim, temporal_attn(init_dim))) + + # dimensions + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + # time conditioning + + time_dim = dim * 4 + self.time_mlp = nn.Sequential( + SinusoidalPosEmb(dim), + nn.Linear(dim, time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim) + ) + + # text conditioning + + self.has_cond = exists(cond_dim) or use_bert_text_cond + cond_dim = BERT_MODEL_DIM if use_bert_text_cond else cond_dim + + self.null_cond_emb = nn.Parameter( + torch.randn(1, cond_dim)) if self.has_cond else None + + cond_dim = time_dim + int(cond_dim or 0) + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + + num_resolutions = len(in_out) + # block type + + block_klass = partial(ResnetBlock, groups=resnet_groups) + block_klass_cond = partial(block_klass, time_emb_dim=cond_dim) + + # modules for all layers + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + block_klass_cond(dim_in, dim_out), + block_klass_cond(dim_out, dim_out), + Residual(PreNorm(dim_out, SpatialLinearAttention( + dim_out, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), + Residual(PreNorm(dim_out, temporal_attn(dim_out))), + Downsample(dim_out) if not is_last else nn.Identity() + ])) + + mid_dim = dims[-1] + self.mid_block1 = block_klass_cond(mid_dim, mid_dim) + + spatial_attn = EinopsToAndFrom( + 'b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads)) + + self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn)) + self.mid_temporal_attn = Residual( + PreNorm(mid_dim, temporal_attn(mid_dim))) + + self.mid_block2 = block_klass_cond(mid_dim, mid_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind >= (num_resolutions - 1) + + self.ups.append(nn.ModuleList([ + block_klass_cond(dim_out * 2, dim_in), + block_klass_cond(dim_in, dim_in), + Residual(PreNorm(dim_in, SpatialLinearAttention( + dim_in, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), + Residual(PreNorm(dim_in, temporal_attn(dim_in))), + Upsample(dim_in) if not is_last else nn.Identity() + ])) + + out_dim = default(out_dim, channels) + self.final_conv = nn.Sequential( + block_klass(dim * 2, dim), + nn.Conv3d(dim, out_dim, 1) + ) + + def forward_with_cond_scale( + self, + 
*args, + cond_scale=2., + **kwargs + ): + logits = self.forward(*args, null_cond_prob=0., **kwargs) + if cond_scale == 1 or not self.has_cond: + return logits + + null_logits = self.forward(*args, null_cond_prob=1., **kwargs) + return null_logits + (logits - null_logits) * cond_scale + + def forward( + self, + x, + time, + cond=None, + null_cond_prob=0., + focus_present_mask=None, + # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time) + prob_focus_present=0. + ): + assert not (self.has_cond and not exists(cond) + ), 'cond must be passed in if cond_dim specified' + x = torch.cat([x, cond], dim=1) + + batch, device = x.shape[0], x.device + + focus_present_mask = default(focus_present_mask, lambda: prob_mask_like( + (batch,), prob_focus_present, device=device)) + + time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device) + + x = self.init_conv(x) + r = x.clone() + + x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias) + + t = self.time_mlp(time) if exists(self.time_mlp) else None # [2, 128] + + # classifier free guidance + + if self.has_cond: + batch, device = x.shape[0], x.device + mask = prob_mask_like((batch,), null_cond_prob, device=device) + cond = torch.where(rearrange(mask, 'b -> b 1'), + self.null_cond_emb, cond) + t = torch.cat((t, cond), dim=-1) + + h = [] + + for block1, block2, spatial_attn, temporal_attn, downsample in self.downs: + x = block1(x, t) + x = block2(x, t) + x = spatial_attn(x) + x = temporal_attn(x, pos_bias=time_rel_pos_bias, + focus_present_mask=focus_present_mask) + h.append(x) + x = downsample(x) + + # [2, 256, 32, 4, 4] + x = self.mid_block1(x, t) + x = self.mid_spatial_attn(x) + x = self.mid_temporal_attn( + x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask) + x = self.mid_block2(x, t) + + for block1, block2, spatial_attn, temporal_attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim=1) + x = block1(x, t) + x = block2(x, t) + x = spatial_attn(x) + x = temporal_attn(x, pos_bias=time_rel_pos_bias, + focus_present_mask=focus_present_mask) + x = upsample(x) + + x = torch.cat((x, r), dim=1) + return self.final_conv(x) + +# gaussian diffusion trainer class + + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = torch.linspace(0, timesteps, steps, dtype=torch.float64) + alphas_cumprod = torch.cos( + ((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0, 0.9999) + + +class GaussianDiffusion(nn.Module): + def __init__( + self, + denoise_fn, + *, + image_size, + num_frames, + text_use_bert_cls=False, + channels=3, + timesteps=1000, + loss_type='l1', + use_dynamic_thres=False, # from the Imagen paper + dynamic_thres_percentile=0.9, + vqgan_ckpt=None, + device=None + ): + super().__init__() + self.channels = channels + self.image_size = image_size + self.num_frames = num_frames + self.denoise_fn = denoise_fn + self.device = device + + if vqgan_ckpt: + self.vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt).cuda() + self.vqgan.eval() + else: + self.vqgan = None + + betas = cosine_beta_schedule(timesteps) + + alphas = 1. 
- betas + alphas_cumprod = torch.cumprod(alphas, axis=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + # register buffer helper function that casts float64 to float32 + + def register_buffer(name, val): return self.register_buffer( + name, val.to(torch.float32)) + + register_buffer('betas', betas) + register_buffer('alphas_cumprod', alphas_cumprod) + register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + + register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) + register_buffer('sqrt_one_minus_alphas_cumprod', + torch.sqrt(1. - alphas_cumprod)) + register_buffer('log_one_minus_alphas_cumprod', + torch.log(1. - alphas_cumprod)) + register_buffer('sqrt_recip_alphas_cumprod', + torch.sqrt(1. / alphas_cumprod)) + register_buffer('sqrt_recipm1_alphas_cumprod', + torch.sqrt(1. / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + + posterior_variance = betas * \ + (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + + register_buffer('posterior_variance', posterior_variance) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + + register_buffer('posterior_log_variance_clipped', + torch.log(posterior_variance.clamp(min=1e-20))) + register_buffer('posterior_mean_coef1', betas * + torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) + register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) + * torch.sqrt(alphas) / (1. - alphas_cumprod)) + + # text conditioning parameters + + self.text_use_bert_cls = text_use_bert_cls + + # dynamic thresholding when sampling + + self.use_dynamic_thres = use_dynamic_thres + self.dynamic_thres_percentile = dynamic_thres_percentile + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. - self.alphas_cumprod, t, x_start.shape) + log_variance = extract( + self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract( + self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool, cond=None, cond_scale=1.): + x_recon = self.predict_start_from_noise( + x, t=t, noise=self.denoise_fn.forward_with_cond_scale(x, t, cond=cond, cond_scale=cond_scale)) + + if clip_denoised: + s = 1. + if self.use_dynamic_thres: + s = torch.quantile( + rearrange(x_recon, 'b ... -> b (...)').abs(), + self.dynamic_thres_percentile, + dim=-1 + ) + + s.clamp_(min=1.) 
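+ # dynamic thresholding (from the Imagen paper, per the flag above): s is the per-sample quantile of |x_recon|, clamped to at least 1, and is reshaped below to broadcast over all non-batch dims before clipping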
+ s = s.view(-1, *((1,) * (x_recon.ndim - 1))) + + # clip by threshold, depending on whether static or dynamic + x_recon = x_recon.clamp(-s, s) / s + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.inference_mode() + def p_sample(self, x, t, cond=None, cond_scale=1., clip_denoised=True): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised, cond=cond, cond_scale=cond_scale) + noise = torch.randn_like(x) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, + *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.inference_mode() + def p_sample_loop(self, shape, cond=None, cond_scale=1.): + device = self.betas.device + + b = shape[0] + img = torch.randn(shape, device=device) + # print('cond', cond.shape) + for i in reversed(range(0, self.num_timesteps)): + img = self.p_sample(img, torch.full( + (b,), i, device=device, dtype=torch.long), cond=cond, cond_scale=cond_scale) + + return img + + @torch.inference_mode() + def sample(self, cond=None, cond_scale=1., batch_size=16): + device = next(self.denoise_fn.parameters()).device + + if is_list_str(cond): + cond = bert_embed(tokenize(cond)).to(device) + + # batch_size = cond.shape[0] if exists(cond) else batch_size + batch_size = batch_size + image_size = self.image_size + channels = 8 # self.channels + num_frames = self.num_frames + # print((batch_size, channels, num_frames, image_size, image_size)) + # print('cond_',cond.shape) + _sample = self.p_sample_loop( + (batch_size, channels, num_frames, image_size, image_size), cond=cond, cond_scale=cond_scale) + + if isinstance(self.vqgan, VQGAN): + # denormalize TODO: Remove eventually + _sample = (((_sample + 1.0) / 2.0) * (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) + self.vqgan.codebook.embeddings.min() + + _sample = self.vqgan.decode(_sample, quantize=True) + else: + unnormalize_img(_sample) + + return _sample + + @torch.inference_mode() + def interpolate(self, x1, x2, t=None, lam=0.5): + b, *_, device = *x1.shape, x1.device + t = default(t, self.num_timesteps - 1) + + assert x1.shape == x2.shape + + t_batched = torch.stack([torch.tensor(t, device=device)] * b) + xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) + + img = (1 - lam) * xt1 + lam * xt2 + for i in reversed(range(0, t)): + img = self.p_sample(img, torch.full( + (b,), i, device=device, dtype=torch.long)) + + return img + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, + t, x_start.shape) * noise + ) + + def p_losses(self, x_start, t, cond=None, noise=None, **kwargs): + b, c, f, h, w, device = *x_start.shape, x_start.device + noise = default(noise, lambda: torch.randn_like(x_start)) + # breakpoint() + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # [2, 8, 32, 32, 32] + + if is_list_str(cond): + cond = bert_embed( + tokenize(cond), return_cls_repr=self.text_use_bert_cls) + cond = cond.to(device) + + x_recon = self.denoise_fn(x_noisy, t, cond=cond, **kwargs) + + if self.loss_type == 'l1': + loss = F.l1_loss(noise, x_recon) + elif self.loss_type == 'l2': + loss = F.mse_loss(noise, x_recon) + else: 
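+ # only 'l1' and 'l2' loss types are supported here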
+ raise NotImplementedError() + + return loss + + def forward(self, x, *args, **kwargs): + bs = int(x.shape[0]/2) + img=x[:bs,...] + mask=x[bs:,...] + mask_=(1-mask).detach() + masked_img = (img*mask_).detach() + masked_img=masked_img.permute(0,1,-1,-3,-2) + img=img.permute(0,1,-1,-3,-2) + mask=mask.permute(0,1,-1,-3,-2) + # breakpoint() + if isinstance(self.vqgan, VQGAN): + with torch.no_grad(): + img = self.vqgan.encode( + img, quantize=False, include_embeddings=True) + # normalize to -1 and 1 + img = ((img - self.vqgan.codebook.embeddings.min()) / + (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + masked_img = self.vqgan.encode( + masked_img, quantize=False, include_embeddings=True) + # normalize to -1 and 1 + masked_img = ((masked_img - self.vqgan.codebook.embeddings.min()) / + (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + else: + print("Hi") + img = normalize_img(img) + masked_img = normalize_img(masked_img) + mask = mask*2.0 - 1.0 + cc = torch.nn.functional.interpolate(mask, size=masked_img.shape[-3:]) + cond = torch.cat((masked_img, cc), dim=1) + + b, device, img_size, = img.shape[0], img.device, self.image_size + t = torch.randint(0, self.num_timesteps, (b,), device=device).long() + # breakpoint() + return self.p_losses(img, t, cond=cond, *args, **kwargs) + +# trainer class + + +CHANNELS_TO_MODE = { + 1: 'L', + 3: 'RGB', + 4: 'RGBA' +} + + +def seek_all_images(img, channels=3): + assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid' + mode = CHANNELS_TO_MODE[channels] + + i = 0 + while True: + try: + img.seek(i) + yield img.convert(mode) + except EOFError: + break + i += 1 + +# tensor of shape (channels, frames, height, width) -> gif + + +def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True): + tensor = ((tensor - tensor.min()) / (tensor.max() - tensor.min())) * 1.0 + images = map(T.ToPILImage(), tensor.unbind(dim=1)) + first_img, *rest_imgs = images + first_img.save(path, save_all=True, append_images=rest_imgs, + duration=duration, loop=loop, optimize=optimize) + return images + +# gif -> (channels, frame, height, width) tensor + + +def gif_to_tensor(path, channels=3, transform=T.ToTensor()): + img = Image.open(path) + tensors = tuple(map(transform, seek_all_images(img, channels=channels))) + return torch.stack(tensors, dim=1) + + +def identity(t, *args, **kwargs): + return t + + +def normalize_img(t): + return t * 2 - 1 + + +def unnormalize_img(t): + return (t + 1) * 0.5 + + +def cast_num_frames(t, *, frames): + f = t.shape[1] + + if f == frames: + return t + + if f > frames: + return t[:, :frames] + + return F.pad(t, (0, 0, 0, 0, 0, frames - f)) + + +class Dataset(data.Dataset): + def __init__( + self, + folder, + image_size, + channels=3, + num_frames=16, + horizontal_flip=False, + force_num_frames=True, + exts=['gif'] + ): + super().__init__() + self.folder = folder + self.image_size = image_size + self.channels = channels + self.paths = [p for ext in exts for p in Path( + f'{folder}').glob(f'**/*.{ext}')] + + self.cast_num_frames_fn = partial( + cast_num_frames, frames=num_frames) if force_num_frames else identity + + self.transform = T.Compose([ + T.Resize(image_size), + T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity), + T.CenterCrop(image_size), + T.ToTensor() + ]) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = self.paths[index] + tensor = gif_to_tensor(path, self.channels, 
transform=self.transform) + return self.cast_num_frames_fn(tensor) + +# trainer class + + +class Tester(object): + def __init__( + self, + diffusion_model, + ): + super().__init__() + self.model = diffusion_model + self.ema_model = copy.deepcopy(self.model) + self.step=0 + self.image_size = diffusion_model.image_size + + self.reset_parameters() + + def reset_parameters(self): + self.ema_model.load_state_dict(self.model.state_dict()) + + + def load(self, milestone, map_location=None, **kwargs): + if milestone == -1: + all_milestones = [int(p.stem.split('-')[-1]) + for p in Path(self.results_folder).glob('**/*.pt')] + assert len( + all_milestones) > 0, 'need to have at least one milestone to load from latest checkpoint (milestone == -1)' + milestone = max(all_milestones) + + if map_location: + data = torch.load(milestone, map_location=map_location) + else: + data = torch.load(milestone) + + self.step = data['step'] + self.model.load_state_dict(data['model'], **kwargs) + self.ema_model.load_state_dict(data['ema'], **kwargs) + + diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/text.py b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/text.py new file mode 100644 index 0000000000000000000000000000000000000000..42dca78bec5075fb4f59c522aa3d00cc395d7536 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/text.py @@ -0,0 +1,94 @@ +"Taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch" + +import torch +from einops import rearrange + + +def exists(val): + return val is not None + +# singleton globals + + +MODEL = None +TOKENIZER = None +BERT_MODEL_DIM = 768 + + +def get_tokenizer(): + global TOKENIZER + if not exists(TOKENIZER): + TOKENIZER = torch.hub.load( + 'huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased') + return TOKENIZER + + +def get_bert(): + global MODEL + if not exists(MODEL): + MODEL = torch.hub.load( + 'huggingface/pytorch-transformers', 'model', 'bert-base-cased') + if torch.cuda.is_available(): + MODEL = MODEL.cuda() + + return MODEL + +# tokenize + + +def tokenize(texts, add_special_tokens=True): + if not isinstance(texts, (list, tuple)): + texts = [texts] + + tokenizer = get_tokenizer() + + encoding = tokenizer.batch_encode_plus( + texts, + add_special_tokens=add_special_tokens, + padding=True, + return_tensors='pt' + ) + + token_ids = encoding.input_ids + return token_ids + +# embedding function + + +@torch.no_grad() +def bert_embed( + token_ids, + return_cls_repr=False, + eps=1e-8, + pad_id=0. 
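+ # pad_id marks padding tokens; they are masked out of the attention mask and of the mean pooling below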
+): + model = get_bert() + mask = token_ids != pad_id + + if torch.cuda.is_available(): + token_ids = token_ids.cuda() + mask = mask.cuda() + + outputs = model( + input_ids=token_ids, + attention_mask=mask, + output_hidden_states=True + ) + + hidden_state = outputs.hidden_states[-1] + + if return_cls_repr: + # return [cls] as representation + return hidden_state[:, 0] + + if not exists(mask): + return hidden_state.mean(dim=1) + + # mean all tokens excluding [cls], accounting for length + mask = mask[:, 1:] + mask = rearrange(mask, 'b n -> b n 1') + + numer = (hidden_state[:, 1:] * mask).sum(dim=1) + denom = mask.sum(dim=1) + masked_mean = numer / (denom + eps) + return diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/time_embedding.py b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/time_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..7c911fa4f485295005cda9c9c2099db3ddbda15c --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/time_embedding.py @@ -0,0 +1,75 @@ +import math +import torch +import torch.nn as nn + +from monai.networks.layers.utils import get_act_layer + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, emb_dim=16, downscale_freq_shift=1, max_period=10000, flip_sin_to_cos=False): + super().__init__() + self.emb_dim = emb_dim + self.downscale_freq_shift = downscale_freq_shift + self.max_period = max_period + self.flip_sin_to_cos = flip_sin_to_cos + + def forward(self, x): + device = x.device + half_dim = self.emb_dim // 2 + emb = math.log(self.max_period) / \ + (half_dim - self.downscale_freq_shift) + emb = torch.exp(-emb*torch.arange(half_dim, device=device)) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + + if self.flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + if self.emb_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class LearnedSinusoidalPosEmb(nn.Module): + """ following @crowsonkb 's lead with learned sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, emb_dim): + super().__init__() + self.emb_dim = emb_dim + half_dim = emb_dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = x[:, None] + freqs = x * self.weights[None, :] * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + if self.emb_dim % 2 == 1: + fouriered = torch.nn.functional.pad(fouriered, (0, 1, 0, 0)) + return fouriered + + +class TimeEmbbeding(nn.Module): + def __init__( + self, + emb_dim=64, + pos_embedder=SinusoidalPosEmb, + pos_embedder_kwargs={}, + act_name=("SWISH", {}) # Swish = SiLU + ): + super().__init__() + self.emb_dim = emb_dim + self.pos_emb_dim = pos_embedder_kwargs.get('emb_dim', emb_dim//4) + pos_embedder_kwargs['emb_dim'] = self.pos_emb_dim + self.pos_embedder = pos_embedder(**pos_embedder_kwargs) + + self.time_emb = nn.Sequential( + self.pos_embedder, + nn.Linear(self.pos_emb_dim, self.emb_dim), + get_act_layer(act_name), + nn.Linear(self.emb_dim, self.emb_dim) + ) + + def forward(self, time): + return self.time_emb(time) diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/unet.py b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/unet.py new file mode 100644 index 
0000000000000000000000000000000000000000..e9cae8021e874a92b6baaaf4decd802516c2dc87 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/unet.py @@ -0,0 +1,226 @@ +from ddpm.time_embedding import TimeEmbbeding + +import monai.networks.nets as nets +import torch +import torch.nn as nn +from einops import rearrange + +from monai.networks.blocks import UnetBasicBlock, UnetResBlock, UnetUpBlock, Convolution, UnetOutBlock +from monai.networks.layers.utils import get_act_layer + + +class DownBlock(nn.Module): + def __init__( + self, + spatial_dims, + in_ch, + out_ch, + time_emb_dim, + cond_emb_dim, + act_name=("swish", {}), + **kwargs): + super(DownBlock, self).__init__() + self.loca_time_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(time_emb_dim, in_ch) # in_ch * 2 + ) + if cond_emb_dim is not None: + self.loca_cond_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(cond_emb_dim, in_ch), + ) + self.down_op = UnetBasicBlock( + spatial_dims, in_ch, out_ch, act_name=act_name, **kwargs) + + def forward(self, x, time_emb, cond_emb): + b, c, *_ = x.shape + sp_dim = x.ndim-2 + + # ------------ Time ---------- + time_emb = self.loca_time_embedder(time_emb) + time_emb = time_emb.reshape(b, c, *((1,)*sp_dim)) + # scale, shift = time_emb.chunk(2, dim = 1) + + # ------------ Combine ------------ + # x = x * (scale + 1) + shift + x = x + time_emb + + # ----------- Condition ------------ + if cond_emb is not None: + cond_emb = self.loca_cond_embedder(cond_emb) + cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim)) + x = x + cond_emb + + # ----------- Image --------- + y = self.down_op(x) + return y + + +class UpBlock(nn.Module): + def __init__( + self, + spatial_dims, + skip_ch, + enc_ch, + time_emb_dim, + cond_emb_dim, + act_name=("swish", {}), + **kwargs): + super(UpBlock, self).__init__() + self.up_op = UnetUpBlock(spatial_dims, enc_ch, + skip_ch, act_name=act_name, **kwargs) + self.loca_time_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(time_emb_dim, skip_ch * 2), + ) + if cond_emb_dim is not None: + self.loca_cond_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(cond_emb_dim, skip_ch * 2), + ) + + def forward(self, x_skip, x_enc, time_emb, cond_emb): + b, c, *_ = x_enc.shape + sp_dim = x_enc.ndim-2 + + # ----------- Time -------------- + time_emb = self.loca_time_embedder(time_emb) + time_emb = time_emb.reshape(b, c, *((1,)*sp_dim)) + # scale, shift = time_emb.chunk(2, dim = 1) + + # -------- Combine ------------- + # y = x * (scale + 1) + shift + x_enc = x_enc + time_emb + + # ----------- Condition ------------ + if cond_emb is not None: + cond_emb = self.loca_cond_embedder(cond_emb) + cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim)) + x_enc = x_enc + cond_emb + + # ----------- Image ------------- + y = self.up_op(x_enc, x_skip) + + # -------- Combine ------------- + # y = y * (scale + 1) + shift + + return y + + +class UNet(nn.Module): + + def __init__(self, + in_ch=1, + out_ch=1, + spatial_dims=3, + hid_chs=[32, 64, 128, 256, 512], + kernel_sizes=[(1, 3, 3), (1, 3, 3), (1, 3, 3), 3, 3], + strides=[1, (1, 2, 2), (1, 2, 2), 2, 2], + upsample_kernel_sizes=None, + act_name=("SWISH", {}), + norm_name=("INSTANCE", {"affine": True}), + time_embedder=TimeEmbbeding, + time_embedder_kwargs={}, + cond_embedder=None, + cond_embedder_kwargs={}, + # True = all but last layer, 0/False=disable, 1=only first layer, ... 
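+ # deep_ver_supervision builds auxiliary decoder output heads (outc_ver) below; note that forward() in this version returns only the final output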
+ deep_ver_supervision=True, + estimate_variance=False, + use_self_conditioning=False, + **kwargs + ): + super().__init__() + if upsample_kernel_sizes is None: + upsample_kernel_sizes = strides[1:] + + # ------------- Time-Embedder----------- + self.time_embedder = time_embedder(**time_embedder_kwargs) + + # ------------- Condition-Embedder----------- + if cond_embedder is not None: + self.cond_embedder = cond_embedder(**cond_embedder_kwargs) + cond_emb_dim = self.cond_embedder.emb_dim + else: + self.cond_embedder = None + cond_emb_dim = None + + # ----------- In-Convolution ------------ + in_ch = in_ch*2 if use_self_conditioning else in_ch + self.inc = UnetBasicBlock(spatial_dims, in_ch, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0], + act_name=act_name, norm_name=norm_name, **kwargs) + + # ----------- Encoder ---------------- + self.encoders = nn.ModuleList([ + DownBlock(spatial_dims, hid_chs[i-1], hid_chs[i], time_emb_dim=self.time_embedder.emb_dim, + cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[ + i], stride=strides[i], act_name=act_name, + norm_name=norm_name, **kwargs) + for i in range(1, len(strides)) + ]) + + # ------------ Decoder ---------- + self.decoders = nn.ModuleList([ + UpBlock(spatial_dims, hid_chs[i], hid_chs[i+1], time_emb_dim=self.time_embedder.emb_dim, + cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[i + + 1], stride=strides[i+1], act_name=act_name, + norm_name=norm_name, upsample_kernel_size=upsample_kernel_sizes[i], **kwargs) + for i in range(len(strides)-1) + ]) + + # --------------- Out-Convolution ---------------- + out_ch_hor = out_ch*2 if estimate_variance else out_ch + self.outc = UnetOutBlock( + spatial_dims, hid_chs[0], out_ch_hor, dropout=None) + if isinstance(deep_ver_supervision, bool): + deep_ver_supervision = len( + strides)-2 if deep_ver_supervision else 0 + self.outc_ver = nn.ModuleList([ + UnetOutBlock(spatial_dims, hid_chs[i], out_ch, dropout=None) + for i in range(1, deep_ver_supervision+1) + ]) + + def forward(self, x_t, t, cond=None, self_cond=None, **kwargs): + condition = cond + # x_t [B, C, (D), H, W] + # t [B,] + + # -------- In-Convolution -------------- + x = [None for _ in range(len(self.encoders)+1)] + x_t = torch.cat([x_t, self_cond], + dim=1) if self_cond is not None else x_t + x[0] = self.inc(x_t) + + # -------- Time Embedding (Gloabl) ----------- + time_emb = self.time_embedder(t) # [B, C] + + # -------- Condition Embedding (Gloabl) ----------- + if (condition is None) or (self.cond_embedder is None): + cond_emb = None + else: + cond_emb = self.cond_embedder(condition) # [B, C] + + # --------- Encoder -------------- + for i in range(len(self.encoders)): + x[i+1] = self.encoders[i](x[i], time_emb, cond_emb) + + # -------- Decoder ----------- + for i in range(len(self.decoders), 0, -1): + x[i-1] = self.decoders[i-1](x[i-1], x[i], time_emb, cond_emb) + + # ---------Out-Convolution ------------ + y_hor = self.outc(x[0]) + y_ver = [outc_ver_i(x[i+1]) + for i, outc_ver_i in enumerate(self.outc_ver)] + + return y_hor # , y_ver + + def forward_with_cond_scale(self, *args, cond_scale=0., **kwargs): + return self.forward(*args, **kwargs) + + +if __name__ == '__main__': + model = UNet(in_ch=3) + input = torch.randn((1, 3, 16, 128, 128)) + time = torch.randn((1,)) + out_hor, out_ver = model(input, time) + print(out_hor[0].shape) diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/util.py b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/util.py new file mode 100644 index 
0000000000000000000000000000000000000000..2047cf56802046bbf37cfd33f819ed75a239a7df --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/ddpm/util.py @@ -0,0 +1,271 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +# from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + if c != 1: + steps_out = ddim_timesteps + 1 + else: + steps_out = ddim_timesteps + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. 
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. 
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac7bdfd0edc6757534df4352a2952faf3c5c588b Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__init__.py b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38c08845c1ee2f7b4b52bf95fa282b0808931a3a --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__init__.py @@ -0,0 +1,3 @@ +from .vqgan import VQGAN +from .codebook import Codebook +from .lpips import LPIPS diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..62702f0203b5aca1df69361cf1271d4659fc63e8 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc differ diff --git 
a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..809398d341751492f5da4cb0646ed6a5e2a58fd1 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..832b32b8e9b977c5dfdf1fff0f125c414e662370 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ef6892b233146a743404465dc24fd5974e4e72c5 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth new file mode 100644 index 0000000000000000000000000000000000000000..f57dcf5cc764d61c8a460365847fb2137ff0a62d --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868 +size 7289 diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/codebook.py b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/codebook.py new file mode 100644 index 0000000000000000000000000000000000000000..c17ed701a2824cc0dcd75670d47a6ab3842e9d35 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/codebook.py @@ -0,0 +1,109 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from ..utils import shift_dim + + +class Codebook(nn.Module): + def __init__(self, n_codes, embedding_dim, no_random_restart=False, restart_thres=1.0): + super().__init__() + self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim)) + self.register_buffer('N', torch.zeros(n_codes)) + self.register_buffer('z_avg', self.embeddings.data.clone()) + + self.n_codes = n_codes + self.embedding_dim = embedding_dim + self._need_init = True + self.no_random_restart = no_random_restart + self.restart_thres = restart_thres + + def _tile(self, x): + d, ew = x.shape + if d < self.n_codes: + n_repeats = (self.n_codes + d - 1) // d + std = 0.01 / np.sqrt(ew) + x = x.repeat(n_repeats, 1) + x = x + torch.randn_like(x) * std + return x + + def _init_embeddings(self, z): + # z: [b, c, t, h, w] + self._need_init = False + breakpoint() + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [65536, 8] [2, 8, 32, 32, 32] + y = self._tile(flat_inputs) # [65536, 8] + + d = y.shape[0] + _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + self.embeddings.data.copy_(_k_rand) + self.z_avg.data.copy_(_k_rand) + self.N.data.copy_(torch.ones(self.n_codes)) + + def forward(self, z): + # z: [b, c, t, h, w] + if self._need_init and self.training: + self._init_embeddings(z) + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [bthw, c] [65536, 8] + distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \ + - 2 * flat_inputs @ self.embeddings.t() \ + + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) # [bthw, c] [65536, 8] + + encoding_indices = torch.argmin(distances, dim=1) # [65536] + encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as( + flat_inputs) # [bthw, ncode] [65536, 16384] + encoding_indices = encoding_indices.view( + z.shape[0], *z.shape[2:]) # [b, t, h, w, ncode] [2, 32, 32, 32] + + embeddings = F.embedding( + encoding_indices, self.embeddings) # [b, t, h, w, c] self.embeddings [16384, 8] + embeddings = shift_dim(embeddings, -1, 1) # [b, c, t, h, w] [2, 8, 32, 32, 32] + + commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) + + # EMA codebook update + if self.training: + n_total = encode_onehot.sum(dim=0) # [16384] + encode_sum = flat_inputs.t() @ encode_onehot # [8, 16384] + if dist.is_initialized(): + dist.all_reduce(n_total) + dist.all_reduce(encode_sum) + + self.N.data.mul_(0.99).add_(n_total, alpha=0.01) + self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) + + n = self.N.sum() + weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n + encode_normalized = self.z_avg / weights.unsqueeze(1) + self.embeddings.data.copy_(encode_normalized) + + y = self._tile(flat_inputs) + _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + + if not self.no_random_restart: + usage = (self.N.view(self.n_codes, 1) + >= self.restart_thres).float() + self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) + + embeddings_st = (embeddings - z).detach() + z + + avg_probs = torch.mean(encode_onehot, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * + torch.log(avg_probs + 1e-10))) + + return dict(embeddings=embeddings_st, encodings=encoding_indices, + commitment_loss=commitment_loss, perplexity=perplexity) + + def dictionary_lookup(self, encodings): + embeddings = F.embedding(encodings, 
self.embeddings) + return embeddings diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/lpips.py b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..019434b9b29945d67e6e0a95dec324df7ff908f3 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/lpips.py @@ -0,0 +1,181 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" + +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + + +from collections import namedtuple +from torchvision import models +import torch.nn as nn +import torch +from tqdm import tqdm +import requests +import os +import hashlib +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format( + name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + # self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + self.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name is not "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + model.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer( + input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in 
range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor( + outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor( + [-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor( + [.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, + padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, + h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..a64ec8878c0ecbf418775e2958b2cbf578ce2918 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py @@ -0,0 +1,561 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import math +import argparse +import numpy as np +import pickle as pkl + +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from ..utils import shift_dim, adopt_weight, comp_getattr +from .lpips import LPIPS +from .codebook import Codebook + + +def silu(x): + return x*torch.sigmoid(x) + + +class SiLU(nn.Module): + def __init__(self): + super(SiLU, self).__init__() + + def forward(self, x): + return silu(x) + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake))) + return d_loss + + +class VQGAN(pl.LightningModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.embedding_dim = cfg.model.embedding_dim # 8 + self.n_codes = cfg.model.n_codes # 16384 + + self.encoder = Encoder(cfg.model.n_hiddens, # 16 + cfg.model.downsample, # [2, 2, 2] + cfg.dataset.image_channels, # 1 + cfg.model.norm_type, # group + cfg.model.padding_type, # replicate + cfg.model.num_groups, # 32 + ) + self.decoder = Decoder( + cfg.model.n_hiddens, cfg.model.downsample, cfg.dataset.image_channels, cfg.model.norm_type, cfg.model.num_groups) + self.enc_out_ch = self.encoder.out_channels + self.pre_vq_conv = SamePadConv3d( + self.enc_out_ch, cfg.model.embedding_dim, 1, padding_type=cfg.model.padding_type) + self.post_vq_conv = SamePadConv3d( + cfg.model.embedding_dim, self.enc_out_ch, 1) + + self.codebook = Codebook(cfg.model.n_codes, cfg.model.embedding_dim, + no_random_restart=cfg.model.no_random_restart, restart_thres=cfg.model.restart_thres) + + self.gan_feat_weight = cfg.model.gan_feat_weight + # TODO: Changed batchnorm from sync to normal + self.image_discriminator = NLayerDiscriminator( + cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm2d) + self.video_discriminator = NLayerDiscriminator3D( + cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm3d) + + if cfg.model.disc_loss_type == 'vanilla': + self.disc_loss = vanilla_d_loss + elif cfg.model.disc_loss_type == 'hinge': + self.disc_loss = hinge_d_loss + + self.perceptual_model = LPIPS().eval() + + self.image_gan_weight = cfg.model.image_gan_weight + self.video_gan_weight = cfg.model.video_gan_weight + + self.perceptual_weight = cfg.model.perceptual_weight + + self.l1_weight = cfg.model.l1_weight + self.save_hyperparameters() + + def encode(self, x, include_embeddings=False, quantize=True): + h = self.pre_vq_conv(self.encoder(x)) + if quantize: + vq_output = self.codebook(h) + if include_embeddings: + return vq_output['embeddings'], vq_output['encodings'] + else: + return vq_output['encodings'] + return h + + def decode(self, latent, quantize=False): + if quantize: + vq_output = self.codebook(latent) + latent = vq_output['encodings'] + h = F.embedding(latent, self.codebook.embeddings) + h = self.post_vq_conv(shift_dim(h, -1, 1)) + return self.decoder(h) + + def forward(self, x, optimizer_idx=None, log_image=False): + B, C, T, H, W = x.shape + z = self.pre_vq_conv(self.encoder(x)) # [2, 32, 32, 32, 32] [2, 8, 32, 32, 32] + vq_output = self.codebook(z) # ['embeddings', 'encodings', 'commitment_loss', 
'perplexity'] + x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings'])) # [2, 8, 32, 32, 32] [2, 32, 32, 32, 32] + + recon_loss = F.l1_loss(x_recon, x) * self.l1_weight + + # Selects one random 2D image from each 3D Image + frame_idx = torch.randint(0, T, [B]).cuda() + frame_idx_selected = frame_idx.reshape(-1, + 1, 1, 1, 1).repeat(1, C, 1, H, W) # [2, 1, 1, 64, 64] + frames = torch.gather(x, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64] + frames_recon = torch.gather(x_recon, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64] + + if log_image: + return frames, frames_recon, x, x_recon + + if optimizer_idx == 0: + # Autoencoder - train the "generator" + + # Perceptual loss + perceptual_loss = 0 + if self.perceptual_weight > 0: + perceptual_loss = self.perceptual_model( + frames, frames_recon).mean() * self.perceptual_weight + + # Discriminator loss (turned on after a certain epoch) + logits_image_fake, pred_image_fake = self.image_discriminator( + frames_recon) + logits_video_fake, pred_video_fake = self.video_discriminator( + x_recon) + g_image_loss = -torch.mean(logits_image_fake) + g_video_loss = -torch.mean(logits_video_fake) + g_loss = self.image_gan_weight*g_image_loss + self.video_gan_weight*g_video_loss + disc_factor = adopt_weight( + self.global_step, threshold=self.cfg.model.discriminator_iter_start) + aeloss = disc_factor * g_loss + + # GAN feature matching loss - tune features such that we get the same prediction result on the discriminator + image_gan_feat_loss = 0 + video_gan_feat_loss = 0 + feat_weights = 4.0 / (3 + 1) + if self.image_gan_weight > 0: + logits_image_real, pred_image_real = self.image_discriminator( + frames) + for i in range(len(pred_image_fake)-1): + image_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_image_fake[i], pred_image_real[i].detach( + )) * (self.image_gan_weight > 0) + if self.video_gan_weight > 0: + logits_video_real, pred_video_real = self.video_discriminator( + x) + for i in range(len(pred_video_fake)-1): + video_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_video_fake[i], pred_video_real[i].detach( + )) * (self.video_gan_weight > 0) + gan_feat_loss = disc_factor * self.gan_feat_weight * \ + (image_gan_feat_loss + video_gan_feat_loss) + + self.log("train/g_image_loss", g_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/g_video_loss", g_video_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/image_gan_feat_loss", image_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/video_gan_feat_loss", video_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/perceptual_loss", perceptual_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log("train/recon_loss", recon_loss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + self.log("train/aeloss", aeloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + self.log("train/commitment_loss", vq_output['commitment_loss'], + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log('train/perplexity', vq_output['perplexity'], + prog_bar=True, logger=True, on_step=True, on_epoch=True) + return recon_loss, x_recon, vq_output, aeloss, perceptual_loss, gan_feat_loss + + if optimizer_idx == 1: + # Train discriminator + logits_image_real, _ = self.image_discriminator(frames.detach()) + logits_video_real, _ = self.video_discriminator(x.detach()) + + logits_image_fake, _ = self.image_discriminator( + frames_recon.detach()) + logits_video_fake, _ = 
self.video_discriminator(x_recon.detach()) + + d_image_loss = self.disc_loss(logits_image_real, logits_image_fake) + d_video_loss = self.disc_loss(logits_video_real, logits_video_fake) + disc_factor = adopt_weight( + self.global_step, threshold=self.cfg.model.discriminator_iter_start) + discloss = disc_factor * \ + (self.image_gan_weight*d_image_loss + + self.video_gan_weight*d_video_loss) + + self.log("train/logits_image_real", logits_image_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_image_fake", logits_image_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_video_real", logits_video_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_video_fake", logits_video_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/d_image_loss", d_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/d_video_loss", d_video_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/discloss", discloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + return discloss + + perceptual_loss = self.perceptual_model( + frames, frames_recon) * self.perceptual_weight + return recon_loss, x_recon, vq_output, perceptual_loss + + def training_step(self, batch, batch_idx, optimizer_idx): + x = batch['image'] + if optimizer_idx == 0: + recon_loss, _, vq_output, aeloss, perceptual_loss, gan_feat_loss = self.forward( + x, optimizer_idx) + commitment_loss = vq_output['commitment_loss'] + loss = recon_loss + commitment_loss + aeloss + perceptual_loss + gan_feat_loss + if optimizer_idx == 1: + discloss = self.forward(x, optimizer_idx) + loss = discloss + return loss + + def validation_step(self, batch, batch_idx): + x = batch['image'] # TODO: batch['stft'] + recon_loss, _, vq_output, perceptual_loss = self.forward(x) + self.log('val/recon_loss', recon_loss, prog_bar=True) + self.log('val/perceptual_loss', perceptual_loss, prog_bar=True) + self.log('val/perplexity', vq_output['perplexity'], prog_bar=True) + self.log('val/commitment_loss', + vq_output['commitment_loss'], prog_bar=True) + + def configure_optimizers(self): + lr = self.cfg.model.lr + opt_ae = torch.optim.Adam(list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.pre_vq_conv.parameters()) + + list(self.post_vq_conv.parameters()) + + list(self.codebook.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(list(self.image_discriminator.parameters()) + + list(self.video_discriminator.parameters()), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def log_images(self, batch, **kwargs): + log = dict() + x = batch['image'] + x = x.to(self.device) + frames, frames_rec, _, _ = self(x, log_image=True) + log["inputs"] = frames + log["reconstructions"] = frames_rec + #log['mean_org'] = batch['mean_org'] + #log['std_org'] = batch['std_org'] + return log + + def log_videos(self, batch, **kwargs): + log = dict() + x = batch['image'] + _, _, x, x_rec = self(x, log_image=True) + log["inputs"] = x + log["reconstructions"] = x_rec + #log['mean_org'] = batch['mean_org'] + #log['std_org'] = batch['std_org'] + return log + + +def Normalize(in_channels, norm_type='group', num_groups=32): + assert norm_type in ['group', 'batch'] + if norm_type == 'group': + # TODO Changed num_groups from 32 to 8 + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == 'batch': + return 
torch.nn.SyncBatchNorm(in_channels) + + +class Encoder(nn.Module): + def __init__(self, n_hiddens, downsample, image_channel=3, norm_type='group', padding_type='replicate', num_groups=32): + super().__init__() + n_times_downsample = np.array([int(math.log2(d)) for d in downsample]) + self.conv_blocks = nn.ModuleList() + max_ds = n_times_downsample.max() + + self.conv_first = SamePadConv3d( + image_channel, n_hiddens, kernel_size=3, padding_type=padding_type) + + for i in range(max_ds): + block = nn.Module() + in_channels = n_hiddens * 2**i + out_channels = n_hiddens * 2**(i+1) + stride = tuple([2 if d > 0 else 1 for d in n_times_downsample]) + block.down = SamePadConv3d( + in_channels, out_channels, 4, stride=stride, padding_type=padding_type) + block.res = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + self.conv_blocks.append(block) + n_times_downsample -= 1 + + self.final_block = nn.Sequential( + Normalize(out_channels, norm_type, num_groups=num_groups), + SiLU() + ) + + self.out_channels = out_channels + + def forward(self, x): + h = self.conv_first(x) + for block in self.conv_blocks: + h = block.down(h) + h = block.res(h) + h = self.final_block(h) + return h + + +class Decoder(nn.Module): + def __init__(self, n_hiddens, upsample, image_channel, norm_type='group', num_groups=32): + super().__init__() + + n_times_upsample = np.array([int(math.log2(d)) for d in upsample]) + max_us = n_times_upsample.max() + + in_channels = n_hiddens*2**max_us + self.final_block = nn.Sequential( + Normalize(in_channels, norm_type, num_groups=num_groups), + SiLU() + ) + + self.conv_blocks = nn.ModuleList() + for i in range(max_us): + block = nn.Module() + in_channels = in_channels if i == 0 else n_hiddens*2**(max_us-i+1) + out_channels = n_hiddens*2**(max_us-i) + us = tuple([2 if d > 0 else 1 for d in n_times_upsample]) + block.up = SamePadConvTranspose3d( + in_channels, out_channels, 4, stride=us) + block.res1 = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + block.res2 = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + self.conv_blocks.append(block) + n_times_upsample -= 1 + + self.conv_last = SamePadConv3d( + out_channels, image_channel, kernel_size=3) + + def forward(self, x): + h = self.final_block(x) + for i, block in enumerate(self.conv_blocks): + h = block.up(h) + h = block.res1(h) + h = block.res2(h) + h = self.conv_last(h) + return h + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group', padding_type='replicate', num_groups=32): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, norm_type, num_groups=num_groups) + self.conv1 = SamePadConv3d( + in_channels, out_channels, kernel_size=3, padding_type=padding_type) + self.dropout = torch.nn.Dropout(dropout) + self.norm2 = Normalize(in_channels, norm_type, num_groups=num_groups) + self.conv2 = SamePadConv3d( + out_channels, out_channels, kernel_size=3, padding_type=padding_type) + if self.in_channels != self.out_channels: + self.conv_shortcut = SamePadConv3d( + in_channels, out_channels, kernel_size=3, padding_type=padding_type) + + def forward(self, x): + h = x + h = self.norm1(h) + h = silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = silu(h) + h = self.conv2(h) + + 
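# ResBlock shortcut: when the input and output channel counts differ, the identity branch is projected by an extra convolution below so that the residual addition stays shape-compatible. +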
if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) + + return x+h + + +# Does not support dilation +class SamePadConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + # assumes that the input shape is divisible by stride + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + self.padding_type = padding_type + + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, + stride=stride, padding=0, bias=bias) + + def forward(self, x): + return self.conv(F.pad(x, self.pad_input, mode=self.padding_type)) + + +class SamePadConvTranspose3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + self.padding_type = padding_type + + self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, + stride=stride, bias=bias, + padding=tuple([k - 1 for k in kernel_size])) + + def forward(self, x): + return self.convt(F.pad(x, self.pad_input, mode=self.padding_type)) + + +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + # def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=True): + super(NLayerDiscriminator, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), None # no intermediate features are collected in this mode + + +class NLayerDiscriminator3D(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + super(NLayerDiscriminator3D, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv3d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv3d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), None # no intermediate features are collected in this mode diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/utils.py b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..acf1a21270d98d0a1ea9935d00f9f16d89d54551 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/ldm/vq_gan_3d/utils.py @@ -0,0 +1,177 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import warnings +import torch +import imageio + +import math +import numpy as np + +import sys +import pdb as pdb_original +import logging + +import imageio.core.util +logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR) + + +class ForkedPdb(pdb_original.Pdb): + """A Pdb subclass that may be used + from a forked multiprocessing child + + """ + + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + pdb_original.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + +# Shifts src dim to dest dim +# i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + + dims = list(range(n_dims)) + del dims[src_dim] + + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x + + +# reshapes tensor start from dim i (inclusive) +# to dim j (exclusive) to the desired shape +# e.g.
if x.shape = (b, thw, c) then +# view_range(x, 1, 2, (t, h, w)) returns +# x of shape (b, t, h, w, c) +def view_range(x, i, j, shape): + shape = tuple(shape) + + n_dims = len(x.shape) + if i < 0: + i = n_dims + i + + if j is None: + j = n_dims + elif j < 0: + j = n_dims + j + + assert 0 <= i < j <= n_dims + + x_shape = x.shape + target_shape = x_shape[:i] + shape + x_shape[j:] + return x.view(target_shape) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def tensor_slice(x, begin, size): + assert all([b >= 0 for b in begin]) + size = [l - b if s == -1 else s + for s, b, l in zip(size, begin, x.shape)] + assert all([s >= 0 for s in size]) + + slices = [slice(b, b + s) for b, s in zip(begin, size)] + return x[slices] + + +def adopt_weight(global_step, threshold=0, value=0.): + weight = 1 + if global_step < threshold: + weight = value + return weight + + +def save_video_grid(video, fname, nrow=None, fps=6): + b, c, t, h, w = video.shape + video = video.permute(0, 2, 3, 4, 1) + video = (video.cpu().numpy() * 255).astype('uint8') + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = np.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype='uint8') + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + video = [] + for i in range(t): + video.append(video_grid[i]) + imageio.mimsave(fname, video, fps=fps) + ## skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '5'}) + #print('saved videos to', fname) + + +def comp_getattr(args, attr_name, default=None): + if hasattr(args, attr_name): + return getattr(args, attr_name) + else: + return default + + +def visualize_tensors(t, name=None, nest=0): + if name is not None: + print(name, "current nest: ", nest) + print("type: ", type(t)) + if 'dict' in str(type(t)): + print(t.keys()) + for k in t.keys(): + if t[k] is None: + print(k, "None") + else: + if 'Tensor' in str(type(t[k])): + print(k, t[k].shape) + elif 'dict' in str(type(t[k])): + print(k, 'dict') + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t[k])): + print(k, len(t[k])) + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t)): + print("list length: ", len(t)) + for t2 in t: + visualize_tensors(t2, name, nest + 1) + elif 'Tensor' in str(type(t)): + print(t.shape) + else: + print(t) + return "" diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/utils.py b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..d9d7972cd1ad2f2e4e1070d747c50df734755a68 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/utils.py @@ -0,0 +1,233 @@ +### Tumor Generateion +import random +import cv2 +import elasticdeform +import numpy as np +from scipy.ndimage import gaussian_filter +from TumorGeneration.ldm.ddpm.ddim import DDIMSampler +import skimage + +def im2col(A, BSZ, stepsize=1): + # Parameters + M, N = 
A.shape + # Get Starting block indices + start_idx = np.arange( + 0, M-BSZ[0]+1, stepsize)[:, None]*N + np.arange(0, N-BSZ[1]+1, stepsize) + # Get offsetted indices across the height and width of input array + offset_idx = np.arange(BSZ[0])[:, None]*N + np.arange(BSZ[1]) + # Get all actual indices & index into input array for final output + return np.take(A, start_idx.ravel()[:, None] + offset_idx.ravel()) + +def seg_to_instance_bd(seg: np.ndarray, + tsz_h: int = 1) -> np.ndarray: + """Generate instance contour map from segmentation masks. + """ + + tsz = tsz_h*2+1 + tsz=int(tsz) + kernel = np.ones((tsz, tsz, tsz), np.uint8) + dilated_seg_mask = skimage.morphology.binary_erosion(seg.astype('uint8'), kernel) + + dilated_seg_mask = dilated_seg_mask.astype(np.uint8) + bd = seg-dilated_seg_mask + bd = (bd>0).astype('uint8') + + return bd + +def sector_mask(shape,centre,radius,angle_range): + """ + Return a boolean mask for a circular sector. The start/stop angles in + `angle_range` should be given in clockwise order. + """ + + x,y = np.ogrid[:shape[0],:shape[1]] + cx,cy = centre + tmin,tmax = np.deg2rad(angle_range) + + # ensure stop angle > start angle + if tmax < tmin: + tmax += 2*np.pi + + # convert cartesian --> polar coordinates + r2 = (x-cx)*(x-cx) + (y-cy)*(y-cy) + theta = np.arctan2(x-cx,y-cy) - tmin + + # wrap angles between 0 and 2*pi + theta %= (2*np.pi) + + # circular mask + circmask = r2 <= radius*radius + + # angular mask + anglemask = theta <= (tmax-tmin) + + return circmask*anglemask + +from scipy.ndimage import label +import elasticdeform +def generate_random_mask(organ_mask): + # initialize tumor mask + tumor_mask = np.zeros_like(organ_mask) + + # randowm mask angle + start_angle = random.randint(0, 360) + angle_range = random.randint(90, 360) + + # generate organ boundary + erode_sz = angle_range//45 * 1 + 3 + # select_size = [3.5, 4, 4.5, 5.0, 5.5, 6.0] + # erode_sz = np.random.choice(select_size) + # print('erode_sz', erode_sz) + organ_bd = seg_to_instance_bd(organ_mask, tsz_h=erode_sz) + + # organ mask range + z_valid_list = np.where(np.any(organ_bd, axis=(0, 1)))[0] + valid_num = len(z_valid_list) + z_valid_list = z_valid_list[round(valid_num*0.25):round(valid_num*0.75)] + # print(z_valid_list) + z = random.choice(z_valid_list) + + # sample thickness + z_thickness = random.randint(10, 20) # 10-20 + # print('z, z_thickness', z, z_thickness) + # crop + tumor_mask[:,:,max(0,z-z_thickness):min(95,z+z_thickness)] = organ_bd[:,:,max(0,z-z_thickness):min(95,z+z_thickness)] + + # random select one + tumor_mask, nb = label(tumor_mask) + sample_id = random.randint(1, nb) + sample_tumor_mask = (tumor_mask==sample_id).astype(np.uint8) + + z_valid = np.where(np.any(sample_tumor_mask, axis=(0, 1)))[0] + z = z_valid[round(0.5 * len(z_valid))] + + # randowm mask region + selected_slice = sample_tumor_mask[..., z] + coordinates = np.argwhere(selected_slice == 1) + center_x, center_y = int(coordinates[:,0].mean()), int(coordinates[:,1].mean()) + # start_angle = random.randint(0, 360) + # angle_range = random.randint(90, 360) + mask_region = sector_mask(selected_slice.shape,(center_x,center_y), 48, (start_angle,start_angle+angle_range)) + mask_region = np.repeat(mask_region[:,:,np.newaxis], axis=-1, repeats=96) + + # elasticdeform + # sigma = random.uniform(1,2) + sigma = random.uniform(2,5) + # sigma = random.uniform(5,10) + deform_tumor_mask = elasticdeform.deform_random_grid(sample_tumor_mask, sigma=sigma, points=3, order=0, axis=(0,1)) + # deform_tumor_mask = 
elasticdeform.deform_random_grid(deform_tumor_mask, sigma=sigma, points=3, order=0, axis=(1,2)) + # deform_tumor_mask = elasticdeform.deform_random_grid(deform_tumor_mask, sigma=sigma, points=3, order=0, axis=(0,2)) + + # final_tumor_mask = deform_tumor_mask*mask_region*organ_mask + final_tumor_mask = deform_tumor_mask*mask_region + + return final_tumor_mask + + +from .ldm.vq_gan_3d.model.vqgan import VQGAN +import matplotlib.pyplot as plt +import SimpleITK as sitk +from .ldm.ddpm import Unet3D, GaussianDiffusion, Tester +from hydra import initialize, compose +import torch +import yaml +def synt_model_prepare(device, vqgan_ckpt='TumorGeneration/model_weight/recon_colon.ckpt', diffusion_ckpt='TumorGeneration/model_weight/', fold=0, organ='colon'): + with initialize(config_path="diffusion_config/"): + cfg=compose(config_name="ddpm.yaml") + vqgan_ckpt = 'TumorGeneration/model_weight/recon_colon.ckpt' + diffusion_ckpt = 'TumorGeneration/model_weight/diffusion_colon.pt' + print('diffusion_ckpt',diffusion_ckpt) + vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt) + vqgan = vqgan.to(device) + vqgan.eval() + + + noearly_Unet3D = Unet3D( + dim=cfg.diffusion_img_size, + dim_mults=cfg.dim_mults, + channels=cfg.diffusion_num_channels, + out_dim=cfg.out_dim + ).to(device) + + noearly_diffusion = GaussianDiffusion( + noearly_Unet3D, + vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt, + image_size=cfg.diffusion_img_size, + num_frames=cfg.diffusion_depth_size, + channels=cfg.diffusion_num_channels, + timesteps=200, # cfg.timesteps, + loss_type=cfg.loss_type, + device=device + ).to(device) + + noearly_checkpoint = torch.load(diffusion_ckpt, map_location=device) + + noearly_diffusion.load_state_dict(noearly_checkpoint['ema']) + + noearly_sampler = DDIMSampler(noearly_diffusion, schedule="cosine") + # breakpoint() + return vqgan, noearly_sampler + +def synthesize_colon_tumor(ct_volume, organ_mask, vqgan, sampler, ddim_ts=50): + device=ct_volume.device + + # generate tumor mask + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + tumor_mask = generate_random_mask(organ_mask_np[bs,0]) + # tumor_mask = organ_mask_np[bs,0] + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + # breakpoint() + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + shape = masked_volume_feat.shape[-4:] + samples_ddim, _ = sampler.sample(S=ddim_ts, + conditioning=cond, + batch_size=1, + shape=shape, + verbose=False) + samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() - + vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min() + + sample = vqgan.decode(samples_ddim, quantize=True) + + # post-process + mask_01 = 
torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(1, 2) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + + # final_volume_ = (sample+1.0)/2.0 + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.zeros_like(organ_mask) + organ_tumor_mask[organ_mask==1] = 1 + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask diff --git a/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/utils_.py b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/utils_.py new file mode 100644 index 0000000000000000000000000000000000000000..312e72d1571b3cd996fd567bcedf939443a4e182 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/TumorGeneration/utils_.py @@ -0,0 +1,298 @@ +### Tumor Generation +import random +import cv2 +import elasticdeform +import numpy as np +from scipy.ndimage import gaussian_filter + +# Step 1: Randomly select (a number of) candidate locations for the tumor. +def random_select(mask_scan): + # we first find the z index range and then sample a point within that z slice + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]] + + # we restrict z's position to the middle of the organ (0.3 - 0.7 of its z-extent) + z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start + + liver_mask = mask_scan[..., z] + + # erode the mask (we don't want the edge points) + kernel = np.ones((5,5), dtype=np.uint8) + liver_mask = cv2.erode(liver_mask, kernel, iterations=1) + + coordinates = np.argwhere(liver_mask == 1) + random_index = np.random.randint(0, len(coordinates)) + xyz = coordinates[random_index].tolist() # get x,y + xyz.append(z) + potential_points = xyz + + return potential_points + +# Step 2: generate the ellipsoid +def get_ellipsoid(x, y, z): + """ + x, y, z are the radii of this ellipsoid in the x, y, z directions respectively.
+ """ + sh = (4*x, 4*y, 4*z) + out = np.zeros(sh, int) + aux = np.zeros(sh) + radii = np.array([x, y, z]) + com = np.array([2*x, 2*y, 2*z]) # center point + + # calculate the ellipsoid + bboxl = np.floor(com-radii).clip(0,None).astype(int) + bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int) + roi = out[tuple(map(slice,bboxl,bboxh))] + roiaux = aux[tuple(map(slice,bboxl,bboxh))] + logrid = *map(np.square,np.ogrid[tuple( + map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]), + dst = (1-sum(logrid)).clip(0,None) + mask = dst>roiaux + roi[mask] = 1 + np.copyto(roiaux,dst,where=mask) + + return out + +def get_fixed_geo(mask_scan, tumor_type): + + enlarge_x, enlarge_y, enlarge_z = 160, 160, 160 + geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8) + tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32 + + if tumor_type == 'tiny': + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + if tumor_type == 'small': + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'medium': + num_tumor = random.randint(2, 5) + for _ in range(num_tumor): + # medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = 
random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'large': + num_tumor = random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == "mix": + # tiny + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + # small + num_tumor = random.randint(5,10) + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), 
int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + # medium + num_tumor = random.randint(2, 5) + for _ in range(num_tumor): + # medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + # large + num_tumor = random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, 
enlarge_z//2:-enlarge_z//2] + # texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + geo_mask = (geo_mask * mask_scan) >=1 + + return geo_mask + + +def get_tumor(volume_scan, mask_scan, tumor_type): + tumor_mask = get_fixed_geo(mask_scan, tumor_type) + + sigma = np.random.uniform(1, 2) + # difference = np.random.uniform(65, 145) + difference = 1 + + # blur the boundary + tumor_mask_blur = gaussian_filter(tumor_mask*1.0, sigma) + + # NOTE: the original intensity-synthesis step was stripped from this legacy copy; as a minimal, + # assumed reconstruction we darken the blurred tumor region inside the organ by `difference`. + abnormally = (volume_scan - tumor_mask_blur * difference) * mask_scan + abnormally_full = volume_scan * (1 - mask_scan) + abnormally + abnormally_mask = mask_scan + tumor_mask + + return abnormally_full, abnormally_mask + +def SynthesisTumor(volume_scan, mask_scan, tumor_type): + # for speed, we only send the organ region into the generation routine + x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0][[0, -1]] + y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0][[0, -1]] + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]] + + # shrink the boundary + x_start, x_end = max(0, x_start+1), min(mask_scan.shape[0], x_end-1) + y_start, y_end = max(0, y_start+1), min(mask_scan.shape[1], y_end-1) + z_start, z_end = max(0, z_start+1), min(mask_scan.shape[2], z_end-1) + + ct_volume = volume_scan[x_start:x_end, y_start:y_end, z_start:z_end] + organ_mask = mask_scan[x_start:x_end, y_start:y_end, z_start:z_end] + + # input texture shape: 420, 300, 320 + # we need to cut it into the shape of liver_mask + # for example, the liver_mask.shape = 286, 173, 46; we should change the texture shape + x_length, y_length, z_length = 64, 64, 64 + crop_x = random.randint(x_start, x_end - x_length - 1) # random select the start point, -1 is to avoid boundary check + crop_y = random.randint(y_start, y_end - y_length - 1) + crop_z = random.randint(z_start, z_end - z_length - 1) + + ct_volume, organ_tumor_mask = get_tumor(ct_volume, organ_mask, tumor_type) + volume_scan[x_start:x_end, y_start:y_end, z_start:z_end] = ct_volume + mask_scan[x_start:x_end, y_start:y_end, z_start:z_end] = organ_tumor_mask + + return volume_scan, mask_scan diff --git a/Generation_Pipeline_filter_all/syn_colon/healthy_colon_1k.txt b/Generation_Pipeline_filter_all/syn_colon/healthy_colon_1k.txt new file mode 100644 index 0000000000000000000000000000000000000000..780aae5af26bdf8a51701d20142eabd922526d0d --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/healthy_colon_1k.txt @@ -0,0 +1,928 @@ +BDMAP_00001823 +BDMAP_00003074 +BDMAP_00001305 +BDMAP_00001635 +BDMAP_00002359 +BDMAP_00001265 +BDMAP_00000701 +BDMAP_00000771 +BDMAP_00003581 +BDMAP_00002523 +BDMAP_00004028 +BDMAP_00005151 +BDMAP_00001183 +BDMAP_00001656 +BDMAP_00003898 +BDMAP_00001845 +BDMAP_00000481 +BDMAP_00003324 +BDMAP_00002688 +BDMAP_00000948 +BDMAP_00004796 +BDMAP_00004198 +BDMAP_00003514 +BDMAP_00000432 +BDMAP_00003832 +BDMAP_00001296 +BDMAP_00003683 +BDMAP_00001607 +BDMAP_00004745 +BDMAP_00005167 +BDMAP_00005154 +BDMAP_00003598 +BDMAP_00003551 +BDMAP_00000176 +BDMAP_00004719 +BDMAP_00003722 +BDMAP_00002690 +BDMAP_00002244 +BDMAP_00000883 +BDMAP_00000926 +BDMAP_00002849 +BDMAP_00004549 +BDMAP_00004017 +BDMAP_00003482 +BDMAP_00003225 +BDMAP_00000416 +BDMAP_00002387 +BDMAP_00002022 +BDMAP_00002909 +BDMAP_00003236 +BDMAP_00000465 +BDMAP_00001784 +BDMAP_00004103 +BDMAP_00000656 +BDMAP_00004850 +BDMAP_00002955 +BDMAP_00003633 +BDMAP_00000137 +BDMAP_00004529 +BDMAP_00004903 +BDMAP_00001309 +BDMAP_00002216 +BDMAP_00001444 +BDMAP_00000263 +BDMAP_00004066 +BDMAP_00003920 +BDMAP_00001434 +BDMAP_00004890
+BDMAP_00000400 +BDMAP_00001238 +BDMAP_00003592 +BDMAP_00000431 +BDMAP_00002304 +BDMAP_00000285 +BDMAP_00004995 +BDMAP_00004264 +BDMAP_00001440 +BDMAP_00001383 +BDMAP_00003614 +BDMAP_00005157 +BDMAP_00003608 +BDMAP_00002619 +BDMAP_00000615 +BDMAP_00000084 +BDMAP_00002804 +BDMAP_00002592 +BDMAP_00001868 +BDMAP_00002021 +BDMAP_00000297 +BDMAP_00003202 +BDMAP_00000411 +BDMAP_00005070 +BDMAP_00003364 +BDMAP_00004395 +BDMAP_00002075 +BDMAP_00002844 +BDMAP_00002712 +BDMAP_00000714 +BDMAP_00002717 +BDMAP_00004895 +BDMAP_00000698 +BDMAP_00003384 +BDMAP_00001286 +BDMAP_00001562 +BDMAP_00004228 +BDMAP_00000831 +BDMAP_00000855 +BDMAP_00004672 +BDMAP_00000882 +BDMAP_00004992 +BDMAP_00002232 +BDMAP_00003849 +BDMAP_00004880 +BDMAP_00004074 +BDMAP_00002626 +BDMAP_00004262 +BDMAP_00000368 +BDMAP_00002826 +BDMAP_00000837 +BDMAP_00001911 +BDMAP_00001557 +BDMAP_00001126 +BDMAP_00002328 +BDMAP_00002959 +BDMAP_00002562 +BDMAP_00003600 +BDMAP_00001057 +BDMAP_00000940 +BDMAP_00002120 +BDMAP_00002227 +BDMAP_00000122 +BDMAP_00002479 +BDMAP_00002805 +BDMAP_00004980 +BDMAP_00001862 +BDMAP_00000778 +BDMAP_00003749 +BDMAP_00000245 +BDMAP_00000989 +BDMAP_00001247 +BDMAP_00000623 +BDMAP_00004113 +BDMAP_00002278 +BDMAP_00004841 +BDMAP_00001602 +BDMAP_00001464 +BDMAP_00001712 +BDMAP_00003815 +BDMAP_00002407 +BDMAP_00003150 +BDMAP_00001711 +BDMAP_00002273 +BDMAP_00002751 +BDMAP_00005074 +BDMAP_00001068 +BDMAP_00004447 +BDMAP_00000977 +BDMAP_00004297 +BDMAP_00000812 +BDMAP_00004641 +BDMAP_00001422 +BDMAP_00003385 +BDMAP_00003164 +BDMAP_00002475 +BDMAP_00002166 +BDMAP_00004232 +BDMAP_00000826 +BDMAP_00003769 +BDMAP_00003569 +BDMAP_00003853 +BDMAP_00004494 +BDMAP_00004011 +BDMAP_00002776 +BDMAP_00001517 +BDMAP_00004304 +BDMAP_00004645 +BDMAP_00000091 +BDMAP_00004738 +BDMAP_00000725 +BDMAP_00003771 +BDMAP_00002524 +BDMAP_00000161 +BDMAP_00000902 +BDMAP_00001786 +BDMAP_00002332 +BDMAP_00004175 +BDMAP_00002419 +BDMAP_00004077 +BDMAP_00004295 +BDMAP_00002871 +BDMAP_00004148 +BDMAP_00000676 +BDMAP_00001782 +BDMAP_00003947 +BDMAP_00003513 +BDMAP_00003130 +BDMAP_00001545 +BDMAP_00000667 +BDMAP_00005078 +BDMAP_00003435 +BDMAP_00002545 +BDMAP_00002498 +BDMAP_00001255 +BDMAP_00004065 +BDMAP_00002099 +BDMAP_00001504 +BDMAP_00001863 +BDMAP_00000542 +BDMAP_00002326 +BDMAP_00005155 +BDMAP_00001476 +BDMAP_00000388 +BDMAP_00000159 +BDMAP_00004060 +BDMAP_00000332 +BDMAP_00004087 +BDMAP_00000516 +BDMAP_00000574 +BDMAP_00004943 +BDMAP_00004514 +BDMAP_00003329 +BDMAP_00001597 +BDMAP_00002172 +BDMAP_00000833 +BDMAP_00004187 +BDMAP_00004744 +BDMAP_00001676 +BDMAP_00003558 +BDMAP_00003438 +BDMAP_00001957 +BDMAP_00004128 +BDMAP_00005140 +BDMAP_00002656 +BDMAP_00004817 +BDMAP_00000745 +BDMAP_00000205 +BDMAP_00000671 +BDMAP_00001962 +BDMAP_00003543 +BDMAP_00001620 +BDMAP_00003128 +BDMAP_00003409 +BDMAP_00000982 +BDMAP_00004015 +BDMAP_00001707 +BDMAP_00002068 +BDMAP_00001236 +BDMAP_00003973 +BDMAP_00004870 +BDMAP_00000366 +BDMAP_00003685 +BDMAP_00001096 +BDMAP_00003347 +BDMAP_00001892 +BDMAP_00003740 +BDMAP_00004773 +BDMAP_00002260 +BDMAP_00002815 +BDMAP_00000972 +BDMAP_00000998 +BDMAP_00003063 +BDMAP_00001791 +BDMAP_00002085 +BDMAP_00002275 +BDMAP_00004016 +BDMAP_00000438 +BDMAP_00000709 +BDMAP_00004416 +BDMAP_00003884 +BDMAP_00002237 +BDMAP_00001794 +BDMAP_00004378 +BDMAP_00000713 +BDMAP_00004286 +BDMAP_00001109 +BDMAP_00001223 +BDMAP_00001027 +BDMAP_00001001 +BDMAP_00005097 +BDMAP_00002942 +BDMAP_00000607 +BDMAP_00002940 +BDMAP_00002930 +BDMAP_00003377 +BDMAP_00004509 +BDMAP_00000923 +BDMAP_00001413 +BDMAP_00001636 +BDMAP_00001705 +BDMAP_00000273 
+BDMAP_00003840 +BDMAP_00001333 +BDMAP_00005092 +BDMAP_00001368 +BDMAP_00003994 +BDMAP_00004925 +BDMAP_00001370 +BDMAP_00003455 +BDMAP_00002631 +BDMAP_00005174 +BDMAP_00005009 +BDMAP_00001549 +BDMAP_00001941 +BDMAP_00000154 +BDMAP_00001521 +BDMAP_00002653 +BDMAP_00001148 +BDMAP_00000774 +BDMAP_00005105 +BDMAP_00002421 +BDMAP_00000139 +BDMAP_00003867 +BDMAP_00003479 +BDMAP_00004741 +BDMAP_00001516 +BDMAP_00002396 +BDMAP_00003481 +BDMAP_00000324 +BDMAP_00002841 +BDMAP_00003326 +BDMAP_00002437 +BDMAP_00000100 +BDMAP_00004586 +BDMAP_00004867 +BDMAP_00001040 +BDMAP_00001185 +BDMAP_00001461 +BDMAP_00000692 +BDMAP_00001563 +BDMAP_00002289 +BDMAP_00004901 +BDMAP_00001632 +BDMAP_00000558 +BDMAP_00000469 +BDMAP_00001966 +BDMAP_00003315 +BDMAP_00002313 +BDMAP_00005006 +BDMAP_00000439 +BDMAP_00004551 +BDMAP_00003294 +BDMAP_00001807 +BDMAP_00004579 +BDMAP_00002057 +BDMAP_00002060 +BDMAP_00004508 +BDMAP_00004104 +BDMAP_00000052 +BDMAP_00003439 +BDMAP_00001502 +BDMAP_00005186 +BDMAP_00002529 +BDMAP_00002775 +BDMAP_00004834 +BDMAP_00001496 +BDMAP_00002319 +BDMAP_00002856 +BDMAP_00004552 +BDMAP_00004878 +BDMAP_00001331 +BDMAP_00001912 +BDMAP_00002758 +BDMAP_00000414 +BDMAP_00004288 +BDMAP_00000805 +BDMAP_00004597 +BDMAP_00003178 +BDMAP_00001752 +BDMAP_00003943 +BDMAP_00004652 +BDMAP_00004541 +BDMAP_00000614 +BDMAP_00004639 +BDMAP_00001804 +BDMAP_00005063 +BDMAP_00002807 +BDMAP_00000062 +BDMAP_00005119 +BDMAP_00004417 +BDMAP_00005075 +BDMAP_00001441 +BDMAP_00002373 +BDMAP_00002041 +BDMAP_00003727 +BDMAP_00001483 +BDMAP_00001128 +BDMAP_00004927 +BDMAP_00001119 +BDMAP_00004106 +BDMAP_00000355 +BDMAP_00002354 +BDMAP_00004030 +BDMAP_00004847 +BDMAP_00000618 +BDMAP_00003736 +BDMAP_00002803 +BDMAP_00005099 +BDMAP_00003168 +BDMAP_00000941 +BDMAP_00000243 +BDMAP_00001664 +BDMAP_00001747 +BDMAP_00003774 +BDMAP_00004917 +BDMAP_00000867 +BDMAP_00000435 +BDMAP_00003822 +BDMAP_00003411 +BDMAP_00000965 +BDMAP_00003612 +BDMAP_00004023 +BDMAP_00002333 +BDMAP_00001270 +BDMAP_00002616 +BDMAP_00004511 +BDMAP_00005130 +BDMAP_00000642 +BDMAP_00002471 +BDMAP_00000589 +BDMAP_00002509 +BDMAP_00004561 +BDMAP_00001275 +BDMAP_00003133 +BDMAP_00000626 +BDMAP_00003491 +BDMAP_00000993 +BDMAP_00003493 +BDMAP_00004499 +BDMAP_00002065 +BDMAP_00001175 +BDMAP_00002696 +BDMAP_00000319 +BDMAP_00002410 +BDMAP_00002485 +BDMAP_00001258 +BDMAP_00000660 +BDMAP_00003272 +BDMAP_00004183 +BDMAP_00003359 +BDMAP_00000956 +BDMAP_00004462 +BDMAP_00001704 +BDMAP_00000039 +BDMAP_00001853 +BDMAP_00003857 +BDMAP_00000572 +BDMAP_00005168 +BDMAP_00000304 +BDMAP_00002426 +BDMAP_00000244 +BDMAP_00001646 +BDMAP_00000413 +BDMAP_00004735 +BDMAP_00002476 +BDMAP_00004039 +BDMAP_00000219 +BDMAP_00004651 +BDMAP_00005065 +BDMAP_00004281 +BDMAP_00000113 +BDMAP_00003956 +BDMAP_00002226 +BDMAP_00004130 +BDMAP_00002707 +BDMAP_00000430 +BDMAP_00002661 +BDMAP_00001617 +BDMAP_00002298 +BDMAP_00003930 +BDMAP_00000687 +BDMAP_00004195 +BDMAP_00001647 +BDMAP_00000487 +BDMAP_00003367 +BDMAP_00003277 +BDMAP_00004600 +BDMAP_00003497 +BDMAP_00004546 +BDMAP_00004808 +BDMAP_00002981 +BDMAP_00000229 +BDMAP_00004185 +BDMAP_00003406 +BDMAP_00002422 +BDMAP_00002947 +BDMAP_00001261 +BDMAP_00005037 +BDMAP_00003590 +BDMAP_00003058 +BDMAP_00003461 +BDMAP_00003151 +BDMAP_00001035 +BDMAP_00001289 +BDMAP_00000087 +BDMAP_00004981 +BDMAP_00001836 +BDMAP_00004712 +BDMAP_00002363 +BDMAP_00002495 +BDMAP_00004398 +BDMAP_00003457 +BDMAP_00003752 +BDMAP_00001891 +BDMAP_00004373 +BDMAP_00001590 +BDMAP_00003506 +BDMAP_00001921 +BDMAP_00004229 +BDMAP_00001898 +BDMAP_00003483 +BDMAP_00004616 +BDMAP_00002648 
+BDMAP_00000562 +BDMAP_00002403 +BDMAP_00003361 +BDMAP_00000887 +BDMAP_00001283 +BDMAP_00002719 +BDMAP_00005064 +BDMAP_00002793 +BDMAP_00002242 +BDMAP_00004278 +BDMAP_00002117 +BDMAP_00000320 +BDMAP_00005191 +BDMAP_00000809 +BDMAP_00000859 +BDMAP_00003955 +BDMAP_00004253 +BDMAP_00004031 +BDMAP_00005139 +BDMAP_00003244 +BDMAP_00000149 +BDMAP_00001414 +BDMAP_00001945 +BDMAP_00004510 +BDMAP_00003824 +BDMAP_00001361 +BDMAP_00000662 +BDMAP_00005022 +BDMAP_00000434 +BDMAP_00000241 +BDMAP_00000710 +BDMAP_00005120 +BDMAP_00002383 +BDMAP_00003036 +BDMAP_00002609 +BDMAP_00004922 +BDMAP_00004407 +BDMAP_00004481 +BDMAP_00001225 +BDMAP_00003556 +BDMAP_00000329 +BDMAP_00003052 +BDMAP_00003396 +BDMAP_00002164 +BDMAP_00001077 +BDMAP_00003153 +BDMAP_00003776 +BDMAP_00002710 +BDMAP_00004746 +BDMAP_00000066 +BDMAP_00005085 +BDMAP_00004435 +BDMAP_00002695 +BDMAP_00001828 +BDMAP_00003392 +BDMAP_00003976 +BDMAP_00002744 +BDMAP_00002214 +BDMAP_00000569 +BDMAP_00000571 +BDMAP_00004888 +BDMAP_00003301 +BDMAP_00004956 +BDMAP_00003809 +BDMAP_00002265 +BDMAP_00002944 +BDMAP_00004457 +BDMAP_00001768 +BDMAP_00001020 +BDMAP_00000541 +BDMAP_00000101 +BDMAP_00003664 +BDMAP_00003255 +BDMAP_00001379 +BDMAP_00002347 +BDMAP_00000128 +BDMAP_00002252 +BDMAP_00001697 +BDMAP_00002953 +BDMAP_00001122 +BDMAP_00003525 +BDMAP_00003070 +BDMAP_00004829 +BDMAP_00002233 +BDMAP_00001288 +BDMAP_00002791 +BDMAP_00004199 +BDMAP_00004184 +BDMAP_00003381 +BDMAP_00001766 +BDMAP_00003114 +BDMAP_00004804 +BDMAP_00002184 +BDMAP_00001138 +BDMAP_00000044 +BDMAP_00002271 +BDMAP_00003603 +BDMAP_00001523 +BDMAP_00004097 +BDMAP_00002440 +BDMAP_00004664 +BDMAP_00003808 +BDMAP_00000427 +BDMAP_00002362 +BDMAP_00005169 +BDMAP_00000023 +BDMAP_00003833 +BDMAP_00001710 +BDMAP_00001518 +BDMAP_00004482 +BDMAP_00003549 +BDMAP_00002171 +BDMAP_00002309 +BDMAP_00000338 +BDMAP_00000715 +BDMAP_00003897 +BDMAP_00003812 +BDMAP_00004257 +BDMAP_00001753 +BDMAP_00000117 +BDMAP_00001456 +BDMAP_00004115 +BDMAP_00003319 +BDMAP_00003744 +BDMAP_00004154 +BDMAP_00003658 +BDMAP_00001214 +BDMAP_00004293 +BDMAP_00001842 +BDMAP_00001420 +BDMAP_00003343 +BDMAP_00001325 +BDMAP_00000921 +BDMAP_00002582 +BDMAP_00002864 +BDMAP_00000889 +BDMAP_00001092 +BDMAP_00000968 +BDMAP_00002402 +BDMAP_00004427 +BDMAP_00001605 +BDMAP_00000462 +BDMAP_00005081 +BDMAP_00002463 +BDMAP_00000839 +BDMAP_00000437 +BDMAP_00000604 +BDMAP_00001104 +BDMAP_00001281 +BDMAP_00000679 +BDMAP_00004717 +BDMAP_00001511 +BDMAP_00003281 +BDMAP_00001977 +BDMAP_00000653 +BDMAP_00000232 +BDMAP_00004328 +BDMAP_00002496 +BDMAP_00000987 +BDMAP_00003717 +BDMAP_00004897 +BDMAP_00003713 +BDMAP_00002889 +BDMAP_00003657 +BDMAP_00002829 +BDMAP_00004839 +BDMAP_00001397 +BDMAP_00001908 +BDMAP_00003911 +BDMAP_00004843 +BDMAP_00004969 +BDMAP_00003918 +BDMAP_00004216 +BDMAP_00000034 +BDMAP_00003923 +BDMAP_00000225 +BDMAP_00003576 +BDMAP_00002884 +BDMAP_00002472 +BDMAP_00001688 +BDMAP_00001246 +BDMAP_00004620 +BDMAP_00005017 +BDMAP_00002990 +BDMAP_00000971 +BDMAP_00004578 +BDMAP_00001735 +BDMAP_00002655 +BDMAP_00000233 +BDMAP_00001205 +BDMAP_00003073 +BDMAP_00003957 +BDMAP_00001093 +BDMAP_00003440 +BDMAP_00001251 +BDMAP_00004793 +BDMAP_00000162 +BDMAP_00003444 +BDMAP_00001533 +BDMAP_00003971 +BDMAP_00001584 +BDMAP_00000036 +BDMAP_00002251 +BDMAP_00003141 +BDMAP_00002484 +BDMAP_00004770 +BDMAP_00001487 +BDMAP_00001754 +BDMAP_00003356 +BDMAP_00000353 +BDMAP_00001419 +BDMAP_00001802 +BDMAP_00003701 +BDMAP_00005141 +BDMAP_00000321 +BDMAP_00001746 +BDMAP_00000364 +BDMAP_00003900 +BDMAP_00001995 +BDMAP_00001025 +BDMAP_00004231 +BDMAP_00000918 
+BDMAP_00001130 +BDMAP_00003443 +BDMAP_00003215 +BDMAP_00004815 +BDMAP_00002933 +BDMAP_00000192 +BDMAP_00003615 +BDMAP_00004704 +BDMAP_00001218 +BDMAP_00002295 +BDMAP_00000429 +BDMAP_00000532 +BDMAP_00001474 +BDMAP_00003961 +BDMAP_00004129 +BDMAP_00000362 +BDMAP_00002863 +BDMAP_00003267 +BDMAP_00001198 +BDMAP_00000259 +BDMAP_00000683 +BDMAP_00001256 +BDMAP_00003252 +BDMAP_00004475 +BDMAP_00004250 +BDMAP_00004887 +BDMAP_00000240 +BDMAP_00003767 +BDMAP_00003427 +BDMAP_00000043 +BDMAP_00003448 +BDMAP_00001114 +BDMAP_00001067 +BDMAP_00001089 +BDMAP_00002133 +BDMAP_00004033 +BDMAP_00002896 +BDMAP_00003138 +BDMAP_00001010 +BDMAP_00001059 +BDMAP_00004990 +BDMAP_00000936 +BDMAP_00001359 +BDMAP_00005077 +BDMAP_00000582 +BDMAP_00004296 +BDMAP_00005114 +BDMAP_00004389 +BDMAP_00004673 +BDMAP_00003254 +BDMAP_00003516 +BDMAP_00001475 +BDMAP_00002580 +BDMAP_00002689 +BDMAP_00004671 +BDMAP_00003762 +BDMAP_00003330 +BDMAP_00002188 +BDMAP_00001736 +BDMAP_00002404 +BDMAP_00003502 +BDMAP_00004117 +BDMAP_00004964 +BDMAP_00002742 +BDMAP_00000093 +BDMAP_00002361 +BDMAP_00000794 +BDMAP_00002349 +BDMAP_00001273 +BDMAP_00000449 +BDMAP_00001628 +BDMAP_00002250 +BDMAP_00004479 +BDMAP_00000608 +BDMAP_00001834 +BDMAP_00002267 +BDMAP_00001125 +BDMAP_00000447 +BDMAP_00005113 +BDMAP_00004014 +BDMAP_00001701 +BDMAP_00003952 +BDMAP_00003520 +BDMAP_00000347 +BDMAP_00002936 +BDMAP_00001624 +BDMAP_00000104 +BDMAP_00002487 +BDMAP_00005020 +BDMAP_00000511 +BDMAP_00001564 +BDMAP_00004294 +BDMAP_00004858 +BDMAP_00004608 +BDMAP_00003650 +BDMAP_00001171 +BDMAP_00000059 +BDMAP_00000871 +BDMAP_00003996 +BDMAP_00001169 +BDMAP_00003363 +BDMAP_00003376 +BDMAP_00002167 +BDMAP_00002737 +BDMAP_00003694 +BDMAP_00001396 +BDMAP_00005083 +BDMAP_00002918 +BDMAP_00003580 +BDMAP_00001324 +BDMAP_00002855 +BDMAP_00001649 +BDMAP_00004459 +BDMAP_00002288 +BDMAP_00004830 +BDMAP_00004775 +BDMAP_00000279 +BDMAP_00002114 +BDMAP_00005185 +BDMAP_00004885 +BDMAP_00000236 +BDMAP_00003928 +BDMAP_00002663 +BDMAP_00002282 +BDMAP_00003798 +BDMAP_00001055 +BDMAP_00002945 +BDMAP_00001316 +BDMAP_00003451 +BDMAP_00000696 +BDMAP_00003143 +BDMAP_00001522 +BDMAP_00000452 +BDMAP_00002603 +BDMAP_00004131 +BDMAP_00001045 +BDMAP_00004954 +BDMAP_00003358 +BDMAP_00000980 +BDMAP_00001343 +BDMAP_00001410 +BDMAP_00002173 +BDMAP_00002840 +BDMAP_00001200 +BDMAP_00001905 +BDMAP_00003425 +BDMAP_00003672 +BDMAP_00003781 +BDMAP_00001906 +BDMAP_00004636 +BDMAP_00000836 +BDMAP_00002076 +BDMAP_00001230 +BDMAP_00003932 +BDMAP_00002029 +BDMAP_00000331 +BDMAP_00000197 +BDMAP_00001539 +BDMAP_00003524 +BDMAP_00001692 +BDMAP_00004550 +BDMAP_00004331 +BDMAP_00004825 +BDMAP_00002411 +BDMAP_00003484 +BDMAP_00000480 +BDMAP_00003886 +BDMAP_00001907 +BDMAP_00002746 +BDMAP_00002899 +BDMAP_00004420 +BDMAP_00002748 +BDMAP_00003002 +BDMAP_00003300 +BDMAP_00005108 +BDMAP_00003400 +BDMAP_00002283 +BDMAP_00003486 +BDMAP_00000935 +BDMAP_00001618 +BDMAP_00001565 +BDMAP_00000616 +BDMAP_00000810 +BDMAP_00001826 +BDMAP_00000470 +BDMAP_00002017 +BDMAP_00003631 +BDMAP_00001242 +BDMAP_00000907 +BDMAP_00001806 +BDMAP_00002854 +BDMAP_00002902 +BDMAP_00003017 +BDMAP_00001471 diff --git a/Generation_Pipeline_filter_all/syn_colon/requirements.txt b/Generation_Pipeline_filter_all/syn_colon/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9c1067c8eb4f44699bbf517601396113cea4b370 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_colon/requirements.txt @@ -0,0 +1,94 @@ +absl-py==1.1.0 +accelerate==0.11.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +antlr4-python3-runtime==4.9.3 
+async-timeout==4.0.2 +attrs==21.4.0 +autopep8==1.6.0 +cachetools==5.2.0 +certifi==2022.6.15 +charset-normalizer==2.0.12 +click==8.1.3 +cycler==0.11.0 +Deprecated==1.2.13 +docker-pycreds==0.4.0 +einops==0.4.1 +einops-exts==0.0.3 +ema-pytorch==0.0.8 +fonttools==4.34.4 +frozenlist==1.3.0 +fsspec==2022.5.0 +ftfy==6.1.1 +future==0.18.2 +gitdb==4.0.9 +GitPython==3.1.27 +google-auth==2.9.0 +google-auth-oauthlib==0.4.6 +grpcio==1.47.0 +h5py==3.7.0 +humanize==4.2.2 +hydra-core==1.2.0 +idna==3.3 +imageio==2.19.3 +imageio-ffmpeg==0.4.7 +importlib-metadata==4.12.0 +importlib-resources==5.9.0 +joblib==1.1.0 +kiwisolver==1.4.3 +lxml==4.9.1 +Markdown==3.3.7 +matplotlib==3.5.2 +multidict==6.0.2 +networkx==2.8.5 +nibabel==4.0.1 +nilearn==0.9.1 +numpy==1.23.0 +oauthlib==3.2.0 +omegaconf==2.2.3 +pandas==1.4.3 +Pillow==9.1.1 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycodestyle==2.8.0 +pyDeprecate==0.3.1 +pydicom==2.3.0 +pytorch-lightning==1.6.4 +pytz==2022.1 +PyWavelets==1.3.0 +PyYAML==6.0 +pyzmq==19.0.2 +regex==2022.6.2 +requests==2.28.0 +requests-oauthlib==1.3.1 +rotary-embedding-torch==0.1.5 +rsa==4.8 +scikit-image==0.19.3 +scikit-learn==1.1.2 +scikit-video==1.1.11 +scipy==1.8.1 +seaborn==0.11.2 +sentry-sdk==1.7.2 +setproctitle==1.2.3 +shortuuid==1.0.9 +SimpleITK==2.1.1.2 +smmap==5.0.0 +tensorboard==2.9.1 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +threadpoolctl==3.1.0 +tifffile==2022.8.3 +toml==0.10.2 +torch-tb-profiler==0.4.0 +torchio==0.18.80 +torchmetrics==0.9.1 +tqdm==4.64.0 +typing_extensions==4.2.0 +urllib3==1.26.9 +wandb==0.12.21 +Werkzeug==2.1.2 +wrapt==1.14.1 +yarl==1.7.2 +zipp==3.8.0 +wandb +tensorboardX==2.4.1 diff --git a/Generation_Pipeline_filter_all/syn_kidney/CT_syn_kidney_data_new.py b/Generation_Pipeline_filter_all/syn_kidney/CT_syn_kidney_data_new.py new file mode 100644 index 0000000000000000000000000000000000000000..765590e5eb2af9433c9309cc9623a909146034de --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/CT_syn_kidney_data_new.py @@ -0,0 +1,241 @@ +import os, time, csv +import numpy as np +import torch +from sklearn.metrics import confusion_matrix +from scipy import ndimage +from scipy.ndimage import label +from functools import partial +import monai +from monai.transforms import AsDiscrete,AsDiscreted,Compose,Invertd,SaveImaged +from monai import transforms, data +from TumorGeneration.utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor, synt_model_prepare +import nibabel as nib + +import warnings +warnings.filterwarnings("ignore") + +import argparse +parser = argparse.ArgumentParser(description='kidney tumor validation') + +# file dir +parser.add_argument('--data_root', default=None, type=str) +parser.add_argument('--organ_type', default='kidney', type=str) +parser.add_argument('--save_dir', default='out', type=str) +parser.add_argument('--data_file', default='out', type=str) +parser.add_argument('--ddim_ts', default=50, type=int) +parser.add_argument('--fg_thresh', default=30, type=int) +parser.add_argument('--start', default=0, type=int) +parser.add_argument('--end', default=1000, type=int) + +def voxel2R(A): + return (np.array(A)/4*3/np.pi)**(1/3) + +class RandCropByPosNegLabeld_select(transforms.RandCropByPosNegLabeld): + def __init__(self, keys, label_key, spatial_size, + pos=1.0, neg=1.0, num_samples=1, + image_key=None, image_threshold=0.0, allow_missing_keys=True, + fg_thresh=0): + super().__init__(keys=keys, label_key=label_key, spatial_size=spatial_size, + pos=pos, neg=neg, num_samples=num_samples, + 
image_key=image_key, image_threshold=image_threshold, allow_missing_keys=allow_missing_keys) + self.fg_thresh = fg_thresh + + def R2voxel(self,R): + return (4/3*np.pi)*(R)**(3) + + def __call__(self, data): + d = dict(data) + data_name = d['name'] + d.pop('name') + + if '10_Decathlon' in data_name or '05_KiTS' in data_name: + d_crop = super().__call__(d) + + else: + flag=0 + while 1: + flag+=1 + + d_crop = super().__call__(d) + pixel_num = (d_crop[0]['label']>0).sum() + + if pixel_num > self.R2voxel(self.fg_thresh): + break + if flag>5 and pixel_num > self.R2voxel(max(self.fg_thresh-5, 5)): + break + if flag>10 and pixel_num > self.R2voxel(max(self.fg_thresh-10, 5)): + break + if flag>15 and pixel_num > self.R2voxel(max(self.fg_thresh-15, 5)): + break + if flag>20 and pixel_num > self.R2voxel(max(self.fg_thresh-20, 5)): + break + if flag>25 and pixel_num > self.R2voxel(max(self.fg_thresh-25, 5)): + break + if flag>50: + break + + d_crop[0]['name'] = data_name + + return d_crop + +def _get_loader(args): + # val_data_dir = args.val_dir + # datalist_json = args.json_dir + val_org_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label", "raw_image"]), + transforms.AddChanneld(keys=["image", "label", "raw_image"]), + transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), + transforms.Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear")), + transforms.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True), + transforms.SpatialPadd(keys=["image", "label"], mode=["minimum", "constant"], spatial_size=[96, 96, 96]), + RandCropByPosNegLabeld_select( + keys=["image", "label", "name"], + label_key="label", + spatial_size=(96, 96, 96), + pos=1, + neg=0, + num_samples=1, + image_key="image", + image_threshold=0, + fg_thresh = args.fg_thresh, + ), + transforms.ToTensord(keys=["image", "label", "raw_image"]), + ] + ) + + val_img=[] + val_lbl=[] + val_name=[] + + for line in open(args.data_file): + # name = line.strip().split()[1].split('.')[0] + # val_img.append(args.data_root + line.strip().split()[0]) + # val_lbl.append(args.data_root + line.strip().split()[1]) + # breakpoint() + name = line.strip() + val_img.append(os.path.join(args.data_root, name, 'ct.nii.gz')) + val_lbl.append(os.path.join(args.data_root, name, 'segmentations/kidney_left.nii.gz')) + val_name.append(name) + data_dicts_val = [{'image': image, 'raw_image':image, 'label': label, 'name': name} + for image, label, name in zip(val_img, val_lbl, val_name)] + + if args.end < len(data_dicts_val): + data_dicts_val = data_dicts_val[args.start:args.end] + else: + data_dicts_val = data_dicts_val[args.start:] + print('val len {}'.format(len(data_dicts_val))) + val_org_ds = data.Dataset(data_dicts_val, transform=val_org_transform) + val_org_loader = data.DataLoader(val_org_ds, batch_size=1, shuffle=False, num_workers=4, sampler=None, pin_memory=True) + + post_transforms = Compose([ + Invertd( + keys=['image'], + transform=val_org_transform, + orig_keys="image", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ), + Invertd( + keys=['label'], + transform=val_org_transform, + orig_keys="label", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ) + ]) + return val_org_loader, post_transforms + +def main(): + args = parser.parse_args() + output_dir = args.save_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print("MAIN Argument values:") + for k, v in vars(args).items(): + print(k, '=>', 
v) + print('-----------------') + + ## loader and post_transform + val_loader, post_transforms = _get_loader(args) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device) + model.load_state_dict(torch.load("../best_metric_model_classification3d_dict.pth")) + model.eval() + + start_time=0 + with torch.no_grad(): + for idx, val_data in enumerate(val_loader): + print('idx',idx) + if idx == 0: + start_time = time.time() + # val_inputs = val_data["image"] + # name = val_data['label_meta_dict']['filename_or_obj'][0].split('/')[-1].split('.')[0] + + vqgan, early_sampler, noearly_sampler= synt_model_prepare(device = torch.device("cuda"), fold=0, organ=args.organ_type) + + healthy_data, healthy_target, data_names, raw_data = val_data['image'], val_data['label'], val_data['name'], val_data['raw_image'] + case_name = data_names[0].split('/')[-1] + print('case_name', case_name) + original_affine = val_data["label_meta_dict"]["original_affine"][0].numpy() + + if healthy_target.sum() == 0: + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + tumor_mask = val_data[0]['label'][0].cpu().numpy().astype(np.uint8) + tumor_mask_ = np.zeros_like(tumor_mask) + nib.save(nib.Nifti1Image(tumor_mask_, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/kidney_tumor.nii.gz')) + continue + + healthy_data, healthy_target = healthy_data.cuda(), healthy_target.cuda() + healthy_target = (healthy_target>0).to(healthy_target) + + tumor_types = ['early', 'medium', 'large'] + # tumor_probs = np.array([0.45, 0.45, 0.1]) + # tumor_probs = np.array([1.0, 0.0, 0.0]) + tumor_probs = np.array([0.5, 0.4, 0.1]) + synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + print('synthetic_tumor_type',synthetic_tumor_type) + flag=0 + while 1: + flag+=1 + if synthetic_tumor_type == 'early': + synt_data, synt_target = synthesize_early_tumor(healthy_data, healthy_target, args.organ_type, vqgan, early_sampler) + elif synthetic_tumor_type == 'medium': + synt_data, synt_target = synthesize_medium_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + elif synthetic_tumor_type == 'large': + synt_data, synt_target = synthesize_large_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + + syn_confidence = model(synt_data).sigmoid()[:,1] + if syn_confidence>0.01: + break + elif flag > 20 and syn_confidence>0.005: + break + elif flag > 40 and syn_confidence>0.001: + break + val_data['image'] = synt_data.detach() + val_data['label'] = synt_target.detach() + + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + synt_data = val_data[0]['image'][0] + synt_target = val_data[0]['label'][0] + final_data = raw_data[0,0] + + synt_data = (synt_data*(250+175)-175) + final_data[synt_target>1] = synt_data[synt_target>1] + final_data = final_data.cpu().numpy() + final_label = (synt_target>=1.5).cpu().numpy().astype(np.uint8) + + os.makedirs(os.path.join(output_dir, f'{case_name}'), exist_ok=True) + os.makedirs(os.path.join(output_dir, f'{case_name}/segmentations'), exist_ok=True) + nib.save(nib.Nifti1Image(final_data, original_affine), os.path.join(output_dir, f'{case_name}/ct.nii.gz')) + nib.save(nib.Nifti1Image(final_label, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/kidney_tumor.nii.gz')) + + print('time = ', time.time()-start_time) + 
start_time = time.time() + + +if __name__ == "__main__": + main() diff --git a/Generation_Pipeline_filter_all/syn_kidney/CT_syn_kidney_data_new2.py b/Generation_Pipeline_filter_all/syn_kidney/CT_syn_kidney_data_new2.py new file mode 100644 index 0000000000000000000000000000000000000000..d178d6d6b374003fb5f53d5d6f16361a190f806f --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/CT_syn_kidney_data_new2.py @@ -0,0 +1,251 @@ +import os, time, csv +import numpy as np +import torch +from sklearn.metrics import confusion_matrix +from scipy import ndimage +from scipy.ndimage import label +from functools import partial +import monai +from monai.transforms import AsDiscrete,AsDiscreted,Compose,Invertd,SaveImaged +from monai import transforms, data +from TumorGeneration.utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor, synt_model_prepare +import nibabel as nib + +import warnings +warnings.filterwarnings("ignore") + +import argparse +parser = argparse.ArgumentParser(description='kidney tumor validation') + +# file dir +parser.add_argument('--data_root', default=None, type=str) +parser.add_argument('--organ_type', default='kidney', type=str) +parser.add_argument('--save_dir', default='out', type=str) +parser.add_argument('--data_file', default='out', type=str) +parser.add_argument('--ddim_ts', default=50, type=int) +parser.add_argument('--fg_thresh', default=30, type=int) +parser.add_argument('--start', default=0, type=int) +parser.add_argument('--end', default=1000, type=int) + +def voxel2R(A): + return (np.array(A)/4*3/np.pi)**(1/3) + +class RandCropByPosNegLabeld_select(transforms.RandCropByPosNegLabeld): + def __init__(self, keys, label_key, spatial_size, + pos=1.0, neg=1.0, num_samples=1, + image_key=None, image_threshold=0.0, allow_missing_keys=True, + fg_thresh=0): + super().__init__(keys=keys, label_key=label_key, spatial_size=spatial_size, + pos=pos, neg=neg, num_samples=num_samples, + image_key=image_key, image_threshold=image_threshold, allow_missing_keys=allow_missing_keys) + self.fg_thresh = fg_thresh + + def R2voxel(self,R): + return (4/3*np.pi)*(R)**(3) + + def __call__(self, data): + d = dict(data) + data_name = d['name'] + d.pop('name') + + if '10_Decathlon' in data_name or '05_KiTS' in data_name: + d_crop = super().__call__(d) + + else: + flag=0 + while 1: + flag+=1 + + d_crop = super().__call__(d) + pixel_num = (d_crop[0]['label']>0).sum() + + if pixel_num > self.R2voxel(self.fg_thresh): + break + if flag>5 and pixel_num > self.R2voxel(max(self.fg_thresh-5, 5)): + break + if flag>10 and pixel_num > self.R2voxel(max(self.fg_thresh-10, 5)): + break + if flag>15 and pixel_num > self.R2voxel(max(self.fg_thresh-15, 5)): + break + if flag>20 and pixel_num > self.R2voxel(max(self.fg_thresh-20, 5)): + break + if flag>25 and pixel_num > self.R2voxel(max(self.fg_thresh-25, 5)): + break + if flag>50: + break + + d_crop[0]['name'] = data_name + + return d_crop + +def _get_loader(args): + # val_data_dir = args.val_dir + # datalist_json = args.json_dir + val_org_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label", 'tumor_label', "raw_image"]), + transforms.AddChanneld(keys=["image", "label", 'tumor_label', "raw_image"]), + transforms.Orientationd(keys=["image", "label", 'tumor_label'], axcodes="RAS"), + transforms.Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear")), + transforms.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True), + 
transforms.SpatialPadd(keys=["image", "label"], mode=["minimum", "constant"], spatial_size=[96, 96, 96]), + RandCropByPosNegLabeld_select( + keys=["image", "label", "name"], + label_key="label", + spatial_size=(96, 96, 96), + pos=1, + neg=0, + num_samples=1, + image_key="image", + image_threshold=0, + fg_thresh = args.fg_thresh, + ), + transforms.ToTensord(keys=["image", "label", 'tumor_label', "raw_image"]), + ] + ) + + val_img=[] + val_lbl=[] + val_name=[] + tumor_lbl=[] + for line in open(args.data_file): + # name = line.strip().split()[1].split('.')[0] + # val_img.append(args.data_root + line.strip().split()[0]) + # val_lbl.append(args.data_root + line.strip().split()[1]) + # breakpoint() + name = line.strip() + val_img.append(os.path.join(args.data_root, name, 'ct.nii.gz')) + val_lbl.append(os.path.join(args.data_root, name, 'segmentations/kidney_right.nii.gz')) + tumor_lbl.append(os.path.join(args.data_root, name, 'segmentations/kidney_tumor.nii.gz')) + val_name.append(name) + data_dicts_val = [{'image': image, 'raw_image':image, 'label': label, 'tumor_label':tumor_label,'name': name} + for image, label, tumor_label, name in zip(val_img, val_lbl, tumor_lbl, val_name)] + + if args.end < len(data_dicts_val): + data_dicts_val = data_dicts_val[args.start:args.end] + else: + data_dicts_val = data_dicts_val[args.start:] + print('val len {}'.format(len(data_dicts_val))) + val_org_ds = data.Dataset(data_dicts_val, transform=val_org_transform) + val_org_loader = data.DataLoader(val_org_ds, batch_size=1, shuffle=False, num_workers=4, sampler=None, pin_memory=True) + + post_transforms = Compose([ + Invertd( + keys=['image'], + transform=val_org_transform, + orig_keys="image", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ), + Invertd( + keys=['label'], + transform=val_org_transform, + orig_keys="label", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ), + Invertd( + keys=['tumor_label'], + transform=val_org_transform, + orig_keys="tumor_label", + nearest_interp=False, + to_tensor=True, + ) + ]) + return val_org_loader, post_transforms + +def main(): + args = parser.parse_args() + output_dir = args.save_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print("MAIN Argument values:") + for k, v in vars(args).items(): + print(k, '=>', v) + print('-----------------') + + ## loader and post_transform + val_loader, post_transforms = _get_loader(args) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device) + model.load_state_dict(torch.load("../best_metric_model_classification3d_dict.pth")) + model.eval() + + start_time=0 + with torch.no_grad(): + for idx, val_data in enumerate(val_loader): + print('idx',idx) + if idx == 0: + start_time = time.time() + # val_inputs = val_data["image"] + # name = val_data['label_meta_dict']['filename_or_obj'][0].split('/')[-1].split('.')[0] + + vqgan, early_sampler, noearly_sampler= synt_model_prepare(device = torch.device("cuda"), fold=0, organ=args.organ_type) + + healthy_data, healthy_target, data_names, raw_data = val_data['image'], val_data['label'], val_data['name'], val_data['raw_image'] + case_name = data_names[0].split('/')[-1] + print('case_name', case_name) + original_affine = val_data["label_meta_dict"]["original_affine"][0].numpy() + + if healthy_target.sum() == 0: + # val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + # tumor_mask = 
val_data[0]['label'][0].cpu().numpy().astype(np.uint8) + # tumor_mask_ = np.zeros_like(tumor_mask) + # nib.save(nib.Nifti1Image(tumor_mask_, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/kidney_tumor.nii.gz')) + continue + + healthy_data, healthy_target = healthy_data.cuda(), healthy_target.cuda() + healthy_target = (healthy_target>0).to(healthy_target) + + tumor_types = ['early', 'medium', 'large'] + # tumor_probs = np.array([0.45, 0.45, 0.1]) + # tumor_probs = np.array([1.0, 0.0, 0.0]) + tumor_probs = np.array([0.5, 0.4, 0.1]) + synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + print('synthetic_tumor_type',synthetic_tumor_type) + flag=0 + while 1: + flag+=1 + if synthetic_tumor_type == 'early': + synt_data, synt_target = synthesize_early_tumor(healthy_data, healthy_target, args.organ_type, vqgan, early_sampler) + elif synthetic_tumor_type == 'medium': + synt_data, synt_target = synthesize_medium_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + elif synthetic_tumor_type == 'large': + synt_data, synt_target = synthesize_large_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + + syn_confidence = model(synt_data).sigmoid()[:,1] + if syn_confidence>0.01: + break + elif flag > 20 and syn_confidence>0.005: + break + elif flag > 40 and syn_confidence>0.001: + break + val_data['image'] = synt_data.detach() + val_data['label'] = synt_target.detach() + + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + synt_data = val_data[0]['image'][0] + synt_target = val_data[0]['label'][0] + tumor_mask = val_data[0]['tumor_label'][0] + final_data = raw_data[0,0] + + synt_data = (synt_data*(250+175)-175) + final_data[synt_target>1] = synt_data[synt_target>1] + final_data = final_data.cpu().numpy() + final_label = (synt_target>=1.5).cpu().numpy().astype(np.uint8) + final_label[tumor_mask==1] = 1 + + os.makedirs(os.path.join(output_dir, f'{case_name}'), exist_ok=True) + os.makedirs(os.path.join(output_dir, f'{case_name}/segmentations'), exist_ok=True) + nib.save(nib.Nifti1Image(final_data, original_affine), os.path.join(output_dir, f'{case_name}/ct.nii.gz')) + nib.save(nib.Nifti1Image(final_label, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/kidney_tumor.nii.gz')) + + print('time = ', time.time()-start_time) + start_time = time.time() + + +if __name__ == "__main__": + main() diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/.DS_Store b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..fb0d207e124cfecc777faae5b4a56c5ca1b9cd2e Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/.DS_Store differ diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/README.md b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/README.md new file mode 100644 index 0000000000000000000000000000000000000000..201911031ca988c410781a60dc3c192a65ee56b3 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/README.md @@ -0,0 +1,5 @@ +```bash +wget https://www.dropbox.com/scl/fi/k856fhk60kck8uqxtxazw/model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll +mv model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll model_weight.tar.gz +tar -xzvf model_weight.tar.gz +``` diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/TumorGenerated.py 
b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/TumorGenerated.py new file mode 100644 index 0000000000000000000000000000000000000000..9983cf047b1532f5edd20c0d4c78102d6d611eb9 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/TumorGenerated.py @@ -0,0 +1,39 @@ +import random +from typing import Hashable, Mapping, Dict + +from monai.config import KeysCollection +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.transform import MapTransform, RandomizableTransform + +from .utils_ import SynthesisTumor +import numpy as np + +class TumorGenerated(RandomizableTransform, MapTransform): + def __init__(self, + keys: KeysCollection, + prob: float = 0.1, + tumor_prob = [0.2, 0.2, 0.2, 0.2, 0.2], + allow_missing_keys: bool = False + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + random.seed(0) + np.random.seed(0) + + self.tumor_types = ['tiny', 'small', 'medium', 'large', 'mix'] + + assert len(tumor_prob) == 5 + self.tumor_prob = np.array(tumor_prob) + + + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + + if self._do_transform and (np.max(d['label']) <= 1): + tumor_type = np.random.choice(self.tumor_types, p=self.tumor_prob.ravel()) + + d['image'][0], d['label'][0] = SynthesisTumor(d['image'][0], d['label'][0], tumor_type) + # print(tumor_type, d['image'].shape, np.max(d['label'])) + return d diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/__init__.py b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1afc8a195ba5fd106ca18d4e219c123a75e6e831 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/__init__.py @@ -0,0 +1,5 @@ +### Online Version TumorGeneration ### + +from .TumorGenerated import TumorGenerated + +from .utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor \ No newline at end of file diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8eba2478d809b697c49e4425ba2fc619b4554f12 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4d7e4379e19d41aaf49dbc8394daeb6ab80b6bc Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/__pycache__/__init__.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/__pycache__/utils.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3065f278859daa8566ac262917eed6ba21daffd1 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/__pycache__/utils.cpython-38.pyc differ diff --git 
a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/__pycache__/utils_.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/__pycache__/utils_.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c35e0a87fa9e862aaefa5d34991f85f12516a30 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/__pycache__/utils_.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/diffusion_config/ddpm.yaml b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/diffusion_config/ddpm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..474a973b1f76c2a026e2df916c4a153b5bbea05f --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/diffusion_config/ddpm.yaml @@ -0,0 +1,29 @@ +vqgan_ckpt: None + +# Have to be derived from VQ-GAN Latent space dimensions +diffusion_img_size: 24 +diffusion_depth_size: 24 +diffusion_num_channels: 17 # 17 +out_dim: 8 +dim_mults: [1,2,4,8] +results_folder: checkpoints/ddpm/ +results_folder_postfix: 'own_dataset_t2' +load_milestone: False # False + +batch_size: 2 # 40 +num_workers: 20 +logger: wandb +objective: pred_x0 +save_and_sample_every: 1000 +denoising_fn: Unet3D +train_lr: 1e-4 +timesteps: 2 # number of steps +sampling_timesteps: 250 # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper]) +loss_type: l1 # L1 or L2 +train_num_steps: 700000 # total training steps +gradient_accumulate_every: 2 # gradient accumulation steps +ema_decay: 0.995 # exponential moving average decay +amp: False # turn on mixed precision +num_sample_rows: 1 +gpus: 0 + diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/diffusion_config/vq_gan_3d.yaml b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/diffusion_config/vq_gan_3d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29940377c5cadb0c06322de8ac60d0713f799024 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/diffusion_config/vq_gan_3d.yaml @@ -0,0 +1,37 @@ +seed: 1234 +batch_size: 2 # 30 +num_workers: 32 # 30 + +gpus: 1 +accumulate_grad_batches: 1 +default_root_dir: checkpoints/vq_gan/ +default_root_dir_postfix: 'flair' +resume_from_checkpoint: +max_steps: -1 +max_epochs: -1 +precision: 16 +gradient_clip_val: 1.0 + + +embedding_dim: 8 # 256 +n_codes: 16384 # 2048 +n_hiddens: 16 +lr: 3e-4 +downsample: [2, 2, 2] # [4, 4, 4] +disc_channels: 64 +disc_layers: 3 +discriminator_iter_start: 10000 # 50000 +disc_loss_type: hinge +image_gan_weight: 1.0 +video_gan_weight: 1.0 +l1_weight: 4.0 +gan_feat_weight: 4.0 # 0.0 +perceptual_weight: 4.0 # 0.0 +i3d_feat: False +restart_thres: 1.0 +no_random_restart: False +norm_type: group +padding_type: replicate +num_groups: 32 + + diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__init__.py b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a01336e46e37471097d8e1420d0ffd5d803a1edd --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__init__.py @@ -0,0 +1 @@ +from .diffusion import Unet3D, GaussianDiffusion, Tester diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..e2e233d6ef842d8b6623b3c34f1efde66e12e471 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f7b9717ebc5153d3d8a33bbefbb8ac5a9e73f742 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1854ef5db100c68b0b4add2f150825df3ae5eb4 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7ec4dc0860dc1632c385c75fa9723a8920ceac40 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7ac1a6aea44a1d1741652f2f0a598fb18b2730d Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20c6786f9995cbd455c900944b0ab9501a97c202 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e6ce39946ee6c03ff7518cbcb1309135e0354b2 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/ddim.py b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..1bc51f94355732aedb0ffd254b8166144c468370 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/ddim.py @@ -0,0 +1,206 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from .util import 
make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): # "uniform" 'quad' + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + # breakpoint() + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, T, H, W = shape + # breakpoint() + size = (batch_size, C, T, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + # iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(time_range): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. 
- mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + # breakpoint() + e_t = self.model.denoise_fn(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.denoise_fn(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/diffusion.py b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..87111dc3b2ab671e798efc3a320ae81d024f20b9 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/diffusion.py @@ -0,0 +1,1016 @@ +"Largely taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch" + +import math +import copy +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial + +from torch.utils import data +from pathlib import Path +from torch.optim import Adam +from torchvision import transforms as T, utils +from torch.cuda.amp import autocast, GradScaler +from PIL import Image + +from tqdm import tqdm +from einops import rearrange +from einops_exts import check_shape, rearrange_many + +from rotary_embedding_torch import RotaryEmbedding + +from .text import tokenize, bert_embed, BERT_MODEL_DIM +from torch.utils.data import Dataset, DataLoader +from ..vq_gan_3d.model.vqgan import VQGAN + +import matplotlib.pyplot as plt + +# helpers functions + + +def exists(x): + return x is not None + + +def noop(*args, **kwargs): + pass + + +def is_odd(n): + return (n % 2) == 1 + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +def cycle(dl): + while True: + for data in dl: + yield data + + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + + +def prob_mask_like(shape, prob, device): + if prob == 1: + return torch.ones(shape, device=device, dtype=torch.bool) + elif prob == 0: + return torch.zeros(shape, device=device, dtype=torch.bool) + else: + return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob + + +def is_list_str(x): + if not isinstance(x, (list, tuple)): + return False + return all([type(el) == str for el in x]) + +# relative positional bias + + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128 + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / + max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min( + val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype=torch.long, device=device) + k_pos = torch.arange(n, dtype=torch.long, device=device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket( + rel_pos, num_buckets=self.num_buckets, 
max_distance=self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') + +# small helper modules + + +class EMA(): + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_model_average(self, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = self.update_average(old_weight, up_weight) + + def update_average(self, old, new): + if old is None: + return new + return old * self.beta + (1 - self.beta) * new + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +def Upsample(dim): + return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) + + +def Downsample(dim): + return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) + + +class LayerNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1)) + + def forward(self, x): + var = torch.var(x, dim=1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=1, keepdim=True) + return (x - mean) / (var + self.eps).sqrt() * self.gamma + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = LayerNorm(dim) + + def forward(self, x, **kwargs): + x = self.norm(x) + return self.fn(x, **kwargs) + +# building block modules + + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1)) + self.norm = nn.GroupNorm(groups, dim_out) + self.act = nn.SiLU() + + def forward(self, x, scale_shift=None): + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + return self.act(x) + + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, dim_out * 2) + ) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) + self.res_conv = nn.Conv3d( + dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb=None): + + scale_shift = None + if exists(self.mlp): + assert exists(time_emb), 'time emb must be passed in' + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') + scale_shift = time_emb.chunk(2, dim=1) + + h = self.block1(x, scale_shift=scale_shift) + + h = self.block2(h) + return h + self.res_conv(x) + + +class SpatialLinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, f, h, w = 
x.shape + x = rearrange(x, 'b c f h w -> (b f) c h w') + + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = rearrange_many( + qkv, 'b (h c) x y -> b h c (x y)', h=self.heads) + + q = q.softmax(dim=-2) + k = k.softmax(dim=-1) + + q = q * self.scale + context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + + out = torch.einsum('b h d e, b h d n -> b h e n', context, q) + out = rearrange(out, 'b h c (x y) -> b (h c) x y', + h=self.heads, x=h, y=w) + out = self.to_out(out) + return rearrange(out, '(b f) c h w -> b c f h w', b=b) + +# attention along space and time + + +class EinopsToAndFrom(nn.Module): + def __init__(self, from_einops, to_einops, fn): + super().__init__() + self.from_einops = from_einops + self.to_einops = to_einops + self.fn = fn + + def forward(self, x, **kwargs): + shape = x.shape + reconstitute_kwargs = dict( + tuple(zip(self.from_einops.split(' '), shape))) + x = rearrange(x, f'{self.from_einops} -> {self.to_einops}') + x = self.fn(x, **kwargs) + x = rearrange( + x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + heads=4, + dim_head=32, + rotary_emb=None + ): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + + self.rotary_emb = rotary_emb + self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False) + self.to_out = nn.Linear(hidden_dim, dim, bias=False) + + def forward( + self, + x, + pos_bias=None, + focus_present_mask=None + ): + n, device = x.shape[-2], x.device + + qkv = self.to_qkv(x).chunk(3, dim=-1) + + if exists(focus_present_mask) and focus_present_mask.all(): + # if all batch samples are focusing on present + # it would be equivalent to passing that token's values through to the output + values = qkv[-1] + return self.to_out(values) + + # split out heads + + q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) + + # scale + + q = q * self.scale + + # rotate positions into queries and keys for time attention + + if exists(self.rotary_emb): + q = self.rotary_emb.rotate_queries_or_keys(q) + k = self.rotary_emb.rotate_queries_or_keys(k) + + # similarity + + sim = einsum('... h i d, ... h j d -> ... h i j', q, k) + + # relative positional bias + + if exists(pos_bias): + sim = sim + pos_bias + + if exists(focus_present_mask) and not (~focus_present_mask).all(): + attend_all_mask = torch.ones( + (n, n), device=device, dtype=torch.bool) + attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) + + mask = torch.where( + rearrange(focus_present_mask, 'b -> b 1 1 1 1'), + rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), + rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), + ) + + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + # numerical stability + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + # aggregate values + + out = einsum('... h i j, ... h j d -> ... h i d', attn, v) + out = rearrange(out, '... h n d -> ... 
n (h d)') + return self.to_out(out) + +# model + + +class Unet3D(nn.Module): + def __init__( + self, + dim, + cond_dim=None, + out_dim=None, + dim_mults=(1, 2, 4, 8), + channels=3, + attn_heads=8, + attn_dim_head=32, + use_bert_text_cond=False, + init_dim=None, + init_kernel_size=7, + use_sparse_linear_attn=True, + block_type='resnet', + resnet_groups=8 + ): + super().__init__() + self.channels = channels + + # temporal attention and its relative positional encoding + + rotary_emb = RotaryEmbedding(min(32, attn_dim_head)) + + def temporal_attn(dim): return EinopsToAndFrom('b c f h w', 'b (h w) f c', Attention( + dim, heads=attn_heads, dim_head=attn_dim_head, rotary_emb=rotary_emb)) + + # realistically will not be able to generate that many frames of video... yet + self.time_rel_pos_bias = RelativePositionBias( + heads=attn_heads, max_distance=32) + + # initial conv + + init_dim = default(init_dim, dim) + assert is_odd(init_kernel_size) + + init_padding = init_kernel_size // 2 + self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, + init_kernel_size), padding=(0, init_padding, init_padding)) + + self.init_temporal_attn = Residual( + PreNorm(init_dim, temporal_attn(init_dim))) + + # dimensions + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + # time conditioning + + time_dim = dim * 4 + self.time_mlp = nn.Sequential( + SinusoidalPosEmb(dim), + nn.Linear(dim, time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim) + ) + + # text conditioning + + self.has_cond = exists(cond_dim) or use_bert_text_cond + cond_dim = BERT_MODEL_DIM if use_bert_text_cond else cond_dim + + self.null_cond_emb = nn.Parameter( + torch.randn(1, cond_dim)) if self.has_cond else None + + cond_dim = time_dim + int(cond_dim or 0) + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + + num_resolutions = len(in_out) + # block type + + block_klass = partial(ResnetBlock, groups=resnet_groups) + block_klass_cond = partial(block_klass, time_emb_dim=cond_dim) + + # modules for all layers + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + block_klass_cond(dim_in, dim_out), + block_klass_cond(dim_out, dim_out), + Residual(PreNorm(dim_out, SpatialLinearAttention( + dim_out, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), + Residual(PreNorm(dim_out, temporal_attn(dim_out))), + Downsample(dim_out) if not is_last else nn.Identity() + ])) + + mid_dim = dims[-1] + self.mid_block1 = block_klass_cond(mid_dim, mid_dim) + + spatial_attn = EinopsToAndFrom( + 'b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads)) + + self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn)) + self.mid_temporal_attn = Residual( + PreNorm(mid_dim, temporal_attn(mid_dim))) + + self.mid_block2 = block_klass_cond(mid_dim, mid_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind >= (num_resolutions - 1) + + self.ups.append(nn.ModuleList([ + block_klass_cond(dim_out * 2, dim_in), + block_klass_cond(dim_in, dim_in), + Residual(PreNorm(dim_in, SpatialLinearAttention( + dim_in, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), + Residual(PreNorm(dim_in, temporal_attn(dim_in))), + Upsample(dim_in) if not is_last else nn.Identity() + ])) + + out_dim = default(out_dim, channels) + self.final_conv = nn.Sequential( + block_klass(dim * 2, dim), + nn.Conv3d(dim, out_dim, 1) + ) + + def forward_with_cond_scale( + self, + 
*args, + cond_scale=2., + **kwargs + ): + logits = self.forward(*args, null_cond_prob=0., **kwargs) + if cond_scale == 1 or not self.has_cond: + return logits + + null_logits = self.forward(*args, null_cond_prob=1., **kwargs) + return null_logits + (logits - null_logits) * cond_scale + + def forward( + self, + x, + time, + cond=None, + null_cond_prob=0., + focus_present_mask=None, + # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time) + prob_focus_present=0. + ): + assert not (self.has_cond and not exists(cond) + ), 'cond must be passed in if cond_dim specified' + x = torch.cat([x, cond], dim=1) + + batch, device = x.shape[0], x.device + + focus_present_mask = default(focus_present_mask, lambda: prob_mask_like( + (batch,), prob_focus_present, device=device)) + + time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device) + + x = self.init_conv(x) + r = x.clone() + + x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias) + + t = self.time_mlp(time) if exists(self.time_mlp) else None # [2, 128] + + # classifier free guidance + + if self.has_cond: + batch, device = x.shape[0], x.device + mask = prob_mask_like((batch,), null_cond_prob, device=device) + cond = torch.where(rearrange(mask, 'b -> b 1'), + self.null_cond_emb, cond) + t = torch.cat((t, cond), dim=-1) + + h = [] + + for block1, block2, spatial_attn, temporal_attn, downsample in self.downs: + x = block1(x, t) + x = block2(x, t) + x = spatial_attn(x) + x = temporal_attn(x, pos_bias=time_rel_pos_bias, + focus_present_mask=focus_present_mask) + h.append(x) + x = downsample(x) + + # [2, 256, 32, 4, 4] + x = self.mid_block1(x, t) + x = self.mid_spatial_attn(x) + x = self.mid_temporal_attn( + x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask) + x = self.mid_block2(x, t) + + for block1, block2, spatial_attn, temporal_attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim=1) + x = block1(x, t) + x = block2(x, t) + x = spatial_attn(x) + x = temporal_attn(x, pos_bias=time_rel_pos_bias, + focus_present_mask=focus_present_mask) + x = upsample(x) + + x = torch.cat((x, r), dim=1) + return self.final_conv(x) + +# gaussian diffusion trainer class + + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = torch.linspace(0, timesteps, steps, dtype=torch.float64) + alphas_cumprod = torch.cos( + ((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0, 0.9999) + + +class GaussianDiffusion(nn.Module): + def __init__( + self, + denoise_fn, + *, + image_size, + num_frames, + text_use_bert_cls=False, + channels=3, + timesteps=1000, + loss_type='l1', + use_dynamic_thres=False, # from the Imagen paper + dynamic_thres_percentile=0.9, + vqgan_ckpt=None, + device=None + ): + super().__init__() + self.channels = channels + self.image_size = image_size + self.num_frames = num_frames + self.denoise_fn = denoise_fn + self.device = device + + if vqgan_ckpt: + self.vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt).cuda() + self.vqgan.eval() + else: + self.vqgan = None + + betas = cosine_beta_schedule(timesteps) + + alphas = 1. 
- betas + alphas_cumprod = torch.cumprod(alphas, axis=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + # register buffer helper function that casts float64 to float32 + + def register_buffer(name, val): return self.register_buffer( + name, val.to(torch.float32)) + + register_buffer('betas', betas) + register_buffer('alphas_cumprod', alphas_cumprod) + register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + + register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) + register_buffer('sqrt_one_minus_alphas_cumprod', + torch.sqrt(1. - alphas_cumprod)) + register_buffer('log_one_minus_alphas_cumprod', + torch.log(1. - alphas_cumprod)) + register_buffer('sqrt_recip_alphas_cumprod', + torch.sqrt(1. / alphas_cumprod)) + register_buffer('sqrt_recipm1_alphas_cumprod', + torch.sqrt(1. / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + + posterior_variance = betas * \ + (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + + register_buffer('posterior_variance', posterior_variance) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + + register_buffer('posterior_log_variance_clipped', + torch.log(posterior_variance.clamp(min=1e-20))) + register_buffer('posterior_mean_coef1', betas * + torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) + register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) + * torch.sqrt(alphas) / (1. - alphas_cumprod)) + + # text conditioning parameters + + self.text_use_bert_cls = text_use_bert_cls + + # dynamic thresholding when sampling + + self.use_dynamic_thres = use_dynamic_thres + self.dynamic_thres_percentile = dynamic_thres_percentile + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. - self.alphas_cumprod, t, x_start.shape) + log_variance = extract( + self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract( + self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool, cond=None, cond_scale=1.): + x_recon = self.predict_start_from_noise( + x, t=t, noise=self.denoise_fn.forward_with_cond_scale(x, t, cond=cond, cond_scale=cond_scale)) + + if clip_denoised: + s = 1. + if self.use_dynamic_thres: + s = torch.quantile( + rearrange(x_recon, 'b ... -> b (...)').abs(), + self.dynamic_thres_percentile, + dim=-1 + ) + + s.clamp_(min=1.) 
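+ # dynamic thresholding (from the Imagen paper): s is the per-sample quantile of |x_recon|; clamping it to at least 1 keeps static clipping to [-1, 1] as the floor before the rescale below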
+ s = s.view(-1, *((1,) * (x_recon.ndim - 1))) + + # clip by threshold, depending on whether static or dynamic + x_recon = x_recon.clamp(-s, s) / s + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.inference_mode() + def p_sample(self, x, t, cond=None, cond_scale=1., clip_denoised=True): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised, cond=cond, cond_scale=cond_scale) + noise = torch.randn_like(x) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, + *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.inference_mode() + def p_sample_loop(self, shape, cond=None, cond_scale=1.): + device = self.betas.device + + b = shape[0] + img = torch.randn(shape, device=device) + # print('cond', cond.shape) + for i in reversed(range(0, self.num_timesteps)): + img = self.p_sample(img, torch.full( + (b,), i, device=device, dtype=torch.long), cond=cond, cond_scale=cond_scale) + + return img + + @torch.inference_mode() + def sample(self, cond=None, cond_scale=1., batch_size=16): + device = next(self.denoise_fn.parameters()).device + + if is_list_str(cond): + cond = bert_embed(tokenize(cond)).to(device) + + # batch_size = cond.shape[0] if exists(cond) else batch_size + batch_size = batch_size + image_size = self.image_size + channels = 8 # self.channels + num_frames = self.num_frames + # print((batch_size, channels, num_frames, image_size, image_size)) + # print('cond_',cond.shape) + _sample = self.p_sample_loop( + (batch_size, channels, num_frames, image_size, image_size), cond=cond, cond_scale=cond_scale) + + if isinstance(self.vqgan, VQGAN): + # denormalize TODO: Remove eventually + _sample = (((_sample + 1.0) / 2.0) * (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) + self.vqgan.codebook.embeddings.min() + + _sample = self.vqgan.decode(_sample, quantize=True) + else: + _sample = unnormalize_img(_sample) + + return _sample + + @torch.inference_mode() + def interpolate(self, x1, x2, t=None, lam=0.5): + b, *_, device = *x1.shape, x1.device + t = default(t, self.num_timesteps - 1) + + assert x1.shape == x2.shape + + t_batched = torch.stack([torch.tensor(t, device=device)] * b) + xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) + + img = (1 - lam) * xt1 + lam * xt2 + for i in reversed(range(0, t)): + img = self.p_sample(img, torch.full( + (b,), i, device=device, dtype=torch.long)) + + return img + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, + t, x_start.shape) * noise + ) + + def p_losses(self, x_start, t, cond=None, noise=None, **kwargs): + b, c, f, h, w, device = *x_start.shape, x_start.device + noise = default(noise, lambda: torch.randn_like(x_start)) + # breakpoint() + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # [2, 8, 32, 32, 32] + + if is_list_str(cond): + cond = bert_embed( + tokenize(cond), return_cls_repr=self.text_use_bert_cls) + cond = cond.to(device) + + x_recon = self.denoise_fn(x_noisy, t, cond=cond, **kwargs) + + if self.loss_type == 'l1': + loss = F.l1_loss(noise, x_recon) + elif self.loss_type == 'l2': + loss = F.mse_loss(noise, x_recon) + else:
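+ # only 'l1' and 'l2' loss types are supported; despite its name, x_recon above is the network's noise prediction, compared against the sampled noise (epsilon parameterization)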
+ raise NotImplementedError() + + return loss + + def forward(self, x, *args, **kwargs): + bs = int(x.shape[0]/2) + img=x[:bs,...] + mask=x[bs:,...] + mask_=(1-mask).detach() + masked_img = (img*mask_).detach() + masked_img=masked_img.permute(0,1,-1,-3,-2) + img=img.permute(0,1,-1,-3,-2) + mask=mask.permute(0,1,-1,-3,-2) + # breakpoint() + if isinstance(self.vqgan, VQGAN): + with torch.no_grad(): + img = self.vqgan.encode( + img, quantize=False, include_embeddings=True) + # normalize to -1 and 1 + img = ((img - self.vqgan.codebook.embeddings.min()) / + (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + masked_img = self.vqgan.encode( + masked_img, quantize=False, include_embeddings=True) + # normalize to -1 and 1 + masked_img = ((masked_img - self.vqgan.codebook.embeddings.min()) / + (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + else: + print("Hi") + img = normalize_img(img) + masked_img = normalize_img(masked_img) + mask = mask*2.0 - 1.0 + cc = torch.nn.functional.interpolate(mask, size=masked_img.shape[-3:]) + cond = torch.cat((masked_img, cc), dim=1) + + b, device, img_size, = img.shape[0], img.device, self.image_size + t = torch.randint(0, self.num_timesteps, (b,), device=device).long() + # breakpoint() + return self.p_losses(img, t, cond=cond, *args, **kwargs) + +# trainer class + + +CHANNELS_TO_MODE = { + 1: 'L', + 3: 'RGB', + 4: 'RGBA' +} + + +def seek_all_images(img, channels=3): + assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid' + mode = CHANNELS_TO_MODE[channels] + + i = 0 + while True: + try: + img.seek(i) + yield img.convert(mode) + except EOFError: + break + i += 1 + +# tensor of shape (channels, frames, height, width) -> gif + + +def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True): + tensor = ((tensor - tensor.min()) / (tensor.max() - tensor.min())) * 1.0 + images = map(T.ToPILImage(), tensor.unbind(dim=1)) + first_img, *rest_imgs = images + first_img.save(path, save_all=True, append_images=rest_imgs, + duration=duration, loop=loop, optimize=optimize) + return images + +# gif -> (channels, frame, height, width) tensor + + +def gif_to_tensor(path, channels=3, transform=T.ToTensor()): + img = Image.open(path) + tensors = tuple(map(transform, seek_all_images(img, channels=channels))) + return torch.stack(tensors, dim=1) + + +def identity(t, *args, **kwargs): + return t + + +def normalize_img(t): + return t * 2 - 1 + + +def unnormalize_img(t): + return (t + 1) * 0.5 + + +def cast_num_frames(t, *, frames): + f = t.shape[1] + + if f == frames: + return t + + if f > frames: + return t[:, :frames] + + return F.pad(t, (0, 0, 0, 0, 0, frames - f)) + + +class Dataset(data.Dataset): + def __init__( + self, + folder, + image_size, + channels=3, + num_frames=16, + horizontal_flip=False, + force_num_frames=True, + exts=['gif'] + ): + super().__init__() + self.folder = folder + self.image_size = image_size + self.channels = channels + self.paths = [p for ext in exts for p in Path( + f'{folder}').glob(f'**/*.{ext}')] + + self.cast_num_frames_fn = partial( + cast_num_frames, frames=num_frames) if force_num_frames else identity + + self.transform = T.Compose([ + T.Resize(image_size), + T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity), + T.CenterCrop(image_size), + T.ToTensor() + ]) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = self.paths[index] + tensor = gif_to_tensor(path, self.channels, 
transform=self.transform) + return self.cast_num_frames_fn(tensor) + +# trainer class + + +class Tester(object): + def __init__( + self, + diffusion_model, + ): + super().__init__() + self.model = diffusion_model + self.ema_model = copy.deepcopy(self.model) + self.step=0 + self.image_size = diffusion_model.image_size + + self.reset_parameters() + + def reset_parameters(self): + self.ema_model.load_state_dict(self.model.state_dict()) + + + def load(self, milestone, map_location=None, **kwargs): + if milestone == -1: + all_milestones = [int(p.stem.split('-')[-1]) + for p in Path(self.results_folder).glob('**/*.pt')] + assert len( + all_milestones) > 0, 'need to have at least one milestone to load from latest checkpoint (milestone == -1)' + milestone = max(all_milestones) + + if map_location: + data = torch.load(milestone, map_location=map_location) + else: + data = torch.load(milestone) + + self.step = data['step'] + self.model.load_state_dict(data['model'], **kwargs) + self.ema_model.load_state_dict(data['ema'], **kwargs) + + diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/text.py b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/text.py new file mode 100644 index 0000000000000000000000000000000000000000..42dca78bec5075fb4f59c522aa3d00cc395d7536 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/text.py @@ -0,0 +1,94 @@ +"Taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch" + +import torch +from einops import rearrange + + +def exists(val): + return val is not None + +# singleton globals + + +MODEL = None +TOKENIZER = None +BERT_MODEL_DIM = 768 + + +def get_tokenizer(): + global TOKENIZER + if not exists(TOKENIZER): + TOKENIZER = torch.hub.load( + 'huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased') + return TOKENIZER + + +def get_bert(): + global MODEL + if not exists(MODEL): + MODEL = torch.hub.load( + 'huggingface/pytorch-transformers', 'model', 'bert-base-cased') + if torch.cuda.is_available(): + MODEL = MODEL.cuda() + + return MODEL + +# tokenize + + +def tokenize(texts, add_special_tokens=True): + if not isinstance(texts, (list, tuple)): + texts = [texts] + + tokenizer = get_tokenizer() + + encoding = tokenizer.batch_encode_plus( + texts, + add_special_tokens=add_special_tokens, + padding=True, + return_tensors='pt' + ) + + token_ids = encoding.input_ids + return token_ids + +# embedding function + + +@torch.no_grad() +def bert_embed( + token_ids, + return_cls_repr=False, + eps=1e-8, + pad_id=0. 
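+ # pad_id 0 is assumed to be the BERT tokenizer's [PAD] token id; the float default still compares correctly against integer token ids below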
+): + model = get_bert() + mask = token_ids != pad_id + + if torch.cuda.is_available(): + token_ids = token_ids.cuda() + mask = mask.cuda() + + outputs = model( + input_ids=token_ids, + attention_mask=mask, + output_hidden_states=True + ) + + hidden_state = outputs.hidden_states[-1] + + if return_cls_repr: + # return [cls] as representation + return hidden_state[:, 0] + + if not exists(mask): + return hidden_state.mean(dim=1) + + # mean all tokens excluding [cls], accounting for length + mask = mask[:, 1:] + mask = rearrange(mask, 'b n -> b n 1') + + numer = (hidden_state[:, 1:] * mask).sum(dim=1) + denom = mask.sum(dim=1) + masked_mean = numer / (denom + eps) + return diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/time_embedding.py b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/time_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..7c911fa4f485295005cda9c9c2099db3ddbda15c --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/time_embedding.py @@ -0,0 +1,75 @@ +import math +import torch +import torch.nn as nn + +from monai.networks.layers.utils import get_act_layer + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, emb_dim=16, downscale_freq_shift=1, max_period=10000, flip_sin_to_cos=False): + super().__init__() + self.emb_dim = emb_dim + self.downscale_freq_shift = downscale_freq_shift + self.max_period = max_period + self.flip_sin_to_cos = flip_sin_to_cos + + def forward(self, x): + device = x.device + half_dim = self.emb_dim // 2 + emb = math.log(self.max_period) / \ + (half_dim - self.downscale_freq_shift) + emb = torch.exp(-emb*torch.arange(half_dim, device=device)) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + + if self.flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + if self.emb_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class LearnedSinusoidalPosEmb(nn.Module): + """ following @crowsonkb 's lead with learned sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, emb_dim): + super().__init__() + self.emb_dim = emb_dim + half_dim = emb_dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = x[:, None] + freqs = x * self.weights[None, :] * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + if self.emb_dim % 2 == 1: + fouriered = torch.nn.functional.pad(fouriered, (0, 1, 0, 0)) + return fouriered + + +class TimeEmbbeding(nn.Module): + def __init__( + self, + emb_dim=64, + pos_embedder=SinusoidalPosEmb, + pos_embedder_kwargs={}, + act_name=("SWISH", {}) # Swish = SiLU + ): + super().__init__() + self.emb_dim = emb_dim + self.pos_emb_dim = pos_embedder_kwargs.get('emb_dim', emb_dim//4) + pos_embedder_kwargs['emb_dim'] = self.pos_emb_dim + self.pos_embedder = pos_embedder(**pos_embedder_kwargs) + + self.time_emb = nn.Sequential( + self.pos_embedder, + nn.Linear(self.pos_emb_dim, self.emb_dim), + get_act_layer(act_name), + nn.Linear(self.emb_dim, self.emb_dim) + ) + + def forward(self, time): + return self.time_emb(time) diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/unet.py b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/unet.py new file mode 100644 index 
0000000000000000000000000000000000000000..e9cae8021e874a92b6baaaf4decd802516c2dc87 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/unet.py @@ -0,0 +1,226 @@ +from ddpm.time_embedding import TimeEmbbeding + +import monai.networks.nets as nets +import torch +import torch.nn as nn +from einops import rearrange + +from monai.networks.blocks import UnetBasicBlock, UnetResBlock, UnetUpBlock, Convolution, UnetOutBlock +from monai.networks.layers.utils import get_act_layer + + +class DownBlock(nn.Module): + def __init__( + self, + spatial_dims, + in_ch, + out_ch, + time_emb_dim, + cond_emb_dim, + act_name=("swish", {}), + **kwargs): + super(DownBlock, self).__init__() + self.loca_time_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(time_emb_dim, in_ch) # in_ch * 2 + ) + if cond_emb_dim is not None: + self.loca_cond_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(cond_emb_dim, in_ch), + ) + self.down_op = UnetBasicBlock( + spatial_dims, in_ch, out_ch, act_name=act_name, **kwargs) + + def forward(self, x, time_emb, cond_emb): + b, c, *_ = x.shape + sp_dim = x.ndim-2 + + # ------------ Time ---------- + time_emb = self.loca_time_embedder(time_emb) + time_emb = time_emb.reshape(b, c, *((1,)*sp_dim)) + # scale, shift = time_emb.chunk(2, dim = 1) + + # ------------ Combine ------------ + # x = x * (scale + 1) + shift + x = x + time_emb + + # ----------- Condition ------------ + if cond_emb is not None: + cond_emb = self.loca_cond_embedder(cond_emb) + cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim)) + x = x + cond_emb + + # ----------- Image --------- + y = self.down_op(x) + return y + + +class UpBlock(nn.Module): + def __init__( + self, + spatial_dims, + skip_ch, + enc_ch, + time_emb_dim, + cond_emb_dim, + act_name=("swish", {}), + **kwargs): + super(UpBlock, self).__init__() + self.up_op = UnetUpBlock(spatial_dims, enc_ch, + skip_ch, act_name=act_name, **kwargs) + self.loca_time_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(time_emb_dim, skip_ch * 2), + ) + if cond_emb_dim is not None: + self.loca_cond_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(cond_emb_dim, skip_ch * 2), + ) + + def forward(self, x_skip, x_enc, time_emb, cond_emb): + b, c, *_ = x_enc.shape + sp_dim = x_enc.ndim-2 + + # ----------- Time -------------- + time_emb = self.loca_time_embedder(time_emb) + time_emb = time_emb.reshape(b, c, *((1,)*sp_dim)) + # scale, shift = time_emb.chunk(2, dim = 1) + + # -------- Combine ------------- + # y = x * (scale + 1) + shift + x_enc = x_enc + time_emb + + # ----------- Condition ------------ + if cond_emb is not None: + cond_emb = self.loca_cond_embedder(cond_emb) + cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim)) + x_enc = x_enc + cond_emb + + # ----------- Image ------------- + y = self.up_op(x_enc, x_skip) + + # -------- Combine ------------- + # y = y * (scale + 1) + shift + + return y + + +class UNet(nn.Module): + + def __init__(self, + in_ch=1, + out_ch=1, + spatial_dims=3, + hid_chs=[32, 64, 128, 256, 512], + kernel_sizes=[(1, 3, 3), (1, 3, 3), (1, 3, 3), 3, 3], + strides=[1, (1, 2, 2), (1, 2, 2), 2, 2], + upsample_kernel_sizes=None, + act_name=("SWISH", {}), + norm_name=("INSTANCE", {"affine": True}), + time_embedder=TimeEmbbeding, + time_embedder_kwargs={}, + cond_embedder=None, + cond_embedder_kwargs={}, + # True = all but last layer, 0/False=disable, 1=only first layer, ... 
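+ # an integer n instead adds n auxiliary output heads (outc_ver) on the intermediate decoder features x[1..n]; True maps to len(strides)-2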
+ deep_ver_supervision=True, + estimate_variance=False, + use_self_conditioning=False, + **kwargs + ): + super().__init__() + if upsample_kernel_sizes is None: + upsample_kernel_sizes = strides[1:] + + # ------------- Time-Embedder----------- + self.time_embedder = time_embedder(**time_embedder_kwargs) + + # ------------- Condition-Embedder----------- + if cond_embedder is not None: + self.cond_embedder = cond_embedder(**cond_embedder_kwargs) + cond_emb_dim = self.cond_embedder.emb_dim + else: + self.cond_embedder = None + cond_emb_dim = None + + # ----------- In-Convolution ------------ + in_ch = in_ch*2 if use_self_conditioning else in_ch + self.inc = UnetBasicBlock(spatial_dims, in_ch, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0], + act_name=act_name, norm_name=norm_name, **kwargs) + + # ----------- Encoder ---------------- + self.encoders = nn.ModuleList([ + DownBlock(spatial_dims, hid_chs[i-1], hid_chs[i], time_emb_dim=self.time_embedder.emb_dim, + cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[ + i], stride=strides[i], act_name=act_name, + norm_name=norm_name, **kwargs) + for i in range(1, len(strides)) + ]) + + # ------------ Decoder ---------- + self.decoders = nn.ModuleList([ + UpBlock(spatial_dims, hid_chs[i], hid_chs[i+1], time_emb_dim=self.time_embedder.emb_dim, + cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[i + + 1], stride=strides[i+1], act_name=act_name, + norm_name=norm_name, upsample_kernel_size=upsample_kernel_sizes[i], **kwargs) + for i in range(len(strides)-1) + ]) + + # --------------- Out-Convolution ---------------- + out_ch_hor = out_ch*2 if estimate_variance else out_ch + self.outc = UnetOutBlock( + spatial_dims, hid_chs[0], out_ch_hor, dropout=None) + if isinstance(deep_ver_supervision, bool): + deep_ver_supervision = len( + strides)-2 if deep_ver_supervision else 0 + self.outc_ver = nn.ModuleList([ + UnetOutBlock(spatial_dims, hid_chs[i], out_ch, dropout=None) + for i in range(1, deep_ver_supervision+1) + ]) + + def forward(self, x_t, t, cond=None, self_cond=None, **kwargs): + condition = cond + # x_t [B, C, (D), H, W] + # t [B,] + + # -------- In-Convolution -------------- + x = [None for _ in range(len(self.encoders)+1)] + x_t = torch.cat([x_t, self_cond], + dim=1) if self_cond is not None else x_t + x[0] = self.inc(x_t) + + # -------- Time Embedding (Gloabl) ----------- + time_emb = self.time_embedder(t) # [B, C] + + # -------- Condition Embedding (Gloabl) ----------- + if (condition is None) or (self.cond_embedder is None): + cond_emb = None + else: + cond_emb = self.cond_embedder(condition) # [B, C] + + # --------- Encoder -------------- + for i in range(len(self.encoders)): + x[i+1] = self.encoders[i](x[i], time_emb, cond_emb) + + # -------- Decoder ----------- + for i in range(len(self.decoders), 0, -1): + x[i-1] = self.decoders[i-1](x[i-1], x[i], time_emb, cond_emb) + + # ---------Out-Convolution ------------ + y_hor = self.outc(x[0]) + y_ver = [outc_ver_i(x[i+1]) + for i, outc_ver_i in enumerate(self.outc_ver)] + + return y_hor # , y_ver + + def forward_with_cond_scale(self, *args, cond_scale=0., **kwargs): + return self.forward(*args, **kwargs) + + +if __name__ == '__main__': + model = UNet(in_ch=3) + input = torch.randn((1, 3, 16, 128, 128)) + time = torch.randn((1,)) + out_hor, out_ver = model(input, time) + print(out_hor[0].shape) diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/util.py b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/util.py new file mode 100644 index 
0000000000000000000000000000000000000000..2047cf56802046bbf37cfd33f819ed75a239a7df --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/ddpm/util.py @@ -0,0 +1,271 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +# from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + if c != 1: + steps_out = ddim_timesteps + 1 + else: + steps_out = ddim_timesteps + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. 
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. 
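+ For example, a tensor of shape [N, C, H, W] is reduced to a tensor of shape [N] holding per-sample means.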
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cbf3deb85a40168e55d5395f42018dd06034c55b Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__init__.py b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38c08845c1ee2f7b4b52bf95fa282b0808931a3a --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__init__.py @@ -0,0 +1,3 @@ +from .vqgan import VQGAN +from .codebook import Codebook +from .lpips import LPIPS diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c2851c21c12d1381767f75f705c7b896990b106 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc differ diff --git 
a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff829982a3c7189eca26440f3947bc15d9040771 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55de5141f6e8504572dfd74c30009219697c8c7d Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d0c1659ef1fe1d88ada76ea763eafd39c5b70820 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth new file mode 100644 index 0000000000000000000000000000000000000000..f57dcf5cc764d61c8a460365847fb2137ff0a62d --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868 +size 7289 diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/codebook.py b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/codebook.py new file mode 100644 index 0000000000000000000000000000000000000000..c17ed701a2824cc0dcd75670d47a6ab3842e9d35 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/codebook.py @@ -0,0 +1,109 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from ..utils import shift_dim + + +class Codebook(nn.Module): + def __init__(self, n_codes, embedding_dim, no_random_restart=False, restart_thres=1.0): + super().__init__() + self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim)) + self.register_buffer('N', torch.zeros(n_codes)) + self.register_buffer('z_avg', self.embeddings.data.clone()) + + self.n_codes = n_codes + self.embedding_dim = embedding_dim + self._need_init = True + self.no_random_restart = no_random_restart + self.restart_thres = restart_thres + + def _tile(self, x): + d, ew = x.shape + if d < self.n_codes: + n_repeats = (self.n_codes + d - 1) // d + std = 0.01 / np.sqrt(ew) + x = x.repeat(n_repeats, 1) + x = x + torch.randn_like(x) * std + return x + + def _init_embeddings(self, z): + # z: [b, c, t, h, w] + self._need_init = False + breakpoint() + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [65536, 8] [2, 8, 32, 32, 32] + y = self._tile(flat_inputs) # [65536, 8] + + d = y.shape[0] + _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + self.embeddings.data.copy_(_k_rand) + self.z_avg.data.copy_(_k_rand) + self.N.data.copy_(torch.ones(self.n_codes)) + + def forward(self, z): + # z: [b, c, t, h, w] + if self._need_init and self.training: + self._init_embeddings(z) + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [bthw, c] [65536, 8] + distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \ + - 2 * flat_inputs @ self.embeddings.t() \ + + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) # [bthw, c] [65536, 8] + + encoding_indices = torch.argmin(distances, dim=1) # [65536] + encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as( + flat_inputs) # [bthw, ncode] [65536, 16384] + encoding_indices = encoding_indices.view( + z.shape[0], *z.shape[2:]) # [b, t, h, w, ncode] [2, 32, 32, 32] + + embeddings = F.embedding( + encoding_indices, self.embeddings) # [b, t, h, w, c] self.embeddings [16384, 8] + embeddings = shift_dim(embeddings, -1, 1) # [b, c, t, h, w] [2, 8, 32, 32, 32] + + commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) + + # EMA codebook update + if self.training: + n_total = encode_onehot.sum(dim=0) # [16384] + encode_sum = flat_inputs.t() @ encode_onehot # [8, 16384] + if dist.is_initialized(): + dist.all_reduce(n_total) + dist.all_reduce(encode_sum) + + self.N.data.mul_(0.99).add_(n_total, alpha=0.01) + self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) + + n = self.N.sum() + weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n + encode_normalized = self.z_avg / weights.unsqueeze(1) + self.embeddings.data.copy_(encode_normalized) + + y = self._tile(flat_inputs) + _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + + if not self.no_random_restart: + usage = (self.N.view(self.n_codes, 1) + >= self.restart_thres).float() + self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) + + embeddings_st = (embeddings - z).detach() + z + + avg_probs = torch.mean(encode_onehot, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * + torch.log(avg_probs + 1e-10))) + + return dict(embeddings=embeddings_st, encodings=encoding_indices, + commitment_loss=commitment_loss, perplexity=perplexity) + + def dictionary_lookup(self, encodings): + embeddings = F.embedding(encodings, 
self.embeddings) + return embeddings diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/lpips.py b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..019434b9b29945d67e6e0a95dec324df7ff908f3 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/lpips.py @@ -0,0 +1,181 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" + +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + + +from collections import namedtuple +from torchvision import models +import torch.nn as nn +import torch +from tqdm import tqdm +import requests +import os +import hashlib +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format( + name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + # self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + self.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name is not "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + model.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer( + input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in 
range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor( + outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor( + [-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor( + [.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, + padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, + h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..a64ec8878c0ecbf418775e2958b2cbf578ce2918 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py @@ -0,0 +1,561 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import math +import argparse +import numpy as np +import pickle as pkl + +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from ..utils import shift_dim, adopt_weight, comp_getattr +from .lpips import LPIPS +from .codebook import Codebook + + +def silu(x): + return x*torch.sigmoid(x) + + +class SiLU(nn.Module): + def __init__(self): + super(SiLU, self).__init__() + + def forward(self, x): + return silu(x) + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake))) + return d_loss + + +class VQGAN(pl.LightningModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.embedding_dim = cfg.model.embedding_dim # 8 + self.n_codes = cfg.model.n_codes # 16384 + + self.encoder = Encoder(cfg.model.n_hiddens, # 16 + cfg.model.downsample, # [2, 2, 2] + cfg.dataset.image_channels, # 1 + cfg.model.norm_type, # group + cfg.model.padding_type, # replicate + cfg.model.num_groups, # 32 + ) + self.decoder = Decoder( + cfg.model.n_hiddens, cfg.model.downsample, cfg.dataset.image_channels, cfg.model.norm_type, cfg.model.num_groups) + self.enc_out_ch = self.encoder.out_channels + self.pre_vq_conv = SamePadConv3d( + self.enc_out_ch, cfg.model.embedding_dim, 1, padding_type=cfg.model.padding_type) + self.post_vq_conv = SamePadConv3d( + cfg.model.embedding_dim, self.enc_out_ch, 1) + + self.codebook = Codebook(cfg.model.n_codes, cfg.model.embedding_dim, + no_random_restart=cfg.model.no_random_restart, restart_thres=cfg.model.restart_thres) + + self.gan_feat_weight = cfg.model.gan_feat_weight + # TODO: Changed batchnorm from sync to normal + self.image_discriminator = NLayerDiscriminator( + cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm2d) + self.video_discriminator = NLayerDiscriminator3D( + cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm3d) + + if cfg.model.disc_loss_type == 'vanilla': + self.disc_loss = vanilla_d_loss + elif cfg.model.disc_loss_type == 'hinge': + self.disc_loss = hinge_d_loss + + self.perceptual_model = LPIPS().eval() + + self.image_gan_weight = cfg.model.image_gan_weight + self.video_gan_weight = cfg.model.video_gan_weight + + self.perceptual_weight = cfg.model.perceptual_weight + + self.l1_weight = cfg.model.l1_weight + self.save_hyperparameters() + + def encode(self, x, include_embeddings=False, quantize=True): + h = self.pre_vq_conv(self.encoder(x)) + if quantize: + vq_output = self.codebook(h) + if include_embeddings: + return vq_output['embeddings'], vq_output['encodings'] + else: + return vq_output['encodings'] + return h + + def decode(self, latent, quantize=False): + if quantize: + vq_output = self.codebook(latent) + latent = vq_output['encodings'] + h = F.embedding(latent, self.codebook.embeddings) + h = self.post_vq_conv(shift_dim(h, -1, 1)) + return self.decoder(h) + + def forward(self, x, optimizer_idx=None, log_image=False): + B, C, T, H, W = x.shape + z = self.pre_vq_conv(self.encoder(x)) # [2, 32, 32, 32, 32] [2, 8, 32, 32, 32] + vq_output = self.codebook(z) # ['embeddings', 'encodings', 'commitment_loss', 
'perplexity'] + x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings'])) # [2, 8, 32, 32, 32] [2, 32, 32, 32, 32] + + recon_loss = F.l1_loss(x_recon, x) * self.l1_weight + + # Selects one random 2D image from each 3D Image + frame_idx = torch.randint(0, T, [B]).cuda() + frame_idx_selected = frame_idx.reshape(-1, + 1, 1, 1, 1).repeat(1, C, 1, H, W) # [2, 1, 1, 64, 64] + frames = torch.gather(x, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64] + frames_recon = torch.gather(x_recon, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64] + + if log_image: + return frames, frames_recon, x, x_recon + + if optimizer_idx == 0: + # Autoencoder - train the "generator" + + # Perceptual loss + perceptual_loss = 0 + if self.perceptual_weight > 0: + perceptual_loss = self.perceptual_model( + frames, frames_recon).mean() * self.perceptual_weight + + # Discriminator loss (turned on after a certain epoch) + logits_image_fake, pred_image_fake = self.image_discriminator( + frames_recon) + logits_video_fake, pred_video_fake = self.video_discriminator( + x_recon) + g_image_loss = -torch.mean(logits_image_fake) + g_video_loss = -torch.mean(logits_video_fake) + g_loss = self.image_gan_weight*g_image_loss + self.video_gan_weight*g_video_loss + disc_factor = adopt_weight( + self.global_step, threshold=self.cfg.model.discriminator_iter_start) + aeloss = disc_factor * g_loss + + # GAN feature matching loss - tune features such that we get the same prediction result on the discriminator + image_gan_feat_loss = 0 + video_gan_feat_loss = 0 + feat_weights = 4.0 / (3 + 1) + if self.image_gan_weight > 0: + logits_image_real, pred_image_real = self.image_discriminator( + frames) + for i in range(len(pred_image_fake)-1): + image_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_image_fake[i], pred_image_real[i].detach( + )) * (self.image_gan_weight > 0) + if self.video_gan_weight > 0: + logits_video_real, pred_video_real = self.video_discriminator( + x) + for i in range(len(pred_video_fake)-1): + video_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_video_fake[i], pred_video_real[i].detach( + )) * (self.video_gan_weight > 0) + gan_feat_loss = disc_factor * self.gan_feat_weight * \ + (image_gan_feat_loss + video_gan_feat_loss) + + self.log("train/g_image_loss", g_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/g_video_loss", g_video_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/image_gan_feat_loss", image_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/video_gan_feat_loss", video_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/perceptual_loss", perceptual_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log("train/recon_loss", recon_loss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + self.log("train/aeloss", aeloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + self.log("train/commitment_loss", vq_output['commitment_loss'], + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log('train/perplexity', vq_output['perplexity'], + prog_bar=True, logger=True, on_step=True, on_epoch=True) + return recon_loss, x_recon, vq_output, aeloss, perceptual_loss, gan_feat_loss + + if optimizer_idx == 1: + # Train discriminator + logits_image_real, _ = self.image_discriminator(frames.detach()) + logits_video_real, _ = self.video_discriminator(x.detach()) + + logits_image_fake, _ = self.image_discriminator( + frames_recon.detach()) + logits_video_fake, _ = 
self.video_discriminator(x_recon.detach()) + + d_image_loss = self.disc_loss(logits_image_real, logits_image_fake) + d_video_loss = self.disc_loss(logits_video_real, logits_video_fake) + disc_factor = adopt_weight( + self.global_step, threshold=self.cfg.model.discriminator_iter_start) + discloss = disc_factor * \ + (self.image_gan_weight*d_image_loss + + self.video_gan_weight*d_video_loss) + + self.log("train/logits_image_real", logits_image_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_image_fake", logits_image_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_video_real", logits_video_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_video_fake", logits_video_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/d_image_loss", d_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/d_video_loss", d_video_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/discloss", discloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + return discloss + + perceptual_loss = self.perceptual_model( + frames, frames_recon) * self.perceptual_weight + return recon_loss, x_recon, vq_output, perceptual_loss + + def training_step(self, batch, batch_idx, optimizer_idx): + x = batch['image'] + if optimizer_idx == 0: + recon_loss, _, vq_output, aeloss, perceptual_loss, gan_feat_loss = self.forward( + x, optimizer_idx) + commitment_loss = vq_output['commitment_loss'] + loss = recon_loss + commitment_loss + aeloss + perceptual_loss + gan_feat_loss + if optimizer_idx == 1: + discloss = self.forward(x, optimizer_idx) + loss = discloss + return loss + + def validation_step(self, batch, batch_idx): + x = batch['image'] # TODO: batch['stft'] + recon_loss, _, vq_output, perceptual_loss = self.forward(x) + self.log('val/recon_loss', recon_loss, prog_bar=True) + self.log('val/perceptual_loss', perceptual_loss, prog_bar=True) + self.log('val/perplexity', vq_output['perplexity'], prog_bar=True) + self.log('val/commitment_loss', + vq_output['commitment_loss'], prog_bar=True) + + def configure_optimizers(self): + lr = self.cfg.model.lr + opt_ae = torch.optim.Adam(list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.pre_vq_conv.parameters()) + + list(self.post_vq_conv.parameters()) + + list(self.codebook.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(list(self.image_discriminator.parameters()) + + list(self.video_discriminator.parameters()), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def log_images(self, batch, **kwargs): + log = dict() + x = batch['image'] + x = x.to(self.device) + frames, frames_rec, _, _ = self(x, log_image=True) + log["inputs"] = frames + log["reconstructions"] = frames_rec + #log['mean_org'] = batch['mean_org'] + #log['std_org'] = batch['std_org'] + return log + + def log_videos(self, batch, **kwargs): + log = dict() + x = batch['image'] + _, _, x, x_rec = self(x, log_image=True) + log["inputs"] = x + log["reconstructions"] = x_rec + #log['mean_org'] = batch['mean_org'] + #log['std_org'] = batch['std_org'] + return log + + +def Normalize(in_channels, norm_type='group', num_groups=32): + assert norm_type in ['group', 'batch'] + if norm_type == 'group': + # TODO Changed num_groups from 32 to 8 + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == 'batch': + return 
torch.nn.SyncBatchNorm(in_channels) + + +class Encoder(nn.Module): + def __init__(self, n_hiddens, downsample, image_channel=3, norm_type='group', padding_type='replicate', num_groups=32): + super().__init__() + n_times_downsample = np.array([int(math.log2(d)) for d in downsample]) + self.conv_blocks = nn.ModuleList() + max_ds = n_times_downsample.max() + + self.conv_first = SamePadConv3d( + image_channel, n_hiddens, kernel_size=3, padding_type=padding_type) + + for i in range(max_ds): + block = nn.Module() + in_channels = n_hiddens * 2**i + out_channels = n_hiddens * 2**(i+1) + stride = tuple([2 if d > 0 else 1 for d in n_times_downsample]) + block.down = SamePadConv3d( + in_channels, out_channels, 4, stride=stride, padding_type=padding_type) + block.res = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + self.conv_blocks.append(block) + n_times_downsample -= 1 + + self.final_block = nn.Sequential( + Normalize(out_channels, norm_type, num_groups=num_groups), + SiLU() + ) + + self.out_channels = out_channels + + def forward(self, x): + h = self.conv_first(x) + for block in self.conv_blocks: + h = block.down(h) + h = block.res(h) + h = self.final_block(h) + return h + + +class Decoder(nn.Module): + def __init__(self, n_hiddens, upsample, image_channel, norm_type='group', num_groups=32): + super().__init__() + + n_times_upsample = np.array([int(math.log2(d)) for d in upsample]) + max_us = n_times_upsample.max() + + in_channels = n_hiddens*2**max_us + self.final_block = nn.Sequential( + Normalize(in_channels, norm_type, num_groups=num_groups), + SiLU() + ) + + self.conv_blocks = nn.ModuleList() + for i in range(max_us): + block = nn.Module() + in_channels = in_channels if i == 0 else n_hiddens*2**(max_us-i+1) + out_channels = n_hiddens*2**(max_us-i) + us = tuple([2 if d > 0 else 1 for d in n_times_upsample]) + block.up = SamePadConvTranspose3d( + in_channels, out_channels, 4, stride=us) + block.res1 = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + block.res2 = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + self.conv_blocks.append(block) + n_times_upsample -= 1 + + self.conv_last = SamePadConv3d( + out_channels, image_channel, kernel_size=3) + + def forward(self, x): + h = self.final_block(x) + for i, block in enumerate(self.conv_blocks): + h = block.up(h) + h = block.res1(h) + h = block.res2(h) + h = self.conv_last(h) + return h + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group', padding_type='replicate', num_groups=32): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, norm_type, num_groups=num_groups) + self.conv1 = SamePadConv3d( + in_channels, out_channels, kernel_size=3, padding_type=padding_type) + self.dropout = torch.nn.Dropout(dropout) + self.norm2 = Normalize(in_channels, norm_type, num_groups=num_groups) + self.conv2 = SamePadConv3d( + out_channels, out_channels, kernel_size=3, padding_type=padding_type) + if self.in_channels != self.out_channels: + self.conv_shortcut = SamePadConv3d( + in_channels, out_channels, kernel_size=3, padding_type=padding_type) + + def forward(self, x): + h = x + h = self.norm1(h) + h = silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = silu(h) + h = self.conv2(h) + + 
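+        # Residual connection: when the block changes the channel count, the
+        # identity branch x is first projected by conv_shortcut (a kernel-3
+        # SamePadConv3d defined in __init__) so that it can be summed with h.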
if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) + + return x+h + + +# Does not support dilation +class SamePadConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + # assumes that the input shape is divisible by stride + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + self.padding_type = padding_type + + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, + stride=stride, padding=0, bias=bias) + + def forward(self, x): + return self.conv(F.pad(x, self.pad_input, mode=self.padding_type)) + + +class SamePadConvTranspose3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + self.padding_type = padding_type + + self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, + stride=stride, bias=bias, + padding=tuple([k - 1 for k in kernel_size])) + + def forward(self, x): + return self.convt(F.pad(x, self.pad_input, mode=self.padding_type)) + + +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + # def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=True): + super(NLayerDiscriminator, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ + + +class NLayerDiscriminator3D(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + 
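+        # 3D counterpart of NLayerDiscriminator above: the same PatchGAN-style
+        # stack of stride-2 convolutions, but built from nn.Conv3d so it scores
+        # whole sub-volumes instead of single 2D frames; intermediate activations
+        # are kept for the GAN feature-matching loss when getIntermFeat is True.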
super(NLayerDiscriminator3D, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv3d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv3d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/utils.py b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..acf1a21270d98d0a1ea9935d00f9f16d89d54551 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/ldm/vq_gan_3d/utils.py @@ -0,0 +1,177 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import warnings +import torch +import imageio + +import math +import numpy as np + +import sys +import pdb as pdb_original +import logging + +import imageio.core.util +logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR) + + +class ForkedPdb(pdb_original.Pdb): + """A Pdb subclass that may be used + from a forked multiprocessing child + + """ + + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + pdb_original.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + +# Shifts src_tf dim to dest dim +# i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + + dims = list(range(n_dims)) + del dims[src_dim] + + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x + + +# reshapes tensor start from dim i (inclusive) +# to dim j (exclusive) to the desired shape +# e.g. 
if x.shape = (b, thw, c) then +# view_range(x, 1, 2, (t, h, w)) returns +# x of shape (b, t, h, w, c) +def view_range(x, i, j, shape): + shape = tuple(shape) + + n_dims = len(x.shape) + if i < 0: + i = n_dims + i + + if j is None: + j = n_dims + elif j < 0: + j = n_dims + j + + assert 0 <= i < j <= n_dims + + x_shape = x.shape + target_shape = x_shape[:i] + shape + x_shape[j:] + return x.view(target_shape) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def tensor_slice(x, begin, size): + assert all([b >= 0 for b in begin]) + size = [l - b if s == -1 else s + for s, b, l in zip(size, begin, x.shape)] + assert all([s >= 0 for s in size]) + + slices = [slice(b, b + s) for b, s in zip(begin, size)] + return x[slices] + + +def adopt_weight(global_step, threshold=0, value=0.): + weight = 1 + if global_step < threshold: + weight = value + return weight + + +def save_video_grid(video, fname, nrow=None, fps=6): + b, c, t, h, w = video.shape + video = video.permute(0, 2, 3, 4, 1) + video = (video.cpu().numpy() * 255).astype('uint8') + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = np.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype='uint8') + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + video = [] + for i in range(t): + video.append(video_grid[i]) + imageio.mimsave(fname, video, fps=fps) + ## skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '5'}) + #print('saved videos to', fname) + + +def comp_getattr(args, attr_name, default=None): + if hasattr(args, attr_name): + return getattr(args, attr_name) + else: + return default + + +def visualize_tensors(t, name=None, nest=0): + if name is not None: + print(name, "current nest: ", nest) + print("type: ", type(t)) + if 'dict' in str(type(t)): + print(t.keys()) + for k in t.keys(): + if t[k] is None: + print(k, "None") + else: + if 'Tensor' in str(type(t[k])): + print(k, t[k].shape) + elif 'dict' in str(type(t[k])): + print(k, 'dict') + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t[k])): + print(k, len(t[k])) + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t)): + print("list length: ", len(t)) + for t2 in t: + visualize_tensors(t2, name, nest + 1) + elif 'Tensor' in str(type(t)): + print(t.shape) + else: + print(t) + return "" diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/utils.py b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..54a3d68432d165ab2895859a89d7be4d150e9721 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/utils.py @@ -0,0 +1,465 @@ +### Tumor Generateion +import random +import cv2 +import elasticdeform +import numpy as np +from scipy.ndimage import gaussian_filter +from TumorGeneration.ldm.ddpm.ddim import DDIMSampler + +# Step 1: Random select (numbers) location for tumor. 
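+# The two helpers below pick a voxel inside the organ mask: random_select
+# samples a z slice from the middle 30-70% of the organ's z extent (eroding the
+# slice for the liver so edge voxels are avoided) and then draws a random
+# in-mask (x, y) on that slice, while center_select simply returns the centre
+# of the organ's bounding box. A minimal illustrative call, assuming mask_scan
+# is a binary numpy volume (hypothetical example, not part of the pipeline):
+#     x, y, z = random_select(mask_scan, organ_type='kidney')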
+def random_select(mask_scan, organ_type):
+    # first find the z range of the organ, then sample a point on one z slice
+    # print('mask_scan', np.unique(mask_scan))
+    # print('pixel num', (mask_scan == 1).sum())
+    z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0].min(), np.where(np.any(mask_scan, axis=(0, 1)))[0].max()
+    # print('z_start, z_end', z_start, z_end)
+    # restrict z to the middle of the organ (0.3 - 0.7 of its z extent)
+    while True:
+        z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start
+        liver_mask = mask_scan[..., z]
+        # erode the mask (we don't want the edge points)
+        if organ_type == 'liver':
+            kernel = np.ones((5, 5), dtype=np.uint8)
+            liver_mask = cv2.erode(liver_mask, kernel, iterations=1)
+        if (liver_mask == 1).sum() > 0:
+            break
+
+    # print('liver_mask', (liver_mask == 1).sum())
+    coordinates = np.argwhere(liver_mask == 1)
+    random_index = np.random.randint(0, len(coordinates))
+    xyz = coordinates[random_index].tolist()  # get x, y
+    xyz.append(z)
+    potential_points = xyz
+
+    return potential_points
+
+def center_select(mask_scan):
+    z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0].min(), np.where(np.any(mask_scan, axis=(0, 1)))[0].max()
+    x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0].min(), np.where(np.any(mask_scan, axis=(1, 2)))[0].max()
+    y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0].min(), np.where(np.any(mask_scan, axis=(0, 2)))[0].max()
+
+    z = round(0.5 * (z_end - z_start)) + z_start
+    x = round(0.5 * (x_end - x_start)) + x_start
+    y = round(0.5 * (y_end - y_start)) + y_start
+
+    xyz = [x, y, z]
+    potential_points = xyz
+
+    return potential_points
+
+# Step 2 : generate the ellipsoid
+def get_ellipsoid(x, y, z):
+    """
+    x, y, z are the radii of this ellipsoid in the x, y, z directions respectively.
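+    The returned array has shape (4*x, 4*y, 4*z) with a binary ellipsoid of the
+    given radii centred at (2*x, 2*y, 2*z); the extra margin presumably leaves
+    room for the elastic deformations applied by the callers below.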
+ """ + sh = (4*x, 4*y, 4*z) + out = np.zeros(sh, int) + aux = np.zeros(sh) + radii = np.array([x, y, z]) + com = np.array([2*x, 2*y, 2*z]) # center point + + # calculate the ellipsoid + bboxl = np.floor(com-radii).clip(0,None).astype(int) + bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int) + roi = out[tuple(map(slice,bboxl,bboxh))] + roiaux = aux[tuple(map(slice,bboxl,bboxh))] + logrid = *map(np.square,np.ogrid[tuple( + map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]), + dst = (1-sum(logrid)).clip(0,None) + mask = dst>roiaux + roi[mask] = 1 + np.copyto(roiaux,dst,where=mask) + + return out + +def get_fixed_geo(mask_scan, tumor_type, organ_type): + if tumor_type == 'large': + enlarge_x, enlarge_y, enlarge_z = 280, 280, 280 + else: + enlarge_x, enlarge_y, enlarge_z = 160, 160, 160 + geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8) + tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32 + + if tumor_type == 'tiny': + num_tumor = random.randint(1,3) + # num_tumor = 1 + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + if tumor_type == 'small': + num_tumor = random.randint(1,3) + # num_tumor = 1 + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'medium': + # num_tumor = random.randint(1, 3) + num_tumor = 1 + for _ in range(num_tumor): + # 
medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'large': + num_tumor = 1 # random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + if organ_type == 'liver' or organ_type == 'kidney' : + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + else: + x_start, x_end = np.where(np.any(geo, axis=(1, 2)))[0].min(), np.where(np.any(geo, axis=(1, 2)))[0].max() + y_start, y_end = np.where(np.any(geo, axis=(0, 2)))[0].min(), np.where(np.any(geo, axis=(0, 2)))[0].max() + z_start, z_end = np.where(np.any(geo, axis=(0, 1)))[0].min(), np.where(np.any(geo, axis=(0, 1)))[0].max() + geo = geo[x_start:x_end, y_start:y_end, z_start:z_end] + + point = center_select(mask_scan) + + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low = new_point[0] - geo.shape[0]//2 + y_low = new_point[1] - geo.shape[1]//2 + z_low = new_point[2] - geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_low+geo.shape[0], y_low:y_low+geo.shape[1], z_low:z_low+geo.shape[2]] += geo + + geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + # texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + if ((tumor_type == 'medium') or (tumor_type == 'large')) and (organ_type == 'kidney'): + if random.random() > 0.5: + geo_mask = (geo_mask>=1) + else: + geo_mask = (geo_mask * mask_scan) >=1 + else: + geo_mask = (geo_mask * mask_scan) >=1 + + return geo_mask + + +from .ldm.vq_gan_3d.model.vqgan 
import VQGAN +import matplotlib.pyplot as plt +import SimpleITK as sitk +from .ldm.ddpm import Unet3D, GaussianDiffusion, Tester +from hydra import initialize, compose +import torch +import yaml +def synt_model_prepare(device, vqgan_ckpt='TumorGeneration/model_weight/recon_96d4_all.ckpt', diffusion_ckpt='TumorGeneration/model_weight/', fold=0, organ='liver'): + with initialize(config_path="diffusion_config/"): + cfg=compose(config_name="ddpm.yaml") + print('diffusion_ckpt',diffusion_ckpt) + vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt) + vqgan = vqgan.to(device) + vqgan.eval() + + early_Unet3D = Unet3D( + dim=cfg.diffusion_img_size, + dim_mults=cfg.dim_mults, + channels=cfg.diffusion_num_channels, + out_dim=cfg.out_dim + ).to(device) + + early_diffusion = GaussianDiffusion( + early_Unet3D, + vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt, + image_size=cfg.diffusion_img_size, + num_frames=cfg.diffusion_depth_size, + channels=cfg.diffusion_num_channels, + timesteps=4, # cfg.timesteps, + loss_type=cfg.loss_type, + device=device + ).to(device) + + noearly_Unet3D = Unet3D( + dim=cfg.diffusion_img_size, + dim_mults=cfg.dim_mults, + channels=cfg.diffusion_num_channels, + out_dim=cfg.out_dim + ).to(device) + + noearly_diffusion = GaussianDiffusion( + noearly_Unet3D, + vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt, + image_size=cfg.diffusion_img_size, + num_frames=cfg.diffusion_depth_size, + channels=cfg.diffusion_num_channels, + timesteps=200, # cfg.timesteps, + loss_type=cfg.loss_type, + device=device + ).to(device) + + early_tester = Tester(early_diffusion) + # noearly_tester = Tester(noearly_diffusion) + early_tester.load(diffusion_ckpt+'diff_{}_fold{}_early.pt'.format(organ, fold), map_location=device) + # noearly_tester.load(diffusion_ckpt+'diff_liver_fold{}_noearly_t200.pt'.format(fold), map_location=device) + + # early_checkpoint = torch.load(diffusion_ckpt+'diff_liver_fold{}_early.pt'.format(fold), map_location=device) + noearly_checkpoint = torch.load(diffusion_ckpt+'diff_{}_fold{}_noearly_t200.pt'.format(organ, fold), map_location=device) + # early_diffusion.load_state_dict(early_checkpoint['ema']) + noearly_diffusion.load_state_dict(noearly_checkpoint['ema']) + # early_sampler = DDIMSampler(early_diffusion, schedule="cosine") + noearly_sampler = DDIMSampler(noearly_diffusion, schedule="cosine") + # breakpoint() + return vqgan, early_tester, noearly_sampler + +def synthesize_early_tumor(ct_volume, organ_mask, organ_type, vqgan, tester): + device=ct_volume.device + + # generate tumor mask + tumor_types = ['tiny', 'small'] + # tumor_probs = np.array([0.5, 0.5]) + tumor_probs = np.array([0.2, 0.8]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - 
vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + tester.ema_model.eval() + sample = tester.ema_model.sample(batch_size=volume.shape[0], cond=cond) + + # if organ_type == 'liver' or organ_type == 'kidney' : + + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask + +def synthesize_medium_tumor(ct_volume, organ_mask, organ_type, vqgan, sampler, ddim_ts=50): + device=ct_volume.device + + # generate tumor mask + # tumor_types = ['large'] + # tumor_probs = np.array([1.0]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + # synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + synthetic_tumor_type = 'medium' + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + shape = masked_volume_feat.shape[-4:] + samples_ddim, _ = sampler.sample(S=ddim_ts, + conditioning=cond, + batch_size=1, + shape=shape, + verbose=False) + samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() - + vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min() + + sample = vqgan.decode(samples_ddim, quantize=True) + + # if organ_type == 'liver' or organ_type == 'kidney': + # post-process + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, 
max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask + +def synthesize_large_tumor(ct_volume, organ_mask, organ_type, vqgan, sampler, ddim_ts=50): + device=ct_volume.device + + # generate tumor mask + # tumor_types = ['large'] + # tumor_probs = np.array([1.0]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + # synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + synthetic_tumor_type = 'large' + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + shape = masked_volume_feat.shape[-4:] + samples_ddim, _ = sampler.sample(S=ddim_ts, + conditioning=cond, + batch_size=1, + shape=shape, + verbose=False) + samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() - + vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min() + + sample = vqgan.decode(samples_ddim, quantize=True) + + # if organ_type == 'liver' or organ_type == 'kidney': + # post-process + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask \ No newline at end of file diff --git a/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/utils_.py b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/utils_.py new file mode 100644 index 0000000000000000000000000000000000000000..312e72d1571b3cd996fd567bcedf939443a4e182 --- /dev/null +++ 
b/Generation_Pipeline_filter_all/syn_kidney/TumorGeneration/utils_.py @@ -0,0 +1,298 @@ +### Tumor Generateion +import random +import cv2 +import elasticdeform +import numpy as np +from scipy.ndimage import gaussian_filter + +# Step 1: Random select (numbers) location for tumor. +def random_select(mask_scan): + # we first find z index and then sample point with z slice + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]] + + # we need to strict number z's position (0.3 - 0.7 in the middle of liver) + z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start + + liver_mask = mask_scan[..., z] + + # erode the mask (we don't want the edge points) + kernel = np.ones((5,5), dtype=np.uint8) + liver_mask = cv2.erode(liver_mask, kernel, iterations=1) + + coordinates = np.argwhere(liver_mask == 1) + random_index = np.random.randint(0, len(coordinates)) + xyz = coordinates[random_index].tolist() # get x,y + xyz.append(z) + potential_points = xyz + + return potential_points + +# Step 2 : generate the ellipsoid +def get_ellipsoid(x, y, z): + """" + x, y, z is the radius of this ellipsoid in x, y, z direction respectly. + """ + sh = (4*x, 4*y, 4*z) + out = np.zeros(sh, int) + aux = np.zeros(sh) + radii = np.array([x, y, z]) + com = np.array([2*x, 2*y, 2*z]) # center point + + # calculate the ellipsoid + bboxl = np.floor(com-radii).clip(0,None).astype(int) + bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int) + roi = out[tuple(map(slice,bboxl,bboxh))] + roiaux = aux[tuple(map(slice,bboxl,bboxh))] + logrid = *map(np.square,np.ogrid[tuple( + map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]), + dst = (1-sum(logrid)).clip(0,None) + mask = dst>roiaux + roi[mask] = 1 + np.copyto(roiaux,dst,where=mask) + + return out + +def get_fixed_geo(mask_scan, tumor_type): + + enlarge_x, enlarge_y, enlarge_z = 160, 160, 160 + geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8) + tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32 + + if tumor_type == 'tiny': + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + if tumor_type == 'small': + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo 
= get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'medium': + num_tumor = random.randint(2, 5) + for _ in range(num_tumor): + # medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'large': + num_tumor = random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == "mix": + # tiny + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = 
random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + # small + num_tumor = random.randint(5,10) + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + # medium + num_tumor = random.randint(2, 5) + for _ in range(num_tumor): + # medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + # large + num_tumor = random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = 
random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + # texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + geo_mask = (geo_mask * mask_scan) >=1 + + return geo_mask + + +def get_tumor(volume_scan, mask_scan, tumor_type): + tumor_mask = get_fixed_geo(mask_scan, tumor_type) + + sigma = np.random.uniform(1, 2) + # difference = np.random.uniform(65, 145) + difference = 1 + + # blur the boundary + tumor_mask_blur = gaussian_filter(tumor_mask*1.0, sigma) + + + abnormally_full = volume_scan * (1 - mask_scan) + abnormally + abnormally_mask = mask_scan + geo_mask + + return abnormally_full, abnormally_mask + +def SynthesisTumor(volume_scan, mask_scan, tumor_type): + # for speed_generate_tumor, we only send the liver part into the generate program + x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0][[0, -1]] + y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0][[0, -1]] + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]] + + # shrink the boundary + x_start, x_end = max(0, x_start+1), min(mask_scan.shape[0], x_end-1) + y_start, y_end = max(0, y_start+1), min(mask_scan.shape[1], y_end-1) + z_start, z_end = max(0, z_start+1), min(mask_scan.shape[2], z_end-1) + + ct_volume = volume_scan[x_start:x_end, y_start:y_end, z_start:z_end] + organ_mask = mask_scan[x_start:x_end, y_start:y_end, z_start:z_end] + + # input texture shape: 420, 300, 320 + # we need to cut it into the shape of liver_mask + # for examples, the liver_mask.shape = 286, 173, 46; we should change the texture shape + x_length, y_length, z_length = 64, 64, 64 + crop_x = random.randint(x_start, x_end - x_length - 1) # random select the start point, -1 is to avoid boundary check + crop_y = random.randint(y_start, y_end - y_length - 1) + crop_z = random.randint(z_start, z_end - z_length - 1) + + ct_volume, organ_tumor_mask = get_tumor(ct_volume, organ_mask, tumor_type) + volume_scan[x_start:x_end, y_start:y_end, z_start:z_end] = ct_volume + mask_scan[x_start:x_end, y_start:y_end, z_start:z_end] = organ_tumor_mask + + return volume_scan, mask_scan diff --git a/Generation_Pipeline_filter_all/syn_kidney/healthy_kidney_1k.txt b/Generation_Pipeline_filter_all/syn_kidney/healthy_kidney_1k.txt new file mode 100644 index 0000000000000000000000000000000000000000..f9487280079db99e7abc891c32997dfa6f4e6751 --- /dev/null +++ 
b/Generation_Pipeline_filter_all/syn_kidney/healthy_kidney_1k.txt @@ -0,0 +1,565 @@ +BDMAP_00002275 +BDMAP_00001907 +BDMAP_00002712 +BDMAP_00004615 +BDMAP_00004651 +BDMAP_00002230 +BDMAP_00002955 +BDMAP_00004183 +BDMAP_00002304 +BDMAP_00002029 +BDMAP_00001646 +BDMAP_00002909 +BDMAP_00002328 +BDMAP_00004829 +BDMAP_00001093 +BDMAP_00002117 +BDMAP_00004600 +BDMAP_00003771 +BDMAP_00001198 +BDMAP_00003451 +BDMAP_00002719 +BDMAP_00002846 +BDMAP_00002282 +BDMAP_00003827 +BDMAP_00001649 +BDMAP_00005141 +BDMAP_00000941 +BDMAP_00002875 +BDMAP_00004641 +BDMAP_00003373 +BDMAP_00001924 +BDMAP_00003897 +BDMAP_00005074 +BDMAP_00001753 +BDMAP_00000101 +BDMAP_00003412 +BDMAP_00002945 +BDMAP_00002598 +BDMAP_00004858 +BDMAP_00001632 +BDMAP_00003327 +BDMAP_00005130 +BDMAP_00004783 +BDMAP_00002844 +BDMAP_00002479 +BDMAP_00001464 +BDMAP_00001809 +BDMAP_00003385 +BDMAP_00003918 +BDMAP_00004995 +BDMAP_00004447 +BDMAP_00003972 +BDMAP_00003438 +BDMAP_00003898 +BDMAP_00001057 +BDMAP_00005005 +BDMAP_00003244 +BDMAP_00003631 +BDMAP_00004103 +BDMAP_00000069 +BDMAP_00001736 +BDMAP_00003002 +BDMAP_00004704 +BDMAP_00001055 +BDMAP_00000447 +BDMAP_00000778 +BDMAP_00005097 +BDMAP_00004264 +BDMAP_00004304 +BDMAP_00005170 +BDMAP_00000547 +BDMAP_00004764 +BDMAP_00004229 +BDMAP_00001414 +BDMAP_00001828 +BDMAP_00003151 +BDMAP_00003769 +BDMAP_00001962 +BDMAP_00003333 +BDMAP_00000676 +BDMAP_00001704 +BDMAP_00004459 +BDMAP_00003683 +BDMAP_00003439 +BDMAP_00004016 +BDMAP_00000438 +BDMAP_00004117 +BDMAP_00001785 +BDMAP_00002688 +BDMAP_00000913 +BDMAP_00000942 +BDMAP_00003400 +BDMAP_00003824 +BDMAP_00000470 +BDMAP_00002918 +BDMAP_00002828 +BDMAP_00004286 +BDMAP_00001845 +BDMAP_00002791 +BDMAP_00004672 +BDMAP_00002717 +BDMAP_00002856 +BDMAP_00002188 +BDMAP_00001701 +BDMAP_00001175 +BDMAP_00002841 +BDMAP_00003254 +BDMAP_00004508 +BDMAP_00000373 +BDMAP_00001565 +BDMAP_00002214 +BDMAP_00000701 +BDMAP_00000690 +BDMAP_00001215 +BDMAP_00000324 +BDMAP_00004015 +BDMAP_00004196 +BDMAP_00001419 +BDMAP_00000618 +BDMAP_00003640 +BDMAP_00001697 +BDMAP_00000332 +BDMAP_00004023 +BDMAP_00002815 +BDMAP_00004199 +BDMAP_00003890 +BDMAP_00002529 +BDMAP_00004843 +BDMAP_00002076 +BDMAP_00004895 +BDMAP_00000623 +BDMAP_00002244 +BDMAP_00000205 +BDMAP_00001185 +BDMAP_00003133 +BDMAP_00001957 +BDMAP_00001015 +BDMAP_00003932 +BDMAP_00001010 +BDMAP_00001102 +BDMAP_00004880 +BDMAP_00004664 +BDMAP_00002748 +BDMAP_00000430 +BDMAP_00004293 +BDMAP_00002829 +BDMAP_00000558 +BDMAP_00000084 +BDMAP_00001438 +BDMAP_00001917 +BDMAP_00004129 +BDMAP_00000232 +BDMAP_00002463 +BDMAP_00004839 +BDMAP_00003664 +BDMAP_00004604 +BDMAP_00002021 +BDMAP_00004550 +BDMAP_00004106 +BDMAP_00004128 +BDMAP_00000696 +BDMAP_00002411 +BDMAP_00003569 +BDMAP_00001912 +BDMAP_00003036 +BDMAP_00001288 +BDMAP_00002216 +BDMAP_00002199 +BDMAP_00000100 +BDMAP_00003634 +BDMAP_00000345 +BDMAP_00000614 +BDMAP_00001769 +BDMAP_00002580 +BDMAP_00004676 +BDMAP_00000388 +BDMAP_00003357 +BDMAP_00004431 +BDMAP_00002359 +BDMAP_00000132 +BDMAP_00004097 +BDMAP_00003847 +BDMAP_00003017 +BDMAP_00003680 +BDMAP_00001737 +BDMAP_00003361 +BDMAP_00003377 +BDMAP_00000437 +BDMAP_00002237 +BDMAP_00003900 +BDMAP_00001754 +BDMAP_00004288 +BDMAP_00002612 +BDMAP_00003329 +BDMAP_00004187 +BDMAP_00000873 +BDMAP_00003525 +BDMAP_00000921 +BDMAP_00004231 +BDMAP_00001343 +BDMAP_00004793 +BDMAP_00001898 +BDMAP_00002271 +BDMAP_00002313 +BDMAP_00002896 +BDMAP_00000851 +BDMAP_00004165 +BDMAP_00003840 +BDMAP_00000338 +BDMAP_00000715 +BDMAP_00004295 +BDMAP_00000236 +BDMAP_00001985 +BDMAP_00003633 +BDMAP_00004825 +BDMAP_00002305 
+BDMAP_00001237 +BDMAP_00002419 +BDMAP_00001766 +BDMAP_00004546 +BDMAP_00000881 +BDMAP_00001836 +BDMAP_00003052 +BDMAP_00001502 +BDMAP_00003483 +BDMAP_00003396 +BDMAP_00005119 +BDMAP_00003299 +BDMAP_00000568 +BDMAP_00003590 +BDMAP_00002616 +BDMAP_00001835 +BDMAP_00002172 +BDMAP_00004964 +BDMAP_00002944 +BDMAP_00002465 +BDMAP_00002227 +BDMAP_00001905 +BDMAP_00002603 +BDMAP_00003111 +BDMAP_00004398 +BDMAP_00002373 +BDMAP_00000093 +BDMAP_00001247 +BDMAP_00003172 +BDMAP_00001865 +BDMAP_00001545 +BDMAP_00000411 +BDMAP_00002349 +BDMAP_00001617 +BDMAP_00003884 +BDMAP_00000809 +BDMAP_00003497 +BDMAP_00003961 +BDMAP_00005139 +BDMAP_00001628 +BDMAP_00004969 +BDMAP_00004228 +BDMAP_00001316 +BDMAP_00005160 +BDMAP_00001024 +BDMAP_00005073 +BDMAP_00001209 +BDMAP_00004954 +BDMAP_00003798 +BDMAP_00005063 +BDMAP_00001476 +BDMAP_00000243 +BDMAP_00003809 +BDMAP_00001309 +BDMAP_00003886 +BDMAP_00002758 +BDMAP_00002289 +BDMAP_00001862 +BDMAP_00004804 +BDMAP_00003113 +BDMAP_00001361 +BDMAP_00000692 +BDMAP_00001523 +BDMAP_00004115 +BDMAP_00002387 +BDMAP_00003781 +BDMAP_00000087 +BDMAP_00001823 +BDMAP_00000940 +BDMAP_00004719 +BDMAP_00004624 +BDMAP_00002849 +BDMAP_00003657 +BDMAP_00001461 +BDMAP_00002690 +BDMAP_00003236 +BDMAP_00004558 +BDMAP_00004639 +BDMAP_00004541 +BDMAP_00005083 +BDMAP_00000907 +BDMAP_00000972 +BDMAP_00001200 +BDMAP_00003168 +BDMAP_00000828 +BDMAP_00004450 +BDMAP_00001597 +BDMAP_00003867 +BDMAP_00001746 +BDMAP_00002252 +BDMAP_00002947 +BDMAP_00004878 +BDMAP_00001842 +BDMAP_00002654 +BDMAP_00002185 +BDMAP_00001802 +BDMAP_00001040 +BDMAP_00004198 +BDMAP_00000831 +BDMAP_00004491 +BDMAP_00003109 +BDMAP_00002120 +BDMAP_00001834 +BDMAP_00002619 +BDMAP_00000138 +BDMAP_00004773 +BDMAP_00001236 +BDMAP_00002402 +BDMAP_00001598 +BDMAP_00000714 +BDMAP_00003356 +BDMAP_00000462 +BDMAP_00001114 +BDMAP_00000607 +BDMAP_00004297 +BDMAP_00004841 +BDMAP_00005022 +BDMAP_00000572 +BDMAP_00000541 +BDMAP_00005140 +BDMAP_00004415 +BDMAP_00003946 +BDMAP_00003319 +BDMAP_00003510 +BDMAP_00004163 +BDMAP_00002458 +BDMAP_00005020 +BDMAP_00004511 +BDMAP_00004549 +BDMAP_00005155 +BDMAP_00004147 +BDMAP_00004876 +BDMAP_00002103 +BDMAP_00000882 +BDMAP_00003138 +BDMAP_00005037 +BDMAP_00003853 +BDMAP_00002039 +BDMAP_00000774 +BDMAP_00004741 +BDMAP_00001171 +BDMAP_00004636 +BDMAP_00002332 +BDMAP_00004894 +BDMAP_00002730 +BDMAP_00001125 +BDMAP_00003822 +BDMAP_00003592 +BDMAP_00001368 +BDMAP_00003513 +BDMAP_00003612 +BDMAP_00005169 +BDMAP_00004017 +BDMAP_00002855 +BDMAP_00000152 +BDMAP_00000091 +BDMAP_00004529 +BDMAP_00003443 +BDMAP_00003543 +BDMAP_00002267 +BDMAP_00004462 +BDMAP_00000874 +BDMAP_00002793 +BDMAP_00001471 +BDMAP_00001605 +BDMAP_00000709 +BDMAP_00004435 +BDMAP_00003524 +BDMAP_00000965 +BDMAP_00000939 +BDMAP_00002278 +BDMAP_00002295 +BDMAP_00000971 +BDMAP_00004917 +BDMAP_00003812 +BDMAP_00002401 +BDMAP_00003074 +BDMAP_00004028 +BDMAP_00001982 +BDMAP_00004281 +BDMAP_00000347 +BDMAP_00001732 +BDMAP_00001205 +BDMAP_00001379 +BDMAP_00001095 +BDMAP_00004770 +BDMAP_00002283 +BDMAP_00000052 +BDMAP_00000192 +BDMAP_00003564 +BDMAP_00003427 +BDMAP_00004888 +BDMAP_00005016 +BDMAP_00004745 +BDMAP_00001078 +BDMAP_00001122 +BDMAP_00001584 +BDMAP_00003551 +BDMAP_00002495 +BDMAP_00000589 +BDMAP_00005065 +BDMAP_00002171 +BDMAP_00004830 +BDMAP_00001804 +BDMAP_00004493 +BDMAP_00000400 +BDMAP_00000745 +BDMAP_00001333 +BDMAP_00004890 +BDMAP_00002845 +BDMAP_00001875 +BDMAP_00001096 +BDMAP_00004060 +BDMAP_00002451 +BDMAP_00002523 +BDMAP_00002899 +BDMAP_00000642 +BDMAP_00005075 +BDMAP_00003685 +BDMAP_00004650 +BDMAP_00001618 +BDMAP_00000771 
+BDMAP_00003920 +BDMAP_00002309 +BDMAP_00004847 +BDMAP_00002485 +BDMAP_00001590 +BDMAP_00001692 +BDMAP_00003502 +BDMAP_00000431 +BDMAP_00000679 +BDMAP_00002986 +BDMAP_00003277 +BDMAP_00004885 +BDMAP_00000427 +BDMAP_00000716 +BDMAP_00003744 +BDMAP_00001806 +BDMAP_00003857 +BDMAP_00000859 +BDMAP_00001067 +BDMAP_00004121 +BDMAP_00002475 +BDMAP_00002318 +BDMAP_00003114 +BDMAP_00001712 +BDMAP_00001214 +BDMAP_00000362 +BDMAP_00001441 +BDMAP_00003272 +BDMAP_00000956 +BDMAP_00005064 +BDMAP_00000154 +BDMAP_00005186 +BDMAP_00003658 +BDMAP_00002704 +BDMAP_00004796 +BDMAP_00000197 +BDMAP_00005070 +BDMAP_00005001 +BDMAP_00000480 +BDMAP_00005078 +BDMAP_00001564 +BDMAP_00001025 +BDMAP_00003598 +BDMAP_00004262 +BDMAP_00001092 +BDMAP_00004185 +BDMAP_00003776 +BDMAP_00001270 +BDMAP_00000615 +BDMAP_00003141 +BDMAP_00003330 +BDMAP_00000190 +BDMAP_00003650 +BDMAP_00001397 +BDMAP_00005185 +BDMAP_00001966 +BDMAP_00004184 +BDMAP_00004992 +BDMAP_00004416 +BDMAP_00000993 +BDMAP_00001445 +BDMAP_00003482 +BDMAP_00004514 +BDMAP_00001504 +BDMAP_00000416 +BDMAP_00002805 +BDMAP_00002232 +BDMAP_00004384 +BDMAP_00001921 +BDMAP_00001426 +BDMAP_00004910 +BDMAP_00003560 +BDMAP_00003130 +BDMAP_00005108 +BDMAP_00000113 +BDMAP_00001521 +BDMAP_00003556 +BDMAP_00003376 +BDMAP_00000273 +BDMAP_00004735 +BDMAP_00001539 +BDMAP_00004494 +BDMAP_00001212 +BDMAP_00005067 +BDMAP_00000413 +BDMAP_00002863 +BDMAP_00000671 +BDMAP_00004927 +BDMAP_00002167 +BDMAP_00002152 +BDMAP_00005168 +BDMAP_00003911 +BDMAP_00002250 +BDMAP_00003215 +BDMAP_00002737 +BDMAP_00001514 +BDMAP_00003440 +BDMAP_00003031 +BDMAP_00001786 +BDMAP_00000552 +BDMAP_00004943 +BDMAP_00003268 +BDMAP_00002233 +BDMAP_00002362 +BDMAP_00001440 +BDMAP_00000225 +BDMAP_00003347 +BDMAP_00002739 +BDMAP_00003479 +BDMAP_00003481 +BDMAP_00003326 +BDMAP_00000683 +BDMAP_00004378 +BDMAP_00003367 +BDMAP_00000855 +BDMAP_00002298 +BDMAP_00004077 +BDMAP_00002253 +BDMAP_00001331 +BDMAP_00000542 +BDMAP_00002924 +BDMAP_00005092 +BDMAP_00004374 +BDMAP_00004509 +BDMAP_00000264 +BDMAP_00000918 +BDMAP_00000030 diff --git a/Generation_Pipeline_filter_all/syn_kidney/requirements.txt b/Generation_Pipeline_filter_all/syn_kidney/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9c1067c8eb4f44699bbf517601396113cea4b370 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_kidney/requirements.txt @@ -0,0 +1,94 @@ +absl-py==1.1.0 +accelerate==0.11.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +antlr4-python3-runtime==4.9.3 +async-timeout==4.0.2 +attrs==21.4.0 +autopep8==1.6.0 +cachetools==5.2.0 +certifi==2022.6.15 +charset-normalizer==2.0.12 +click==8.1.3 +cycler==0.11.0 +Deprecated==1.2.13 +docker-pycreds==0.4.0 +einops==0.4.1 +einops-exts==0.0.3 +ema-pytorch==0.0.8 +fonttools==4.34.4 +frozenlist==1.3.0 +fsspec==2022.5.0 +ftfy==6.1.1 +future==0.18.2 +gitdb==4.0.9 +GitPython==3.1.27 +google-auth==2.9.0 +google-auth-oauthlib==0.4.6 +grpcio==1.47.0 +h5py==3.7.0 +humanize==4.2.2 +hydra-core==1.2.0 +idna==3.3 +imageio==2.19.3 +imageio-ffmpeg==0.4.7 +importlib-metadata==4.12.0 +importlib-resources==5.9.0 +joblib==1.1.0 +kiwisolver==1.4.3 +lxml==4.9.1 +Markdown==3.3.7 +matplotlib==3.5.2 +multidict==6.0.2 +networkx==2.8.5 +nibabel==4.0.1 +nilearn==0.9.1 +numpy==1.23.0 +oauthlib==3.2.0 +omegaconf==2.2.3 +pandas==1.4.3 +Pillow==9.1.1 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycodestyle==2.8.0 +pyDeprecate==0.3.1 +pydicom==2.3.0 +pytorch-lightning==1.6.4 +pytz==2022.1 +PyWavelets==1.3.0 +PyYAML==6.0 +pyzmq==19.0.2 +regex==2022.6.2 +requests==2.28.0 +requests-oauthlib==1.3.1 
+rotary-embedding-torch==0.1.5 +rsa==4.8 +scikit-image==0.19.3 +scikit-learn==1.1.2 +scikit-video==1.1.11 +scipy==1.8.1 +seaborn==0.11.2 +sentry-sdk==1.7.2 +setproctitle==1.2.3 +shortuuid==1.0.9 +SimpleITK==2.1.1.2 +smmap==5.0.0 +tensorboard==2.9.1 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +threadpoolctl==3.1.0 +tifffile==2022.8.3 +toml==0.10.2 +torch-tb-profiler==0.4.0 +torchio==0.18.80 +torchmetrics==0.9.1 +tqdm==4.64.0 +typing_extensions==4.2.0 +urllib3==1.26.9 +wandb==0.12.21 +Werkzeug==2.1.2 +wrapt==1.14.1 +yarl==1.7.2 +zipp==3.8.0 +wandb +tensorboardX==2.4.1 diff --git a/Generation_Pipeline_filter_all/syn_liver/CT_syn_data.py b/Generation_Pipeline_filter_all/syn_liver/CT_syn_data.py new file mode 100644 index 0000000000000000000000000000000000000000..aa2343cd6488587bfd05f76d0112d79192af1563 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/CT_syn_data.py @@ -0,0 +1,242 @@ +import os, time, csv +import numpy as np +import torch +from sklearn.metrics import confusion_matrix +from scipy import ndimage +from scipy.ndimage import label +from functools import partial +import monai +from monai.transforms import AsDiscrete,AsDiscreted,Compose,Invertd,SaveImaged +from monai import transforms, data +from TumorGeneration.utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor, synt_model_prepare +import nibabel as nib + +import warnings +warnings.filterwarnings("ignore") + +import argparse +parser = argparse.ArgumentParser(description='liver tumor validation') + +# file dir +parser.add_argument('--data_root', default=None, type=str) +parser.add_argument('--organ_type', default='liver', type=str) +parser.add_argument('--save_dir', default='out', type=str) +parser.add_argument('--data_file', default='out', type=str) +parser.add_argument('--ddim_ts', default=50, type=int) +parser.add_argument('--fg_thresh', default=30, type=int) +parser.add_argument('--start', default=0, type=int) +parser.add_argument('--end', default=1000, type=int) + +def voxel2R(A): + return (np.array(A)/4*3/np.pi)**(1/3) + +class RandCropByPosNegLabeld_select(transforms.RandCropByPosNegLabeld): + def __init__(self, keys, label_key, spatial_size, + pos=1.0, neg=1.0, num_samples=1, + image_key=None, image_threshold=0.0, allow_missing_keys=True, + fg_thresh=0): + super().__init__(keys=keys, label_key=label_key, spatial_size=spatial_size, + pos=pos, neg=neg, num_samples=num_samples, + image_key=image_key, image_threshold=image_threshold, allow_missing_keys=allow_missing_keys) + self.fg_thresh = fg_thresh + + def R2voxel(self,R): + return (4/3*np.pi)*(R)**(3) + + def __call__(self, data): + d = dict(data) + data_name = d['name'] + d.pop('name') + + if '10_Decathlon' in data_name or '05_KiTS' in data_name: + d_crop = super().__call__(d) + + else: + flag=0 + while 1: + flag+=1 + + d_crop = super().__call__(d) + pixel_num = (d_crop[0]['label']>0).sum() + + if pixel_num > self.R2voxel(self.fg_thresh): + break + if flag>5 and pixel_num > self.R2voxel(max(self.fg_thresh-5, 5)): + break + if flag>10 and pixel_num > self.R2voxel(max(self.fg_thresh-10, 5)): + break + if flag>15 and pixel_num > self.R2voxel(max(self.fg_thresh-15, 5)): + break + if flag>20 and pixel_num > self.R2voxel(max(self.fg_thresh-20, 5)): + break + if flag>25 and pixel_num > self.R2voxel(max(self.fg_thresh-25, 5)): + break + if flag>50: + break + + d_crop[0]['name'] = data_name + + return d_crop + +def _get_loader(args): + # val_data_dir = args.val_dir + # datalist_json = args.json_dir + val_org_transform = 
transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label", "raw_image"]), + transforms.AddChanneld(keys=["image", "label", "raw_image"]), + transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), + transforms.Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear")), + transforms.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True), + transforms.SpatialPadd(keys=["image", "label"], mode=["minimum", "constant"], spatial_size=[96, 96, 96]), + RandCropByPosNegLabeld_select( + keys=["image", "label", "name"], + label_key="label", + spatial_size=(96, 96, 96), + pos=1, + neg=0, + num_samples=1, + image_key="image", + image_threshold=0, + fg_thresh = args.fg_thresh, + ), + transforms.ToTensord(keys=["image", "label", "raw_image"]), + ] + ) + + val_img=[] + val_lbl=[] + val_name=[] + + for line in open(args.data_file): + # name = line.strip().split()[1].split('.')[0] + # val_img.append(args.data_root + line.strip().split()[0]) + # val_lbl.append(args.data_root + line.strip().split()[1]) + # breakpoint() + name = line.strip() + val_img.append(os.path.join(args.data_root, name, 'ct.nii.gz')) + val_lbl.append(os.path.join(args.data_root, name, 'segmentations/liver.nii.gz')) + val_name.append(name) + data_dicts_val = [{'image': image, 'raw_image':image, 'label': label, 'name': name} + for image, label, name in zip(val_img, val_lbl, val_name)] + + if args.end < len(data_dicts_val): + data_dicts_val = data_dicts_val[args.start:args.end] + else: + data_dicts_val = data_dicts_val[args.start:] + print('val len {}'.format(len(data_dicts_val))) + val_org_ds = data.Dataset(data_dicts_val, transform=val_org_transform) + val_org_loader = data.DataLoader(val_org_ds, batch_size=1, shuffle=False, num_workers=4, sampler=None, pin_memory=True) + + post_transforms = Compose([ + Invertd( + keys=['image'], + transform=val_org_transform, + orig_keys="image", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ), + Invertd( + keys=['label'], + transform=val_org_transform, + orig_keys="label", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ) + ]) + return val_org_loader, post_transforms + +def main(): + args = parser.parse_args() + output_dir = args.save_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print("MAIN Argument values:") + for k, v in vars(args).items(): + print(k, '=>', v) + print('-----------------') + + ## loader and post_transform + val_loader, post_transforms = _get_loader(args) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device) + model.load_state_dict(torch.load("../best_metric_model_classification3d_dict.pth")) + model.eval() + + start_time=0 + with torch.no_grad(): + for idx, val_data in enumerate(val_loader): + print('idx',idx) + if idx == 0: + start_time = time.time() + # val_inputs = val_data["image"] + # name = val_data['label_meta_dict']['filename_or_obj'][0].split('/')[-1].split('.')[0] + + vqgan, early_sampler, noearly_sampler= synt_model_prepare(device = torch.device("cuda"), fold=0, organ=args.organ_type) + + healthy_data, healthy_target, data_names, raw_data = val_data['image'], val_data['label'], val_data['name'], val_data['raw_image'] + case_name = data_names[0].split('/')[-1] + print('case_name', case_name) + original_affine = val_data["label_meta_dict"]["original_affine"][0].numpy() + if healthy_target.sum() == 0: + 
val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + tumor_mask = val_data[0]['label'][0].cpu().numpy().astype(np.uint8) + tumor_mask_ = np.zeros_like(tumor_mask) + nib.save(nib.Nifti1Image(tumor_mask_, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/liver_tumor.nii.gz')) + continue + + healthy_data, healthy_target = healthy_data.cuda(), healthy_target.cuda() + healthy_target = (healthy_target==1).to(healthy_target) + + tumor_types = ['early', 'medium', 'large'] + # tumor_probs = np.array([0.45, 0.45, 0.1]) + # tumor_probs = np.array([1.0, 0.0, 0.0]) + tumor_probs = np.array([0.5, 0.4, 0.1]) + synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + print('synthetic_tumor_type',synthetic_tumor_type) + flag=0 + while 1: + flag+=1 + if synthetic_tumor_type == 'early': + synt_data, synt_target = synthesize_early_tumor(healthy_data, healthy_target, args.organ_type, vqgan, early_sampler) + elif synthetic_tumor_type == 'medium': + synt_data, synt_target = synthesize_medium_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + elif synthetic_tumor_type == 'large': + synt_data, synt_target = synthesize_large_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + + syn_confidence = model(synt_data).sigmoid()[:,1] + if syn_confidence>0.01: + break + elif flag > 20 and syn_confidence>0.005: + break + elif flag > 40 and syn_confidence>0.001: + break + val_data['image'] = synt_data.detach() + val_data['label'] = synt_target.detach() + + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + synt_data = val_data[0]['image'][0] + synt_target = val_data[0]['label'][0] + final_data = raw_data[0,0] + + synt_data = (synt_data*(250+175)-175) + final_data[synt_target>1] = synt_data[synt_target>1] + final_data = final_data.cpu().numpy() + final_label = (synt_target>=1.5).cpu().numpy().astype(np.uint8) + + os.makedirs(os.path.join(output_dir, f'{case_name}'), exist_ok=True) + os.makedirs(os.path.join(output_dir, f'{case_name}/segmentations'), exist_ok=True) + nib.save(nib.Nifti1Image(final_data, original_affine), os.path.join(output_dir, f'{case_name}/ct.nii.gz')) + nib.save(nib.Nifti1Image(final_label, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/liver_tumor.nii.gz')) + # breakpoint() + # nib.save(nib.Nifti1Image(synt_data.cpu().numpy(), original_affine), os.path.join(output_dir, 'synt_data.nii.gz')) + # nib.save(nib.Nifti1Image(synt_target.cpu().numpy(), original_affine), os.path.join(output_dir, 'synt_target.nii.gz')) + print('time = ', time.time()-start_time) + start_time = time.time() + + # breakpoint() +if __name__ == "__main__": + main() diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/.DS_Store b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..fb0d207e124cfecc777faae5b4a56c5ca1b9cd2e Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/.DS_Store differ diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/README.md b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/README.md new file mode 100644 index 0000000000000000000000000000000000000000..201911031ca988c410781a60dc3c192a65ee56b3 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/README.md @@ -0,0 +1,5 @@ +```bash +wget 
https://www.dropbox.com/scl/fi/k856fhk60kck8uqxtxazw/model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll +mv model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll model_weight.tar.gz +tar -xzvf model_weight.tar.gz +``` diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/TumorGenerated.py b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/TumorGenerated.py new file mode 100644 index 0000000000000000000000000000000000000000..9983cf047b1532f5edd20c0d4c78102d6d611eb9 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/TumorGenerated.py @@ -0,0 +1,39 @@ +import random +from typing import Hashable, Mapping, Dict + +from monai.config import KeysCollection +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.transform import MapTransform, RandomizableTransform + +from .utils_ import SynthesisTumor +import numpy as np + +class TumorGenerated(RandomizableTransform, MapTransform): + def __init__(self, + keys: KeysCollection, + prob: float = 0.1, + tumor_prob = [0.2, 0.2, 0.2, 0.2, 0.2], + allow_missing_keys: bool = False + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + random.seed(0) + np.random.seed(0) + + self.tumor_types = ['tiny', 'small', 'medium', 'large', 'mix'] + + assert len(tumor_prob) == 5 + self.tumor_prob = np.array(tumor_prob) + + + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + + if self._do_transform and (np.max(d['label']) <= 1): + tumor_type = np.random.choice(self.tumor_types, p=self.tumor_prob.ravel()) + + d['image'][0], d['label'][0] = SynthesisTumor(d['image'][0], d['label'][0], tumor_type) + # print(tumor_type, d['image'].shape, np.max(d['label'])) + return d diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/__init__.py b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d9bbd5e8cede113145b2742ebdd63d7226fe6396 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/__init__.py @@ -0,0 +1,5 @@ +### Online Version TumorGeneration ### + +from .TumorGenerated import TumorGenerated + +# from .utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor \ No newline at end of file diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9261842ad2342d4665210e72b1d4fbf1cecc6102 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..186f2601c1547c1c75e98ec4c4d43d8f8b666622 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/__pycache__/__init__.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/__pycache__/utils.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/__pycache__/utils.cpython-38.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..98fd1dddb42405a44aba063ce03c63e9c6037e60 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/__pycache__/utils.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/__pycache__/utils_.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/__pycache__/utils_.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a986e9e3aaa342229e214e9ce3baeb2d14309bc8 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/__pycache__/utils_.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/diffusion_config/ddpm.yaml b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/diffusion_config/ddpm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..474a973b1f76c2a026e2df916c4a153b5bbea05f --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/diffusion_config/ddpm.yaml @@ -0,0 +1,29 @@ +vqgan_ckpt: None + +# Have to be derived from VQ-GAN Latent space dimensions +diffusion_img_size: 24 +diffusion_depth_size: 24 +diffusion_num_channels: 17 # 17 +out_dim: 8 +dim_mults: [1,2,4,8] +results_folder: checkpoints/ddpm/ +results_folder_postfix: 'own_dataset_t2' +load_milestone: False # False + +batch_size: 2 # 40 +num_workers: 20 +logger: wandb +objective: pred_x0 +save_and_sample_every: 1000 +denoising_fn: Unet3D +train_lr: 1e-4 +timesteps: 2 # number of steps +sampling_timesteps: 250 # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper]) +loss_type: l1 # L1 or L2 +train_num_steps: 700000 # total training steps +gradient_accumulate_every: 2 # gradient accumulation steps +ema_decay: 0.995 # exponential moving average decay +amp: False # turn on mixed precision +num_sample_rows: 1 +gpus: 0 + diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/diffusion_config/vq_gan_3d.yaml b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/diffusion_config/vq_gan_3d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29940377c5cadb0c06322de8ac60d0713f799024 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/diffusion_config/vq_gan_3d.yaml @@ -0,0 +1,37 @@ +seed: 1234 +batch_size: 2 # 30 +num_workers: 32 # 30 + +gpus: 1 +accumulate_grad_batches: 1 +default_root_dir: checkpoints/vq_gan/ +default_root_dir_postfix: 'flair' +resume_from_checkpoint: +max_steps: -1 +max_epochs: -1 +precision: 16 +gradient_clip_val: 1.0 + + +embedding_dim: 8 # 256 +n_codes: 16384 # 2048 +n_hiddens: 16 +lr: 3e-4 +downsample: [2, 2, 2] # [4, 4, 4] +disc_channels: 64 +disc_layers: 3 +discriminator_iter_start: 10000 # 50000 +disc_loss_type: hinge +image_gan_weight: 1.0 +video_gan_weight: 1.0 +l1_weight: 4.0 +gan_feat_weight: 4.0 # 0.0 +perceptual_weight: 4.0 # 0.0 +i3d_feat: False +restart_thres: 1.0 +no_random_restart: False +norm_type: group +padding_type: replicate +num_groups: 32 + + diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__init__.py b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a01336e46e37471097d8e1420d0ffd5d803a1edd --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__init__.py @@ -0,0 +1 @@ +from .diffusion import Unet3D, GaussianDiffusion, Tester diff --git 
a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5a4c25126591caccaa2ed31e3a72691a9cd24627 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3045045494382964990c4ff5045b8fc419d87cc Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..af98810e7df2e5c57b8d018ebb489bf6db6530cc Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da1d313f12fe28085f1e582f20ed83e9d11dc01a Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7ac1a6aea44a1d1741652f2f0a598fb18b2730d Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20c6786f9995cbd455c900944b0ab9501a97c202 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c1d57fa4c73e9d438d9c37de5dbbe4cbd481361e Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/ddim.py b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..1bc51f94355732aedb0ffd254b8166144c468370 --- /dev/null +++ 
b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/ddim.py @@ -0,0 +1,206 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from .util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): # "uniform" 'quad' + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) + # breakpoint() + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... 
+ **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, T, H, W = shape + # breakpoint() + size = (batch_size, C, T, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + # iterator = tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(time_range): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. 
- mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + # breakpoint() + e_t = self.model.denoise_fn(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.denoise_fn(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/diffusion.py b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..87111dc3b2ab671e798efc3a320ae81d024f20b9 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/diffusion.py @@ -0,0 +1,1016 @@ +"Largely taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch" + +import math +import copy +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial + +from torch.utils import data +from pathlib import Path +from torch.optim import Adam +from torchvision import transforms as T, utils +from torch.cuda.amp import autocast, GradScaler +from PIL import Image + +from tqdm import tqdm +from einops import rearrange +from einops_exts import check_shape, rearrange_many + +from rotary_embedding_torch import RotaryEmbedding + +from .text import tokenize, bert_embed, BERT_MODEL_DIM +from torch.utils.data import Dataset, DataLoader +from ..vq_gan_3d.model.vqgan import VQGAN + +import matplotlib.pyplot as plt + +# helpers functions + + +def exists(x): + return x is not None + + +def noop(*args, **kwargs): + pass + + +def is_odd(n): + return (n % 2) == 1 + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +def cycle(dl): + while True: + for data in dl: + yield data + + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + + +def prob_mask_like(shape, prob, device): + if prob == 1: + return torch.ones(shape, device=device, dtype=torch.bool) + elif prob == 0: + return torch.zeros(shape, device=device, dtype=torch.bool) + else: + return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob + + +def is_list_str(x): + if not isinstance(x, (list, tuple)): + return False + return all([type(el) == str for el in x]) + +# relative positional bias + + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128 + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / + max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min( + val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype=torch.long, device=device) + k_pos = torch.arange(n, dtype=torch.long, device=device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket( + rel_pos, num_buckets=self.num_buckets, 
max_distance=self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') + +# small helper modules + + +class EMA(): + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_model_average(self, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = self.update_average(old_weight, up_weight) + + def update_average(self, old, new): + if old is None: + return new + return old * self.beta + (1 - self.beta) * new + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +def Upsample(dim): + return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) + + +def Downsample(dim): + return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) + + +class LayerNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1)) + + def forward(self, x): + var = torch.var(x, dim=1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=1, keepdim=True) + return (x - mean) / (var + self.eps).sqrt() * self.gamma + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = LayerNorm(dim) + + def forward(self, x, **kwargs): + x = self.norm(x) + return self.fn(x, **kwargs) + +# building block modules + + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1)) + self.norm = nn.GroupNorm(groups, dim_out) + self.act = nn.SiLU() + + def forward(self, x, scale_shift=None): + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + return self.act(x) + + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, dim_out * 2) + ) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) + self.res_conv = nn.Conv3d( + dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb=None): + + scale_shift = None + if exists(self.mlp): + assert exists(time_emb), 'time emb must be passed in' + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') + scale_shift = time_emb.chunk(2, dim=1) + + h = self.block1(x, scale_shift=scale_shift) + + h = self.block2(h) + return h + self.res_conv(x) + + +class SpatialLinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, f, h, w = 
x.shape + x = rearrange(x, 'b c f h w -> (b f) c h w') + + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = rearrange_many( + qkv, 'b (h c) x y -> b h c (x y)', h=self.heads) + + q = q.softmax(dim=-2) + k = k.softmax(dim=-1) + + q = q * self.scale + context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + + out = torch.einsum('b h d e, b h d n -> b h e n', context, q) + out = rearrange(out, 'b h c (x y) -> b (h c) x y', + h=self.heads, x=h, y=w) + out = self.to_out(out) + return rearrange(out, '(b f) c h w -> b c f h w', b=b) + +# attention along space and time + + +class EinopsToAndFrom(nn.Module): + def __init__(self, from_einops, to_einops, fn): + super().__init__() + self.from_einops = from_einops + self.to_einops = to_einops + self.fn = fn + + def forward(self, x, **kwargs): + shape = x.shape + reconstitute_kwargs = dict( + tuple(zip(self.from_einops.split(' '), shape))) + x = rearrange(x, f'{self.from_einops} -> {self.to_einops}') + x = self.fn(x, **kwargs) + x = rearrange( + x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + heads=4, + dim_head=32, + rotary_emb=None + ): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + + self.rotary_emb = rotary_emb + self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False) + self.to_out = nn.Linear(hidden_dim, dim, bias=False) + + def forward( + self, + x, + pos_bias=None, + focus_present_mask=None + ): + n, device = x.shape[-2], x.device + + qkv = self.to_qkv(x).chunk(3, dim=-1) + + if exists(focus_present_mask) and focus_present_mask.all(): + # if all batch samples are focusing on present + # it would be equivalent to passing that token's values through to the output + values = qkv[-1] + return self.to_out(values) + + # split out heads + + q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) + + # scale + + q = q * self.scale + + # rotate positions into queries and keys for time attention + + if exists(self.rotary_emb): + q = self.rotary_emb.rotate_queries_or_keys(q) + k = self.rotary_emb.rotate_queries_or_keys(k) + + # similarity + + sim = einsum('... h i d, ... h j d -> ... h i j', q, k) + + # relative positional bias + + if exists(pos_bias): + sim = sim + pos_bias + + if exists(focus_present_mask) and not (~focus_present_mask).all(): + attend_all_mask = torch.ones( + (n, n), device=device, dtype=torch.bool) + attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) + + mask = torch.where( + rearrange(focus_present_mask, 'b -> b 1 1 1 1'), + rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), + rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), + ) + + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + # numerical stability + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + # aggregate values + + out = einsum('... h i j, ... h j d -> ... h i d', attn, v) + out = rearrange(out, '... h n d -> ... 
n (h d)') + return self.to_out(out) + +# model + + +class Unet3D(nn.Module): + def __init__( + self, + dim, + cond_dim=None, + out_dim=None, + dim_mults=(1, 2, 4, 8), + channels=3, + attn_heads=8, + attn_dim_head=32, + use_bert_text_cond=False, + init_dim=None, + init_kernel_size=7, + use_sparse_linear_attn=True, + block_type='resnet', + resnet_groups=8 + ): + super().__init__() + self.channels = channels + + # temporal attention and its relative positional encoding + + rotary_emb = RotaryEmbedding(min(32, attn_dim_head)) + + def temporal_attn(dim): return EinopsToAndFrom('b c f h w', 'b (h w) f c', Attention( + dim, heads=attn_heads, dim_head=attn_dim_head, rotary_emb=rotary_emb)) + + # realistically will not be able to generate that many frames of video... yet + self.time_rel_pos_bias = RelativePositionBias( + heads=attn_heads, max_distance=32) + + # initial conv + + init_dim = default(init_dim, dim) + assert is_odd(init_kernel_size) + + init_padding = init_kernel_size // 2 + self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, + init_kernel_size), padding=(0, init_padding, init_padding)) + + self.init_temporal_attn = Residual( + PreNorm(init_dim, temporal_attn(init_dim))) + + # dimensions + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + # time conditioning + + time_dim = dim * 4 + self.time_mlp = nn.Sequential( + SinusoidalPosEmb(dim), + nn.Linear(dim, time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim) + ) + + # text conditioning + + self.has_cond = exists(cond_dim) or use_bert_text_cond + cond_dim = BERT_MODEL_DIM if use_bert_text_cond else cond_dim + + self.null_cond_emb = nn.Parameter( + torch.randn(1, cond_dim)) if self.has_cond else None + + cond_dim = time_dim + int(cond_dim or 0) + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + + num_resolutions = len(in_out) + # block type + + block_klass = partial(ResnetBlock, groups=resnet_groups) + block_klass_cond = partial(block_klass, time_emb_dim=cond_dim) + + # modules for all layers + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + block_klass_cond(dim_in, dim_out), + block_klass_cond(dim_out, dim_out), + Residual(PreNorm(dim_out, SpatialLinearAttention( + dim_out, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), + Residual(PreNorm(dim_out, temporal_attn(dim_out))), + Downsample(dim_out) if not is_last else nn.Identity() + ])) + + mid_dim = dims[-1] + self.mid_block1 = block_klass_cond(mid_dim, mid_dim) + + spatial_attn = EinopsToAndFrom( + 'b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads)) + + self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn)) + self.mid_temporal_attn = Residual( + PreNorm(mid_dim, temporal_attn(mid_dim))) + + self.mid_block2 = block_klass_cond(mid_dim, mid_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind >= (num_resolutions - 1) + + self.ups.append(nn.ModuleList([ + block_klass_cond(dim_out * 2, dim_in), + block_klass_cond(dim_in, dim_in), + Residual(PreNorm(dim_in, SpatialLinearAttention( + dim_in, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), + Residual(PreNorm(dim_in, temporal_attn(dim_in))), + Upsample(dim_in) if not is_last else nn.Identity() + ])) + + out_dim = default(out_dim, channels) + self.final_conv = nn.Sequential( + block_klass(dim * 2, dim), + nn.Conv3d(dim, out_dim, 1) + ) + + def forward_with_cond_scale( + self, + 
*args, + cond_scale=2., + **kwargs + ): + logits = self.forward(*args, null_cond_prob=0., **kwargs) + if cond_scale == 1 or not self.has_cond: + return logits + + null_logits = self.forward(*args, null_cond_prob=1., **kwargs) + return null_logits + (logits - null_logits) * cond_scale + + def forward( + self, + x, + time, + cond=None, + null_cond_prob=0., + focus_present_mask=None, + # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time) + prob_focus_present=0. + ): + assert not (self.has_cond and not exists(cond) + ), 'cond must be passed in if cond_dim specified' + x = torch.cat([x, cond], dim=1) + + batch, device = x.shape[0], x.device + + focus_present_mask = default(focus_present_mask, lambda: prob_mask_like( + (batch,), prob_focus_present, device=device)) + + time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device) + + x = self.init_conv(x) + r = x.clone() + + x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias) + + t = self.time_mlp(time) if exists(self.time_mlp) else None # [2, 128] + + # classifier free guidance + + if self.has_cond: + batch, device = x.shape[0], x.device + mask = prob_mask_like((batch,), null_cond_prob, device=device) + cond = torch.where(rearrange(mask, 'b -> b 1'), + self.null_cond_emb, cond) + t = torch.cat((t, cond), dim=-1) + + h = [] + + for block1, block2, spatial_attn, temporal_attn, downsample in self.downs: + x = block1(x, t) + x = block2(x, t) + x = spatial_attn(x) + x = temporal_attn(x, pos_bias=time_rel_pos_bias, + focus_present_mask=focus_present_mask) + h.append(x) + x = downsample(x) + + # [2, 256, 32, 4, 4] + x = self.mid_block1(x, t) + x = self.mid_spatial_attn(x) + x = self.mid_temporal_attn( + x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask) + x = self.mid_block2(x, t) + + for block1, block2, spatial_attn, temporal_attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim=1) + x = block1(x, t) + x = block2(x, t) + x = spatial_attn(x) + x = temporal_attn(x, pos_bias=time_rel_pos_bias, + focus_present_mask=focus_present_mask) + x = upsample(x) + + x = torch.cat((x, r), dim=1) + return self.final_conv(x) + +# gaussian diffusion trainer class + + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = torch.linspace(0, timesteps, steps, dtype=torch.float64) + alphas_cumprod = torch.cos( + ((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0, 0.9999) + + +class GaussianDiffusion(nn.Module): + def __init__( + self, + denoise_fn, + *, + image_size, + num_frames, + text_use_bert_cls=False, + channels=3, + timesteps=1000, + loss_type='l1', + use_dynamic_thres=False, # from the Imagen paper + dynamic_thres_percentile=0.9, + vqgan_ckpt=None, + device=None + ): + super().__init__() + self.channels = channels + self.image_size = image_size + self.num_frames = num_frames + self.denoise_fn = denoise_fn + self.device = device + + if vqgan_ckpt: + self.vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt).cuda() + self.vqgan.eval() + else: + self.vqgan = None + + betas = cosine_beta_schedule(timesteps) + + alphas = 1. 
- betas + alphas_cumprod = torch.cumprod(alphas, axis=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + # register buffer helper function that casts float64 to float32 + + def register_buffer(name, val): return self.register_buffer( + name, val.to(torch.float32)) + + register_buffer('betas', betas) + register_buffer('alphas_cumprod', alphas_cumprod) + register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + + register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) + register_buffer('sqrt_one_minus_alphas_cumprod', + torch.sqrt(1. - alphas_cumprod)) + register_buffer('log_one_minus_alphas_cumprod', + torch.log(1. - alphas_cumprod)) + register_buffer('sqrt_recip_alphas_cumprod', + torch.sqrt(1. / alphas_cumprod)) + register_buffer('sqrt_recipm1_alphas_cumprod', + torch.sqrt(1. / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + + posterior_variance = betas * \ + (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + + register_buffer('posterior_variance', posterior_variance) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + + register_buffer('posterior_log_variance_clipped', + torch.log(posterior_variance.clamp(min=1e-20))) + register_buffer('posterior_mean_coef1', betas * + torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) + register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) + * torch.sqrt(alphas) / (1. - alphas_cumprod)) + + # text conditioning parameters + + self.text_use_bert_cls = text_use_bert_cls + + # dynamic thresholding when sampling + + self.use_dynamic_thres = use_dynamic_thres + self.dynamic_thres_percentile = dynamic_thres_percentile + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. - self.alphas_cumprod, t, x_start.shape) + log_variance = extract( + self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract( + self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool, cond=None, cond_scale=1.): + x_recon = self.predict_start_from_noise( + x, t=t, noise=self.denoise_fn.forward_with_cond_scale(x, t, cond=cond, cond_scale=cond_scale)) + + if clip_denoised: + s = 1. + if self.use_dynamic_thres: + s = torch.quantile( + rearrange(x_recon, 'b ... -> b (...)').abs(), + self.dynamic_thres_percentile, + dim=-1 + ) + + s.clamp_(min=1.) 
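+            # Editor's note (descriptive comment, inferred from the surrounding code): with use_dynamic_thres enabled
+            # (the "dynamic thresholding" referenced from the Imagen paper above), s is the per-sample
+            # dynamic_thres_percentile quantile of |x_recon|, floored at 1 by the clamp_ call; x_recon is then
+            # clipped to [-s, s] and rescaled by s a few lines below, keeping the denoised estimate in range.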
+ s = s.view(-1, *((1,) * (x_recon.ndim - 1))) + + # clip by threshold, depending on whether static or dynamic + x_recon = x_recon.clamp(-s, s) / s + + model_mean, posterior_variance, posterior_log_variance = self.q_posterior( + x_start=x_recon, x_t=x, t=t) + return model_mean, posterior_variance, posterior_log_variance + + @torch.inference_mode() + def p_sample(self, x, t, cond=None, cond_scale=1., clip_denoised=True): + b, *_, device = *x.shape, x.device + model_mean, _, model_log_variance = self.p_mean_variance( + x=x, t=t, clip_denoised=clip_denoised, cond=cond, cond_scale=cond_scale) + noise = torch.randn_like(x) + # no noise when t == 0 + nonzero_mask = (1 - (t == 0).float()).reshape(b, + *((1,) * (len(x.shape) - 1))) + return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise + + @torch.inference_mode() + def p_sample_loop(self, shape, cond=None, cond_scale=1.): + device = self.betas.device + + b = shape[0] + img = torch.randn(shape, device=device) + # print('cond', cond.shape) + for i in reversed(range(0, self.num_timesteps)): + img = self.p_sample(img, torch.full( + (b,), i, device=device, dtype=torch.long), cond=cond, cond_scale=cond_scale) + + return img + + @torch.inference_mode() + def sample(self, cond=None, cond_scale=1., batch_size=16): + device = next(self.denoise_fn.parameters()).device + + if is_list_str(cond): + cond = bert_embed(tokenize(cond)).to(device) + + # batch_size = cond.shape[0] if exists(cond) else batch_size + batch_size = batch_size + image_size = self.image_size + channels = 8 # self.channels + num_frames = self.num_frames + # print((batch_size, channels, num_frames, image_size, image_size)) + # print('cond_',cond.shape) + _sample = self.p_sample_loop( + (batch_size, channels, num_frames, image_size, image_size), cond=cond, cond_scale=cond_scale) + + if isinstance(self.vqgan, VQGAN): + # denormalize TODO: Remove eventually + _sample = (((_sample + 1.0) / 2.0) * (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) + self.vqgan.codebook.embeddings.min() + + _sample = self.vqgan.decode(_sample, quantize=True) + else: + unnormalize_img(_sample) + + return _sample + + @torch.inference_mode() + def interpolate(self, x1, x2, t=None, lam=0.5): + b, *_, device = *x1.shape, x1.device + t = default(t, self.num_timesteps - 1) + + assert x1.shape == x2.shape + + t_batched = torch.stack([torch.tensor(t, device=device)] * b) + xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2)) + + img = (1 - lam) * xt1 + lam * xt2 + for i in reversed(range(0, t)): + img = self.p_sample(img, torch.full( + (b,), i, device=device, dtype=torch.long)) + + return img + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + + return ( + extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract(self.sqrt_one_minus_alphas_cumprod, + t, x_start.shape) * noise + ) + + def p_losses(self, x_start, t, cond=None, noise=None, **kwargs): + b, c, f, h, w, device = *x_start.shape, x_start.device + noise = default(noise, lambda: torch.randn_like(x_start)) + # breakpoint() + x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise) # [2, 8, 32, 32, 32] + + if is_list_str(cond): + cond = bert_embed( + tokenize(cond), return_cls_repr=self.text_use_bert_cls) + cond = cond.to(device) + + x_recon = self.denoise_fn(x_noisy, t, cond=cond, **kwargs) + + if self.loss_type == 'l1': + loss = F.l1_loss(noise, x_recon) + elif self.loss_type == 'l2': + loss = F.mse_loss(noise, x_recon) + else: 
+ raise NotImplementedError() + + return loss + + def forward(self, x, *args, **kwargs): + bs = int(x.shape[0]/2) + img=x[:bs,...] + mask=x[bs:,...] + mask_=(1-mask).detach() + masked_img = (img*mask_).detach() + masked_img=masked_img.permute(0,1,-1,-3,-2) + img=img.permute(0,1,-1,-3,-2) + mask=mask.permute(0,1,-1,-3,-2) + # breakpoint() + if isinstance(self.vqgan, VQGAN): + with torch.no_grad(): + img = self.vqgan.encode( + img, quantize=False, include_embeddings=True) + # normalize to -1 and 1 + img = ((img - self.vqgan.codebook.embeddings.min()) / + (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + masked_img = self.vqgan.encode( + masked_img, quantize=False, include_embeddings=True) + # normalize to -1 and 1 + masked_img = ((masked_img - self.vqgan.codebook.embeddings.min()) / + (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + else: + print("Hi") + img = normalize_img(img) + masked_img = normalize_img(masked_img) + mask = mask*2.0 - 1.0 + cc = torch.nn.functional.interpolate(mask, size=masked_img.shape[-3:]) + cond = torch.cat((masked_img, cc), dim=1) + + b, device, img_size, = img.shape[0], img.device, self.image_size + t = torch.randint(0, self.num_timesteps, (b,), device=device).long() + # breakpoint() + return self.p_losses(img, t, cond=cond, *args, **kwargs) + +# trainer class + + +CHANNELS_TO_MODE = { + 1: 'L', + 3: 'RGB', + 4: 'RGBA' +} + + +def seek_all_images(img, channels=3): + assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid' + mode = CHANNELS_TO_MODE[channels] + + i = 0 + while True: + try: + img.seek(i) + yield img.convert(mode) + except EOFError: + break + i += 1 + +# tensor of shape (channels, frames, height, width) -> gif + + +def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True): + tensor = ((tensor - tensor.min()) / (tensor.max() - tensor.min())) * 1.0 + images = map(T.ToPILImage(), tensor.unbind(dim=1)) + first_img, *rest_imgs = images + first_img.save(path, save_all=True, append_images=rest_imgs, + duration=duration, loop=loop, optimize=optimize) + return images + +# gif -> (channels, frame, height, width) tensor + + +def gif_to_tensor(path, channels=3, transform=T.ToTensor()): + img = Image.open(path) + tensors = tuple(map(transform, seek_all_images(img, channels=channels))) + return torch.stack(tensors, dim=1) + + +def identity(t, *args, **kwargs): + return t + + +def normalize_img(t): + return t * 2 - 1 + + +def unnormalize_img(t): + return (t + 1) * 0.5 + + +def cast_num_frames(t, *, frames): + f = t.shape[1] + + if f == frames: + return t + + if f > frames: + return t[:, :frames] + + return F.pad(t, (0, 0, 0, 0, 0, frames - f)) + + +class Dataset(data.Dataset): + def __init__( + self, + folder, + image_size, + channels=3, + num_frames=16, + horizontal_flip=False, + force_num_frames=True, + exts=['gif'] + ): + super().__init__() + self.folder = folder + self.image_size = image_size + self.channels = channels + self.paths = [p for ext in exts for p in Path( + f'{folder}').glob(f'**/*.{ext}')] + + self.cast_num_frames_fn = partial( + cast_num_frames, frames=num_frames) if force_num_frames else identity + + self.transform = T.Compose([ + T.Resize(image_size), + T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity), + T.CenterCrop(image_size), + T.ToTensor() + ]) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = self.paths[index] + tensor = gif_to_tensor(path, self.channels, 
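+ # gif_to_tensor stacks the decoded frames along dim=1, yielding a (channels, frames, height, width) tensor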
transform=self.transform) + return self.cast_num_frames_fn(tensor) + +# trainer class + + +class Tester(object): + def __init__( + self, + diffusion_model, + ): + super().__init__() + self.model = diffusion_model + self.ema_model = copy.deepcopy(self.model) + self.step=0 + self.image_size = diffusion_model.image_size + + self.reset_parameters() + + def reset_parameters(self): + self.ema_model.load_state_dict(self.model.state_dict()) + + + def load(self, milestone, map_location=None, **kwargs): + if milestone == -1: + all_milestones = [int(p.stem.split('-')[-1]) + for p in Path(self.results_folder).glob('**/*.pt')] + assert len( + all_milestones) > 0, 'need to have at least one milestone to load from latest checkpoint (milestone == -1)' + milestone = max(all_milestones) + + if map_location: + data = torch.load(milestone, map_location=map_location) + else: + data = torch.load(milestone) + + self.step = data['step'] + self.model.load_state_dict(data['model'], **kwargs) + self.ema_model.load_state_dict(data['ema'], **kwargs) + + diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/text.py b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/text.py new file mode 100644 index 0000000000000000000000000000000000000000..42dca78bec5075fb4f59c522aa3d00cc395d7536 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/text.py @@ -0,0 +1,94 @@ +"Taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch" + +import torch +from einops import rearrange + + +def exists(val): + return val is not None + +# singleton globals + + +MODEL = None +TOKENIZER = None +BERT_MODEL_DIM = 768 + + +def get_tokenizer(): + global TOKENIZER + if not exists(TOKENIZER): + TOKENIZER = torch.hub.load( + 'huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased') + return TOKENIZER + + +def get_bert(): + global MODEL + if not exists(MODEL): + MODEL = torch.hub.load( + 'huggingface/pytorch-transformers', 'model', 'bert-base-cased') + if torch.cuda.is_available(): + MODEL = MODEL.cuda() + + return MODEL + +# tokenize + + +def tokenize(texts, add_special_tokens=True): + if not isinstance(texts, (list, tuple)): + texts = [texts] + + tokenizer = get_tokenizer() + + encoding = tokenizer.batch_encode_plus( + texts, + add_special_tokens=add_special_tokens, + padding=True, + return_tensors='pt' + ) + + token_ids = encoding.input_ids + return token_ids + +# embedding function + + +@torch.no_grad() +def bert_embed( + token_ids, + return_cls_repr=False, + eps=1e-8, + pad_id=0. 
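+ # pad_id: token id treated as padding; those positions are excluded from the attention mask below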
+): + model = get_bert() + mask = token_ids != pad_id + + if torch.cuda.is_available(): + token_ids = token_ids.cuda() + mask = mask.cuda() + + outputs = model( + input_ids=token_ids, + attention_mask=mask, + output_hidden_states=True + ) + + hidden_state = outputs.hidden_states[-1] + + if return_cls_repr: + # return [cls] as representation + return hidden_state[:, 0] + + if not exists(mask): + return hidden_state.mean(dim=1) + + # mean all tokens excluding [cls], accounting for length + mask = mask[:, 1:] + mask = rearrange(mask, 'b n -> b n 1') + + numer = (hidden_state[:, 1:] * mask).sum(dim=1) + denom = mask.sum(dim=1) + masked_mean = numer / (denom + eps) + return diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/time_embedding.py b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/time_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..7c911fa4f485295005cda9c9c2099db3ddbda15c --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/time_embedding.py @@ -0,0 +1,75 @@ +import math +import torch +import torch.nn as nn + +from monai.networks.layers.utils import get_act_layer + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, emb_dim=16, downscale_freq_shift=1, max_period=10000, flip_sin_to_cos=False): + super().__init__() + self.emb_dim = emb_dim + self.downscale_freq_shift = downscale_freq_shift + self.max_period = max_period + self.flip_sin_to_cos = flip_sin_to_cos + + def forward(self, x): + device = x.device + half_dim = self.emb_dim // 2 + emb = math.log(self.max_period) / \ + (half_dim - self.downscale_freq_shift) + emb = torch.exp(-emb*torch.arange(half_dim, device=device)) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + + if self.flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + if self.emb_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +class LearnedSinusoidalPosEmb(nn.Module): + """ following @crowsonkb 's lead with learned sinusoidal pos emb """ + """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """ + + def __init__(self, emb_dim): + super().__init__() + self.emb_dim = emb_dim + half_dim = emb_dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = x[:, None] + freqs = x * self.weights[None, :] * 2 * math.pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + if self.emb_dim % 2 == 1: + fouriered = torch.nn.functional.pad(fouriered, (0, 1, 0, 0)) + return fouriered + + +class TimeEmbbeding(nn.Module): + def __init__( + self, + emb_dim=64, + pos_embedder=SinusoidalPosEmb, + pos_embedder_kwargs={}, + act_name=("SWISH", {}) # Swish = SiLU + ): + super().__init__() + self.emb_dim = emb_dim + self.pos_emb_dim = pos_embedder_kwargs.get('emb_dim', emb_dim//4) + pos_embedder_kwargs['emb_dim'] = self.pos_emb_dim + self.pos_embedder = pos_embedder(**pos_embedder_kwargs) + + self.time_emb = nn.Sequential( + self.pos_embedder, + nn.Linear(self.pos_emb_dim, self.emb_dim), + get_act_layer(act_name), + nn.Linear(self.emb_dim, self.emb_dim) + ) + + def forward(self, time): + return self.time_emb(time) diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/unet.py b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/unet.py new file mode 100644 index 
0000000000000000000000000000000000000000..e9cae8021e874a92b6baaaf4decd802516c2dc87 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/unet.py @@ -0,0 +1,226 @@ +from ddpm.time_embedding import TimeEmbbeding + +import monai.networks.nets as nets +import torch +import torch.nn as nn +from einops import rearrange + +from monai.networks.blocks import UnetBasicBlock, UnetResBlock, UnetUpBlock, Convolution, UnetOutBlock +from monai.networks.layers.utils import get_act_layer + + +class DownBlock(nn.Module): + def __init__( + self, + spatial_dims, + in_ch, + out_ch, + time_emb_dim, + cond_emb_dim, + act_name=("swish", {}), + **kwargs): + super(DownBlock, self).__init__() + self.loca_time_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(time_emb_dim, in_ch) # in_ch * 2 + ) + if cond_emb_dim is not None: + self.loca_cond_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(cond_emb_dim, in_ch), + ) + self.down_op = UnetBasicBlock( + spatial_dims, in_ch, out_ch, act_name=act_name, **kwargs) + + def forward(self, x, time_emb, cond_emb): + b, c, *_ = x.shape + sp_dim = x.ndim-2 + + # ------------ Time ---------- + time_emb = self.loca_time_embedder(time_emb) + time_emb = time_emb.reshape(b, c, *((1,)*sp_dim)) + # scale, shift = time_emb.chunk(2, dim = 1) + + # ------------ Combine ------------ + # x = x * (scale + 1) + shift + x = x + time_emb + + # ----------- Condition ------------ + if cond_emb is not None: + cond_emb = self.loca_cond_embedder(cond_emb) + cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim)) + x = x + cond_emb + + # ----------- Image --------- + y = self.down_op(x) + return y + + +class UpBlock(nn.Module): + def __init__( + self, + spatial_dims, + skip_ch, + enc_ch, + time_emb_dim, + cond_emb_dim, + act_name=("swish", {}), + **kwargs): + super(UpBlock, self).__init__() + self.up_op = UnetUpBlock(spatial_dims, enc_ch, + skip_ch, act_name=act_name, **kwargs) + self.loca_time_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(time_emb_dim, skip_ch * 2), + ) + if cond_emb_dim is not None: + self.loca_cond_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(cond_emb_dim, skip_ch * 2), + ) + + def forward(self, x_skip, x_enc, time_emb, cond_emb): + b, c, *_ = x_enc.shape + sp_dim = x_enc.ndim-2 + + # ----------- Time -------------- + time_emb = self.loca_time_embedder(time_emb) + time_emb = time_emb.reshape(b, c, *((1,)*sp_dim)) + # scale, shift = time_emb.chunk(2, dim = 1) + + # -------- Combine ------------- + # y = x * (scale + 1) + shift + x_enc = x_enc + time_emb + + # ----------- Condition ------------ + if cond_emb is not None: + cond_emb = self.loca_cond_embedder(cond_emb) + cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim)) + x_enc = x_enc + cond_emb + + # ----------- Image ------------- + y = self.up_op(x_enc, x_skip) + + # -------- Combine ------------- + # y = y * (scale + 1) + shift + + return y + + +class UNet(nn.Module): + + def __init__(self, + in_ch=1, + out_ch=1, + spatial_dims=3, + hid_chs=[32, 64, 128, 256, 512], + kernel_sizes=[(1, 3, 3), (1, 3, 3), (1, 3, 3), 3, 3], + strides=[1, (1, 2, 2), (1, 2, 2), 2, 2], + upsample_kernel_sizes=None, + act_name=("SWISH", {}), + norm_name=("INSTANCE", {"affine": True}), + time_embedder=TimeEmbbeding, + time_embedder_kwargs={}, + cond_embedder=None, + cond_embedder_kwargs={}, + # True = all but last layer, 0/False=disable, 1=only first layer, ... 
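+ # an integer n keeps auxiliary output heads on the first n intermediate decoder feature maps (the outc_ver heads below)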
+ deep_ver_supervision=True, + estimate_variance=False, + use_self_conditioning=False, + **kwargs + ): + super().__init__() + if upsample_kernel_sizes is None: + upsample_kernel_sizes = strides[1:] + + # ------------- Time-Embedder----------- + self.time_embedder = time_embedder(**time_embedder_kwargs) + + # ------------- Condition-Embedder----------- + if cond_embedder is not None: + self.cond_embedder = cond_embedder(**cond_embedder_kwargs) + cond_emb_dim = self.cond_embedder.emb_dim + else: + self.cond_embedder = None + cond_emb_dim = None + + # ----------- In-Convolution ------------ + in_ch = in_ch*2 if use_self_conditioning else in_ch + self.inc = UnetBasicBlock(spatial_dims, in_ch, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0], + act_name=act_name, norm_name=norm_name, **kwargs) + + # ----------- Encoder ---------------- + self.encoders = nn.ModuleList([ + DownBlock(spatial_dims, hid_chs[i-1], hid_chs[i], time_emb_dim=self.time_embedder.emb_dim, + cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[ + i], stride=strides[i], act_name=act_name, + norm_name=norm_name, **kwargs) + for i in range(1, len(strides)) + ]) + + # ------------ Decoder ---------- + self.decoders = nn.ModuleList([ + UpBlock(spatial_dims, hid_chs[i], hid_chs[i+1], time_emb_dim=self.time_embedder.emb_dim, + cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[i + + 1], stride=strides[i+1], act_name=act_name, + norm_name=norm_name, upsample_kernel_size=upsample_kernel_sizes[i], **kwargs) + for i in range(len(strides)-1) + ]) + + # --------------- Out-Convolution ---------------- + out_ch_hor = out_ch*2 if estimate_variance else out_ch + self.outc = UnetOutBlock( + spatial_dims, hid_chs[0], out_ch_hor, dropout=None) + if isinstance(deep_ver_supervision, bool): + deep_ver_supervision = len( + strides)-2 if deep_ver_supervision else 0 + self.outc_ver = nn.ModuleList([ + UnetOutBlock(spatial_dims, hid_chs[i], out_ch, dropout=None) + for i in range(1, deep_ver_supervision+1) + ]) + + def forward(self, x_t, t, cond=None, self_cond=None, **kwargs): + condition = cond + # x_t [B, C, (D), H, W] + # t [B,] + + # -------- In-Convolution -------------- + x = [None for _ in range(len(self.encoders)+1)] + x_t = torch.cat([x_t, self_cond], + dim=1) if self_cond is not None else x_t + x[0] = self.inc(x_t) + + # -------- Time Embedding (Gloabl) ----------- + time_emb = self.time_embedder(t) # [B, C] + + # -------- Condition Embedding (Gloabl) ----------- + if (condition is None) or (self.cond_embedder is None): + cond_emb = None + else: + cond_emb = self.cond_embedder(condition) # [B, C] + + # --------- Encoder -------------- + for i in range(len(self.encoders)): + x[i+1] = self.encoders[i](x[i], time_emb, cond_emb) + + # -------- Decoder ----------- + for i in range(len(self.decoders), 0, -1): + x[i-1] = self.decoders[i-1](x[i-1], x[i], time_emb, cond_emb) + + # ---------Out-Convolution ------------ + y_hor = self.outc(x[0]) + y_ver = [outc_ver_i(x[i+1]) + for i, outc_ver_i in enumerate(self.outc_ver)] + + return y_hor # , y_ver + + def forward_with_cond_scale(self, *args, cond_scale=0., **kwargs): + return self.forward(*args, **kwargs) + + +if __name__ == '__main__': + model = UNet(in_ch=3) + input = torch.randn((1, 3, 16, 128, 128)) + time = torch.randn((1,)) + out_hor, out_ver = model(input, time) + print(out_hor[0].shape) diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/util.py b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/util.py new file mode 100644 index 
0000000000000000000000000000000000000000..2047cf56802046bbf37cfd33f819ed75a239a7df --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/ddpm/util.py @@ -0,0 +1,271 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +# from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + if c != 1: + steps_out = ddim_timesteps + 1 + else: + steps_out = ddim_timesteps + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. 
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. 
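+ For example, a tensor of shape (B, C, H, W) is reduced to a tensor of shape (B,).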
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6db109a51de3e5b659aef7922b6a7a4e21b218ae Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__init__.py b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38c08845c1ee2f7b4b52bf95fa282b0808931a3a --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__init__.py @@ -0,0 +1,3 @@ +from .vqgan import VQGAN +from .codebook import Codebook +from .lpips import LPIPS diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cf059a16f68c0ed927cb076bdde87e3d4b24bef Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc differ diff --git 
a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..044d85f49f99efcf654fa12107f49a7064be4449 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22441541d9e4e1adf1caa18b4d6b8ca923e1fe0c Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be6282b2dd9dafb677790c29d12e4699bd903c97 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth new file mode 100644 index 0000000000000000000000000000000000000000..f57dcf5cc764d61c8a460365847fb2137ff0a62d --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868 +size 7289 diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/codebook.py b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/codebook.py new file mode 100644 index 0000000000000000000000000000000000000000..c17ed701a2824cc0dcd75670d47a6ab3842e9d35 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/codebook.py @@ -0,0 +1,109 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import numpy as np + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from ..utils import shift_dim + + +class Codebook(nn.Module): + def __init__(self, n_codes, embedding_dim, no_random_restart=False, restart_thres=1.0): + super().__init__() + self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim)) + self.register_buffer('N', torch.zeros(n_codes)) + self.register_buffer('z_avg', self.embeddings.data.clone()) + + self.n_codes = n_codes + self.embedding_dim = embedding_dim + self._need_init = True + self.no_random_restart = no_random_restart + self.restart_thres = restart_thres + + def _tile(self, x): + d, ew = x.shape + if d < self.n_codes: + n_repeats = (self.n_codes + d - 1) // d + std = 0.01 / np.sqrt(ew) + x = x.repeat(n_repeats, 1) + x = x + torch.randn_like(x) * std + return x + + def _init_embeddings(self, z): + # z: [b, c, t, h, w] + self._need_init = False + breakpoint() + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [65536, 8] [2, 8, 32, 32, 32] + y = self._tile(flat_inputs) # [65536, 8] + + d = y.shape[0] + _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + self.embeddings.data.copy_(_k_rand) + self.z_avg.data.copy_(_k_rand) + self.N.data.copy_(torch.ones(self.n_codes)) + + def forward(self, z): + # z: [b, c, t, h, w] + if self._need_init and self.training: + self._init_embeddings(z) + flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2) # [bthw, c] [65536, 8] + distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \ + - 2 * flat_inputs @ self.embeddings.t() \ + + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True) # [bthw, c] [65536, 8] + + encoding_indices = torch.argmin(distances, dim=1) # [65536] + encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as( + flat_inputs) # [bthw, ncode] [65536, 16384] + encoding_indices = encoding_indices.view( + z.shape[0], *z.shape[2:]) # [b, t, h, w, ncode] [2, 32, 32, 32] + + embeddings = F.embedding( + encoding_indices, self.embeddings) # [b, t, h, w, c] self.embeddings [16384, 8] + embeddings = shift_dim(embeddings, -1, 1) # [b, c, t, h, w] [2, 8, 32, 32, 32] + + commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach()) + + # EMA codebook update + if self.training: + n_total = encode_onehot.sum(dim=0) # [16384] + encode_sum = flat_inputs.t() @ encode_onehot # [8, 16384] + if dist.is_initialized(): + dist.all_reduce(n_total) + dist.all_reduce(encode_sum) + + self.N.data.mul_(0.99).add_(n_total, alpha=0.01) + self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01) + + n = self.N.sum() + weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n + encode_normalized = self.z_avg / weights.unsqueeze(1) + self.embeddings.data.copy_(encode_normalized) + + y = self._tile(flat_inputs) + _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes] + if dist.is_initialized(): + dist.broadcast(_k_rand, 0) + + if not self.no_random_restart: + usage = (self.N.view(self.n_codes, 1) + >= self.restart_thres).float() + self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage)) + + embeddings_st = (embeddings - z).detach() + z + + avg_probs = torch.mean(encode_onehot, dim=0) + perplexity = torch.exp(-torch.sum(avg_probs * + torch.log(avg_probs + 1e-10))) + + return dict(embeddings=embeddings_st, encodings=encoding_indices, + commitment_loss=commitment_loss, perplexity=perplexity) + + def dictionary_lookup(self, encodings): + embeddings = F.embedding(encodings, 
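+ # F.embedding maps each integer code index to its row of the [n_codes, embedding_dim] codebook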
self.embeddings) + return embeddings diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/lpips.py b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..019434b9b29945d67e6e0a95dec324df7ff908f3 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/lpips.py @@ -0,0 +1,181 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" + +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + + +from collections import namedtuple +from torchvision import models +import torch.nn as nn +import torch +from tqdm import tqdm +import requests +import os +import hashlib +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format( + name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + # self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + self.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name is not "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + model.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer( + input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for kk in 
range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor( + outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor( + [-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor( + [.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, + padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, + h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..a64ec8878c0ecbf418775e2958b2cbf578ce2918 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py @@ -0,0 +1,561 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import math +import argparse +import numpy as np +import pickle as pkl + +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from ..utils import shift_dim, adopt_weight, comp_getattr +from .lpips import LPIPS +from .codebook import Codebook + + +def silu(x): + return x*torch.sigmoid(x) + + +class SiLU(nn.Module): + def __init__(self): + super(SiLU, self).__init__() + + def forward(self, x): + return silu(x) + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake))) + return d_loss + + +class VQGAN(pl.LightningModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.embedding_dim = cfg.model.embedding_dim # 8 + self.n_codes = cfg.model.n_codes # 16384 + + self.encoder = Encoder(cfg.model.n_hiddens, # 16 + cfg.model.downsample, # [2, 2, 2] + cfg.dataset.image_channels, # 1 + cfg.model.norm_type, # group + cfg.model.padding_type, # replicate + cfg.model.num_groups, # 32 + ) + self.decoder = Decoder( + cfg.model.n_hiddens, cfg.model.downsample, cfg.dataset.image_channels, cfg.model.norm_type, cfg.model.num_groups) + self.enc_out_ch = self.encoder.out_channels + self.pre_vq_conv = SamePadConv3d( + self.enc_out_ch, cfg.model.embedding_dim, 1, padding_type=cfg.model.padding_type) + self.post_vq_conv = SamePadConv3d( + cfg.model.embedding_dim, self.enc_out_ch, 1) + + self.codebook = Codebook(cfg.model.n_codes, cfg.model.embedding_dim, + no_random_restart=cfg.model.no_random_restart, restart_thres=cfg.model.restart_thres) + + self.gan_feat_weight = cfg.model.gan_feat_weight + # TODO: Changed batchnorm from sync to normal + self.image_discriminator = NLayerDiscriminator( + cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm2d) + self.video_discriminator = NLayerDiscriminator3D( + cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm3d) + + if cfg.model.disc_loss_type == 'vanilla': + self.disc_loss = vanilla_d_loss + elif cfg.model.disc_loss_type == 'hinge': + self.disc_loss = hinge_d_loss + + self.perceptual_model = LPIPS().eval() + + self.image_gan_weight = cfg.model.image_gan_weight + self.video_gan_weight = cfg.model.video_gan_weight + + self.perceptual_weight = cfg.model.perceptual_weight + + self.l1_weight = cfg.model.l1_weight + self.save_hyperparameters() + + def encode(self, x, include_embeddings=False, quantize=True): + h = self.pre_vq_conv(self.encoder(x)) + if quantize: + vq_output = self.codebook(h) + if include_embeddings: + return vq_output['embeddings'], vq_output['encodings'] + else: + return vq_output['encodings'] + return h + + def decode(self, latent, quantize=False): + if quantize: + vq_output = self.codebook(latent) + latent = vq_output['encodings'] + h = F.embedding(latent, self.codebook.embeddings) + h = self.post_vq_conv(shift_dim(h, -1, 1)) + return self.decoder(h) + + def forward(self, x, optimizer_idx=None, log_image=False): + B, C, T, H, W = x.shape + z = self.pre_vq_conv(self.encoder(x)) # [2, 32, 32, 32, 32] [2, 8, 32, 32, 32] + vq_output = self.codebook(z) # ['embeddings', 'encodings', 'commitment_loss', 
'perplexity'] + x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings'])) # [2, 8, 32, 32, 32] [2, 32, 32, 32, 32] + + recon_loss = F.l1_loss(x_recon, x) * self.l1_weight + + # Selects one random 2D image from each 3D Image + frame_idx = torch.randint(0, T, [B]).cuda() + frame_idx_selected = frame_idx.reshape(-1, + 1, 1, 1, 1).repeat(1, C, 1, H, W) # [2, 1, 1, 64, 64] + frames = torch.gather(x, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64] + frames_recon = torch.gather(x_recon, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64] + + if log_image: + return frames, frames_recon, x, x_recon + + if optimizer_idx == 0: + # Autoencoder - train the "generator" + + # Perceptual loss + perceptual_loss = 0 + if self.perceptual_weight > 0: + perceptual_loss = self.perceptual_model( + frames, frames_recon).mean() * self.perceptual_weight + + # Discriminator loss (turned on after a certain epoch) + logits_image_fake, pred_image_fake = self.image_discriminator( + frames_recon) + logits_video_fake, pred_video_fake = self.video_discriminator( + x_recon) + g_image_loss = -torch.mean(logits_image_fake) + g_video_loss = -torch.mean(logits_video_fake) + g_loss = self.image_gan_weight*g_image_loss + self.video_gan_weight*g_video_loss + disc_factor = adopt_weight( + self.global_step, threshold=self.cfg.model.discriminator_iter_start) + aeloss = disc_factor * g_loss + + # GAN feature matching loss - tune features such that we get the same prediction result on the discriminator + image_gan_feat_loss = 0 + video_gan_feat_loss = 0 + feat_weights = 4.0 / (3 + 1) + if self.image_gan_weight > 0: + logits_image_real, pred_image_real = self.image_discriminator( + frames) + for i in range(len(pred_image_fake)-1): + image_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_image_fake[i], pred_image_real[i].detach( + )) * (self.image_gan_weight > 0) + if self.video_gan_weight > 0: + logits_video_real, pred_video_real = self.video_discriminator( + x) + for i in range(len(pred_video_fake)-1): + video_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_video_fake[i], pred_video_real[i].detach( + )) * (self.video_gan_weight > 0) + gan_feat_loss = disc_factor * self.gan_feat_weight * \ + (image_gan_feat_loss + video_gan_feat_loss) + + self.log("train/g_image_loss", g_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/g_video_loss", g_video_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/image_gan_feat_loss", image_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/video_gan_feat_loss", video_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/perceptual_loss", perceptual_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log("train/recon_loss", recon_loss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + self.log("train/aeloss", aeloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + self.log("train/commitment_loss", vq_output['commitment_loss'], + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log('train/perplexity', vq_output['perplexity'], + prog_bar=True, logger=True, on_step=True, on_epoch=True) + return recon_loss, x_recon, vq_output, aeloss, perceptual_loss, gan_feat_loss + + if optimizer_idx == 1: + # Train discriminator + logits_image_real, _ = self.image_discriminator(frames.detach()) + logits_video_real, _ = self.video_discriminator(x.detach()) + + logits_image_fake, _ = self.image_discriminator( + frames_recon.detach()) + logits_video_fake, _ = 
self.video_discriminator(x_recon.detach()) + + d_image_loss = self.disc_loss(logits_image_real, logits_image_fake) + d_video_loss = self.disc_loss(logits_video_real, logits_video_fake) + disc_factor = adopt_weight( + self.global_step, threshold=self.cfg.model.discriminator_iter_start) + discloss = disc_factor * \ + (self.image_gan_weight*d_image_loss + + self.video_gan_weight*d_video_loss) + + self.log("train/logits_image_real", logits_image_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_image_fake", logits_image_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_video_real", logits_video_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_video_fake", logits_video_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/d_image_loss", d_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/d_video_loss", d_video_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/discloss", discloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + return discloss + + perceptual_loss = self.perceptual_model( + frames, frames_recon) * self.perceptual_weight + return recon_loss, x_recon, vq_output, perceptual_loss + + def training_step(self, batch, batch_idx, optimizer_idx): + x = batch['image'] + if optimizer_idx == 0: + recon_loss, _, vq_output, aeloss, perceptual_loss, gan_feat_loss = self.forward( + x, optimizer_idx) + commitment_loss = vq_output['commitment_loss'] + loss = recon_loss + commitment_loss + aeloss + perceptual_loss + gan_feat_loss + if optimizer_idx == 1: + discloss = self.forward(x, optimizer_idx) + loss = discloss + return loss + + def validation_step(self, batch, batch_idx): + x = batch['image'] # TODO: batch['stft'] + recon_loss, _, vq_output, perceptual_loss = self.forward(x) + self.log('val/recon_loss', recon_loss, prog_bar=True) + self.log('val/perceptual_loss', perceptual_loss, prog_bar=True) + self.log('val/perplexity', vq_output['perplexity'], prog_bar=True) + self.log('val/commitment_loss', + vq_output['commitment_loss'], prog_bar=True) + + def configure_optimizers(self): + lr = self.cfg.model.lr + opt_ae = torch.optim.Adam(list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.pre_vq_conv.parameters()) + + list(self.post_vq_conv.parameters()) + + list(self.codebook.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(list(self.image_discriminator.parameters()) + + list(self.video_discriminator.parameters()), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def log_images(self, batch, **kwargs): + log = dict() + x = batch['image'] + x = x.to(self.device) + frames, frames_rec, _, _ = self(x, log_image=True) + log["inputs"] = frames + log["reconstructions"] = frames_rec + #log['mean_org'] = batch['mean_org'] + #log['std_org'] = batch['std_org'] + return log + + def log_videos(self, batch, **kwargs): + log = dict() + x = batch['image'] + _, _, x, x_rec = self(x, log_image=True) + log["inputs"] = x + log["reconstructions"] = x_rec + #log['mean_org'] = batch['mean_org'] + #log['std_org'] = batch['std_org'] + return log + + +def Normalize(in_channels, norm_type='group', num_groups=32): + assert norm_type in ['group', 'batch'] + if norm_type == 'group': + # TODO Changed num_groups from 32 to 8 + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == 'batch': + return 
torch.nn.SyncBatchNorm(in_channels) + + +class Encoder(nn.Module): + def __init__(self, n_hiddens, downsample, image_channel=3, norm_type='group', padding_type='replicate', num_groups=32): + super().__init__() + n_times_downsample = np.array([int(math.log2(d)) for d in downsample]) + self.conv_blocks = nn.ModuleList() + max_ds = n_times_downsample.max() + + self.conv_first = SamePadConv3d( + image_channel, n_hiddens, kernel_size=3, padding_type=padding_type) + + for i in range(max_ds): + block = nn.Module() + in_channels = n_hiddens * 2**i + out_channels = n_hiddens * 2**(i+1) + stride = tuple([2 if d > 0 else 1 for d in n_times_downsample]) + block.down = SamePadConv3d( + in_channels, out_channels, 4, stride=stride, padding_type=padding_type) + block.res = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + self.conv_blocks.append(block) + n_times_downsample -= 1 + + self.final_block = nn.Sequential( + Normalize(out_channels, norm_type, num_groups=num_groups), + SiLU() + ) + + self.out_channels = out_channels + + def forward(self, x): + h = self.conv_first(x) + for block in self.conv_blocks: + h = block.down(h) + h = block.res(h) + h = self.final_block(h) + return h + + +class Decoder(nn.Module): + def __init__(self, n_hiddens, upsample, image_channel, norm_type='group', num_groups=32): + super().__init__() + + n_times_upsample = np.array([int(math.log2(d)) for d in upsample]) + max_us = n_times_upsample.max() + + in_channels = n_hiddens*2**max_us + self.final_block = nn.Sequential( + Normalize(in_channels, norm_type, num_groups=num_groups), + SiLU() + ) + + self.conv_blocks = nn.ModuleList() + for i in range(max_us): + block = nn.Module() + in_channels = in_channels if i == 0 else n_hiddens*2**(max_us-i+1) + out_channels = n_hiddens*2**(max_us-i) + us = tuple([2 if d > 0 else 1 for d in n_times_upsample]) + block.up = SamePadConvTranspose3d( + in_channels, out_channels, 4, stride=us) + block.res1 = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + block.res2 = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + self.conv_blocks.append(block) + n_times_upsample -= 1 + + self.conv_last = SamePadConv3d( + out_channels, image_channel, kernel_size=3) + + def forward(self, x): + h = self.final_block(x) + for i, block in enumerate(self.conv_blocks): + h = block.up(h) + h = block.res1(h) + h = block.res2(h) + h = self.conv_last(h) + return h + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group', padding_type='replicate', num_groups=32): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, norm_type, num_groups=num_groups) + self.conv1 = SamePadConv3d( + in_channels, out_channels, kernel_size=3, padding_type=padding_type) + self.dropout = torch.nn.Dropout(dropout) + self.norm2 = Normalize(in_channels, norm_type, num_groups=num_groups) + self.conv2 = SamePadConv3d( + out_channels, out_channels, kernel_size=3, padding_type=padding_type) + if self.in_channels != self.out_channels: + self.conv_shortcut = SamePadConv3d( + in_channels, out_channels, kernel_size=3, padding_type=padding_type) + + def forward(self, x): + h = x + h = self.norm1(h) + h = silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = silu(h) + h = self.conv2(h) + + 
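+ # when the block changes its channel count, the skip path is projected with conv_shortcut
+ # so that the residual addition x + h below is shape-compatible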
if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) + + return x+h + + +# Does not support dilation +class SamePadConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + # assumes that the input shape is divisible by stride + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + self.padding_type = padding_type + + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, + stride=stride, padding=0, bias=bias) + + def forward(self, x): + return self.conv(F.pad(x, self.pad_input, mode=self.padding_type)) + + +class SamePadConvTranspose3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + self.padding_type = padding_type + + self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, + stride=stride, bias=bias, + padding=tuple([k - 1 for k in kernel_size])) + + def forward(self, x): + return self.convt(F.pad(x, self.pad_input, mode=self.padding_type)) + + +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + # def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=True): + super(NLayerDiscriminator, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ + + +class NLayerDiscriminator3D(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + 
super(NLayerDiscriminator3D, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv3d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv3d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/utils.py b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..acf1a21270d98d0a1ea9935d00f9f16d89d54551 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/ldm/vq_gan_3d/utils.py @@ -0,0 +1,177 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import warnings +import torch +import imageio + +import math +import numpy as np + +import sys +import pdb as pdb_original +import logging + +import imageio.core.util +logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR) + + +class ForkedPdb(pdb_original.Pdb): + """A Pdb subclass that may be used + from a forked multiprocessing child + + """ + + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + pdb_original.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + +# Shifts src_tf dim to dest dim +# i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + + dims = list(range(n_dims)) + del dims[src_dim] + + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x + + +# reshapes tensor start from dim i (inclusive) +# to dim j (exclusive) to the desired shape +# e.g. 
if x.shape = (b, thw, c) then +# view_range(x, 1, 2, (t, h, w)) returns +# x of shape (b, t, h, w, c) +def view_range(x, i, j, shape): + shape = tuple(shape) + + n_dims = len(x.shape) + if i < 0: + i = n_dims + i + + if j is None: + j = n_dims + elif j < 0: + j = n_dims + j + + assert 0 <= i < j <= n_dims + + x_shape = x.shape + target_shape = x_shape[:i] + shape + x_shape[j:] + return x.view(target_shape) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def tensor_slice(x, begin, size): + assert all([b >= 0 for b in begin]) + size = [l - b if s == -1 else s + for s, b, l in zip(size, begin, x.shape)] + assert all([s >= 0 for s in size]) + + slices = [slice(b, b + s) for b, s in zip(begin, size)] + return x[slices] + + +def adopt_weight(global_step, threshold=0, value=0.): + weight = 1 + if global_step < threshold: + weight = value + return weight + + +def save_video_grid(video, fname, nrow=None, fps=6): + b, c, t, h, w = video.shape + video = video.permute(0, 2, 3, 4, 1) + video = (video.cpu().numpy() * 255).astype('uint8') + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = np.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype='uint8') + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + video = [] + for i in range(t): + video.append(video_grid[i]) + imageio.mimsave(fname, video, fps=fps) + ## skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '5'}) + #print('saved videos to', fname) + + +def comp_getattr(args, attr_name, default=None): + if hasattr(args, attr_name): + return getattr(args, attr_name) + else: + return default + + +def visualize_tensors(t, name=None, nest=0): + if name is not None: + print(name, "current nest: ", nest) + print("type: ", type(t)) + if 'dict' in str(type(t)): + print(t.keys()) + for k in t.keys(): + if t[k] is None: + print(k, "None") + else: + if 'Tensor' in str(type(t[k])): + print(k, t[k].shape) + elif 'dict' in str(type(t[k])): + print(k, 'dict') + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t[k])): + print(k, len(t[k])) + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t)): + print("list length: ", len(t)) + for t2 in t: + visualize_tensors(t2, name, nest + 1) + elif 'Tensor' in str(type(t)): + print(t.shape) + else: + print(t) + return "" diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/utils.py b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..88e303a085dc90228553170866ec732e2cd86bcd --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/utils.py @@ -0,0 +1,471 @@ +### Tumor Generateion +import random +import cv2 +import elasticdeform +import numpy as np +from scipy.ndimage import gaussian_filter +from TumorGeneration.ldm.ddpm.ddim import DDIMSampler + +# Step 1: Random select (numbers) location for tumor. 
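+# random_select below restricts z to the middle 30-70% of the organ's axial
+# extent, erodes that slice of the mask so candidate voxels sit away from the
+# organ boundary (retrying with smaller kernels, then no erosion, if the eroded
+# slice is empty), and finally draws one in-mask (x, y) coordinate at random.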
+def random_select(mask_scan, organ_type): + # we first find z index and then sample point with z slice + # print('mask_scan',np.unique(mask_scan)) + # print('pixel num', (mask_scan == 1).sum()) + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0].min(), np.where(np.any(mask_scan, axis=(0, 1)))[0].max() + # print('z_start, z_end',z_start, z_end) + # we need to strict number z's position (0.3 - 0.7 in the middle of liver) + flag=0 + while 1: + z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start + liver_mask = mask_scan[..., z] + # erode the mask (we don't want the edge points) + if organ_type == 'liver': + flag+=1 + if flag <= 10: + kernel = np.ones((5,5), dtype=np.uint8) + liver_mask = cv2.erode(liver_mask, kernel, iterations=1) + elif flag >10 and flag <= 20: + kernel = np.ones((3,3), dtype=np.uint8) + liver_mask = cv2.erode(liver_mask, kernel, iterations=1) + else: + pass + print(flag) + if (liver_mask == 1).sum() > 0: + break + + # print('liver_mask', (liver_mask == 1).sum()) + coordinates = np.argwhere(liver_mask == 1) + random_index = np.random.randint(0, len(coordinates)) + xyz = coordinates[random_index].tolist() # get x,y + xyz.append(z) + potential_points = xyz + + return potential_points + +def center_select(mask_scan): + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0].min(), np.where(np.any(mask_scan, axis=(0, 1)))[0].max() + x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0].min(), np.where(np.any(mask_scan, axis=(1, 2)))[0].max() + y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0].min(), np.where(np.any(mask_scan, axis=(0, 2)))[0].max() + + z = round(0.5 * (z_end - z_start)) + z_start + x = round(0.5 * (x_end - x_start)) + x_start + y = round(0.5 * (y_end - y_start)) + y_start + + xyz = [x, y, z] + potential_points = xyz + + return potential_points + +# Step 2 : generate the ellipsoid +def get_ellipsoid(x, y, z): + """" + x, y, z is the radius of this ellipsoid in x, y, z direction respectly. 
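+    The returned array has shape (4*x, 4*y, 4*z), with voxels inside the
+    ellipsoid set to 1 and the rest left 0.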
+ """ + sh = (4*x, 4*y, 4*z) + out = np.zeros(sh, int) + aux = np.zeros(sh) + radii = np.array([x, y, z]) + com = np.array([2*x, 2*y, 2*z]) # center point + + # calculate the ellipsoid + bboxl = np.floor(com-radii).clip(0,None).astype(int) + bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int) + roi = out[tuple(map(slice,bboxl,bboxh))] + roiaux = aux[tuple(map(slice,bboxl,bboxh))] + logrid = *map(np.square,np.ogrid[tuple( + map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]), + dst = (1-sum(logrid)).clip(0,None) + mask = dst>roiaux + roi[mask] = 1 + np.copyto(roiaux,dst,where=mask) + + return out + +def get_fixed_geo(mask_scan, tumor_type, organ_type): + if tumor_type == 'large': + enlarge_x, enlarge_y, enlarge_z = 280, 280, 280 + else: + enlarge_x, enlarge_y, enlarge_z = 160, 160, 160 + geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8) + tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32 + + if tumor_type == 'tiny': + num_tumor = random.randint(1,3) + # num_tumor = 1 + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + if tumor_type == 'small': + num_tumor = random.randint(1,3) + # num_tumor = 1 + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'medium': + # num_tumor = random.randint(1, 3) + num_tumor = 1 + for _ in range(num_tumor): + # 
medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'large': + num_tumor = 1 # random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + if organ_type == 'liver' or organ_type == 'kidney' : + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + else: + x_start, x_end = np.where(np.any(geo, axis=(1, 2)))[0].min(), np.where(np.any(geo, axis=(1, 2)))[0].max() + y_start, y_end = np.where(np.any(geo, axis=(0, 2)))[0].min(), np.where(np.any(geo, axis=(0, 2)))[0].max() + z_start, z_end = np.where(np.any(geo, axis=(0, 1)))[0].min(), np.where(np.any(geo, axis=(0, 1)))[0].max() + geo = geo[x_start:x_end, y_start:y_end, z_start:z_end] + + point = center_select(mask_scan) + + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low = new_point[0] - geo.shape[0]//2 + y_low = new_point[1] - geo.shape[1]//2 + z_low = new_point[2] - geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_low+geo.shape[0], y_low:y_low+geo.shape[1], z_low:z_low+geo.shape[2]] += geo + + geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + # texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + if ((tumor_type == 'medium') or (tumor_type == 'large')) and (organ_type == 'kidney'): + if random.random() > 0.5: + geo_mask = (geo_mask>=1) + else: + geo_mask = (geo_mask * mask_scan) >=1 + else: + geo_mask = (geo_mask * mask_scan) >=1 + + return geo_mask + + +from .ldm.vq_gan_3d.model.vqgan 
import VQGAN +import matplotlib.pyplot as plt +import SimpleITK as sitk +from .ldm.ddpm import Unet3D, GaussianDiffusion, Tester +from hydra import initialize, compose +import torch +import yaml +def synt_model_prepare(device, vqgan_ckpt='TumorGeneration/model_weight/recon_96d4_all.ckpt', diffusion_ckpt='TumorGeneration/model_weight/', fold=0, organ='liver'): + with initialize(config_path="diffusion_config/"): + cfg=compose(config_name="ddpm.yaml") + print('diffusion_ckpt',diffusion_ckpt) + vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt) + vqgan = vqgan.to(device) + vqgan.eval() + + early_Unet3D = Unet3D( + dim=cfg.diffusion_img_size, + dim_mults=cfg.dim_mults, + channels=cfg.diffusion_num_channels, + out_dim=cfg.out_dim + ).to(device) + + early_diffusion = GaussianDiffusion( + early_Unet3D, + vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt, + image_size=cfg.diffusion_img_size, + num_frames=cfg.diffusion_depth_size, + channels=cfg.diffusion_num_channels, + timesteps=4, # cfg.timesteps, + loss_type=cfg.loss_type, + device=device + ).to(device) + + noearly_Unet3D = Unet3D( + dim=cfg.diffusion_img_size, + dim_mults=cfg.dim_mults, + channels=cfg.diffusion_num_channels, + out_dim=cfg.out_dim + ).to(device) + + noearly_diffusion = GaussianDiffusion( + noearly_Unet3D, + vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt, + image_size=cfg.diffusion_img_size, + num_frames=cfg.diffusion_depth_size, + channels=cfg.diffusion_num_channels, + timesteps=200, # cfg.timesteps, + loss_type=cfg.loss_type, + device=device + ).to(device) + + early_tester = Tester(early_diffusion) + # noearly_tester = Tester(noearly_diffusion) + early_tester.load(diffusion_ckpt+'diff_{}_fold{}_early.pt'.format(organ, fold), map_location=device) + # noearly_tester.load(diffusion_ckpt+'diff_liver_fold{}_noearly_t200.pt'.format(fold), map_location=device) + + # early_checkpoint = torch.load(diffusion_ckpt+'diff_liver_fold{}_early.pt'.format(fold), map_location=device) + noearly_checkpoint = torch.load(diffusion_ckpt+'diff_{}_fold{}_noearly_t200.pt'.format(organ, fold), map_location=device) + # early_diffusion.load_state_dict(early_checkpoint['ema']) + noearly_diffusion.load_state_dict(noearly_checkpoint['ema']) + # early_sampler = DDIMSampler(early_diffusion, schedule="cosine") + noearly_sampler = DDIMSampler(noearly_diffusion, schedule="cosine") + # breakpoint() + return vqgan, early_tester, noearly_sampler + +def synthesize_early_tumor(ct_volume, organ_mask, organ_type, vqgan, tester): + device=ct_volume.device + + # generate tumor mask + tumor_types = ['tiny', 'small'] + tumor_probs = np.array([0.5, 0.5]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) / + 
(vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + tester.ema_model.eval() + sample = tester.ema_model.sample(batch_size=volume.shape[0], cond=cond) + + # if organ_type == 'liver' or organ_type == 'kidney' : + + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask + +def synthesize_medium_tumor(ct_volume, organ_mask, organ_type, vqgan, sampler, ddim_ts=50): + device=ct_volume.device + + # generate tumor mask + # tumor_types = ['large'] + # tumor_probs = np.array([1.0]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + # synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + synthetic_tumor_type = 'medium' + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + shape = masked_volume_feat.shape[-4:] + samples_ddim, _ = sampler.sample(S=ddim_ts, + conditioning=cond, + batch_size=1, + shape=shape, + verbose=False) + samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() - + vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min() + + sample = vqgan.decode(samples_ddim, quantize=True) + + # if organ_type == 'liver' or organ_type == 'kidney': + # post-process + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0) + + mask_01_blur = 
torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask + +def synthesize_large_tumor(ct_volume, organ_mask, organ_type, vqgan, sampler, ddim_ts=50): + device=ct_volume.device + + # generate tumor mask + # tumor_types = ['large'] + # tumor_probs = np.array([1.0]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + # synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + synthetic_tumor_type = 'large' + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + shape = masked_volume_feat.shape[-4:] + samples_ddim, _ = sampler.sample(S=ddim_ts, + conditioning=cond, + batch_size=1, + shape=shape, + verbose=False) + samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() - + vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min() + + sample = vqgan.decode(samples_ddim, quantize=True) + + # if organ_type == 'liver' or organ_type == 'kidney': + # post-process + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask \ No newline at end of file diff --git a/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/utils_.py b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/utils_.py new file mode 100644 index 0000000000000000000000000000000000000000..312e72d1571b3cd996fd567bcedf939443a4e182 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/TumorGeneration/utils_.py @@ 
-0,0 +1,298 @@ +### Tumor Generateion +import random +import cv2 +import elasticdeform +import numpy as np +from scipy.ndimage import gaussian_filter + +# Step 1: Random select (numbers) location for tumor. +def random_select(mask_scan): + # we first find z index and then sample point with z slice + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]] + + # we need to strict number z's position (0.3 - 0.7 in the middle of liver) + z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start + + liver_mask = mask_scan[..., z] + + # erode the mask (we don't want the edge points) + kernel = np.ones((5,5), dtype=np.uint8) + liver_mask = cv2.erode(liver_mask, kernel, iterations=1) + + coordinates = np.argwhere(liver_mask == 1) + random_index = np.random.randint(0, len(coordinates)) + xyz = coordinates[random_index].tolist() # get x,y + xyz.append(z) + potential_points = xyz + + return potential_points + +# Step 2 : generate the ellipsoid +def get_ellipsoid(x, y, z): + """" + x, y, z is the radius of this ellipsoid in x, y, z direction respectly. + """ + sh = (4*x, 4*y, 4*z) + out = np.zeros(sh, int) + aux = np.zeros(sh) + radii = np.array([x, y, z]) + com = np.array([2*x, 2*y, 2*z]) # center point + + # calculate the ellipsoid + bboxl = np.floor(com-radii).clip(0,None).astype(int) + bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int) + roi = out[tuple(map(slice,bboxl,bboxh))] + roiaux = aux[tuple(map(slice,bboxl,bboxh))] + logrid = *map(np.square,np.ogrid[tuple( + map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]), + dst = (1-sum(logrid)).clip(0,None) + mask = dst>roiaux + roi[mask] = 1 + np.copyto(roiaux,dst,where=mask) + + return out + +def get_fixed_geo(mask_scan, tumor_type): + + enlarge_x, enlarge_y, enlarge_z = 160, 160, 160 + geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8) + tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32 + + if tumor_type == 'tiny': + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + if tumor_type == 'small': + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, 
sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'medium': + num_tumor = random.randint(2, 5) + for _ in range(num_tumor): + # medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'large': + num_tumor = random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == "mix": + # tiny + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = 
random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + # small + num_tumor = random.randint(5,10) + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + # medium + num_tumor = random.randint(2, 5) + for _ in range(num_tumor): + # medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + # large + num_tumor = random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = 
random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + # texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + geo_mask = (geo_mask * mask_scan) >=1 + + return geo_mask + + +def get_tumor(volume_scan, mask_scan, tumor_type): + tumor_mask = get_fixed_geo(mask_scan, tumor_type) + + sigma = np.random.uniform(1, 2) + # difference = np.random.uniform(65, 145) + difference = 1 + + # blur the boundary + tumor_mask_blur = gaussian_filter(tumor_mask*1.0, sigma) + + + abnormally_full = volume_scan * (1 - mask_scan) + abnormally + abnormally_mask = mask_scan + geo_mask + + return abnormally_full, abnormally_mask + +def SynthesisTumor(volume_scan, mask_scan, tumor_type): + # for speed_generate_tumor, we only send the liver part into the generate program + x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0][[0, -1]] + y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0][[0, -1]] + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]] + + # shrink the boundary + x_start, x_end = max(0, x_start+1), min(mask_scan.shape[0], x_end-1) + y_start, y_end = max(0, y_start+1), min(mask_scan.shape[1], y_end-1) + z_start, z_end = max(0, z_start+1), min(mask_scan.shape[2], z_end-1) + + ct_volume = volume_scan[x_start:x_end, y_start:y_end, z_start:z_end] + organ_mask = mask_scan[x_start:x_end, y_start:y_end, z_start:z_end] + + # input texture shape: 420, 300, 320 + # we need to cut it into the shape of liver_mask + # for examples, the liver_mask.shape = 286, 173, 46; we should change the texture shape + x_length, y_length, z_length = 64, 64, 64 + crop_x = random.randint(x_start, x_end - x_length - 1) # random select the start point, -1 is to avoid boundary check + crop_y = random.randint(y_start, y_end - y_length - 1) + crop_z = random.randint(z_start, z_end - z_length - 1) + + ct_volume, organ_tumor_mask = get_tumor(ct_volume, organ_mask, tumor_type) + volume_scan[x_start:x_end, y_start:y_end, z_start:z_end] = ct_volume + mask_scan[x_start:x_end, y_start:y_end, z_start:z_end] = organ_tumor_mask + + return volume_scan, mask_scan diff --git a/Generation_Pipeline_filter_all/syn_liver/healthy_liver_1k.txt b/Generation_Pipeline_filter_all/syn_liver/healthy_liver_1k.txt new file mode 100644 index 0000000000000000000000000000000000000000..74ed74167da166c49bfa98088ad2683251771bb4 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/healthy_liver_1k.txt @@ -0,0 +1,895 @@ +BDMAP_00004578 +BDMAP_00004183 
+BDMAP_00002690 +BDMAP_00004295 +BDMAP_00001736 +BDMAP_00000411 +BDMAP_00003277 +BDMAP_00000696 +BDMAP_00004196 +BDMAP_00001598 +BDMAP_00001183 +BDMAP_00002626 +BDMAP_00004793 +BDMAP_00003385 +BDMAP_00005037 +BDMAP_00004652 +BDMAP_00001383 +BDMAP_00001092 +BDMAP_00004927 +BDMAP_00001618 +BDMAP_00004087 +BDMAP_00002273 +BDMAP_00001288 +BDMAP_00000043 +BDMAP_00003356 +BDMAP_00002776 +BDMAP_00003961 +BDMAP_00002422 +BDMAP_00000345 +BDMAP_00000438 +BDMAP_00001517 +BDMAP_00003564 +BDMAP_00001275 +BDMAP_00003315 +BDMAP_00002986 +BDMAP_00003514 +BDMAP_00000190 +BDMAP_00001434 +BDMAP_00003608 +BDMAP_00001995 +BDMAP_00000414 +BDMAP_00003451 +BDMAP_00002612 +BDMAP_00003744 +BDMAP_00005170 +BDMAP_00002328 +BDMAP_00002940 +BDMAP_00005020 +BDMAP_00000562 +BDMAP_00000810 +BDMAP_00003833 +BDMAP_00000320 +BDMAP_00001791 +BDMAP_00004895 +BDMAP_00003576 +BDMAP_00001924 +BDMAP_00005140 +BDMAP_00003946 +BDMAP_00005067 +BDMAP_00001102 +BDMAP_00001826 +BDMAP_00004131 +BDMAP_00003141 +BDMAP_00002758 +BDMAP_00004969 +BDMAP_00003633 +BDMAP_00004195 +BDMAP_00000030 +BDMAP_00000939 +BDMAP_00001835 +BDMAP_00003762 +BDMAP_00003215 +BDMAP_00003396 +BDMAP_00001078 +BDMAP_00003484 +BDMAP_00001096 +BDMAP_00001688 +BDMAP_00005155 +BDMAP_00005064 +BDMAP_00001862 +BDMAP_00004867 +BDMAP_00001982 +BDMAP_00002295 +BDMAP_00000062 +BDMAP_00000715 +BDMAP_00004608 +BDMAP_00000162 +BDMAP_00003558 +BDMAP_00005070 +BDMAP_00003812 +BDMAP_00000725 +BDMAP_00004624 +BDMAP_00003752 +BDMAP_00001557 +BDMAP_00002185 +BDMAP_00000093 +BDMAP_00003774 +BDMAP_00001701 +BDMAP_00004184 +BDMAP_00000873 +BDMAP_00000236 +BDMAP_00001676 +BDMAP_00001635 +BDMAP_00002475 +BDMAP_00002653 +BDMAP_00003400 +BDMAP_00001863 +BDMAP_00003017 +BDMAP_00001283 +BDMAP_00001359 +BDMAP_00001281 +BDMAP_00004293 +BDMAP_00000582 +BDMAP_00001752 +BDMAP_00004910 +BDMAP_00003373 +BDMAP_00004297 +BDMAP_00003947 +BDMAP_00003612 +BDMAP_00003598 +BDMAP_00002746 +BDMAP_00004552 +BDMAP_00002333 +BDMAP_00002580 +BDMAP_00002871 +BDMAP_00001565 +BDMAP_00003549 +BDMAP_00003976 +BDMAP_00001712 +BDMAP_00001602 +BDMAP_00000812 +BDMAP_00000353 +BDMAP_00001251 +BDMAP_00004841 +BDMAP_00000429 +BDMAP_00000432 +BDMAP_00000159 +BDMAP_00002347 +BDMAP_00002496 +BDMAP_00004735 +BDMAP_00001514 +BDMAP_00003560 +BDMAP_00001209 +BDMAP_00002313 +BDMAP_00005092 +BDMAP_00005009 +BDMAP_00004673 +BDMAP_00000547 +BDMAP_00003255 +BDMAP_00000229 +BDMAP_00001522 +BDMAP_00002426 +BDMAP_00004015 +BDMAP_00004541 +BDMAP_00003952 +BDMAP_00003853 +BDMAP_00001119 +BDMAP_00004198 +BDMAP_00004427 +BDMAP_00004417 +BDMAP_00000833 +BDMAP_00002487 +BDMAP_00002981 +BDMAP_00000653 +BDMAP_00003815 +BDMAP_00003972 +BDMAP_00000373 +BDMAP_00002864 +BDMAP_00002902 +BDMAP_00001836 +BDMAP_00004897 +BDMAP_00002889 +BDMAP_00003493 +BDMAP_00000667 +BDMAP_00004163 +BDMAP_00004586 +BDMAP_00001704 +BDMAP_00002152 +BDMAP_00001258 +BDMAP_00003827 +BDMAP_00001265 +BDMAP_00001040 +BDMAP_00004106 +BDMAP_00000059 +BDMAP_00002363 +BDMAP_00000161 +BDMAP_00001475 +BDMAP_00001747 +BDMAP_00001027 +BDMAP_00000279 +BDMAP_00002242 +BDMAP_00004175 +BDMAP_00003358 +BDMAP_00004815 +BDMAP_00003580 +BDMAP_00001068 +BDMAP_00003327 +BDMAP_00004616 +BDMAP_00000197 +BDMAP_00003740 +BDMAP_00005074 +BDMAP_00001261 +BDMAP_00002775 +BDMAP_00002545 +BDMAP_00000104 +BDMAP_00004738 +BDMAP_00005099 +BDMAP_00004672 +BDMAP_00004074 +BDMAP_00004288 +BDMAP_00003590 +BDMAP_00001545 +BDMAP_00004922 +BDMAP_00002619 +BDMAP_00000874 +BDMAP_00001438 +BDMAP_00003138 +BDMAP_00002251 +BDMAP_00003769 +BDMAP_00003267 +BDMAP_00002216 +BDMAP_00003994 +BDMAP_00002742 +BDMAP_00001089 
+BDMAP_00003957 +BDMAP_00001533 +BDMAP_00004636 +BDMAP_00004499 +BDMAP_00000698 +BDMAP_00002232 +BDMAP_00004250 +BDMAP_00004491 +BDMAP_00001636 +BDMAP_00005078 +BDMAP_00004121 +BDMAP_00001845 +BDMAP_00004264 +BDMAP_00000137 +BDMAP_00003516 +BDMAP_00005017 +BDMAP_00000087 +BDMAP_00000319 +BDMAP_00001828 +BDMAP_00000948 +BDMAP_00001977 +BDMAP_00003457 +BDMAP_00005157 +BDMAP_00003150 +BDMAP_00002166 +BDMAP_00003301 +BDMAP_00003680 +BDMAP_00003133 +BDMAP_00000574 +BDMAP_00002305 +BDMAP_00004843 +BDMAP_00002230 +BDMAP_00000332 +BDMAP_00003063 +BDMAP_00002076 +BDMAP_00003319 +BDMAP_00004373 +BDMAP_00004880 +BDMAP_00000623 +BDMAP_00003631 +BDMAP_00001737 +BDMAP_00001057 +BDMAP_00002173 +BDMAP_00000139 +BDMAP_00001891 +BDMAP_00000552 +BDMAP_00004717 +BDMAP_00003172 +BDMAP_00003955 +BDMAP_00001664 +BDMAP_00003070 +BDMAP_00004550 +BDMAP_00002057 +BDMAP_00000616 +BDMAP_00000913 +BDMAP_00000388 +BDMAP_00000355 +BDMAP_00003333 +BDMAP_00004148 +BDMAP_00001985 +BDMAP_00001921 +BDMAP_00001624 +BDMAP_00004129 +BDMAP_00002598 +BDMAP_00000859 +BDMAP_00000558 +BDMAP_00002226 +BDMAP_00000452 +BDMAP_00004829 +BDMAP_00003455 +BDMAP_00002402 +BDMAP_00000117 +BDMAP_00000826 +BDMAP_00000243 +BDMAP_00002319 +BDMAP_00002737 +BDMAP_00002318 +BDMAP_00003357 +BDMAP_00000692 +BDMAP_00003427 +BDMAP_00001441 +BDMAP_00004796 +BDMAP_00002171 +BDMAP_00001296 +BDMAP_00004296 +BDMAP_00003808 +BDMAP_00003058 +BDMAP_00003502 +BDMAP_00001045 +BDMAP_00003438 +BDMAP_00002884 +BDMAP_00004561 +BDMAP_00000462 +BDMAP_00001785 +BDMAP_00000794 +BDMAP_00000942 +BDMAP_00002947 +BDMAP_00004744 +BDMAP_00004328 +BDMAP_00004671 +BDMAP_00005108 +BDMAP_00002278 +BDMAP_00000679 +BDMAP_00004903 +BDMAP_00001732 +BDMAP_00001095 +BDMAP_00003343 +BDMAP_00001289 +BDMAP_00001109 +BDMAP_00003650 +BDMAP_00001710 +BDMAP_00003031 +BDMAP_00001617 +BDMAP_00001246 +BDMAP_00004894 +BDMAP_00003520 +BDMAP_00004097 +BDMAP_00001020 +BDMAP_00003600 +BDMAP_00001518 +BDMAP_00000416 +BDMAP_00004990 +BDMAP_00005151 +BDMAP_00000132 +BDMAP_00000138 +BDMAP_00004885 +BDMAP_00000771 +BDMAP_00003928 +BDMAP_00001419 +BDMAP_00003130 +BDMAP_00001892 +BDMAP_00003886 +BDMAP_00004479 +BDMAP_00003918 +BDMAP_00003324 +BDMAP_00002410 +BDMAP_00002509 +BDMAP_00000701 +BDMAP_00003847 +BDMAP_00004450 +BDMAP_00003363 +BDMAP_00002875 +BDMAP_00002793 +BDMAP_00005113 +BDMAP_00000465 +BDMAP_00004847 +BDMAP_00004294 +BDMAP_00000936 +BDMAP_00002476 +BDMAP_00003840 +BDMAP_00004130 +BDMAP_00003614 +BDMAP_00000883 +BDMAP_00000542 +BDMAP_00002562 +BDMAP_00000285 +BDMAP_00001256 +BDMAP_00004597 +BDMAP_00002260 +BDMAP_00001067 +BDMAP_00000968 +BDMAP_00005085 +BDMAP_00003412 +BDMAP_00003884 +BDMAP_00001420 +BDMAP_00003268 +BDMAP_00001735 +BDMAP_00003392 +BDMAP_00000241 +BDMAP_00003326 +BDMAP_00001853 +BDMAP_00001126 +BDMAP_00002237 +BDMAP_00003809 +BDMAP_00001584 +BDMAP_00003359 +BDMAP_00002730 +BDMAP_00000923 +BDMAP_00000687 +BDMAP_00003281 +BDMAP_00004431 +BDMAP_00001440 +BDMAP_00001410 +BDMAP_00004650 +BDMAP_00004065 +BDMAP_00001806 +BDMAP_00002227 +BDMAP_00001906 +BDMAP_00000331 +BDMAP_00001130 +BDMAP_00003178 +BDMAP_00002707 +BDMAP_00001646 +BDMAP_00001707 +BDMAP_00003592 +BDMAP_00003943 +BDMAP_00002361 +BDMAP_00004901 +BDMAP_00003329 +BDMAP_00005075 +BDMAP_00002326 +BDMAP_00003713 +BDMAP_00003832 +BDMAP_00004165 +BDMAP_00004415 +BDMAP_00004331 +BDMAP_00001035 +BDMAP_00004457 +BDMAP_00003347 +BDMAP_00001422 +BDMAP_00002437 +BDMAP_00003996 +BDMAP_00003461 +BDMAP_00002751 +BDMAP_00002523 +BDMAP_00000439 +BDMAP_00004746 +BDMAP_00002188 +BDMAP_00004253 +BDMAP_00000935 +BDMAP_00002451 +BDMAP_00003971 
+BDMAP_00000926 +BDMAP_00003109 +BDMAP_00000660 +BDMAP_00001169 +BDMAP_00001331 +BDMAP_00001175 +BDMAP_00000881 +BDMAP_00000263 +BDMAP_00002401 +BDMAP_00005167 +BDMAP_00002041 +BDMAP_00000656 +BDMAP_00000366 +BDMAP_00002582 +BDMAP_00001238 +BDMAP_00001590 +BDMAP_00001784 +BDMAP_00001564 +BDMAP_00004719 +BDMAP_00001917 +BDMAP_00003956 +BDMAP_00003225 +BDMAP_00000982 +BDMAP_00004992 +BDMAP_00003479 +BDMAP_00001215 +BDMAP_00004147 +BDMAP_00001711 +BDMAP_00000626 +BDMAP_00000516 +BDMAP_00004876 +BDMAP_00003376 +BDMAP_00001628 +BDMAP_00001148 +BDMAP_00003672 +BDMAP_00001205 +BDMAP_00004651 +BDMAP_00000987 +BDMAP_00004104 +BDMAP_00001647 +BDMAP_00000998 +BDMAP_00002244 +BDMAP_00004676 +BDMAP_00001908 +BDMAP_00000714 +BDMAP_00001104 +BDMAP_00001911 +BDMAP_00000882 +BDMAP_00003930 +BDMAP_00000368 +BDMAP_00003923 +BDMAP_00002099 +BDMAP_00000240 +BDMAP_00003658 +BDMAP_00005077 +BDMAP_00002696 +BDMAP_00002184 +BDMAP_00003890 +BDMAP_00002704 +BDMAP_00000066 +BDMAP_00005006 +BDMAP_00001242 +BDMAP_00002396 +BDMAP_00004389 +BDMAP_00002656 +BDMAP_00000469 +BDMAP_00001138 +BDMAP_00004773 +BDMAP_00004033 +BDMAP_00004128 +BDMAP_00002631 +BDMAP_00004925 +BDMAP_00004475 +BDMAP_00001521 +BDMAP_00000364 +BDMAP_00002953 +BDMAP_00003776 +BDMAP_00004154 +BDMAP_00002654 +BDMAP_00002959 +BDMAP_00002199 +BDMAP_00003551 +BDMAP_00002465 +BDMAP_00005154 +BDMAP_00002648 +BDMAP_00000128 +BDMAP_00001001 +BDMAP_00002017 +BDMAP_00004712 +BDMAP_00004286 +BDMAP_00000568 +BDMAP_00004858 +BDMAP_00001782 +BDMAP_00001496 +BDMAP_00004407 +BDMAP_00002250 +BDMAP_00001212 +BDMAP_00000972 +BDMAP_00004374 +BDMAP_00002846 +BDMAP_00002472 +BDMAP_00000569 +BDMAP_00004981 +BDMAP_00000176 +BDMAP_00003510 +BDMAP_00003771 +BDMAP_00002804 +BDMAP_00004558 +BDMAP_00003411 +BDMAP_00001563 +BDMAP_00000604 +BDMAP_00002075 +BDMAP_00005160 +BDMAP_00001511 +BDMAP_00001273 +BDMAP_00002603 +BDMAP_00001656 +BDMAP_00003822 +BDMAP_00004510 +BDMAP_00001809 +BDMAP_00002944 +BDMAP_00002739 +BDMAP_00002609 +BDMAP_00003849 +BDMAP_00001128 +BDMAP_00003717 +BDMAP_00000036 +BDMAP_00002863 +BDMAP_00004956 +BDMAP_00004229 +BDMAP_00003425 +BDMAP_00001865 +BDMAP_00000608 +BDMAP_00004620 +BDMAP_00000589 +BDMAP_00001597 +BDMAP_00003543 +BDMAP_00004645 +BDMAP_00004395 +BDMAP_00005105 +BDMAP_00001426 +BDMAP_00000264 +BDMAP_00001504 +BDMAP_00001649 +BDMAP_00000662 +BDMAP_00002854 +BDMAP_00004060 +BDMAP_00003440 +BDMAP_00003367 +BDMAP_00004011 +BDMAP_00003634 +BDMAP_00003443 +BDMAP_00000828 +BDMAP_00000889 +BDMAP_00000321 +BDMAP_00004615 +BDMAP_00000244 +BDMAP_00003685 +BDMAP_00001461 +BDMAP_00001396 +BDMAP_00004262 +BDMAP_00004579 +BDMAP_00005022 +BDMAP_00004804 +BDMAP_00001632 +BDMAP_00002661 +BDMAP_00000980 +BDMAP_00001445 +BDMAP_00000809 +BDMAP_00004384 +BDMAP_00003114 +BDMAP_00000435 +BDMAP_00003406 +BDMAP_00002899 +BDMAP_00002164 +BDMAP_00002498 +BDMAP_00000039 +BDMAP_00002524 +BDMAP_00000805 +BDMAP_00004604 +BDMAP_00000338 +BDMAP_00002990 +BDMAP_00001516 +BDMAP_00002896 +BDMAP_00004549 +BDMAP_00000259 +BDMAP_00001945 +BDMAP_00002695 +BDMAP_00005141 +BDMAP_00002828 +BDMAP_00003781 +BDMAP_00003900 +BDMAP_00004278 +BDMAP_00004551 +BDMAP_00000532 +BDMAP_00002844 +BDMAP_00001476 +BDMAP_00004887 +BDMAP_00005174 +BDMAP_00000836 +BDMAP_00001456 +BDMAP_00001607 +BDMAP_00003164 +BDMAP_00002404 +BDMAP_00003036 +BDMAP_00001225 +BDMAP_00002022 +BDMAP_00004030 +BDMAP_00000329 +BDMAP_00002253 +BDMAP_00000154 +BDMAP_00003111 +BDMAP_00003384 +BDMAP_00000023 +BDMAP_00001125 +BDMAP_00001414 +BDMAP_00002383 +BDMAP_00003483 +BDMAP_00000034 +BDMAP_00001413 +BDMAP_00003767 +BDMAP_00001368 
+BDMAP_00003448 +BDMAP_00000940 +BDMAP_00000430 +BDMAP_00003153 +BDMAP_00003603 +BDMAP_00003202 +BDMAP_00002421 +BDMAP_00005001 +BDMAP_00004447 +BDMAP_00001325 +BDMAP_00003168 +BDMAP_00000887 +BDMAP_00004481 +BDMAP_00001324 +BDMAP_00004066 +BDMAP_00001474 +BDMAP_00004850 +BDMAP_00002233 +BDMAP_00000511 +BDMAP_00001223 +BDMAP_00003581 +BDMAP_00002930 +BDMAP_00001305 +BDMAP_00002689 +BDMAP_00002332 +BDMAP_00000683 +BDMAP_00003300 +BDMAP_00003701 +BDMAP_00001015 +BDMAP_00001562 +BDMAP_00001898 +BDMAP_00001247 +BDMAP_00001941 +BDMAP_00002840 +BDMAP_00002440 +BDMAP_00000245 +BDMAP_00002855 +BDMAP_00004493 +BDMAP_00000989 +BDMAP_00003736 +BDMAP_00002265 +BDMAP_00004039 +BDMAP_00002826 +BDMAP_00002924 +BDMAP_00003299 +BDMAP_00001361 +BDMAP_00004014 +BDMAP_00001444 +BDMAP_00001370 +BDMAP_00002304 +BDMAP_00000774 +BDMAP_00000614 +BDMAP_00000434 +BDMAP_00001230 +BDMAP_00000044 +BDMAP_00001768 +BDMAP_00004783 +BDMAP_00004494 +BDMAP_00001905 +BDMAP_00003824 +BDMAP_00002309 +BDMAP_00004511 +BDMAP_00000233 +BDMAP_00002845 +BDMAP_00005016 +BDMAP_00002829 +BDMAP_00001059 +BDMAP_00001549 +BDMAP_00002403 +BDMAP_00001794 +BDMAP_00001286 +BDMAP_00003294 +BDMAP_00003722 +BDMAP_00000902 +BDMAP_00002298 +BDMAP_00005191 +BDMAP_00001487 +BDMAP_00003364 +BDMAP_00001605 +BDMAP_00001483 +BDMAP_00000676 +BDMAP_00002945 +BDMAP_00005073 +BDMAP_00002085 +BDMAP_00000716 +BDMAP_00003435 +BDMAP_00002803 +BDMAP_00002663 +BDMAP_00003727 +BDMAP_00000839 +BDMAP_00002068 +BDMAP_00004764 +BDMAP_00002114 +BDMAP_00004741 +BDMAP_00004077 +BDMAP_00004870 +BDMAP_00000571 +BDMAP_00004115 +BDMAP_00001868 +BDMAP_00004113 +BDMAP_00002039 +BDMAP_00004257 +BDMAP_00001620 +BDMAP_00000470 +BDMAP_00000149 +BDMAP_00002815 +BDMAP_00000304 +BDMAP_00005185 +BDMAP_00003113 +BDMAP_00005063 +BDMAP_00000122 +BDMAP_00004482 +BDMAP_00002471 +BDMAP_00004023 +BDMAP_00000225 +BDMAP_00003657 +BDMAP_00001255 +BDMAP_00002616 +BDMAP_00002407 +BDMAP_00002060 +BDMAP_00004546 +BDMAP_00004917 +BDMAP_00003615 +BDMAP_00003525 +BDMAP_00002120 +BDMAP_00000481 +BDMAP_00004770 +BDMAP_00003683 +BDMAP_00000618 +BDMAP_00001875 +BDMAP_00003409 +BDMAP_00003381 +BDMAP_00004398 +BDMAP_00000867 +BDMAP_00000487 +BDMAP_00003073 +BDMAP_00002592 +BDMAP_00005120 +BDMAP_00003128 +BDMAP_00001754 +BDMAP_00004232 +BDMAP_00000855 +BDMAP_00000069 +BDMAP_00002744 +BDMAP_00004808 +BDMAP_00004031 +BDMAP_00001842 +BDMAP_00000324 +BDMAP_00002933 +BDMAP_00004954 +BDMAP_00000541 +BDMAP_00002458 +BDMAP_00002288 +BDMAP_00002807 +BDMAP_00000837 +BDMAP_00002065 +BDMAP_00000152 +BDMAP_00003491 +BDMAP_00001464 +BDMAP_00003486 +BDMAP_00003244 +BDMAP_00000871 +BDMAP_00002362 +BDMAP_00000993 +BDMAP_00000219 +BDMAP_00000192 +BDMAP_00001218 +BDMAP_00001024 +BDMAP_00004980 +BDMAP_00000713 +BDMAP_00001523 +BDMAP_00002688 +BDMAP_00003143 +BDMAP_00005114 +BDMAP_00003749 +BDMAP_00002354 +BDMAP_00000052 +BDMAP_00002710 +BDMAP_00004817 +BDMAP_00004964 +BDMAP_00004775 +BDMAP_00005005 +BDMAP_00004216 +BDMAP_00002936 +BDMAP_00000956 +BDMAP_00002942 +BDMAP_00001705 +BDMAP_00001823 +BDMAP_00002387 +BDMAP_00000690 +BDMAP_00002021 +BDMAP_00000851 +BDMAP_00000427 +BDMAP_00002133 +BDMAP_00004231 +BDMAP_00005169 +BDMAP_00003640 +BDMAP_00000977 +BDMAP_00002103 +BDMAP_00000449 +BDMAP_00001214 +BDMAP_00003506 +BDMAP_00002411 +BDMAP_00003973 +BDMAP_00001912 +BDMAP_00000710 +BDMAP_00004514 +BDMAP_00001807 +BDMAP_00001769 +BDMAP_00001746 +BDMAP_00001804 +BDMAP_00002484 +BDMAP_00003444 +BDMAP_00002029 +BDMAP_00001237 +BDMAP_00004420 +BDMAP_00000431 +BDMAP_00003252 +BDMAP_00005081 +BDMAP_00003694 +BDMAP_00002655 +BDMAP_00004641 
+BDMAP_00000297 +BDMAP_00001077 +BDMAP_00003254 +BDMAP_00000447 +BDMAP_00004834 diff --git a/Generation_Pipeline_filter_all/syn_liver/requirements.txt b/Generation_Pipeline_filter_all/syn_liver/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9c1067c8eb4f44699bbf517601396113cea4b370 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_liver/requirements.txt @@ -0,0 +1,94 @@ +absl-py==1.1.0 +accelerate==0.11.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +antlr4-python3-runtime==4.9.3 +async-timeout==4.0.2 +attrs==21.4.0 +autopep8==1.6.0 +cachetools==5.2.0 +certifi==2022.6.15 +charset-normalizer==2.0.12 +click==8.1.3 +cycler==0.11.0 +Deprecated==1.2.13 +docker-pycreds==0.4.0 +einops==0.4.1 +einops-exts==0.0.3 +ema-pytorch==0.0.8 +fonttools==4.34.4 +frozenlist==1.3.0 +fsspec==2022.5.0 +ftfy==6.1.1 +future==0.18.2 +gitdb==4.0.9 +GitPython==3.1.27 +google-auth==2.9.0 +google-auth-oauthlib==0.4.6 +grpcio==1.47.0 +h5py==3.7.0 +humanize==4.2.2 +hydra-core==1.2.0 +idna==3.3 +imageio==2.19.3 +imageio-ffmpeg==0.4.7 +importlib-metadata==4.12.0 +importlib-resources==5.9.0 +joblib==1.1.0 +kiwisolver==1.4.3 +lxml==4.9.1 +Markdown==3.3.7 +matplotlib==3.5.2 +multidict==6.0.2 +networkx==2.8.5 +nibabel==4.0.1 +nilearn==0.9.1 +numpy==1.23.0 +oauthlib==3.2.0 +omegaconf==2.2.3 +pandas==1.4.3 +Pillow==9.1.1 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycodestyle==2.8.0 +pyDeprecate==0.3.1 +pydicom==2.3.0 +pytorch-lightning==1.6.4 +pytz==2022.1 +PyWavelets==1.3.0 +PyYAML==6.0 +pyzmq==19.0.2 +regex==2022.6.2 +requests==2.28.0 +requests-oauthlib==1.3.1 +rotary-embedding-torch==0.1.5 +rsa==4.8 +scikit-image==0.19.3 +scikit-learn==1.1.2 +scikit-video==1.1.11 +scipy==1.8.1 +seaborn==0.11.2 +sentry-sdk==1.7.2 +setproctitle==1.2.3 +shortuuid==1.0.9 +SimpleITK==2.1.1.2 +smmap==5.0.0 +tensorboard==2.9.1 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +threadpoolctl==3.1.0 +tifffile==2022.8.3 +toml==0.10.2 +torch-tb-profiler==0.4.0 +torchio==0.18.80 +torchmetrics==0.9.1 +tqdm==4.64.0 +typing_extensions==4.2.0 +urllib3==1.26.9 +wandb==0.12.21 +Werkzeug==2.1.2 +wrapt==1.14.1 +yarl==1.7.2 +zipp==3.8.0 +wandb +tensorboardX==2.4.1 diff --git a/Generation_Pipeline_filter_all/syn_pancreas/CT_syn_pancreas_data_new.py b/Generation_Pipeline_filter_all/syn_pancreas/CT_syn_pancreas_data_new.py new file mode 100644 index 0000000000000000000000000000000000000000..94fe7c96cc0d65667edc9fd8fc2cb85cf7f718c7 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/CT_syn_pancreas_data_new.py @@ -0,0 +1,242 @@ +import os, time, csv +import numpy as np +import torch +from sklearn.metrics import confusion_matrix +from scipy import ndimage +from scipy.ndimage import label +from functools import partial +import monai +from monai.transforms import AsDiscrete,AsDiscreted,Compose,Invertd,SaveImaged +from monai import transforms, data +from TumorGeneration.utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor, synt_model_prepare +import nibabel as nib + +import warnings +warnings.filterwarnings("ignore") + +import argparse +parser = argparse.ArgumentParser(description='pancreas tumor validation') + +# file dir +parser.add_argument('--data_root', default=None, type=str) +parser.add_argument('--organ_type', default='pancreas', type=str) +parser.add_argument('--save_dir', default='out', type=str) +parser.add_argument('--data_file', default='out', type=str) +parser.add_argument('--ddim_ts', default=50, type=int) +parser.add_argument('--fg_thresh', default=30, type=int) 
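+# --start/--end select a contiguous slice [start, end) of the case list so the
+# generation run can be sharded across multiple jobs.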
+parser.add_argument('--start', default=0, type=int) +parser.add_argument('--end', default=1000, type=int) + +def voxel2R(A): + return (np.array(A)/4*3/np.pi)**(1/3) + +class RandCropByPosNegLabeld_select(transforms.RandCropByPosNegLabeld): + def __init__(self, keys, label_key, spatial_size, + pos=1.0, neg=1.0, num_samples=1, + image_key=None, image_threshold=0.0, allow_missing_keys=True, + fg_thresh=0): + super().__init__(keys=keys, label_key=label_key, spatial_size=spatial_size, + pos=pos, neg=neg, num_samples=num_samples, + image_key=image_key, image_threshold=image_threshold, allow_missing_keys=allow_missing_keys) + self.fg_thresh = fg_thresh + + def R2voxel(self,R): + return (4/3*np.pi)*(R)**(3) + + def __call__(self, data): + d = dict(data) + data_name = d['name'] + d.pop('name') + + if '10_Decathlon' in data_name or '05_KiTS' in data_name: + d_crop = super().__call__(d) + + else: + flag=0 + while 1: + flag+=1 + + d_crop = super().__call__(d) + pixel_num = (d_crop[0]['label']>0).sum() + + if pixel_num > self.R2voxel(self.fg_thresh): + break + if flag>5 and pixel_num > self.R2voxel(max(self.fg_thresh-5, 5)): + break + if flag>10 and pixel_num > self.R2voxel(max(self.fg_thresh-10, 5)): + break + if flag>15 and pixel_num > self.R2voxel(max(self.fg_thresh-15, 5)): + break + if flag>20 and pixel_num > self.R2voxel(max(self.fg_thresh-20, 5)): + break + if flag>25 and pixel_num > self.R2voxel(max(self.fg_thresh-25, 5)): + break + if flag>50: + break + + d_crop[0]['name'] = data_name + + return d_crop + +def _get_loader(args): + # val_data_dir = args.val_dir + # datalist_json = args.json_dir + val_org_transform = transforms.Compose( + [ + transforms.LoadImaged(keys=["image", "label", "raw_image"]), + transforms.AddChanneld(keys=["image", "label", "raw_image"]), + transforms.Orientationd(keys=["image", "label"], axcodes="RAS"), + transforms.Spacingd(keys=["image", "label"], pixdim=(1.0, 1.0, 1.0), mode=("bilinear", "bilinear")), + transforms.ScaleIntensityRanged(keys=["image"], a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True), + transforms.SpatialPadd(keys=["image", "label"], mode=["minimum", "constant"], spatial_size=[96, 96, 96]), + RandCropByPosNegLabeld_select( + keys=["image", "label", "name"], + label_key="label", + spatial_size=(96, 96, 96), + pos=1, + neg=0, + num_samples=1, + image_key="image", + image_threshold=0, + fg_thresh = args.fg_thresh, + ), + transforms.ToTensord(keys=["image", "label", "raw_image"]), + ] + ) + + val_img=[] + val_lbl=[] + val_name=[] + + for line in open(args.data_file): + # name = line.strip().split()[1].split('.')[0] + # val_img.append(args.data_root + line.strip().split()[0]) + # val_lbl.append(args.data_root + line.strip().split()[1]) + # breakpoint() + name = line.strip() + val_img.append(os.path.join(args.data_root, name, 'ct.nii.gz')) + val_lbl.append(os.path.join(args.data_root, name, 'segmentations/pancreas.nii.gz')) + val_name.append(name) + data_dicts_val = [{'image': image, 'raw_image':image, 'label': label, 'name': name} + for image, label, name in zip(val_img, val_lbl, val_name)] + + if args.end < len(data_dicts_val): + data_dicts_val = data_dicts_val[args.start:args.end] + else: + data_dicts_val = data_dicts_val[args.start:] + print('val len {}'.format(len(data_dicts_val))) + val_org_ds = data.Dataset(data_dicts_val, transform=val_org_transform) + val_org_loader = data.DataLoader(val_org_ds, batch_size=1, shuffle=False, num_workers=4, sampler=None, pin_memory=True) + + post_transforms = Compose([ + Invertd( + keys=['image'], + 
transform=val_org_transform, + orig_keys="image", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ), + Invertd( + keys=['label'], + transform=val_org_transform, + orig_keys="label", + nearest_interp=False, + # nearest_interp=True, + to_tensor=True, + ) + ]) + return val_org_loader, post_transforms + +def main(): + args = parser.parse_args() + output_dir = args.save_dir + if not os.path.exists(output_dir): + os.makedirs(output_dir) + print("MAIN Argument values:") + for k, v in vars(args).items(): + print(k, '=>', v) + print('-----------------') + + ## loader and post_transform + val_loader, post_transforms = _get_loader(args) + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = monai.networks.nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=2).to(device) + model.load_state_dict(torch.load("../best_metric_model_classification3d_dict.pth")) + model.eval() + + start_time=0 + with torch.no_grad(): + for idx, val_data in enumerate(val_loader): + print('idx',idx) + if idx == 0: + start_time = time.time() + # val_inputs = val_data["image"] + # name = val_data['label_meta_dict']['filename_or_obj'][0].split('/')[-1].split('.')[0] + + vqgan, early_sampler, noearly_sampler= synt_model_prepare(device = torch.device("cuda"), fold=0, organ=args.organ_type) + + healthy_data, healthy_target, data_names, raw_data = val_data['image'], val_data['label'], val_data['name'], val_data['raw_image'] + case_name = data_names[0].split('/')[-1] + print('case_name', case_name) + original_affine = val_data["label_meta_dict"]["original_affine"][0].numpy() + + if healthy_target.sum() == 0: + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + tumor_mask = val_data[0]['label'][0].cpu().numpy().astype(np.uint8) + tumor_mask_ = np.zeros_like(tumor_mask) + nib.save(nib.Nifti1Image(tumor_mask_, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/pancreas_tumor.nii.gz')) + continue + + healthy_data, healthy_target = healthy_data.cuda(), healthy_target.cuda() + healthy_target = (healthy_target>0).to(healthy_target) + + tumor_types = ['early', 'medium', 'large'] + # tumor_probs = np.array([0.45, 0.45, 0.1]) + # tumor_probs = np.array([1.0, 0.0, 0.0]) + tumor_probs = np.array([0.5, 0.4, 0.1]) + synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + print('synthetic_tumor_type',synthetic_tumor_type) + flag=0 + while 1: + flag+=1 + if synthetic_tumor_type == 'early': + synt_data, synt_target = synthesize_early_tumor(healthy_data, healthy_target, args.organ_type, vqgan, early_sampler) + elif synthetic_tumor_type == 'medium': + synt_data, synt_target = synthesize_medium_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + elif synthetic_tumor_type == 'large': + synt_data, synt_target = synthesize_large_tumor(healthy_data, healthy_target, args.organ_type, vqgan, noearly_sampler, ddim_ts=args.ddim_ts) + + syn_confidence = model(synt_data).sigmoid()[:,1] + if syn_confidence>0.01: + break + elif flag > 20 and syn_confidence>0.005: + break + elif flag > 40 and syn_confidence>0.001: + break + + val_data['image'] = synt_data.detach() + val_data['label'] = synt_target.detach() + + val_data = [post_transforms(i) for i in data.decollate_batch(val_data)] + synt_data = val_data[0]['image'][0] + synt_target = val_data[0]['label'][0] + final_data = raw_data[0,0] + + synt_data = (synt_data*(250+175)-175) + final_data[synt_target>1] = synt_data[synt_target>1] + final_data = 
final_data.cpu().numpy() + final_label = (synt_target>=1.5).cpu().numpy().astype(np.uint8) + + os.makedirs(os.path.join(output_dir, f'{case_name}'), exist_ok=True) + os.makedirs(os.path.join(output_dir, f'{case_name}/segmentations'), exist_ok=True) + nib.save(nib.Nifti1Image(final_data, original_affine), os.path.join(output_dir, f'{case_name}/ct.nii.gz')) + nib.save(nib.Nifti1Image(final_label, original_affine), os.path.join(output_dir, f'{case_name}/segmentations/pancreas_tumor.nii.gz')) + + print('time = ', time.time()-start_time) + start_time = time.time() + + +if __name__ == "__main__": + main() diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/.DS_Store b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..fb0d207e124cfecc777faae5b4a56c5ca1b9cd2e Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/.DS_Store differ diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/README.md b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/README.md new file mode 100644 index 0000000000000000000000000000000000000000..201911031ca988c410781a60dc3c192a65ee56b3 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/README.md @@ -0,0 +1,5 @@ +```bash +wget https://www.dropbox.com/scl/fi/k856fhk60kck8uqxtxazw/model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll +mv model_weight.tar.gz?rlkey=hrcn4cbt690dzern1bkbfejll model_weight.tar.gz +tar -xzvf model_weight.tar.gz +``` diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/TumorGenerated.py b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/TumorGenerated.py new file mode 100644 index 0000000000000000000000000000000000000000..9983cf047b1532f5edd20c0d4c78102d6d611eb9 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/TumorGenerated.py @@ -0,0 +1,39 @@ +import random +from typing import Hashable, Mapping, Dict + +from monai.config import KeysCollection +from monai.config.type_definitions import NdarrayOrTensor +from monai.transforms.transform import MapTransform, RandomizableTransform + +from .utils_ import SynthesisTumor +import numpy as np + +class TumorGenerated(RandomizableTransform, MapTransform): + def __init__(self, + keys: KeysCollection, + prob: float = 0.1, + tumor_prob = [0.2, 0.2, 0.2, 0.2, 0.2], + allow_missing_keys: bool = False + ) -> None: + MapTransform.__init__(self, keys, allow_missing_keys) + RandomizableTransform.__init__(self, prob) + random.seed(0) + np.random.seed(0) + + self.tumor_types = ['tiny', 'small', 'medium', 'large', 'mix'] + + assert len(tumor_prob) == 5 + self.tumor_prob = np.array(tumor_prob) + + + + def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]: + d = dict(data) + self.randomize(None) + + if self._do_transform and (np.max(d['label']) <= 1): + tumor_type = np.random.choice(self.tumor_types, p=self.tumor_prob.ravel()) + + d['image'][0], d['label'][0] = SynthesisTumor(d['image'][0], d['label'][0], tumor_type) + # print(tumor_type, d['image'].shape, np.max(d['label'])) + return d diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/__init__.py b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1afc8a195ba5fd106ca18d4e219c123a75e6e831 --- /dev/null +++ 
b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/__init__.py @@ -0,0 +1,5 @@ +### Online Version TumorGeneration ### + +from .TumorGenerated import TumorGenerated + +from .utils import synthesize_early_tumor, synthesize_medium_tumor, synthesize_large_tumor \ No newline at end of file diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02cb9a3bda15ccd819e927754d80c3538090b4d2 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/__pycache__/TumorGenerated.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc1142ed95f5cd5ccc2ab25fb5f501ffd782f335 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/__pycache__/__init__.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/__pycache__/utils.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad72b834e61a83099d491b2a359824c51a42beb4 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/__pycache__/utils.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/__pycache__/utils_.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/__pycache__/utils_.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c67d98211992fb7cb17035aac7241ee0afb0ae44 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/__pycache__/utils_.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/diffusion_config/ddpm.yaml b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/diffusion_config/ddpm.yaml new file mode 100644 index 0000000000000000000000000000000000000000..474a973b1f76c2a026e2df916c4a153b5bbea05f --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/diffusion_config/ddpm.yaml @@ -0,0 +1,29 @@ +vqgan_ckpt: None + +# Have to be derived from VQ-GAN Latent space dimensions +diffusion_img_size: 24 +diffusion_depth_size: 24 +diffusion_num_channels: 17 # 17 +out_dim: 8 +dim_mults: [1,2,4,8] +results_folder: checkpoints/ddpm/ +results_folder_postfix: 'own_dataset_t2' +load_milestone: False # False + +batch_size: 2 # 40 +num_workers: 20 +logger: wandb +objective: pred_x0 +save_and_sample_every: 1000 +denoising_fn: Unet3D +train_lr: 1e-4 +timesteps: 2 # number of steps +sampling_timesteps: 250 # number of sampling timesteps (using ddim for faster inference [see citation for ddim paper]) +loss_type: l1 # L1 or L2 +train_num_steps: 700000 # total training steps +gradient_accumulate_every: 2 # gradient accumulation steps +ema_decay: 0.995 # exponential moving average decay +amp: False # turn on mixed precision +num_sample_rows: 1 +gpus: 0 + diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/diffusion_config/vq_gan_3d.yaml 
b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/diffusion_config/vq_gan_3d.yaml new file mode 100644 index 0000000000000000000000000000000000000000..29940377c5cadb0c06322de8ac60d0713f799024 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/diffusion_config/vq_gan_3d.yaml @@ -0,0 +1,37 @@ +seed: 1234 +batch_size: 2 # 30 +num_workers: 32 # 30 + +gpus: 1 +accumulate_grad_batches: 1 +default_root_dir: checkpoints/vq_gan/ +default_root_dir_postfix: 'flair' +resume_from_checkpoint: +max_steps: -1 +max_epochs: -1 +precision: 16 +gradient_clip_val: 1.0 + + +embedding_dim: 8 # 256 +n_codes: 16384 # 2048 +n_hiddens: 16 +lr: 3e-4 +downsample: [2, 2, 2] # [4, 4, 4] +disc_channels: 64 +disc_layers: 3 +discriminator_iter_start: 10000 # 50000 +disc_loss_type: hinge +image_gan_weight: 1.0 +video_gan_weight: 1.0 +l1_weight: 4.0 +gan_feat_weight: 4.0 # 0.0 +perceptual_weight: 4.0 # 0.0 +i3d_feat: False +restart_thres: 1.0 +no_random_restart: False +norm_type: group +padding_type: replicate +num_groups: 32 + + diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__init__.py b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a01336e46e37471097d8e1420d0ffd5d803a1edd --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__init__.py @@ -0,0 +1 @@ +from .diffusion import Unet3D, GaussianDiffusion, Tester diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fe5bab6ebc89b70a4b1a01afcaff758e59ed73b Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/__init__.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c25d2c3abacd1a912c6753385896dd607193ab33 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/ddim.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8479abbdbbc5bcbd3be9dda38a303ff83cef8fc5 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/diffusion.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d933fb26b3a9bad8718b0fc74180a3e595d19854 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/text.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc 
b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7ac1a6aea44a1d1741652f2f0a598fb18b2730d Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/time_embedding.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..20c6786f9995cbd455c900944b0ab9501a97c202 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/unet.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12e044abbf40919d2be472579cb32f762b2d8f4e Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/__pycache__/util.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/ddim.py b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..1bc51f94355732aedb0ffd254b8166144c468370 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/ddim.py @@ -0,0 +1,206 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm +from functools import partial + +from .util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) + + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): # "uniform" 'quad' + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) + + alphas_cumprod = self.model.alphas_cumprod + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer('betas', to_torch(self.model.betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. 
/ alphas_cumprod.cpu() - 1))) + # breakpoint() + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta,verbose=verbose) + self.register_buffer('ddim_sigmas', ddim_sigmas) + self.register_buffer('ddim_alphas', ddim_alphas) + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) + + @torch.no_grad() + def sample(self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0., + mask=None, + x0=None, + temperature=1., + noise_dropout=0., + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1., + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") + else: + if conditioning.shape[0] != batch_size: + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, T, H, W = shape + # breakpoint() + size = (batch_size, C, T, H, W) + # print(f'Data shape for DDIM sampling is {size}, eta {eta}') + + samples, intermediates = self.ddim_sampling(conditioning, size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling(self, cond, shape, + x_T=None, ddim_use_original_steps=False, + callback=None, timesteps=None, quantize_denoised=False, + mask=None, x0=None, img_callback=None, log_every_t=100, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None,): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps + elif timesteps is not None and not ddim_use_original_steps: + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {'x_inter': [img], 'pred_x0': [img]} + time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + # iterator = 
tqdm(time_range, desc='DDIM Sampler', total=total_steps) + + for i, step in enumerate(time_range): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? + img = img_orig * mask + (1. - mask) * img + + outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, temperature=temperature, + noise_dropout=noise_dropout, score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning) + img, pred_x0 = outs + if callback: callback(i) + if img_callback: img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates['x_inter'].append(img) + intermediates['pred_x0'].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, + unconditional_guidance_scale=1., unconditional_conditioning=None): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: + # breakpoint() + e_t = self.model.denoise_fn(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.denoise_fn(x_in, t_in, c_in).chunk(2) + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas + sigmas = self.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1. 
- a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise + return x_prev, pred_x0 diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/diffusion.py b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..87111dc3b2ab671e798efc3a320ae81d024f20b9 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/diffusion.py @@ -0,0 +1,1016 @@ +"Largely taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch" + +import math +import copy +import torch +from torch import nn, einsum +import torch.nn.functional as F +from functools import partial + +from torch.utils import data +from pathlib import Path +from torch.optim import Adam +from torchvision import transforms as T, utils +from torch.cuda.amp import autocast, GradScaler +from PIL import Image + +from tqdm import tqdm +from einops import rearrange +from einops_exts import check_shape, rearrange_many + +from rotary_embedding_torch import RotaryEmbedding + +from .text import tokenize, bert_embed, BERT_MODEL_DIM +from torch.utils.data import Dataset, DataLoader +from ..vq_gan_3d.model.vqgan import VQGAN + +import matplotlib.pyplot as plt + +# helpers functions + + +def exists(x): + return x is not None + + +def noop(*args, **kwargs): + pass + + +def is_odd(n): + return (n % 2) == 1 + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +def cycle(dl): + while True: + for data in dl: + yield data + + +def num_to_groups(num, divisor): + groups = num // divisor + remainder = num % divisor + arr = [divisor] * groups + if remainder > 0: + arr.append(remainder) + return arr + + +def prob_mask_like(shape, prob, device): + if prob == 1: + return torch.ones(shape, device=device, dtype=torch.bool) + elif prob == 0: + return torch.zeros(shape, device=device, dtype=torch.bool) + else: + return torch.zeros(shape, device=device).float().uniform_(0, 1) < prob + + +def is_list_str(x): + if not isinstance(x, (list, tuple)): + return False + return all([type(el) == str for el in x]) + +# relative positional bias + + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128 + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / + max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min( + val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype=torch.long, device=device) + k_pos = torch.arange(n, dtype=torch.long, device=device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket( + rel_pos, num_buckets=self.num_buckets, 
max_distance=self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') + +# small helper modules + + +class EMA(): + def __init__(self, beta): + super().__init__() + self.beta = beta + + def update_model_average(self, ma_model, current_model): + for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()): + old_weight, up_weight = ma_params.data, current_params.data + ma_params.data = self.update_average(old_weight, up_weight) + + def update_average(self, old, new): + if old is None: + return new + return old * self.beta + (1 - self.beta) * new + + +class Residual(nn.Module): + def __init__(self, fn): + super().__init__() + self.fn = fn + + def forward(self, x, *args, **kwargs): + return self.fn(x, *args, **kwargs) + x + + +class SinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + device = x.device + half_dim = self.dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, device=device) * -emb) + emb = x[:, None] * emb[None, :] + emb = torch.cat((emb.sin(), emb.cos()), dim=-1) + return emb + + +def Upsample(dim): + return nn.ConvTranspose3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) + + +def Downsample(dim): + return nn.Conv3d(dim, dim, (1, 4, 4), (1, 2, 2), (0, 1, 1)) + + +class LayerNorm(nn.Module): + def __init__(self, dim, eps=1e-5): + super().__init__() + self.eps = eps + self.gamma = nn.Parameter(torch.ones(1, dim, 1, 1, 1)) + + def forward(self, x): + var = torch.var(x, dim=1, unbiased=False, keepdim=True) + mean = torch.mean(x, dim=1, keepdim=True) + return (x - mean) / (var + self.eps).sqrt() * self.gamma + + +class PreNorm(nn.Module): + def __init__(self, dim, fn): + super().__init__() + self.fn = fn + self.norm = LayerNorm(dim) + + def forward(self, x, **kwargs): + x = self.norm(x) + return self.fn(x, **kwargs) + +# building block modules + + +class Block(nn.Module): + def __init__(self, dim, dim_out, groups=8): + super().__init__() + self.proj = nn.Conv3d(dim, dim_out, (1, 3, 3), padding=(0, 1, 1)) + self.norm = nn.GroupNorm(groups, dim_out) + self.act = nn.SiLU() + + def forward(self, x, scale_shift=None): + x = self.proj(x) + x = self.norm(x) + + if exists(scale_shift): + scale, shift = scale_shift + x = x * (scale + 1) + shift + + return self.act(x) + + +class ResnetBlock(nn.Module): + def __init__(self, dim, dim_out, *, time_emb_dim=None, groups=8): + super().__init__() + self.mlp = nn.Sequential( + nn.SiLU(), + nn.Linear(time_emb_dim, dim_out * 2) + ) if exists(time_emb_dim) else None + + self.block1 = Block(dim, dim_out, groups=groups) + self.block2 = Block(dim_out, dim_out, groups=groups) + self.res_conv = nn.Conv3d( + dim, dim_out, 1) if dim != dim_out else nn.Identity() + + def forward(self, x, time_emb=None): + + scale_shift = None + if exists(self.mlp): + assert exists(time_emb), 'time emb must be passed in' + time_emb = self.mlp(time_emb) + time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') + scale_shift = time_emb.chunk(2, dim=1) + + h = self.block1(x, scale_shift=scale_shift) + + h = self.block2(h) + return h + self.res_conv(x) + + +class SpatialLinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, f, h, w = 
x.shape + x = rearrange(x, 'b c f h w -> (b f) c h w') + + qkv = self.to_qkv(x).chunk(3, dim=1) + q, k, v = rearrange_many( + qkv, 'b (h c) x y -> b h c (x y)', h=self.heads) + + q = q.softmax(dim=-2) + k = k.softmax(dim=-1) + + q = q * self.scale + context = torch.einsum('b h d n, b h e n -> b h d e', k, v) + + out = torch.einsum('b h d e, b h d n -> b h e n', context, q) + out = rearrange(out, 'b h c (x y) -> b (h c) x y', + h=self.heads, x=h, y=w) + out = self.to_out(out) + return rearrange(out, '(b f) c h w -> b c f h w', b=b) + +# attention along space and time + + +class EinopsToAndFrom(nn.Module): + def __init__(self, from_einops, to_einops, fn): + super().__init__() + self.from_einops = from_einops + self.to_einops = to_einops + self.fn = fn + + def forward(self, x, **kwargs): + shape = x.shape + reconstitute_kwargs = dict( + tuple(zip(self.from_einops.split(' '), shape))) + x = rearrange(x, f'{self.from_einops} -> {self.to_einops}') + x = self.fn(x, **kwargs) + x = rearrange( + x, f'{self.to_einops} -> {self.from_einops}', **reconstitute_kwargs) + return x + + +class Attention(nn.Module): + def __init__( + self, + dim, + heads=4, + dim_head=32, + rotary_emb=None + ): + super().__init__() + self.scale = dim_head ** -0.5 + self.heads = heads + hidden_dim = dim_head * heads + + self.rotary_emb = rotary_emb + self.to_qkv = nn.Linear(dim, hidden_dim * 3, bias=False) + self.to_out = nn.Linear(hidden_dim, dim, bias=False) + + def forward( + self, + x, + pos_bias=None, + focus_present_mask=None + ): + n, device = x.shape[-2], x.device + + qkv = self.to_qkv(x).chunk(3, dim=-1) + + if exists(focus_present_mask) and focus_present_mask.all(): + # if all batch samples are focusing on present + # it would be equivalent to passing that token's values through to the output + values = qkv[-1] + return self.to_out(values) + + # split out heads + + q, k, v = rearrange_many(qkv, '... n (h d) -> ... h n d', h=self.heads) + + # scale + + q = q * self.scale + + # rotate positions into queries and keys for time attention + + if exists(self.rotary_emb): + q = self.rotary_emb.rotate_queries_or_keys(q) + k = self.rotary_emb.rotate_queries_or_keys(k) + + # similarity + + sim = einsum('... h i d, ... h j d -> ... h i j', q, k) + + # relative positional bias + + if exists(pos_bias): + sim = sim + pos_bias + + if exists(focus_present_mask) and not (~focus_present_mask).all(): + attend_all_mask = torch.ones( + (n, n), device=device, dtype=torch.bool) + attend_self_mask = torch.eye(n, device=device, dtype=torch.bool) + + mask = torch.where( + rearrange(focus_present_mask, 'b -> b 1 1 1 1'), + rearrange(attend_self_mask, 'i j -> 1 1 1 i j'), + rearrange(attend_all_mask, 'i j -> 1 1 1 i j'), + ) + + sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max) + + # numerical stability + + sim = sim - sim.amax(dim=-1, keepdim=True).detach() + attn = sim.softmax(dim=-1) + + # aggregate values + + out = einsum('... h i j, ... h j d -> ... h i d', attn, v) + out = rearrange(out, '... h n d -> ... 
n (h d)') + return self.to_out(out) + +# model + + +class Unet3D(nn.Module): + def __init__( + self, + dim, + cond_dim=None, + out_dim=None, + dim_mults=(1, 2, 4, 8), + channels=3, + attn_heads=8, + attn_dim_head=32, + use_bert_text_cond=False, + init_dim=None, + init_kernel_size=7, + use_sparse_linear_attn=True, + block_type='resnet', + resnet_groups=8 + ): + super().__init__() + self.channels = channels + + # temporal attention and its relative positional encoding + + rotary_emb = RotaryEmbedding(min(32, attn_dim_head)) + + def temporal_attn(dim): return EinopsToAndFrom('b c f h w', 'b (h w) f c', Attention( + dim, heads=attn_heads, dim_head=attn_dim_head, rotary_emb=rotary_emb)) + + # realistically will not be able to generate that many frames of video... yet + self.time_rel_pos_bias = RelativePositionBias( + heads=attn_heads, max_distance=32) + + # initial conv + + init_dim = default(init_dim, dim) + assert is_odd(init_kernel_size) + + init_padding = init_kernel_size // 2 + self.init_conv = nn.Conv3d(channels, init_dim, (1, init_kernel_size, + init_kernel_size), padding=(0, init_padding, init_padding)) + + self.init_temporal_attn = Residual( + PreNorm(init_dim, temporal_attn(init_dim))) + + # dimensions + + dims = [init_dim, *map(lambda m: dim * m, dim_mults)] + in_out = list(zip(dims[:-1], dims[1:])) + + # time conditioning + + time_dim = dim * 4 + self.time_mlp = nn.Sequential( + SinusoidalPosEmb(dim), + nn.Linear(dim, time_dim), + nn.GELU(), + nn.Linear(time_dim, time_dim) + ) + + # text conditioning + + self.has_cond = exists(cond_dim) or use_bert_text_cond + cond_dim = BERT_MODEL_DIM if use_bert_text_cond else cond_dim + + self.null_cond_emb = nn.Parameter( + torch.randn(1, cond_dim)) if self.has_cond else None + + cond_dim = time_dim + int(cond_dim or 0) + + # layers + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + + num_resolutions = len(in_out) + # block type + + block_klass = partial(ResnetBlock, groups=resnet_groups) + block_klass_cond = partial(block_klass, time_emb_dim=cond_dim) + + # modules for all layers + for ind, (dim_in, dim_out) in enumerate(in_out): + is_last = ind >= (num_resolutions - 1) + + self.downs.append(nn.ModuleList([ + block_klass_cond(dim_in, dim_out), + block_klass_cond(dim_out, dim_out), + Residual(PreNorm(dim_out, SpatialLinearAttention( + dim_out, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), + Residual(PreNorm(dim_out, temporal_attn(dim_out))), + Downsample(dim_out) if not is_last else nn.Identity() + ])) + + mid_dim = dims[-1] + self.mid_block1 = block_klass_cond(mid_dim, mid_dim) + + spatial_attn = EinopsToAndFrom( + 'b c f h w', 'b f (h w) c', Attention(mid_dim, heads=attn_heads)) + + self.mid_spatial_attn = Residual(PreNorm(mid_dim, spatial_attn)) + self.mid_temporal_attn = Residual( + PreNorm(mid_dim, temporal_attn(mid_dim))) + + self.mid_block2 = block_klass_cond(mid_dim, mid_dim) + + for ind, (dim_in, dim_out) in enumerate(reversed(in_out)): + is_last = ind >= (num_resolutions - 1) + + self.ups.append(nn.ModuleList([ + block_klass_cond(dim_out * 2, dim_in), + block_klass_cond(dim_in, dim_in), + Residual(PreNorm(dim_in, SpatialLinearAttention( + dim_in, heads=attn_heads))) if use_sparse_linear_attn else nn.Identity(), + Residual(PreNorm(dim_in, temporal_attn(dim_in))), + Upsample(dim_in) if not is_last else nn.Identity() + ])) + + out_dim = default(out_dim, channels) + self.final_conv = nn.Sequential( + block_klass(dim * 2, dim), + nn.Conv3d(dim, out_dim, 1) + ) + + def forward_with_cond_scale( + self, + 
*args, + cond_scale=2., + **kwargs + ): + logits = self.forward(*args, null_cond_prob=0., **kwargs) + if cond_scale == 1 or not self.has_cond: + return logits + + null_logits = self.forward(*args, null_cond_prob=1., **kwargs) + return null_logits + (logits - null_logits) * cond_scale + + def forward( + self, + x, + time, + cond=None, + null_cond_prob=0., + focus_present_mask=None, + # probability at which a given batch sample will focus on the present (0. is all off, 1. is completely arrested attention across time) + prob_focus_present=0. + ): + assert not (self.has_cond and not exists(cond) + ), 'cond must be passed in if cond_dim specified' + x = torch.cat([x, cond], dim=1) + + batch, device = x.shape[0], x.device + + focus_present_mask = default(focus_present_mask, lambda: prob_mask_like( + (batch,), prob_focus_present, device=device)) + + time_rel_pos_bias = self.time_rel_pos_bias(x.shape[2], device=x.device) + + x = self.init_conv(x) + r = x.clone() + + x = self.init_temporal_attn(x, pos_bias=time_rel_pos_bias) + + t = self.time_mlp(time) if exists(self.time_mlp) else None # [2, 128] + + # classifier free guidance + + if self.has_cond: + batch, device = x.shape[0], x.device + mask = prob_mask_like((batch,), null_cond_prob, device=device) + cond = torch.where(rearrange(mask, 'b -> b 1'), + self.null_cond_emb, cond) + t = torch.cat((t, cond), dim=-1) + + h = [] + + for block1, block2, spatial_attn, temporal_attn, downsample in self.downs: + x = block1(x, t) + x = block2(x, t) + x = spatial_attn(x) + x = temporal_attn(x, pos_bias=time_rel_pos_bias, + focus_present_mask=focus_present_mask) + h.append(x) + x = downsample(x) + + # [2, 256, 32, 4, 4] + x = self.mid_block1(x, t) + x = self.mid_spatial_attn(x) + x = self.mid_temporal_attn( + x, pos_bias=time_rel_pos_bias, focus_present_mask=focus_present_mask) + x = self.mid_block2(x, t) + + for block1, block2, spatial_attn, temporal_attn, upsample in self.ups: + x = torch.cat((x, h.pop()), dim=1) + x = block1(x, t) + x = block2(x, t) + x = spatial_attn(x) + x = temporal_attn(x, pos_bias=time_rel_pos_bias, + focus_present_mask=focus_present_mask) + x = upsample(x) + + x = torch.cat((x, r), dim=1) + return self.final_conv(x) + +# gaussian diffusion trainer class + + +def extract(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def cosine_beta_schedule(timesteps, s=0.008): + """ + cosine schedule + as proposed in https://openreview.net/forum?id=-NEXDKk8gZ + """ + steps = timesteps + 1 + x = torch.linspace(0, timesteps, steps, dtype=torch.float64) + alphas_cumprod = torch.cos( + ((x / timesteps) + s) / (1 + s) * torch.pi * 0.5) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1]) + return torch.clip(betas, 0, 0.9999) + + +class GaussianDiffusion(nn.Module): + def __init__( + self, + denoise_fn, + *, + image_size, + num_frames, + text_use_bert_cls=False, + channels=3, + timesteps=1000, + loss_type='l1', + use_dynamic_thres=False, # from the Imagen paper + dynamic_thres_percentile=0.9, + vqgan_ckpt=None, + device=None + ): + super().__init__() + self.channels = channels + self.image_size = image_size + self.num_frames = num_frames + self.denoise_fn = denoise_fn + self.device = device + + if vqgan_ckpt: + self.vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt).cuda() + self.vqgan.eval() + else: + self.vqgan = None + + betas = cosine_beta_schedule(timesteps) + + alphas = 1. 
- betas + alphas_cumprod = torch.cumprod(alphas, axis=0) + alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value=1.) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.loss_type = loss_type + + # register buffer helper function that casts float64 to float32 + + def register_buffer(name, val): return self.register_buffer( + name, val.to(torch.float32)) + + register_buffer('betas', betas) + register_buffer('alphas_cumprod', alphas_cumprod) + register_buffer('alphas_cumprod_prev', alphas_cumprod_prev) + + # calculations for diffusion q(x_t | x_{t-1}) and others + + register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod)) + register_buffer('sqrt_one_minus_alphas_cumprod', + torch.sqrt(1. - alphas_cumprod)) + register_buffer('log_one_minus_alphas_cumprod', + torch.log(1. - alphas_cumprod)) + register_buffer('sqrt_recip_alphas_cumprod', + torch.sqrt(1. / alphas_cumprod)) + register_buffer('sqrt_recipm1_alphas_cumprod', + torch.sqrt(1. / alphas_cumprod - 1)) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + + posterior_variance = betas * \ + (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) + + # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) + + register_buffer('posterior_variance', posterior_variance) + + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + + register_buffer('posterior_log_variance_clipped', + torch.log(posterior_variance.clamp(min=1e-20))) + register_buffer('posterior_mean_coef1', betas * + torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)) + register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) + * torch.sqrt(alphas) / (1. - alphas_cumprod)) + + # text conditioning parameters + + self.text_use_bert_cls = text_use_bert_cls + + # dynamic thresholding when sampling + + self.use_dynamic_thres = use_dynamic_thres + self.dynamic_thres_percentile = dynamic_thres_percentile + + def q_mean_variance(self, x_start, t): + mean = extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = extract(1. - self.alphas_cumprod, t, x_start.shape) + log_variance = extract( + self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def predict_start_from_noise(self, x_t, t, noise): + return ( + extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - + extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise + ) + + def q_posterior(self, x_start, x_t, t): + posterior_mean = ( + extract(self.posterior_mean_coef1, t, x_t.shape) * x_start + + extract(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = extract(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = extract( + self.posterior_log_variance_clipped, t, x_t.shape) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, x, t, clip_denoised: bool, cond=None, cond_scale=1.): + x_recon = self.predict_start_from_noise( + x, t=t, noise=self.denoise_fn.forward_with_cond_scale(x, t, cond=cond, cond_scale=cond_scale)) + + if clip_denoised: + s = 1. + if self.use_dynamic_thres: + s = torch.quantile( + rearrange(x_recon, 'b ... -> b (...)').abs(), + self.dynamic_thres_percentile, + dim=-1 + ) + + s.clamp_(min=1.) 
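+                # dynamic thresholding: reshape the per-sample quantile so it broadcasts over all non-batch dims before clipping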
+                s = s.view(-1, *((1,) * (x_recon.ndim - 1)))
+
+            # clip by threshold, depending on whether static or dynamic
+            x_recon = x_recon.clamp(-s, s) / s
+
+        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
+            x_start=x_recon, x_t=x, t=t)
+        return model_mean, posterior_variance, posterior_log_variance
+
+    @torch.inference_mode()
+    def p_sample(self, x, t, cond=None, cond_scale=1., clip_denoised=True):
+        b, *_, device = *x.shape, x.device
+        model_mean, _, model_log_variance = self.p_mean_variance(
+            x=x, t=t, clip_denoised=clip_denoised, cond=cond, cond_scale=cond_scale)
+        noise = torch.randn_like(x)
+        # no noise when t == 0
+        nonzero_mask = (1 - (t == 0).float()).reshape(b,
+                                                      *((1,) * (len(x.shape) - 1)))
+        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+    @torch.inference_mode()
+    def p_sample_loop(self, shape, cond=None, cond_scale=1.):
+        device = self.betas.device
+
+        b = shape[0]
+        img = torch.randn(shape, device=device)
+        # print('cond', cond.shape)
+        for i in reversed(range(0, self.num_timesteps)):
+            img = self.p_sample(img, torch.full(
+                (b,), i, device=device, dtype=torch.long), cond=cond, cond_scale=cond_scale)
+
+        return img
+
+    @torch.inference_mode()
+    def sample(self, cond=None, cond_scale=1., batch_size=16):
+        device = next(self.denoise_fn.parameters()).device
+
+        if is_list_str(cond):
+            cond = bert_embed(tokenize(cond)).to(device)
+
+        # batch_size = cond.shape[0] if exists(cond) else batch_size
+        batch_size = batch_size
+        image_size = self.image_size
+        channels = 8  # self.channels
+        num_frames = self.num_frames
+        # print((batch_size, channels, num_frames, image_size, image_size))
+        # print('cond_',cond.shape)
+        _sample = self.p_sample_loop(
+            (batch_size, channels, num_frames, image_size, image_size), cond=cond, cond_scale=cond_scale)
+
+        if isinstance(self.vqgan, VQGAN):
+            # denormalize TODO: Remove eventually
+            _sample = (((_sample + 1.0) / 2.0) * (self.vqgan.codebook.embeddings.max() -
+                                                  self.vqgan.codebook.embeddings.min())) + self.vqgan.codebook.embeddings.min()
+
+            _sample = self.vqgan.decode(_sample, quantize=True)
+        else:
+            # map latents from [-1, 1] back to [0, 1] when no VQ-GAN decoder is used
+            _sample = unnormalize_img(_sample)
+
+        return _sample
+
+    @torch.inference_mode()
+    def interpolate(self, x1, x2, t=None, lam=0.5):
+        b, *_, device = *x1.shape, x1.device
+        t = default(t, self.num_timesteps - 1)
+
+        assert x1.shape == x2.shape
+
+        t_batched = torch.stack([torch.tensor(t, device=device)] * b)
+        xt1, xt2 = map(lambda x: self.q_sample(x, t=t_batched), (x1, x2))
+
+        img = (1 - lam) * xt1 + lam * xt2
+        for i in reversed(range(0, t)):
+            img = self.p_sample(img, torch.full(
+                (b,), i, device=device, dtype=torch.long))
+
+        return img
+
+    def q_sample(self, x_start, t, noise=None):
+        noise = default(noise, lambda: torch.randn_like(x_start))
+
+        return (
+            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+            extract(self.sqrt_one_minus_alphas_cumprod,
+                    t, x_start.shape) * noise
+        )
+
+    def p_losses(self, x_start, t, cond=None, noise=None, **kwargs):
+        b, c, f, h, w, device = *x_start.shape, x_start.device
+        noise = default(noise, lambda: torch.randn_like(x_start))
+        # breakpoint()
+        x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)  # [2, 8, 32, 32, 32]
+
+        if is_list_str(cond):
+            cond = bert_embed(
+                tokenize(cond), return_cls_repr=self.text_use_bert_cls)
+            cond = cond.to(device)
+
+        x_recon = self.denoise_fn(x_noisy, t, cond=cond, **kwargs)
+
+        if self.loss_type == 'l1':
+            loss = F.l1_loss(noise, x_recon)
+        elif self.loss_type == 'l2':
+            loss = F.mse_loss(noise, x_recon)
+        else:
+ raise NotImplementedError() + + return loss + + def forward(self, x, *args, **kwargs): + bs = int(x.shape[0]/2) + img=x[:bs,...] + mask=x[bs:,...] + mask_=(1-mask).detach() + masked_img = (img*mask_).detach() + masked_img=masked_img.permute(0,1,-1,-3,-2) + img=img.permute(0,1,-1,-3,-2) + mask=mask.permute(0,1,-1,-3,-2) + # breakpoint() + if isinstance(self.vqgan, VQGAN): + with torch.no_grad(): + img = self.vqgan.encode( + img, quantize=False, include_embeddings=True) + # normalize to -1 and 1 + img = ((img - self.vqgan.codebook.embeddings.min()) / + (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + masked_img = self.vqgan.encode( + masked_img, quantize=False, include_embeddings=True) + # normalize to -1 and 1 + masked_img = ((masked_img - self.vqgan.codebook.embeddings.min()) / + (self.vqgan.codebook.embeddings.max() - + self.vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + else: + print("Hi") + img = normalize_img(img) + masked_img = normalize_img(masked_img) + mask = mask*2.0 - 1.0 + cc = torch.nn.functional.interpolate(mask, size=masked_img.shape[-3:]) + cond = torch.cat((masked_img, cc), dim=1) + + b, device, img_size, = img.shape[0], img.device, self.image_size + t = torch.randint(0, self.num_timesteps, (b,), device=device).long() + # breakpoint() + return self.p_losses(img, t, cond=cond, *args, **kwargs) + +# trainer class + + +CHANNELS_TO_MODE = { + 1: 'L', + 3: 'RGB', + 4: 'RGBA' +} + + +def seek_all_images(img, channels=3): + assert channels in CHANNELS_TO_MODE, f'channels {channels} invalid' + mode = CHANNELS_TO_MODE[channels] + + i = 0 + while True: + try: + img.seek(i) + yield img.convert(mode) + except EOFError: + break + i += 1 + +# tensor of shape (channels, frames, height, width) -> gif + + +def video_tensor_to_gif(tensor, path, duration=120, loop=0, optimize=True): + tensor = ((tensor - tensor.min()) / (tensor.max() - tensor.min())) * 1.0 + images = map(T.ToPILImage(), tensor.unbind(dim=1)) + first_img, *rest_imgs = images + first_img.save(path, save_all=True, append_images=rest_imgs, + duration=duration, loop=loop, optimize=optimize) + return images + +# gif -> (channels, frame, height, width) tensor + + +def gif_to_tensor(path, channels=3, transform=T.ToTensor()): + img = Image.open(path) + tensors = tuple(map(transform, seek_all_images(img, channels=channels))) + return torch.stack(tensors, dim=1) + + +def identity(t, *args, **kwargs): + return t + + +def normalize_img(t): + return t * 2 - 1 + + +def unnormalize_img(t): + return (t + 1) * 0.5 + + +def cast_num_frames(t, *, frames): + f = t.shape[1] + + if f == frames: + return t + + if f > frames: + return t[:, :frames] + + return F.pad(t, (0, 0, 0, 0, 0, frames - f)) + + +class Dataset(data.Dataset): + def __init__( + self, + folder, + image_size, + channels=3, + num_frames=16, + horizontal_flip=False, + force_num_frames=True, + exts=['gif'] + ): + super().__init__() + self.folder = folder + self.image_size = image_size + self.channels = channels + self.paths = [p for ext in exts for p in Path( + f'{folder}').glob(f'**/*.{ext}')] + + self.cast_num_frames_fn = partial( + cast_num_frames, frames=num_frames) if force_num_frames else identity + + self.transform = T.Compose([ + T.Resize(image_size), + T.RandomHorizontalFlip() if horizontal_flip else T.Lambda(identity), + T.CenterCrop(image_size), + T.ToTensor() + ]) + + def __len__(self): + return len(self.paths) + + def __getitem__(self, index): + path = self.paths[index] + tensor = gif_to_tensor(path, self.channels, 
transform=self.transform) + return self.cast_num_frames_fn(tensor) + +# trainer class + + +class Tester(object): + def __init__( + self, + diffusion_model, + ): + super().__init__() + self.model = diffusion_model + self.ema_model = copy.deepcopy(self.model) + self.step=0 + self.image_size = diffusion_model.image_size + + self.reset_parameters() + + def reset_parameters(self): + self.ema_model.load_state_dict(self.model.state_dict()) + + + def load(self, milestone, map_location=None, **kwargs): + if milestone == -1: + all_milestones = [int(p.stem.split('-')[-1]) + for p in Path(self.results_folder).glob('**/*.pt')] + assert len( + all_milestones) > 0, 'need to have at least one milestone to load from latest checkpoint (milestone == -1)' + milestone = max(all_milestones) + + if map_location: + data = torch.load(milestone, map_location=map_location) + else: + data = torch.load(milestone) + + self.step = data['step'] + self.model.load_state_dict(data['model'], **kwargs) + self.ema_model.load_state_dict(data['ema'], **kwargs) + + diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/text.py b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/text.py new file mode 100644 index 0000000000000000000000000000000000000000..42dca78bec5075fb4f59c522aa3d00cc395d7536 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/text.py @@ -0,0 +1,94 @@ +"Taken and adapted from https://github.com/lucidrains/video-diffusion-pytorch" + +import torch +from einops import rearrange + + +def exists(val): + return val is not None + +# singleton globals + + +MODEL = None +TOKENIZER = None +BERT_MODEL_DIM = 768 + + +def get_tokenizer(): + global TOKENIZER + if not exists(TOKENIZER): + TOKENIZER = torch.hub.load( + 'huggingface/pytorch-transformers', 'tokenizer', 'bert-base-cased') + return TOKENIZER + + +def get_bert(): + global MODEL + if not exists(MODEL): + MODEL = torch.hub.load( + 'huggingface/pytorch-transformers', 'model', 'bert-base-cased') + if torch.cuda.is_available(): + MODEL = MODEL.cuda() + + return MODEL + +# tokenize + + +def tokenize(texts, add_special_tokens=True): + if not isinstance(texts, (list, tuple)): + texts = [texts] + + tokenizer = get_tokenizer() + + encoding = tokenizer.batch_encode_plus( + texts, + add_special_tokens=add_special_tokens, + padding=True, + return_tensors='pt' + ) + + token_ids = encoding.input_ids + return token_ids + +# embedding function + + +@torch.no_grad() +def bert_embed( + token_ids, + return_cls_repr=False, + eps=1e-8, + pad_id=0. 
+):
+    model = get_bert()
+    mask = token_ids != pad_id
+
+    if torch.cuda.is_available():
+        token_ids = token_ids.cuda()
+        mask = mask.cuda()
+
+    outputs = model(
+        input_ids=token_ids,
+        attention_mask=mask,
+        output_hidden_states=True
+    )
+
+    hidden_state = outputs.hidden_states[-1]
+
+    if return_cls_repr:
+        # return [cls] as representation
+        return hidden_state[:, 0]
+
+    if not exists(mask):
+        return hidden_state.mean(dim=1)
+
+    # mean all tokens excluding [cls], accounting for length
+    mask = mask[:, 1:]
+    mask = rearrange(mask, 'b n -> b n 1')
+
+    numer = (hidden_state[:, 1:] * mask).sum(dim=1)
+    denom = mask.sum(dim=1)
+    masked_mean = numer / (denom + eps)
+    return masked_mean
diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/time_embedding.py b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/time_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c911fa4f485295005cda9c9c2099db3ddbda15c
--- /dev/null
+++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/time_embedding.py
@@ -0,0 +1,75 @@
+import math
+import torch
+import torch.nn as nn
+
+from monai.networks.layers.utils import get_act_layer
+
+
+class SinusoidalPosEmb(nn.Module):
+    def __init__(self, emb_dim=16, downscale_freq_shift=1, max_period=10000, flip_sin_to_cos=False):
+        super().__init__()
+        self.emb_dim = emb_dim
+        self.downscale_freq_shift = downscale_freq_shift
+        self.max_period = max_period
+        self.flip_sin_to_cos = flip_sin_to_cos
+
+    def forward(self, x):
+        device = x.device
+        half_dim = self.emb_dim // 2
+        emb = math.log(self.max_period) / (half_dim - self.downscale_freq_shift)
+        emb = torch.exp(-emb * torch.arange(half_dim, device=device))
+        emb = x[:, None] * emb[None, :]
+        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
+
+        if self.flip_sin_to_cos:
+            emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)
+
+        if self.emb_dim % 2 == 1:
+            emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
+        return emb
+
+
+class LearnedSinusoidalPosEmb(nn.Module):
+    """ following @crowsonkb 's lead with learned sinusoidal pos emb """
+    """ https://github.com/crowsonkb/v-diffusion-jax/blob/master/diffusion/models/danbooru_128.py#L8 """
+
+    def __init__(self, emb_dim):
+        super().__init__()
+        self.emb_dim = emb_dim
+        half_dim = emb_dim // 2
+        self.weights = nn.Parameter(torch.randn(half_dim))
+
+    def forward(self, x):
+        x = x[:, None]
+        freqs = x * self.weights[None, :] * 2 * math.pi
+        fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
+        fouriered = torch.cat((x, fouriered), dim=-1)
+        if self.emb_dim % 2 == 1:
+            fouriered = torch.nn.functional.pad(fouriered, (0, 1, 0, 0))
+        return fouriered
+
+
+class TimeEmbbeding(nn.Module):
+    def __init__(
+        self,
+        emb_dim=64,
+        pos_embedder=SinusoidalPosEmb,
+        pos_embedder_kwargs={},
+        act_name=("SWISH", {})  # Swish = SiLU
+    ):
+        super().__init__()
+        self.emb_dim = emb_dim
+        self.pos_emb_dim = pos_embedder_kwargs.get('emb_dim', emb_dim // 4)
+        pos_embedder_kwargs['emb_dim'] = self.pos_emb_dim
+        self.pos_embedder = pos_embedder(**pos_embedder_kwargs)
+
+        self.time_emb = nn.Sequential(
+            self.pos_embedder,
+            nn.Linear(self.pos_emb_dim, self.emb_dim),
+            get_act_layer(act_name),
+            nn.Linear(self.emb_dim, self.emb_dim)
+        )
+
+    def forward(self, time):
+        return self.time_emb(time)
diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/unet.py b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/unet.py new file mode 100644 index
0000000000000000000000000000000000000000..e9cae8021e874a92b6baaaf4decd802516c2dc87 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/unet.py @@ -0,0 +1,226 @@ +from ddpm.time_embedding import TimeEmbbeding + +import monai.networks.nets as nets +import torch +import torch.nn as nn +from einops import rearrange + +from monai.networks.blocks import UnetBasicBlock, UnetResBlock, UnetUpBlock, Convolution, UnetOutBlock +from monai.networks.layers.utils import get_act_layer + + +class DownBlock(nn.Module): + def __init__( + self, + spatial_dims, + in_ch, + out_ch, + time_emb_dim, + cond_emb_dim, + act_name=("swish", {}), + **kwargs): + super(DownBlock, self).__init__() + self.loca_time_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(time_emb_dim, in_ch) # in_ch * 2 + ) + if cond_emb_dim is not None: + self.loca_cond_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(cond_emb_dim, in_ch), + ) + self.down_op = UnetBasicBlock( + spatial_dims, in_ch, out_ch, act_name=act_name, **kwargs) + + def forward(self, x, time_emb, cond_emb): + b, c, *_ = x.shape + sp_dim = x.ndim-2 + + # ------------ Time ---------- + time_emb = self.loca_time_embedder(time_emb) + time_emb = time_emb.reshape(b, c, *((1,)*sp_dim)) + # scale, shift = time_emb.chunk(2, dim = 1) + + # ------------ Combine ------------ + # x = x * (scale + 1) + shift + x = x + time_emb + + # ----------- Condition ------------ + if cond_emb is not None: + cond_emb = self.loca_cond_embedder(cond_emb) + cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim)) + x = x + cond_emb + + # ----------- Image --------- + y = self.down_op(x) + return y + + +class UpBlock(nn.Module): + def __init__( + self, + spatial_dims, + skip_ch, + enc_ch, + time_emb_dim, + cond_emb_dim, + act_name=("swish", {}), + **kwargs): + super(UpBlock, self).__init__() + self.up_op = UnetUpBlock(spatial_dims, enc_ch, + skip_ch, act_name=act_name, **kwargs) + self.loca_time_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(time_emb_dim, skip_ch * 2), + ) + if cond_emb_dim is not None: + self.loca_cond_embedder = nn.Sequential( + get_act_layer(name=act_name), + nn.Linear(cond_emb_dim, skip_ch * 2), + ) + + def forward(self, x_skip, x_enc, time_emb, cond_emb): + b, c, *_ = x_enc.shape + sp_dim = x_enc.ndim-2 + + # ----------- Time -------------- + time_emb = self.loca_time_embedder(time_emb) + time_emb = time_emb.reshape(b, c, *((1,)*sp_dim)) + # scale, shift = time_emb.chunk(2, dim = 1) + + # -------- Combine ------------- + # y = x * (scale + 1) + shift + x_enc = x_enc + time_emb + + # ----------- Condition ------------ + if cond_emb is not None: + cond_emb = self.loca_cond_embedder(cond_emb) + cond_emb = cond_emb.reshape(b, c, *((1,)*sp_dim)) + x_enc = x_enc + cond_emb + + # ----------- Image ------------- + y = self.up_op(x_enc, x_skip) + + # -------- Combine ------------- + # y = y * (scale + 1) + shift + + return y + + +class UNet(nn.Module): + + def __init__(self, + in_ch=1, + out_ch=1, + spatial_dims=3, + hid_chs=[32, 64, 128, 256, 512], + kernel_sizes=[(1, 3, 3), (1, 3, 3), (1, 3, 3), 3, 3], + strides=[1, (1, 2, 2), (1, 2, 2), 2, 2], + upsample_kernel_sizes=None, + act_name=("SWISH", {}), + norm_name=("INSTANCE", {"affine": True}), + time_embedder=TimeEmbbeding, + time_embedder_kwargs={}, + cond_embedder=None, + cond_embedder_kwargs={}, + # True = all but last layer, 0/False=disable, 1=only first layer, ... 
+                 deep_ver_supervision=True,
+                 estimate_variance=False,
+                 use_self_conditioning=False,
+                 **kwargs
+                 ):
+        super().__init__()
+        if upsample_kernel_sizes is None:
+            upsample_kernel_sizes = strides[1:]
+
+        # ------------- Time-Embedder-----------
+        self.time_embedder = time_embedder(**time_embedder_kwargs)
+
+        # ------------- Condition-Embedder-----------
+        if cond_embedder is not None:
+            self.cond_embedder = cond_embedder(**cond_embedder_kwargs)
+            cond_emb_dim = self.cond_embedder.emb_dim
+        else:
+            self.cond_embedder = None
+            cond_emb_dim = None
+
+        # ----------- In-Convolution ------------
+        in_ch = in_ch*2 if use_self_conditioning else in_ch
+        self.inc = UnetBasicBlock(spatial_dims, in_ch, hid_chs[0], kernel_size=kernel_sizes[0], stride=strides[0],
+                                  act_name=act_name, norm_name=norm_name, **kwargs)
+
+        # ----------- Encoder ----------------
+        self.encoders = nn.ModuleList([
+            DownBlock(spatial_dims, hid_chs[i-1], hid_chs[i], time_emb_dim=self.time_embedder.emb_dim,
+                      cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[i], stride=strides[i], act_name=act_name,
+                      norm_name=norm_name, **kwargs)
+            for i in range(1, len(strides))
+        ])
+
+        # ------------ Decoder ----------
+        self.decoders = nn.ModuleList([
+            UpBlock(spatial_dims, hid_chs[i], hid_chs[i+1], time_emb_dim=self.time_embedder.emb_dim,
+                    cond_emb_dim=cond_emb_dim, kernel_size=kernel_sizes[i+1], stride=strides[i+1], act_name=act_name,
+                    norm_name=norm_name, upsample_kernel_size=upsample_kernel_sizes[i], **kwargs)
+            for i in range(len(strides)-1)
+        ])
+
+        # --------------- Out-Convolution ----------------
+        out_ch_hor = out_ch*2 if estimate_variance else out_ch
+        self.outc = UnetOutBlock(
+            spatial_dims, hid_chs[0], out_ch_hor, dropout=None)
+        if isinstance(deep_ver_supervision, bool):
+            deep_ver_supervision = len(strides)-2 if deep_ver_supervision else 0
+        self.outc_ver = nn.ModuleList([
+            UnetOutBlock(spatial_dims, hid_chs[i], out_ch, dropout=None)
+            for i in range(1, deep_ver_supervision+1)
+        ])
+
+    def forward(self, x_t, t, cond=None, self_cond=None, **kwargs):
+        condition = cond
+        # x_t [B, C, (D), H, W]
+        # t [B,]
+
+        # -------- In-Convolution --------------
+        x = [None for _ in range(len(self.encoders)+1)]
+        x_t = torch.cat([x_t, self_cond],
+                        dim=1) if self_cond is not None else x_t
+        x[0] = self.inc(x_t)
+
+        # -------- Time Embedding (Global) -----------
+        time_emb = self.time_embedder(t)  # [B, C]
+
+        # -------- Condition Embedding (Global) -----------
+        if (condition is None) or (self.cond_embedder is None):
+            cond_emb = None
+        else:
+            cond_emb = self.cond_embedder(condition)  # [B, C]
+
+        # --------- Encoder --------------
+        for i in range(len(self.encoders)):
+            x[i+1] = self.encoders[i](x[i], time_emb, cond_emb)
+
+        # -------- Decoder -----------
+        for i in range(len(self.decoders), 0, -1):
+            x[i-1] = self.decoders[i-1](x[i-1], x[i], time_emb, cond_emb)
+
+        # --------- Out-Convolution ------------
+        y_hor = self.outc(x[0])
+        y_ver = [outc_ver_i(x[i+1])
+                 for i, outc_ver_i in enumerate(self.outc_ver)]
+
+        return y_hor  # , y_ver
+
+    def forward_with_cond_scale(self, *args, cond_scale=0., **kwargs):
+        return self.forward(*args, **kwargs)
+
+
+if __name__ == '__main__':
+    # forward() returns only the horizontal output, so a single tensor is unpacked here
+    model = UNet(in_ch=3)
+    input = torch.randn((1, 3, 16, 128, 128))
+    time = torch.randn((1,))
+    out_hor = model(input, time)
+    print(out_hor.shape)
diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/util.py b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/util.py new file mode 100644 index
0000000000000000000000000000000000000000..2047cf56802046bbf37cfd33f819ed75a239a7df --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/ddpm/util.py @@ -0,0 +1,271 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch +import torch.nn as nn +import numpy as np +from einops import repeat + +# from ldm.util import instantiate_from_config + + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + if c != 1: + steps_out = ddim_timesteps + 1 + else: + steps_out = ddim_timesteps + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. 
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. 
+ """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() \ No newline at end of file diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c9267820ac3147285b3dbc3eca053f259ca015c Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/__pycache__/utils.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__init__.py b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38c08845c1ee2f7b4b52bf95fa282b0808931a3a --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__init__.py @@ -0,0 +1,3 @@ +from .vqgan import VQGAN +from .codebook import Codebook +from .lpips import LPIPS diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5af25919e8562ffd0096c9a9795af5e6311d5dc Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/__init__.cpython-38.pyc differ 
diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c14899b2127c37d8cdba304c2380cbf6fc3ccd3 Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/codebook.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6fd1b98c409d33c2783dcf7b43ee55e86cd775d Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/lpips.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..708f349ec739b22404f5242cc4945d38baf8872f Binary files /dev/null and b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/__pycache__/vqgan.cpython-38.pyc differ diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth new file mode 100644 index 0000000000000000000000000000000000000000..f57dcf5cc764d61c8a460365847fb2137ff0a62d --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/cache/vgg.pth @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a78928a0af1e5f0fcb1f3b9e8f8c3a2a5a3de244d830ad5c1feddc79b8432868 +size 7289 diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/codebook.py b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/codebook.py new file mode 100644 index 0000000000000000000000000000000000000000..c17ed701a2824cc0dcd75670d47a6ab3842e9d35 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/codebook.py @@ -0,0 +1,109 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved
+
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.distributed as dist
+
+from ..utils import shift_dim
+
+
+class Codebook(nn.Module):
+    def __init__(self, n_codes, embedding_dim, no_random_restart=False, restart_thres=1.0):
+        super().__init__()
+        self.register_buffer('embeddings', torch.randn(n_codes, embedding_dim))
+        self.register_buffer('N', torch.zeros(n_codes))
+        self.register_buffer('z_avg', self.embeddings.data.clone())
+
+        self.n_codes = n_codes
+        self.embedding_dim = embedding_dim
+        self._need_init = True
+        self.no_random_restart = no_random_restart
+        self.restart_thres = restart_thres
+
+    def _tile(self, x):
+        d, ew = x.shape
+        if d < self.n_codes:
+            n_repeats = (self.n_codes + d - 1) // d
+            std = 0.01 / np.sqrt(ew)
+            x = x.repeat(n_repeats, 1)
+            x = x + torch.randn_like(x) * std
+        return x
+
+    def _init_embeddings(self, z):
+        # z: [b, c, t, h, w]
+        self._need_init = False
+        # breakpoint()
+        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)  # [65536, 8] [2, 8, 32, 32, 32]
+        y = self._tile(flat_inputs)  # [65536, 8]
+
+        d = y.shape[0]
+        _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
+        if dist.is_initialized():
+            dist.broadcast(_k_rand, 0)
+        self.embeddings.data.copy_(_k_rand)
+        self.z_avg.data.copy_(_k_rand)
+        self.N.data.copy_(torch.ones(self.n_codes))
+
+    def forward(self, z):
+        # z: [b, c, t, h, w]
+        if self._need_init and self.training:
+            self._init_embeddings(z)
+        flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)  # [bthw, c] [65536, 8]
+        distances = (flat_inputs ** 2).sum(dim=1, keepdim=True) \
+            - 2 * flat_inputs @ self.embeddings.t() \
+            + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)  # [bthw, c] [65536, 8]
+
+        encoding_indices = torch.argmin(distances, dim=1)  # [65536]
+        encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(
+            flat_inputs)  # [bthw, ncode] [65536, 16384]
+        encoding_indices = encoding_indices.view(
+            z.shape[0], *z.shape[2:])  # [b, t, h, w] [2, 32, 32, 32]
+
+        embeddings = F.embedding(
+            encoding_indices, self.embeddings)  # [b, t, h, w, c] self.embeddings [16384, 8]
+        embeddings = shift_dim(embeddings, -1, 1)  # [b, c, t, h, w] [2, 8, 32, 32, 32]
+
+        commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
+
+        # EMA codebook update
+        if self.training:
+            n_total = encode_onehot.sum(dim=0)  # [16384]
+            encode_sum = flat_inputs.t() @ encode_onehot  # [8, 16384]
+            if dist.is_initialized():
+                dist.all_reduce(n_total)
+                dist.all_reduce(encode_sum)
+
+            self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
+            self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)
+
+            n = self.N.sum()
+            weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
+            encode_normalized = self.z_avg / weights.unsqueeze(1)
+            self.embeddings.data.copy_(encode_normalized)
+
+            y = self._tile(flat_inputs)
+            _k_rand = y[torch.randperm(y.shape[0])][:self.n_codes]
+            if dist.is_initialized():
+                dist.broadcast(_k_rand, 0)
+
+            if not self.no_random_restart:
+                usage = (self.N.view(self.n_codes, 1)
+                         >= self.restart_thres).float()
+                self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))
+
+        embeddings_st = (embeddings - z).detach() + z
+
+        avg_probs = torch.mean(encode_onehot, dim=0)
+        perplexity = torch.exp(-torch.sum(avg_probs *
+                                          torch.log(avg_probs + 1e-10)))
+
+        return dict(embeddings=embeddings_st, encodings=encoding_indices,
+                    commitment_loss=commitment_loss, perplexity=perplexity)
+
+    def dictionary_lookup(self, encodings):
+        embeddings = F.embedding(encodings,
self.embeddings) + return embeddings diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/lpips.py b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/lpips.py new file mode 100644 index 0000000000000000000000000000000000000000..019434b9b29945d67e6e0a95dec324df7ff908f3 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/lpips.py @@ -0,0 +1,181 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" + +"""Stripped version of https://github.com/richzhang/PerceptualSimilarity/tree/master/models""" + + +from collections import namedtuple +from torchvision import models +import torch.nn as nn +import torch +from tqdm import tqdm +import requests +import os +import hashlib +URL_MAP = { + "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1" +} + +CKPT_MAP = { + "vgg_lpips": "vgg.pth" +} + +MD5_MAP = { + "vgg_lpips": "d507d7349b931f0638a25a48a722f98a" +} + + +def download(url, local_path, chunk_size=1024): + os.makedirs(os.path.split(local_path)[0], exist_ok=True) + with requests.get(url, stream=True) as r: + total_size = int(r.headers.get("content-length", 0)) + with tqdm(total=total_size, unit="B", unit_scale=True) as pbar: + with open(local_path, "wb") as f: + for data in r.iter_content(chunk_size=chunk_size): + if data: + f.write(data) + pbar.update(chunk_size) + + +def md5_hash(path): + with open(path, "rb") as f: + content = f.read() + return hashlib.md5(content).hexdigest() + + +def get_ckpt_path(name, root, check=False): + assert name in URL_MAP + path = os.path.join(root, CKPT_MAP[name]) + if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]): + print("Downloading {} model from {} to {}".format( + name, URL_MAP[name], path)) + download(URL_MAP[name], path) + md5 = md5_hash(path) + assert md5 == MD5_MAP[name], md5 + return path + + +class LPIPS(nn.Module): + # Learned perceptual metric + def __init__(self, use_dropout=True): + super().__init__() + self.scaling_layer = ScalingLayer() + self.chns = [64, 128, 256, 512, 512] # vg16 features + self.net = vgg16(pretrained=True, requires_grad=False) + self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout) + self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout) + self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout) + self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout) + self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout) + # self.load_from_pretrained() + for param in self.parameters(): + param.requires_grad = False + + def load_from_pretrained(self, name="vgg_lpips"): + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + self.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + print("loaded pretrained LPIPS loss from {}".format(ckpt)) + + @classmethod + def from_pretrained(cls, name="vgg_lpips"): + if name is not "vgg_lpips": + raise NotImplementedError + model = cls() + ckpt = get_ckpt_path(name, os.path.join( + os.path.dirname(os.path.abspath(__file__)), "cache")) + model.load_state_dict(torch.load( + ckpt, map_location=torch.device("cpu")), strict=False) + return model + + def forward(self, input, target): + in0_input, in1_input = (self.scaling_layer( + input), self.scaling_layer(target)) + outs0, outs1 = self.net(in0_input), self.net(in1_input) + feats0, feats1, diffs = {}, {}, {} + lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4] + for 
kk in range(len(self.chns)): + feats0[kk], feats1[kk] = normalize_tensor( + outs0[kk]), normalize_tensor(outs1[kk]) + diffs[kk] = (feats0[kk] - feats1[kk]) ** 2 + + res = [spatial_average(lins[kk].model(diffs[kk]), keepdim=True) + for kk in range(len(self.chns))] + val = res[0] + for l in range(1, len(self.chns)): + val += res[l] + return val + + +class ScalingLayer(nn.Module): + def __init__(self): + super(ScalingLayer, self).__init__() + self.register_buffer('shift', torch.Tensor( + [-.030, -.088, -.188])[None, :, None, None]) + self.register_buffer('scale', torch.Tensor( + [.458, .448, .450])[None, :, None, None]) + + def forward(self, inp): + return (inp - self.shift) / self.scale + + +class NetLinLayer(nn.Module): + """ A single linear layer which does a 1x1 conv """ + + def __init__(self, chn_in, chn_out=1, use_dropout=False): + super(NetLinLayer, self).__init__() + layers = [nn.Dropout(), ] if (use_dropout) else [] + layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, + padding=0, bias=False), ] + self.model = nn.Sequential(*layers) + + +class vgg16(torch.nn.Module): + def __init__(self, requires_grad=False, pretrained=True): + super(vgg16, self).__init__() + vgg_pretrained_features = models.vgg16(pretrained=pretrained).features + self.slice1 = torch.nn.Sequential() + self.slice2 = torch.nn.Sequential() + self.slice3 = torch.nn.Sequential() + self.slice4 = torch.nn.Sequential() + self.slice5 = torch.nn.Sequential() + self.N_slices = 5 + for x in range(4): + self.slice1.add_module(str(x), vgg_pretrained_features[x]) + for x in range(4, 9): + self.slice2.add_module(str(x), vgg_pretrained_features[x]) + for x in range(9, 16): + self.slice3.add_module(str(x), vgg_pretrained_features[x]) + for x in range(16, 23): + self.slice4.add_module(str(x), vgg_pretrained_features[x]) + for x in range(23, 30): + self.slice5.add_module(str(x), vgg_pretrained_features[x]) + if not requires_grad: + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X): + h = self.slice1(X) + h_relu1_2 = h + h = self.slice2(h) + h_relu2_2 = h + h = self.slice3(h) + h_relu3_3 = h + h = self.slice4(h) + h_relu4_3 = h + h = self.slice5(h) + h_relu5_3 = h + vgg_outputs = namedtuple( + "VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3']) + out = vgg_outputs(h_relu1_2, h_relu2_2, + h_relu3_3, h_relu4_3, h_relu5_3) + return out + + +def normalize_tensor(x, eps=1e-10): + norm_factor = torch.sqrt(torch.sum(x**2, dim=1, keepdim=True)) + return x/(norm_factor+eps) + + +def spatial_average(x, keepdim=True): + return x.mean([2, 3], keepdim=keepdim) diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py new file mode 100644 index 0000000000000000000000000000000000000000..a64ec8878c0ecbf418775e2958b2cbf578ce2918 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/model/vqgan.py @@ -0,0 +1,561 @@ +"""Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. 
All Rights Reserved + +import math +import argparse +import numpy as np +import pickle as pkl + +import pytorch_lightning as pl +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.distributed as dist + +from ..utils import shift_dim, adopt_weight, comp_getattr +from .lpips import LPIPS +from .codebook import Codebook + + +def silu(x): + return x*torch.sigmoid(x) + + +class SiLU(nn.Module): + def __init__(self): + super(SiLU, self).__init__() + + def forward(self, x): + return silu(x) + + +def hinge_d_loss(logits_real, logits_fake): + loss_real = torch.mean(F.relu(1. - logits_real)) + loss_fake = torch.mean(F.relu(1. + logits_fake)) + d_loss = 0.5 * (loss_real + loss_fake) + return d_loss + + +def vanilla_d_loss(logits_real, logits_fake): + d_loss = 0.5 * ( + torch.mean(torch.nn.functional.softplus(-logits_real)) + + torch.mean(torch.nn.functional.softplus(logits_fake))) + return d_loss + + +class VQGAN(pl.LightningModule): + def __init__(self, cfg): + super().__init__() + self.cfg = cfg + self.embedding_dim = cfg.model.embedding_dim # 8 + self.n_codes = cfg.model.n_codes # 16384 + + self.encoder = Encoder(cfg.model.n_hiddens, # 16 + cfg.model.downsample, # [2, 2, 2] + cfg.dataset.image_channels, # 1 + cfg.model.norm_type, # group + cfg.model.padding_type, # replicate + cfg.model.num_groups, # 32 + ) + self.decoder = Decoder( + cfg.model.n_hiddens, cfg.model.downsample, cfg.dataset.image_channels, cfg.model.norm_type, cfg.model.num_groups) + self.enc_out_ch = self.encoder.out_channels + self.pre_vq_conv = SamePadConv3d( + self.enc_out_ch, cfg.model.embedding_dim, 1, padding_type=cfg.model.padding_type) + self.post_vq_conv = SamePadConv3d( + cfg.model.embedding_dim, self.enc_out_ch, 1) + + self.codebook = Codebook(cfg.model.n_codes, cfg.model.embedding_dim, + no_random_restart=cfg.model.no_random_restart, restart_thres=cfg.model.restart_thres) + + self.gan_feat_weight = cfg.model.gan_feat_weight + # TODO: Changed batchnorm from sync to normal + self.image_discriminator = NLayerDiscriminator( + cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm2d) + self.video_discriminator = NLayerDiscriminator3D( + cfg.dataset.image_channels, cfg.model.disc_channels, cfg.model.disc_layers, norm_layer=nn.BatchNorm3d) + + if cfg.model.disc_loss_type == 'vanilla': + self.disc_loss = vanilla_d_loss + elif cfg.model.disc_loss_type == 'hinge': + self.disc_loss = hinge_d_loss + + self.perceptual_model = LPIPS().eval() + + self.image_gan_weight = cfg.model.image_gan_weight + self.video_gan_weight = cfg.model.video_gan_weight + + self.perceptual_weight = cfg.model.perceptual_weight + + self.l1_weight = cfg.model.l1_weight + self.save_hyperparameters() + + def encode(self, x, include_embeddings=False, quantize=True): + h = self.pre_vq_conv(self.encoder(x)) + if quantize: + vq_output = self.codebook(h) + if include_embeddings: + return vq_output['embeddings'], vq_output['encodings'] + else: + return vq_output['encodings'] + return h + + def decode(self, latent, quantize=False): + if quantize: + vq_output = self.codebook(latent) + latent = vq_output['encodings'] + h = F.embedding(latent, self.codebook.embeddings) + h = self.post_vq_conv(shift_dim(h, -1, 1)) + return self.decoder(h) + + def forward(self, x, optimizer_idx=None, log_image=False): + B, C, T, H, W = x.shape + z = self.pre_vq_conv(self.encoder(x)) # [2, 32, 32, 32, 32] [2, 8, 32, 32, 32] + vq_output = self.codebook(z) # ['embeddings', 'encodings', 'commitment_loss', 
'perplexity'] + x_recon = self.decoder(self.post_vq_conv(vq_output['embeddings'])) # [2, 8, 32, 32, 32] [2, 32, 32, 32, 32] + + recon_loss = F.l1_loss(x_recon, x) * self.l1_weight + + # Selects one random 2D image from each 3D Image + frame_idx = torch.randint(0, T, [B]).cuda() + frame_idx_selected = frame_idx.reshape(-1, + 1, 1, 1, 1).repeat(1, C, 1, H, W) # [2, 1, 1, 64, 64] + frames = torch.gather(x, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64] + frames_recon = torch.gather(x_recon, 2, frame_idx_selected).squeeze(2) # [2, 1, 64, 64] + + if log_image: + return frames, frames_recon, x, x_recon + + if optimizer_idx == 0: + # Autoencoder - train the "generator" + + # Perceptual loss + perceptual_loss = 0 + if self.perceptual_weight > 0: + perceptual_loss = self.perceptual_model( + frames, frames_recon).mean() * self.perceptual_weight + + # Discriminator loss (turned on after a certain epoch) + logits_image_fake, pred_image_fake = self.image_discriminator( + frames_recon) + logits_video_fake, pred_video_fake = self.video_discriminator( + x_recon) + g_image_loss = -torch.mean(logits_image_fake) + g_video_loss = -torch.mean(logits_video_fake) + g_loss = self.image_gan_weight*g_image_loss + self.video_gan_weight*g_video_loss + disc_factor = adopt_weight( + self.global_step, threshold=self.cfg.model.discriminator_iter_start) + aeloss = disc_factor * g_loss + + # GAN feature matching loss - tune features such that we get the same prediction result on the discriminator + image_gan_feat_loss = 0 + video_gan_feat_loss = 0 + feat_weights = 4.0 / (3 + 1) + if self.image_gan_weight > 0: + logits_image_real, pred_image_real = self.image_discriminator( + frames) + for i in range(len(pred_image_fake)-1): + image_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_image_fake[i], pred_image_real[i].detach( + )) * (self.image_gan_weight > 0) + if self.video_gan_weight > 0: + logits_video_real, pred_video_real = self.video_discriminator( + x) + for i in range(len(pred_video_fake)-1): + video_gan_feat_loss += feat_weights * \ + F.l1_loss(pred_video_fake[i], pred_video_real[i].detach( + )) * (self.video_gan_weight > 0) + gan_feat_loss = disc_factor * self.gan_feat_weight * \ + (image_gan_feat_loss + video_gan_feat_loss) + + self.log("train/g_image_loss", g_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/g_video_loss", g_video_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/image_gan_feat_loss", image_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/video_gan_feat_loss", video_gan_feat_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/perceptual_loss", perceptual_loss, + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log("train/recon_loss", recon_loss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + self.log("train/aeloss", aeloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + self.log("train/commitment_loss", vq_output['commitment_loss'], + prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log('train/perplexity', vq_output['perplexity'], + prog_bar=True, logger=True, on_step=True, on_epoch=True) + return recon_loss, x_recon, vq_output, aeloss, perceptual_loss, gan_feat_loss + + if optimizer_idx == 1: + # Train discriminator + logits_image_real, _ = self.image_discriminator(frames.detach()) + logits_video_real, _ = self.video_discriminator(x.detach()) + + logits_image_fake, _ = self.image_discriminator( + frames_recon.detach()) + logits_video_fake, _ = 
self.video_discriminator(x_recon.detach()) + + d_image_loss = self.disc_loss(logits_image_real, logits_image_fake) + d_video_loss = self.disc_loss(logits_video_real, logits_video_fake) + disc_factor = adopt_weight( + self.global_step, threshold=self.cfg.model.discriminator_iter_start) + discloss = disc_factor * \ + (self.image_gan_weight*d_image_loss + + self.video_gan_weight*d_video_loss) + + self.log("train/logits_image_real", logits_image_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_image_fake", logits_image_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_video_real", logits_video_real.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/logits_video_fake", logits_video_fake.mean().detach(), + logger=True, on_step=True, on_epoch=True) + self.log("train/d_image_loss", d_image_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/d_video_loss", d_video_loss, + logger=True, on_step=True, on_epoch=True) + self.log("train/discloss", discloss, prog_bar=True, + logger=True, on_step=True, on_epoch=True) + return discloss + + perceptual_loss = self.perceptual_model( + frames, frames_recon) * self.perceptual_weight + return recon_loss, x_recon, vq_output, perceptual_loss + + def training_step(self, batch, batch_idx, optimizer_idx): + x = batch['image'] + if optimizer_idx == 0: + recon_loss, _, vq_output, aeloss, perceptual_loss, gan_feat_loss = self.forward( + x, optimizer_idx) + commitment_loss = vq_output['commitment_loss'] + loss = recon_loss + commitment_loss + aeloss + perceptual_loss + gan_feat_loss + if optimizer_idx == 1: + discloss = self.forward(x, optimizer_idx) + loss = discloss + return loss + + def validation_step(self, batch, batch_idx): + x = batch['image'] # TODO: batch['stft'] + recon_loss, _, vq_output, perceptual_loss = self.forward(x) + self.log('val/recon_loss', recon_loss, prog_bar=True) + self.log('val/perceptual_loss', perceptual_loss, prog_bar=True) + self.log('val/perplexity', vq_output['perplexity'], prog_bar=True) + self.log('val/commitment_loss', + vq_output['commitment_loss'], prog_bar=True) + + def configure_optimizers(self): + lr = self.cfg.model.lr + opt_ae = torch.optim.Adam(list(self.encoder.parameters()) + + list(self.decoder.parameters()) + + list(self.pre_vq_conv.parameters()) + + list(self.post_vq_conv.parameters()) + + list(self.codebook.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(list(self.image_discriminator.parameters()) + + list(self.video_discriminator.parameters()), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def log_images(self, batch, **kwargs): + log = dict() + x = batch['image'] + x = x.to(self.device) + frames, frames_rec, _, _ = self(x, log_image=True) + log["inputs"] = frames + log["reconstructions"] = frames_rec + #log['mean_org'] = batch['mean_org'] + #log['std_org'] = batch['std_org'] + return log + + def log_videos(self, batch, **kwargs): + log = dict() + x = batch['image'] + _, _, x, x_rec = self(x, log_image=True) + log["inputs"] = x + log["reconstructions"] = x_rec + #log['mean_org'] = batch['mean_org'] + #log['std_org'] = batch['std_org'] + return log + + +def Normalize(in_channels, norm_type='group', num_groups=32): + assert norm_type in ['group', 'batch'] + if norm_type == 'group': + # TODO Changed num_groups from 32 to 8 + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + elif norm_type == 'batch': + return 
torch.nn.SyncBatchNorm(in_channels) + + +class Encoder(nn.Module): + def __init__(self, n_hiddens, downsample, image_channel=3, norm_type='group', padding_type='replicate', num_groups=32): + super().__init__() + n_times_downsample = np.array([int(math.log2(d)) for d in downsample]) + self.conv_blocks = nn.ModuleList() + max_ds = n_times_downsample.max() + + self.conv_first = SamePadConv3d( + image_channel, n_hiddens, kernel_size=3, padding_type=padding_type) + + for i in range(max_ds): + block = nn.Module() + in_channels = n_hiddens * 2**i + out_channels = n_hiddens * 2**(i+1) + stride = tuple([2 if d > 0 else 1 for d in n_times_downsample]) + block.down = SamePadConv3d( + in_channels, out_channels, 4, stride=stride, padding_type=padding_type) + block.res = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + self.conv_blocks.append(block) + n_times_downsample -= 1 + + self.final_block = nn.Sequential( + Normalize(out_channels, norm_type, num_groups=num_groups), + SiLU() + ) + + self.out_channels = out_channels + + def forward(self, x): + h = self.conv_first(x) + for block in self.conv_blocks: + h = block.down(h) + h = block.res(h) + h = self.final_block(h) + return h + + +class Decoder(nn.Module): + def __init__(self, n_hiddens, upsample, image_channel, norm_type='group', num_groups=32): + super().__init__() + + n_times_upsample = np.array([int(math.log2(d)) for d in upsample]) + max_us = n_times_upsample.max() + + in_channels = n_hiddens*2**max_us + self.final_block = nn.Sequential( + Normalize(in_channels, norm_type, num_groups=num_groups), + SiLU() + ) + + self.conv_blocks = nn.ModuleList() + for i in range(max_us): + block = nn.Module() + in_channels = in_channels if i == 0 else n_hiddens*2**(max_us-i+1) + out_channels = n_hiddens*2**(max_us-i) + us = tuple([2 if d > 0 else 1 for d in n_times_upsample]) + block.up = SamePadConvTranspose3d( + in_channels, out_channels, 4, stride=us) + block.res1 = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + block.res2 = ResBlock( + out_channels, out_channels, norm_type=norm_type, num_groups=num_groups) + self.conv_blocks.append(block) + n_times_upsample -= 1 + + self.conv_last = SamePadConv3d( + out_channels, image_channel, kernel_size=3) + + def forward(self, x): + h = self.final_block(x) + for i, block in enumerate(self.conv_blocks): + h = block.up(h) + h = block.res1(h) + h = block.res2(h) + h = self.conv_last(h) + return h + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, norm_type='group', padding_type='replicate', num_groups=32): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels, norm_type, num_groups=num_groups) + self.conv1 = SamePadConv3d( + in_channels, out_channels, kernel_size=3, padding_type=padding_type) + self.dropout = torch.nn.Dropout(dropout) + self.norm2 = Normalize(in_channels, norm_type, num_groups=num_groups) + self.conv2 = SamePadConv3d( + out_channels, out_channels, kernel_size=3, padding_type=padding_type) + if self.in_channels != self.out_channels: + self.conv_shortcut = SamePadConv3d( + in_channels, out_channels, kernel_size=3, padding_type=padding_type) + + def forward(self, x): + h = x + h = self.norm1(h) + h = silu(h) + h = self.conv1(h) + h = self.norm2(h) + h = silu(h) + h = self.conv2(h) + + 
if self.in_channels != self.out_channels: + x = self.conv_shortcut(x) + + return x+h + + +# Does not support dilation +class SamePadConv3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + # assumes that the input shape is divisible by stride + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + self.padding_type = padding_type + + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, + stride=stride, padding=0, bias=bias) + + def forward(self, x): + return self.conv(F.pad(x, self.pad_input, mode=self.padding_type)) + + +class SamePadConvTranspose3d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride=1, bias=True, padding_type='replicate'): + super().__init__() + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * 3 + if isinstance(stride, int): + stride = (stride,) * 3 + + total_pad = tuple([k - s for k, s in zip(kernel_size, stride)]) + pad_input = [] + for p in total_pad[::-1]: # reverse since F.pad starts from last dim + pad_input.append((p // 2 + p % 2, p // 2)) + pad_input = sum(pad_input, tuple()) + self.pad_input = pad_input + self.padding_type = padding_type + + self.convt = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, + stride=stride, bias=bias, + padding=tuple([k - 1 for k in kernel_size])) + + def forward(self, x): + return self.convt(F.pad(x, self.pad_input, mode=self.padding_type)) + + +class NLayerDiscriminator(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + # def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_sigmoid=False, getIntermFeat=True): + super(NLayerDiscriminator, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv2d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ + + +class NLayerDiscriminator3D(nn.Module): + def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.SyncBatchNorm, use_sigmoid=False, getIntermFeat=True): + 
super(NLayerDiscriminator3D, self).__init__() + self.getIntermFeat = getIntermFeat + self.n_layers = n_layers + + kw = 4 + padw = int(np.ceil((kw-1.0)/2)) + sequence = [[nn.Conv3d(input_nc, ndf, kernel_size=kw, + stride=2, padding=padw), nn.LeakyReLU(0.2, True)]] + + nf = ndf + for n in range(1, n_layers): + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=2, padding=padw), + norm_layer(nf), nn.LeakyReLU(0.2, True) + ]] + + nf_prev = nf + nf = min(nf * 2, 512) + sequence += [[ + nn.Conv3d(nf_prev, nf, kernel_size=kw, stride=1, padding=padw), + norm_layer(nf), + nn.LeakyReLU(0.2, True) + ]] + + sequence += [[nn.Conv3d(nf, 1, kernel_size=kw, + stride=1, padding=padw)]] + + if use_sigmoid: + sequence += [[nn.Sigmoid()]] + + if getIntermFeat: + for n in range(len(sequence)): + setattr(self, 'model'+str(n), nn.Sequential(*sequence[n])) + else: + sequence_stream = [] + for n in range(len(sequence)): + sequence_stream += sequence[n] + self.model = nn.Sequential(*sequence_stream) + + def forward(self, input): + if self.getIntermFeat: + res = [input] + for n in range(self.n_layers+2): + model = getattr(self, 'model'+str(n)) + res.append(model(res[-1])) + return res[-1], res[1:] + else: + return self.model(input), _ diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/utils.py b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..acf1a21270d98d0a1ea9935d00f9f16d89d54551 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/ldm/vq_gan_3d/utils.py @@ -0,0 +1,177 @@ +""" Adapted from https://github.com/SongweiGe/TATS""" +# Copyright (c) Meta Platforms, Inc. All Rights Reserved + +import warnings +import torch +import imageio + +import math +import numpy as np + +import sys +import pdb as pdb_original +import logging + +import imageio.core.util +logging.getLogger("imageio_ffmpeg").setLevel(logging.ERROR) + + +class ForkedPdb(pdb_original.Pdb): + """A Pdb subclass that may be used + from a forked multiprocessing child + + """ + + def interaction(self, *args, **kwargs): + _stdin = sys.stdin + try: + sys.stdin = open('/dev/stdin') + pdb_original.Pdb.interaction(self, *args, **kwargs) + finally: + sys.stdin = _stdin + + +# Shifts src_tf dim to dest dim +# i.e. shift_dim(x, 1, -1) would be (b, c, t, h, w) -> (b, t, h, w, c) +def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True): + n_dims = len(x.shape) + if src_dim < 0: + src_dim = n_dims + src_dim + if dest_dim < 0: + dest_dim = n_dims + dest_dim + + assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims + + dims = list(range(n_dims)) + del dims[src_dim] + + permutation = [] + ctr = 0 + for i in range(n_dims): + if i == dest_dim: + permutation.append(src_dim) + else: + permutation.append(dims[ctr]) + ctr += 1 + x = x.permute(permutation) + if make_contiguous: + x = x.contiguous() + return x + + +# reshapes tensor start from dim i (inclusive) +# to dim j (exclusive) to the desired shape +# e.g. 
if x.shape = (b, thw, c) then +# view_range(x, 1, 2, (t, h, w)) returns +# x of shape (b, t, h, w, c) +def view_range(x, i, j, shape): + shape = tuple(shape) + + n_dims = len(x.shape) + if i < 0: + i = n_dims + i + + if j is None: + j = n_dims + elif j < 0: + j = n_dims + j + + assert 0 <= i < j <= n_dims + + x_shape = x.shape + target_shape = x_shape[:i] + shape + x_shape[j:] + return x.view(target_shape) + + +def accuracy(output, target, topk=(1,)): + """Computes the accuracy over the k top predictions for the specified values of k""" + with torch.no_grad(): + maxk = max(topk) + batch_size = target.size(0) + + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.reshape(1, -1).expand_as(pred)) + + res = [] + for k in topk: + correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True) + res.append(correct_k.mul_(100.0 / batch_size)) + return res + + +def tensor_slice(x, begin, size): + assert all([b >= 0 for b in begin]) + size = [l - b if s == -1 else s + for s, b, l in zip(size, begin, x.shape)] + assert all([s >= 0 for s in size]) + + slices = [slice(b, b + s) for b, s in zip(begin, size)] + return x[slices] + + +def adopt_weight(global_step, threshold=0, value=0.): + weight = 1 + if global_step < threshold: + weight = value + return weight + + +def save_video_grid(video, fname, nrow=None, fps=6): + b, c, t, h, w = video.shape + video = video.permute(0, 2, 3, 4, 1) + video = (video.cpu().numpy() * 255).astype('uint8') + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = np.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype='uint8') + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + video = [] + for i in range(t): + video.append(video_grid[i]) + imageio.mimsave(fname, video, fps=fps) + ## skvideo.io.vwrite(fname, video_grid, inputdict={'-r': '5'}) + #print('saved videos to', fname) + + +def comp_getattr(args, attr_name, default=None): + if hasattr(args, attr_name): + return getattr(args, attr_name) + else: + return default + + +def visualize_tensors(t, name=None, nest=0): + if name is not None: + print(name, "current nest: ", nest) + print("type: ", type(t)) + if 'dict' in str(type(t)): + print(t.keys()) + for k in t.keys(): + if t[k] is None: + print(k, "None") + else: + if 'Tensor' in str(type(t[k])): + print(k, t[k].shape) + elif 'dict' in str(type(t[k])): + print(k, 'dict') + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t[k])): + print(k, len(t[k])) + visualize_tensors(t[k], name, nest + 1) + elif 'list' in str(type(t)): + print("list length: ", len(t)) + for t2 in t: + visualize_tensors(t2, name, nest + 1) + elif 'Tensor' in str(type(t)): + print(t.shape) + else: + print(t) + return "" diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/utils.py b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..54a3d68432d165ab2895859a89d7be4d150e9721 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/utils.py @@ -0,0 +1,465 @@ +### Tumor Generateion +import random +import cv2 +import elasticdeform +import numpy as np +from scipy.ndimage import gaussian_filter +from TumorGeneration.ldm.ddpm.ddim import DDIMSampler + +# Step 1: Random select (numbers) location for tumor. 
+def random_select(mask_scan, organ_type): + # we first find z index and then sample point with z slice + # print('mask_scan',np.unique(mask_scan)) + # print('pixel num', (mask_scan == 1).sum()) + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0].min(), np.where(np.any(mask_scan, axis=(0, 1)))[0].max() + # print('z_start, z_end',z_start, z_end) + # we need to strict number z's position (0.3 - 0.7 in the middle of liver) + while 1: + z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start + liver_mask = mask_scan[..., z] + # erode the mask (we don't want the edge points) + if organ_type == 'liver': + kernel = np.ones((5,5), dtype=np.uint8) + liver_mask = cv2.erode(liver_mask, kernel, iterations=1) + if (liver_mask == 1).sum() > 0: + break + + + + # print('liver_mask', (liver_mask == 1).sum()) + coordinates = np.argwhere(liver_mask == 1) + random_index = np.random.randint(0, len(coordinates)) + xyz = coordinates[random_index].tolist() # get x,y + xyz.append(z) + potential_points = xyz + + return potential_points + +def center_select(mask_scan): + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0].min(), np.where(np.any(mask_scan, axis=(0, 1)))[0].max() + x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0].min(), np.where(np.any(mask_scan, axis=(1, 2)))[0].max() + y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0].min(), np.where(np.any(mask_scan, axis=(0, 2)))[0].max() + + z = round(0.5 * (z_end - z_start)) + z_start + x = round(0.5 * (x_end - x_start)) + x_start + y = round(0.5 * (y_end - y_start)) + y_start + + xyz = [x, y, z] + potential_points = xyz + + return potential_points + +# Step 2 : generate the ellipsoid +def get_ellipsoid(x, y, z): + """" + x, y, z is the radius of this ellipsoid in x, y, z direction respectly. 
+ """ + sh = (4*x, 4*y, 4*z) + out = np.zeros(sh, int) + aux = np.zeros(sh) + radii = np.array([x, y, z]) + com = np.array([2*x, 2*y, 2*z]) # center point + + # calculate the ellipsoid + bboxl = np.floor(com-radii).clip(0,None).astype(int) + bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int) + roi = out[tuple(map(slice,bboxl,bboxh))] + roiaux = aux[tuple(map(slice,bboxl,bboxh))] + logrid = *map(np.square,np.ogrid[tuple( + map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]), + dst = (1-sum(logrid)).clip(0,None) + mask = dst>roiaux + roi[mask] = 1 + np.copyto(roiaux,dst,where=mask) + + return out + +def get_fixed_geo(mask_scan, tumor_type, organ_type): + if tumor_type == 'large': + enlarge_x, enlarge_y, enlarge_z = 280, 280, 280 + else: + enlarge_x, enlarge_y, enlarge_z = 160, 160, 160 + geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8) + tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32 + + if tumor_type == 'tiny': + num_tumor = random.randint(1,3) + # num_tumor = 1 + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + if tumor_type == 'small': + num_tumor = random.randint(1,3) + # num_tumor = 1 + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'medium': + # num_tumor = random.randint(1, 3) + num_tumor = 1 + for _ in range(num_tumor): + # 
medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'large': + num_tumor = 1 # random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + if organ_type == 'liver' or organ_type == 'kidney' : + point = random_select(mask_scan, organ_type) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + else: + x_start, x_end = np.where(np.any(geo, axis=(1, 2)))[0].min(), np.where(np.any(geo, axis=(1, 2)))[0].max() + y_start, y_end = np.where(np.any(geo, axis=(0, 2)))[0].min(), np.where(np.any(geo, axis=(0, 2)))[0].max() + z_start, z_end = np.where(np.any(geo, axis=(0, 1)))[0].min(), np.where(np.any(geo, axis=(0, 1)))[0].max() + geo = geo[x_start:x_end, y_start:y_end, z_start:z_end] + + point = center_select(mask_scan) + + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low = new_point[0] - geo.shape[0]//2 + y_low = new_point[1] - geo.shape[1]//2 + z_low = new_point[2] - geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_low+geo.shape[0], y_low:y_low+geo.shape[1], z_low:z_low+geo.shape[2]] += geo + + geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + # texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + if ((tumor_type == 'medium') or (tumor_type == 'large')) and (organ_type == 'kidney'): + if random.random() > 0.5: + geo_mask = (geo_mask>=1) + else: + geo_mask = (geo_mask * mask_scan) >=1 + else: + geo_mask = (geo_mask * mask_scan) >=1 + + return geo_mask + + +from .ldm.vq_gan_3d.model.vqgan 
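The bookkeeping pattern used throughout get_fixed_geo above is worth isolating: ellipsoids are pasted onto a canvas padded by enlarge//2 on every side so no bounds checks are needed, the padding is then cropped away, and (outside the kidney special case) the result is intersected with the organ mask. A minimal NumPy-only sketch of that paste, crop and intersect idea with toy sizes, not the exact values used above:

import numpy as np

organ = np.zeros((64, 64, 64), dtype=np.uint8)
organ[20:50, 20:50, 20:50] = 1

pad = 80                                      # plays the role of enlarge_x//2 etc.
canvas = np.zeros((64 + 2 * pad,) * 3, dtype=np.int8)

blob = np.ones((9, 9, 9), dtype=np.int8)      # stand-in for a deformed ellipsoid
cx, cy, cz = 22 + pad, 30 + pad, 40 + pad     # seed point shifted into the padded frame
canvas[cx - 4:cx + 5, cy - 4:cy + 5, cz - 4:cz + 5] += blob

tumor = canvas[pad:-pad, pad:-pad, pad:-pad]  # crop the padding back off
tumor = (tumor * organ) >= 1                  # keep only voxels inside the organ
print(int(tumor.sum()))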
import VQGAN +import matplotlib.pyplot as plt +import SimpleITK as sitk +from .ldm.ddpm import Unet3D, GaussianDiffusion, Tester +from hydra import initialize, compose +import torch +import yaml +def synt_model_prepare(device, vqgan_ckpt='TumorGeneration/model_weight/recon_96d4_all.ckpt', diffusion_ckpt='TumorGeneration/model_weight/', fold=0, organ='liver'): + with initialize(config_path="diffusion_config/"): + cfg=compose(config_name="ddpm.yaml") + print('diffusion_ckpt',diffusion_ckpt) + vqgan = VQGAN.load_from_checkpoint(vqgan_ckpt) + vqgan = vqgan.to(device) + vqgan.eval() + + early_Unet3D = Unet3D( + dim=cfg.diffusion_img_size, + dim_mults=cfg.dim_mults, + channels=cfg.diffusion_num_channels, + out_dim=cfg.out_dim + ).to(device) + + early_diffusion = GaussianDiffusion( + early_Unet3D, + vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt, + image_size=cfg.diffusion_img_size, + num_frames=cfg.diffusion_depth_size, + channels=cfg.diffusion_num_channels, + timesteps=4, # cfg.timesteps, + loss_type=cfg.loss_type, + device=device + ).to(device) + + noearly_Unet3D = Unet3D( + dim=cfg.diffusion_img_size, + dim_mults=cfg.dim_mults, + channels=cfg.diffusion_num_channels, + out_dim=cfg.out_dim + ).to(device) + + noearly_diffusion = GaussianDiffusion( + noearly_Unet3D, + vqgan_ckpt= vqgan_ckpt, # cfg.vqgan_ckpt, + image_size=cfg.diffusion_img_size, + num_frames=cfg.diffusion_depth_size, + channels=cfg.diffusion_num_channels, + timesteps=200, # cfg.timesteps, + loss_type=cfg.loss_type, + device=device + ).to(device) + + early_tester = Tester(early_diffusion) + # noearly_tester = Tester(noearly_diffusion) + early_tester.load(diffusion_ckpt+'diff_{}_fold{}_early.pt'.format(organ, fold), map_location=device) + # noearly_tester.load(diffusion_ckpt+'diff_liver_fold{}_noearly_t200.pt'.format(fold), map_location=device) + + # early_checkpoint = torch.load(diffusion_ckpt+'diff_liver_fold{}_early.pt'.format(fold), map_location=device) + noearly_checkpoint = torch.load(diffusion_ckpt+'diff_{}_fold{}_noearly_t200.pt'.format(organ, fold), map_location=device) + # early_diffusion.load_state_dict(early_checkpoint['ema']) + noearly_diffusion.load_state_dict(noearly_checkpoint['ema']) + # early_sampler = DDIMSampler(early_diffusion, schedule="cosine") + noearly_sampler = DDIMSampler(noearly_diffusion, schedule="cosine") + # breakpoint() + return vqgan, early_tester, noearly_sampler + +def synthesize_early_tumor(ct_volume, organ_mask, organ_type, vqgan, tester): + device=ct_volume.device + + # generate tumor mask + tumor_types = ['tiny', 'small'] + # tumor_probs = np.array([0.5, 0.5]) + tumor_probs = np.array([0.2, 0.8]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - 
vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + tester.ema_model.eval() + sample = tester.ema_model.sample(batch_size=volume.shape[0], cond=cond) + + # if organ_type == 'liver' or organ_type == 'kidney' : + + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask + +def synthesize_medium_tumor(ct_volume, organ_mask, organ_type, vqgan, sampler, ddim_ts=50): + device=ct_volume.device + + # generate tumor mask + # tumor_types = ['large'] + # tumor_probs = np.array([1.0]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + # synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + synthetic_tumor_type = 'medium' + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + shape = masked_volume_feat.shape[-4:] + samples_ddim, _ = sampler.sample(S=ddim_ts, + conditioning=cond, + batch_size=1, + shape=shape, + verbose=False) + samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() - + vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min() + + sample = vqgan.decode(samples_ddim, quantize=True) + + # if organ_type == 'liver' or organ_type == 'kidney': + # post-process + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, 
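The same range mapping appears on both sides of the diffusion model in these synthesis functions: encoder features are min-max scaled into [-1, 1] before being concatenated into the conditioning, and DDIM samples are mapped back to the codebook's embedding range before vqgan.decode. A small sketch of the two inverse affine maps, with lo and hi standing in for the embedding minimum and maximum:

import torch

def to_unit_range(feat, lo, hi):
    # codebook range [lo, hi] -> [-1, 1], as done before building cond
    return (feat - lo) / (hi - lo) * 2.0 - 1.0

def from_unit_range(z, lo, hi):
    # [-1, 1] -> [lo, hi], as done on samples_ddim before vqgan.decode
    return (z + 1.0) / 2.0 * (hi - lo) + lo

lo, hi = -0.37, 0.42                          # stand-in embedding min / max
feat = torch.randn(1, 8, 24, 24, 24)
roundtrip = from_unit_range(to_unit_range(feat, lo, hi), lo, hi)
assert torch.allclose(roundtrip, feat, atol=1e-6)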
max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask + +def synthesize_large_tumor(ct_volume, organ_mask, organ_type, vqgan, sampler, ddim_ts=50): + device=ct_volume.device + + # generate tumor mask + # tumor_types = ['large'] + # tumor_probs = np.array([1.0]) + total_tumor_mask = [] + organ_mask_np = organ_mask.cpu().numpy() + with torch.no_grad(): + # get model input + for bs in range(organ_mask_np.shape[0]): + # synthetic_tumor_type = np.random.choice(tumor_types, p=tumor_probs.ravel()) + synthetic_tumor_type = 'large' + tumor_mask = get_fixed_geo(organ_mask_np[bs,0], synthetic_tumor_type, organ_type) + total_tumor_mask.append(torch.from_numpy(tumor_mask)[None,:]) + total_tumor_mask = torch.stack(total_tumor_mask, dim=0).to(dtype=torch.float32, device=device) + + volume = ct_volume*2.0 - 1.0 + mask = total_tumor_mask*2.0 - 1.0 + mask_ = 1-total_tumor_mask + masked_volume = (volume*mask_).detach() + + volume = volume.permute(0,1,-1,-3,-2) + masked_volume = masked_volume.permute(0,1,-1,-3,-2) + mask = mask.permute(0,1,-1,-3,-2) + + # vqgan encoder inference + masked_volume_feat = vqgan.encode(masked_volume, quantize=False, include_embeddings=True) + masked_volume_feat = ((masked_volume_feat - vqgan.codebook.embeddings.min()) / + (vqgan.codebook.embeddings.max() - vqgan.codebook.embeddings.min())) * 2.0 - 1.0 + + cc = torch.nn.functional.interpolate(mask, size=masked_volume_feat.shape[-3:]) + cond = torch.cat((masked_volume_feat, cc), dim=1) + + # diffusion inference and decoder + shape = masked_volume_feat.shape[-4:] + samples_ddim, _ = sampler.sample(S=ddim_ts, + conditioning=cond, + batch_size=1, + shape=shape, + verbose=False) + samples_ddim = (((samples_ddim + 1.0) / 2.0) * (vqgan.codebook.embeddings.max() - + vqgan.codebook.embeddings.min())) + vqgan.codebook.embeddings.min() + + sample = vqgan.decode(samples_ddim, quantize=True) + + # if organ_type == 'liver' or organ_type == 'kidney': + # post-process + mask_01 = torch.clamp((mask+1.0)/2.0, min=0.0, max=1.0) + sigma = np.random.uniform(0, 4) # (1, 2) + mask_01_np_blur = gaussian_filter(mask_01.cpu().numpy()*1.0, sigma=[0,0,sigma,sigma,sigma]) + # mask_01_np_blur = mask_01_np_blur*mask_01.cpu().numpy() + + volume_ = torch.clamp((volume+1.0)/2.0, min=0.0, max=1.0) + sample_ = torch.clamp((sample+1.0)/2.0, min=0.0, max=1.0) + + mask_01_blur = torch.from_numpy(mask_01_np_blur).to(device=device) + final_volume_ = (1-mask_01_blur)*volume_ +mask_01_blur*sample_ + final_volume_ = torch.clamp(final_volume_, min=0.0, max=1.0) + # elif organ_type == 'pancreas': + # final_volume_ = (sample+1.0)/2.0 + + final_volume_ = final_volume_.permute(0,1,-2,-1,-3) + organ_tumor_mask = torch.ones_like(organ_mask) + organ_tumor_mask[total_tumor_mask==1] = 2 + + return final_volume_, organ_tumor_mask \ No newline at end of file diff --git a/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/utils_.py b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/utils_.py new file mode 100644 index 0000000000000000000000000000000000000000..312e72d1571b3cd996fd567bcedf939443a4e182 --- /dev/null +++ 
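How the three entry points in this file are meant to be combined, as a hedged sketch only: the device, batch shape, organ label and the default checkpoint locations are illustrative assumptions, and the model weights must already exist on disk for synt_model_prepare to succeed.

import torch

device = torch.device('cuda:0')
vqgan, early_tester, noearly_sampler = synt_model_prepare(device, organ='pancreas')

# ct_volume and organ_mask are (B, 1, H, W, D) tensors; the CT is assumed normalized to [0, 1]
ct_volume = torch.rand(1, 1, 96, 96, 96, device=device)
organ_mask = torch.zeros_like(ct_volume)
organ_mask[..., 24:72, 24:72, 24:72] = 1

vol_early, mask_early = synthesize_early_tumor(ct_volume, organ_mask, 'pancreas', vqgan, early_tester)
vol_medium, mask_medium = synthesize_medium_tumor(ct_volume, organ_mask, 'pancreas', vqgan, noearly_sampler, ddim_ts=50)
vol_large, mask_large = synthesize_large_tumor(ct_volume, organ_mask, 'pancreas', vqgan, noearly_sampler, ddim_ts=50)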
b/Generation_Pipeline_filter_all/syn_pancreas/TumorGeneration/utils_.py @@ -0,0 +1,298 @@ +### Tumor Generateion +import random +import cv2 +import elasticdeform +import numpy as np +from scipy.ndimage import gaussian_filter + +# Step 1: Random select (numbers) location for tumor. +def random_select(mask_scan): + # we first find z index and then sample point with z slice + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]] + + # we need to strict number z's position (0.3 - 0.7 in the middle of liver) + z = round(random.uniform(0.3, 0.7) * (z_end - z_start)) + z_start + + liver_mask = mask_scan[..., z] + + # erode the mask (we don't want the edge points) + kernel = np.ones((5,5), dtype=np.uint8) + liver_mask = cv2.erode(liver_mask, kernel, iterations=1) + + coordinates = np.argwhere(liver_mask == 1) + random_index = np.random.randint(0, len(coordinates)) + xyz = coordinates[random_index].tolist() # get x,y + xyz.append(z) + potential_points = xyz + + return potential_points + +# Step 2 : generate the ellipsoid +def get_ellipsoid(x, y, z): + """" + x, y, z is the radius of this ellipsoid in x, y, z direction respectly. + """ + sh = (4*x, 4*y, 4*z) + out = np.zeros(sh, int) + aux = np.zeros(sh) + radii = np.array([x, y, z]) + com = np.array([2*x, 2*y, 2*z]) # center point + + # calculate the ellipsoid + bboxl = np.floor(com-radii).clip(0,None).astype(int) + bboxh = (np.ceil(com+radii)+1).clip(None, sh).astype(int) + roi = out[tuple(map(slice,bboxl,bboxh))] + roiaux = aux[tuple(map(slice,bboxl,bboxh))] + logrid = *map(np.square,np.ogrid[tuple( + map(slice,(bboxl-com)/radii,(bboxh-com-1)/radii,1j*(bboxh-bboxl)))]), + dst = (1-sum(logrid)).clip(0,None) + mask = dst>roiaux + roi[mask] = 1 + np.copyto(roiaux,dst,where=mask) + + return out + +def get_fixed_geo(mask_scan, tumor_type): + + enlarge_x, enlarge_y, enlarge_z = 160, 160, 160 + geo_mask = np.zeros((mask_scan.shape[0] + enlarge_x, mask_scan.shape[1] + enlarge_y, mask_scan.shape[2] + enlarge_z), dtype=np.int8) + tiny_radius, small_radius, medium_radius, large_radius = 4, 8, 16, 32 + + if tumor_type == 'tiny': + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + if tumor_type == 'small': + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + 
geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'medium': + num_tumor = random.randint(2, 5) + for _ in range(num_tumor): + # medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == 'large': + num_tumor = random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + if tumor_type == "mix": + # tiny + num_tumor = random.randint(3,10) + for _ in range(num_tumor): + # Tiny tumor + x = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + y = 
random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + z = random.randint(int(0.75*tiny_radius), int(1.25*tiny_radius)) + sigma = random.uniform(0.5,1) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + + # small + num_tumor = random.randint(5,10) + for _ in range(num_tumor): + # Small tumor + x = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + y = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + z = random.randint(int(0.75*small_radius), int(1.25*small_radius)) + sigma = random.randint(1, 2) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + # medium + num_tumor = random.randint(2, 5) + for _ in range(num_tumor): + # medium tumor + x = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + y = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + z = random.randint(int(0.75*medium_radius), int(1.25*medium_radius)) + sigma = random.randint(3, 6) + + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste medium tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + # large + num_tumor = random.randint(1,3) + for _ in range(num_tumor): + # Large tumor + x = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + y = 
random.randint(int(0.75*large_radius), int(1.25*large_radius)) + z = random.randint(int(0.75*large_radius), int(1.25*large_radius)) + sigma = random.randint(5, 10) + geo = get_ellipsoid(x, y, z) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,1)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(1,2)) + geo = elasticdeform.deform_random_grid(geo, sigma=sigma, points=3, order=0, axis=(0,2)) + # texture = get_texture((4*x, 4*y, 4*z)) + point = random_select(mask_scan) + new_point = [point[0] + enlarge_x//2, point[1] + enlarge_y//2, point[2] + enlarge_z//2] + x_low, x_high = new_point[0] - geo.shape[0]//2, new_point[0] + geo.shape[0]//2 + y_low, y_high = new_point[1] - geo.shape[1]//2, new_point[1] + geo.shape[1]//2 + z_low, z_high = new_point[2] - geo.shape[2]//2, new_point[2] + geo.shape[2]//2 + + # paste small tumor geo into test sample + geo_mask[x_low:x_high, y_low:y_high, z_low:z_high] += geo + # texture_map[x_low:x_high, y_low:y_high, z_low:z_high] = texture + + geo_mask = geo_mask[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + # texture_map = texture_map[enlarge_x//2:-enlarge_x//2, enlarge_y//2:-enlarge_y//2, enlarge_z//2:-enlarge_z//2] + geo_mask = (geo_mask * mask_scan) >=1 + + return geo_mask + + +def get_tumor(volume_scan, mask_scan, tumor_type): + tumor_mask = get_fixed_geo(mask_scan, tumor_type) + + sigma = np.random.uniform(1, 2) + # difference = np.random.uniform(65, 145) + difference = 1 + + # blur the boundary + tumor_mask_blur = gaussian_filter(tumor_mask*1.0, sigma) + + + abnormally_full = volume_scan * (1 - mask_scan) + abnormally + abnormally_mask = mask_scan + geo_mask + + return abnormally_full, abnormally_mask + +def SynthesisTumor(volume_scan, mask_scan, tumor_type): + # for speed_generate_tumor, we only send the liver part into the generate program + x_start, x_end = np.where(np.any(mask_scan, axis=(1, 2)))[0][[0, -1]] + y_start, y_end = np.where(np.any(mask_scan, axis=(0, 2)))[0][[0, -1]] + z_start, z_end = np.where(np.any(mask_scan, axis=(0, 1)))[0][[0, -1]] + + # shrink the boundary + x_start, x_end = max(0, x_start+1), min(mask_scan.shape[0], x_end-1) + y_start, y_end = max(0, y_start+1), min(mask_scan.shape[1], y_end-1) + z_start, z_end = max(0, z_start+1), min(mask_scan.shape[2], z_end-1) + + ct_volume = volume_scan[x_start:x_end, y_start:y_end, z_start:z_end] + organ_mask = mask_scan[x_start:x_end, y_start:y_end, z_start:z_end] + + # input texture shape: 420, 300, 320 + # we need to cut it into the shape of liver_mask + # for examples, the liver_mask.shape = 286, 173, 46; we should change the texture shape + x_length, y_length, z_length = 64, 64, 64 + crop_x = random.randint(x_start, x_end - x_length - 1) # random select the start point, -1 is to avoid boundary check + crop_y = random.randint(y_start, y_end - y_length - 1) + crop_z = random.randint(z_start, z_end - z_length - 1) + + ct_volume, organ_tumor_mask = get_tumor(ct_volume, organ_mask, tumor_type) + volume_scan[x_start:x_end, y_start:y_end, z_start:z_end] = ct_volume + mask_scan[x_start:x_end, y_start:y_end, z_start:z_end] = organ_tumor_mask + + return volume_scan, mask_scan diff --git a/Generation_Pipeline_filter_all/syn_pancreas/healthy_pancreas_1k.txt b/Generation_Pipeline_filter_all/syn_pancreas/healthy_pancreas_1k.txt new file mode 100644 index 0000000000000000000000000000000000000000..5b54b3f0c568b2953320a4691f0196f8315a79fe --- /dev/null +++ 
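get_tumor in this utils_.py variant stops short of building the blended volume (the abnormally array it returns is never constructed here, and the blurred mask it computes is unused). One plausible completion is an alpha blend of a darkened tumor region through a Gaussian-feathered mask; the function below, including its name and the intensity_drop parameter, is a hypothetical sketch and not the authors' exact recipe.

import numpy as np
from scipy.ndimage import gaussian_filter

def blend_tumor(volume_scan, mask_scan, tumor_mask, intensity_drop=0.3, sigma=1.5):
    # feather the tumor edge with a Gaussian-blurred mask
    tumor_blur = gaussian_filter(tumor_mask.astype(float), sigma)
    # darken the tumor region and alpha-blend it into the surrounding tissue
    abnormally = volume_scan * (1.0 - tumor_blur) + (volume_scan - intensity_drop) * tumor_blur
    abnormally_full = volume_scan * (1 - mask_scan) + abnormally * mask_scan
    abnormally_mask = mask_scan + tumor_mask      # 1 = organ, 2 = tumor
    return abnormally_full, abnormally_mask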
b/Generation_Pipeline_filter_all/syn_pancreas/healthy_pancreas_1k.txt @@ -0,0 +1,774 @@ +BDMAP_00004652 +BDMAP_00002164 +BDMAP_00000439 +BDMAP_00004775 +BDMAP_00001148 +BDMAP_00003384 +BDMAP_00003634 +BDMAP_00000030 +BDMAP_00005119 +BDMAP_00003580 +BDMAP_00000320 +BDMAP_00002185 +BDMAP_00002487 +BDMAP_00002775 +BDMAP_00000159 +BDMAP_00004475 +BDMAP_00004278 +BDMAP_00004131 +BDMAP_00000244 +BDMAP_00001966 +BDMAP_00000939 +BDMAP_00003928 +BDMAP_00003031 +BDMAP_00000568 +BDMAP_00001309 +BDMAP_00004232 +BDMAP_00004431 +BDMAP_00001794 +BDMAP_00001752 +BDMAP_00004228 +BDMAP_00002363 +BDMAP_00001020 +BDMAP_00002739 +BDMAP_00004462 +BDMAP_00003749 +BDMAP_00002472 +BDMAP_00003455 +BDMAP_00003911 +BDMAP_00003486 +BDMAP_00001496 +BDMAP_00004378 +BDMAP_00002282 +BDMAP_00001343 +BDMAP_00001747 +BDMAP_00004450 +BDMAP_00000607 +BDMAP_00000100 +BDMAP_00001620 +BDMAP_00004415 +BDMAP_00004897 +BDMAP_00001737 +BDMAP_00000604 +BDMAP_00004922 +BDMAP_00004870 +BDMAP_00002060 +BDMAP_00003412 +BDMAP_00004850 +BDMAP_00000794 +BDMAP_00000989 +BDMAP_00000205 +BDMAP_00003151 +BDMAP_00001255 +BDMAP_00000667 +BDMAP_00003343 +BDMAP_00001237 +BDMAP_00000023 +BDMAP_00003281 +BDMAP_00000907 +BDMAP_00000432 +BDMAP_00001782 +BDMAP_00002166 +BDMAP_00004481 +BDMAP_00003363 +BDMAP_00001474 +BDMAP_00001995 +BDMAP_00002854 +BDMAP_00003603 +BDMAP_00001383 +BDMAP_00002656 +BDMAP_00004427 +BDMAP_00001563 +BDMAP_00001809 +BDMAP_00002114 +BDMAP_00000304 +BDMAP_00001692 +BDMAP_00001688 +BDMAP_00001119 +BDMAP_00000449 +BDMAP_00004014 +BDMAP_00002524 +BDMAP_00000725 +BDMAP_00000918 +BDMAP_00000642 +BDMAP_00001238 +BDMAP_00002373 +BDMAP_00002326 +BDMAP_00004373 +BDMAP_00003615 +BDMAP_00003324 +BDMAP_00002654 +BDMAP_00002849 +BDMAP_00003491 +BDMAP_00002655 +BDMAP_00002712 +BDMAP_00001516 +BDMAP_00000469 +BDMAP_00001549 +BDMAP_00000713 +BDMAP_00000745 +BDMAP_00000259 +BDMAP_00003569 +BDMAP_00005067 +BDMAP_00004185 +BDMAP_00003357 +BDMAP_00002419 +BDMAP_00002598 +BDMAP_00002167 +BDMAP_00000137 +BDMAP_00003448 +BDMAP_00000965 +BDMAP_00000232 +BDMAP_00004608 +BDMAP_00003680 +BDMAP_00000716 +BDMAP_00002403 +BDMAP_00004216 +BDMAP_00001359 +BDMAP_00004175 +BDMAP_00002791 +BDMAP_00002940 +BDMAP_00000355 +BDMAP_00004294 +BDMAP_00001426 +BDMAP_00001475 +BDMAP_00002986 +BDMAP_00002884 +BDMAP_00000400 +BDMAP_00002410 +BDMAP_00000297 +BDMAP_00001636 +BDMAP_00005113 +BDMAP_00004074 +BDMAP_00002333 +BDMAP_00003976 +BDMAP_00002383 +BDMAP_00000161 +BDMAP_00001212 +BDMAP_00000366 +BDMAP_00003070 +BDMAP_00003943 +BDMAP_00003930 +BDMAP_00003164 +BDMAP_00001906 +BDMAP_00002889 +BDMAP_00004163 +BDMAP_00001456 +BDMAP_00003972 +BDMAP_00004586 +BDMAP_00000626 +BDMAP_00001095 +BDMAP_00000532 +BDMAP_00003377 +BDMAP_00003225 +BDMAP_00001289 +BDMAP_00001275 +BDMAP_00004509 +BDMAP_00000998 +BDMAP_00000836 +BDMAP_00001015 +BDMAP_00004650 +BDMAP_00005186 +BDMAP_00000608 +BDMAP_00003898 +BDMAP_00002696 +BDMAP_00003560 +BDMAP_00004578 +BDMAP_00000828 +BDMAP_00000690 +BDMAP_00003564 +BDMAP_00005174 +BDMAP_00000132 +BDMAP_00005105 +BDMAP_00000902 +BDMAP_00003947 +BDMAP_00002184 +BDMAP_00001785 +BDMAP_00002361 +BDMAP_00003255 +BDMAP_00000971 +BDMAP_00003493 +BDMAP_00002267 +BDMAP_00005154 +BDMAP_00000982 +BDMAP_00005157 +BDMAP_00004384 +BDMAP_00003063 +BDMAP_00001982 +BDMAP_00002273 +BDMAP_00001102 +BDMAP_00002689 +BDMAP_00000034 +BDMAP_00001514 +BDMAP_00005081 +BDMAP_00001786 +BDMAP_00004033 +BDMAP_00004457 +BDMAP_00000710 +BDMAP_00001198 +BDMAP_00004479 +BDMAP_00000873 +BDMAP_00000362 +BDMAP_00004616 +BDMAP_00003128 +BDMAP_00001607 +BDMAP_00004104 +BDMAP_00001517 
+BDMAP_00004639 +BDMAP_00005170 +BDMAP_00002305 +BDMAP_00004746 +BDMAP_00003333 +BDMAP_00001807 +BDMAP_00004579 +BDMAP_00002260 +BDMAP_00004416 +BDMAP_00003932 +BDMAP_00001316 +BDMAP_00003411 +BDMAP_00000839 +BDMAP_00004738 +BDMAP_00001438 +BDMAP_00003435 +BDMAP_00001697 +BDMAP_00001911 +BDMAP_00001735 +BDMAP_00002902 +BDMAP_00001834 +BDMAP_00000069 +BDMAP_00004066 +BDMAP_00000434 +BDMAP_00004744 +BDMAP_00000347 +BDMAP_00001246 +BDMAP_00003150 +BDMAP_00003957 +BDMAP_00001768 +BDMAP_00002663 +BDMAP_00004147 +BDMAP_00003510 +BDMAP_00002242 +BDMAP_00005016 +BDMAP_00002275 +BDMAP_00001924 +BDMAP_00002214 +BDMAP_00002529 +BDMAP_00000562 +BDMAP_00000122 +BDMAP_00002707 +BDMAP_00000874 +BDMAP_00000176 +BDMAP_00002804 +BDMAP_00005005 +BDMAP_00001422 +BDMAP_00005017 +BDMAP_00000653 +BDMAP_00002609 +BDMAP_00003327 +BDMAP_00002484 +BDMAP_00004673 +BDMAP_00004493 +BDMAP_00003740 +BDMAP_00002271 +BDMAP_00002742 +BDMAP_00002826 +BDMAP_00001035 +BDMAP_00002068 +BDMAP_00003815 +BDMAP_00003052 +BDMAP_00004499 +BDMAP_00002065 +BDMAP_00001025 +BDMAP_00004888 +BDMAP_00002592 +BDMAP_00004030 +BDMAP_00001024 +BDMAP_00002041 +BDMAP_00002807 +BDMAP_00002751 +BDMAP_00003272 +BDMAP_00004600 +BDMAP_00004154 +BDMAP_00003774 +BDMAP_00000948 +BDMAP_00002173 +BDMAP_00004510 +BDMAP_00000104 +BDMAP_00004374 +BDMAP_00000429 +BDMAP_00004420 +BDMAP_00001853 +BDMAP_00003600 +BDMAP_00002349 +BDMAP_00001863 +BDMAP_00004830 +BDMAP_00002981 +BDMAP_00001941 +BDMAP_00001128 +BDMAP_00005151 +BDMAP_00003890 +BDMAP_00003640 +BDMAP_00004257 +BDMAP_00004943 +BDMAP_00001068 +BDMAP_00001305 +BDMAP_00000414 +BDMAP_00000465 +BDMAP_00003727 +BDMAP_00002199 +BDMAP_00001769 +BDMAP_00004187 +BDMAP_00001891 +BDMAP_00000980 +BDMAP_00003923 +BDMAP_00000942 +BDMAP_00001114 +BDMAP_00001602 +BDMAP_00002845 +BDMAP_00003178 +BDMAP_00003409 +BDMAP_00001562 +BDMAP_00002909 +BDMAP_00003808 +BDMAP_00001169 +BDMAP_00001104 +BDMAP_00001483 +BDMAP_00005009 +BDMAP_00001957 +BDMAP_00003153 +BDMAP_00001444 +BDMAP_00000851 +BDMAP_00005191 +BDMAP_00000687 +BDMAP_00003722 +BDMAP_00003330 +BDMAP_00002347 +BDMAP_00002955 +BDMAP_00001089 +BDMAP_00004529 +BDMAP_00003268 +BDMAP_00001522 +BDMAP_00001502 +BDMAP_00000240 +BDMAP_00004867 +BDMAP_00000480 +BDMAP_00000452 +BDMAP_00002918 +BDMAP_00002953 +BDMAP_00002039 +BDMAP_00000889 +BDMAP_00002746 +BDMAP_00003608 +BDMAP_00003664 +BDMAP_00003299 +BDMAP_00001445 +BDMAP_00000113 +BDMAP_00001705 +BDMAP_00000044 +BDMAP_00003513 +BDMAP_00001261 +BDMAP_00004990 +BDMAP_00003143 +BDMAP_00003111 +BDMAP_00002319 +BDMAP_00004664 +BDMAP_00003717 +BDMAP_00004717 +BDMAP_00004745 +BDMAP_00000671 +BDMAP_00002990 +BDMAP_00004901 +BDMAP_00002545 +BDMAP_00004980 +BDMAP_00000913 +BDMAP_00000437 +BDMAP_00002864 +BDMAP_00000364 +BDMAP_00004195 +BDMAP_00000162 +BDMAP_00002840 +BDMAP_00000233 +BDMAP_00002744 +BDMAP_00001218 +BDMAP_00002289 +BDMAP_00000229 +BDMAP_00005114 +BDMAP_00000279 +BDMAP_00003832 +BDMAP_00000241 +BDMAP_00002251 +BDMAP_00001676 +BDMAP_00001635 +BDMAP_00003444 +BDMAP_00002265 +BDMAP_00002498 +BDMAP_00001209 +BDMAP_00001138 +BDMAP_00002407 +BDMAP_00003798 +BDMAP_00001325 +BDMAP_00002631 +BDMAP_00004304 +BDMAP_00001078 +BDMAP_00002562 +BDMAP_00003576 +BDMAP_00001977 +BDMAP_00002396 +BDMAP_00001333 +BDMAP_00004925 +BDMAP_00004903 +BDMAP_00000273 +BDMAP_00000571 +BDMAP_00001027 +BDMAP_00000149 +BDMAP_00001962 +BDMAP_00003481 +BDMAP_00001256 +BDMAP_00000871 +BDMAP_00000926 +BDMAP_00000572 +BDMAP_00004558 +BDMAP_00000435 +BDMAP_00000837 +BDMAP_00003713 +BDMAP_00002875 +BDMAP_00004645 +BDMAP_00001711 +BDMAP_00001296 +BDMAP_00002648 
+BDMAP_00004561 +BDMAP_00002318 +BDMAP_00001835 +BDMAP_00003524 +BDMAP_00002959 +BDMAP_00002422 +BDMAP_00004597 +BDMAP_00000487 +BDMAP_00002359 +BDMAP_00005001 +BDMAP_00004817 +BDMAP_00001539 +BDMAP_00002936 +BDMAP_00002719 +BDMAP_00005167 +BDMAP_00001265 +BDMAP_00001471 +BDMAP_00001511 +BDMAP_00005139 +BDMAP_00002426 +BDMAP_00002288 +BDMAP_00004808 +BDMAP_00002085 +BDMAP_00004435 +BDMAP_00000319 +BDMAP_00003614 +BDMAP_00001109 +BDMAP_00000331 +BDMAP_00004491 +BDMAP_00002440 +BDMAP_00003373 +BDMAP_00005065 +BDMAP_00005006 +BDMAP_00002509 +BDMAP_00003973 +BDMAP_00004417 +BDMAP_00000935 +BDMAP_00004624 +BDMAP_00003364 +BDMAP_00005085 +BDMAP_00003073 +BDMAP_00002730 +BDMAP_00004825 +BDMAP_00000039 +BDMAP_00004615 +BDMAP_00003736 +BDMAP_00005097 +BDMAP_00003074 +BDMAP_00000662 +BDMAP_00001122 +BDMAP_00002252 +BDMAP_00001396 +BDMAP_00004011 +BDMAP_00004981 +BDMAP_00004165 +BDMAP_00003920 +BDMAP_00001215 +BDMAP_00003867 +BDMAP_00000923 +BDMAP_00002626 +BDMAP_00003315 +BDMAP_00000660 +BDMAP_00000329 +BDMAP_00004508 +BDMAP_00001518 +BDMAP_00003849 +BDMAP_00003897 +BDMAP_00003300 +BDMAP_00002253 +BDMAP_00003514 +BDMAP_00000117 +BDMAP_00002421 +BDMAP_00001413 +BDMAP_00004328 +BDMAP_00001130 +BDMAP_00000043 +BDMAP_00001410 +BDMAP_00000245 +BDMAP_00004117 +BDMAP_00002401 +BDMAP_00003857 +BDMAP_00000921 +BDMAP_00000138 +BDMAP_00003113 +BDMAP_00003358 +BDMAP_00002099 +BDMAP_00004016 +BDMAP_00003439 +BDMAP_00002152 +BDMAP_00003767 +BDMAP_00001598 +BDMAP_00003482 +BDMAP_00003520 +BDMAP_00002075 +BDMAP_00000987 +BDMAP_00003946 +BDMAP_00005160 +BDMAP_00001286 +BDMAP_00003359 +BDMAP_00002661 +BDMAP_00004704 +BDMAP_00003994 +BDMAP_00002226 +BDMAP_00000968 +BDMAP_00003556 +BDMAP_00003236 +BDMAP_00001791 +BDMAP_00004712 +BDMAP_00001077 +BDMAP_00003955 +BDMAP_00002479 +BDMAP_00001865 +BDMAP_00001059 +BDMAP_00002704 +BDMAP_00000656 +BDMAP_00001379 +BDMAP_00000883 +BDMAP_00002856 +BDMAP_00004199 +BDMAP_00001200 +BDMAP_00005083 +BDMAP_00004552 +BDMAP_00000616 +BDMAP_00004834 +BDMAP_00004815 +BDMAP_00001826 +BDMAP_00000615 +BDMAP_00001045 +BDMAP_00002695 +BDMAP_00004017 +BDMAP_00002103 +BDMAP_00002057 +BDMAP_00004620 +BDMAP_00000128 +BDMAP_00001185 +BDMAP_00002612 +BDMAP_00005073 +BDMAP_00001753 +BDMAP_00004196 +BDMAP_00004281 +BDMAP_00002717 +BDMAP_00000263 +BDMAP_00004103 +BDMAP_00003381 +BDMAP_00001093 +BDMAP_00000373 +BDMAP_00000881 +BDMAP_00002230 +BDMAP_00001707 +BDMAP_00002476 +BDMAP_00003294 +BDMAP_00004482 +BDMAP_00003267 +BDMAP_00002710 +BDMAP_00002451 +BDMAP_00001270 +BDMAP_00004878 +BDMAP_00001784 +BDMAP_00001281 +BDMAP_00002283 +BDMAP_00001183 +BDMAP_00001945 +BDMAP_00004604 +BDMAP_00000413 +BDMAP_00003506 +BDMAP_00002458 +BDMAP_00000977 +BDMAP_00000833 +BDMAP_00001055 +BDMAP_00002495 +BDMAP_00000887 +BDMAP_00002496 +BDMAP_00002942 +BDMAP_00000574 +BDMAP_00001868 +BDMAP_00000547 +BDMAP_00001230 +BDMAP_00003762 +BDMAP_00003971 +BDMAP_00000321 +BDMAP_00004876 +BDMAP_00003833 +BDMAP_00003461 +BDMAP_00003301 +BDMAP_00002846 +BDMAP_00002582 +BDMAP_00001710 +BDMAP_00001487 +BDMAP_00000936 +BDMAP_00004121 +BDMAP_00004459 +BDMAP_00000219 +BDMAP_00000091 +BDMAP_00001283 +BDMAP_00000084 +BDMAP_00000516 +BDMAP_00004250 +BDMAP_00001732 +BDMAP_00003694 +BDMAP_00004031 +BDMAP_00001557 +BDMAP_00002437 +BDMAP_00002933 +BDMAP_00000264 +BDMAP_00005099 +BDMAP_00004296 +BDMAP_00001917 +BDMAP_00003252 +BDMAP_00004389 +BDMAP_00002463 +BDMAP_00004253 +BDMAP_00004910 +BDMAP_00003172 +BDMAP_00001624 +BDMAP_00003484 +BDMAP_00001907 +BDMAP_00003952 +BDMAP_00002653 +BDMAP_00000368 +BDMAP_00000569 +BDMAP_00004995 +BDMAP_00003956 
+BDMAP_00003497 +BDMAP_00003058 +BDMAP_00000552 +BDMAP_00000481 +BDMAP_00000805 +BDMAP_00003002 +BDMAP_00000698 +BDMAP_00004783 +BDMAP_00001324 +BDMAP_00002133 +BDMAP_00005120 +BDMAP_00003581 +BDMAP_00004890 +BDMAP_00001533 +BDMAP_00004039 +BDMAP_00000190 +BDMAP_00004028 +BDMAP_00004130 +BDMAP_00001370 +BDMAP_00002805 +BDMAP_00001397 +BDMAP_00001126 +BDMAP_00001875 +BDMAP_00005130 +BDMAP_00003361 +BDMAP_00002485 +BDMAP_00001273 +BDMAP_00000582 +BDMAP_00003672 +BDMAP_00000778 +BDMAP_00002841 +BDMAP_00001242 +BDMAP_00000345 +BDMAP_00000036 +BDMAP_00003996 +BDMAP_00003701 +BDMAP_00003425 +BDMAP_00001656 +BDMAP_00001802 +BDMAP_00001420 +BDMAP_00003752 +BDMAP_00002924 +BDMAP_00003202 +BDMAP_00000831 +BDMAP_00003392 +BDMAP_00002022 +BDMAP_00001223 +BDMAP_00003457 +BDMAP_00001236 +BDMAP_00000810 +BDMAP_00004676 +BDMAP_00003847 +BDMAP_00001225 +BDMAP_00005168 +BDMAP_00004113 +BDMAP_00002828 +BDMAP_00004087 +BDMAP_00004407 +BDMAP_00002748 +BDMAP_00003516 +BDMAP_00004395 +BDMAP_00001985 +BDMAP_00001171 +BDMAP_00000101 +BDMAP_00002117 +BDMAP_00001434 +BDMAP_00000139 +BDMAP_00002465 +BDMAP_00001251 +BDMAP_00001908 +BDMAP_00002354 +BDMAP_00002776 +BDMAP_00004887 +BDMAP_00000066 +BDMAP_00003549 +BDMAP_00000812 +BDMAP_00000353 +BDMAP_00004894 +BDMAP_00004956 +BDMAP_00002871 +BDMAP_00004764 +BDMAP_00004551 +BDMAP_00002404 +BDMAP_00000059 +BDMAP_00002017 +BDMAP_00003558 +BDMAP_00004065 +BDMAP_00003406 +BDMAP_00002471 +BDMAP_00000941 +BDMAP_00003109 +BDMAP_00000511 +BDMAP_00000826 +BDMAP_00004839 +BDMAP_00004671 +BDMAP_00002930 +BDMAP_00004331 +BDMAP_00001664 +BDMAP_00001001 +BDMAP_00001766 +BDMAP_00003827 +BDMAP_00001258 +BDMAP_00001892 +BDMAP_00000062 +BDMAP_00000867 +BDMAP_00002803 +BDMAP_00000285 +BDMAP_00001647 +BDMAP_00005077 +BDMAP_00000152 +BDMAP_00000709 +BDMAP_00002172 +BDMAP_00004148 +BDMAP_00001010 diff --git a/Generation_Pipeline_filter_all/syn_pancreas/requirements.txt b/Generation_Pipeline_filter_all/syn_pancreas/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9c1067c8eb4f44699bbf517601396113cea4b370 --- /dev/null +++ b/Generation_Pipeline_filter_all/syn_pancreas/requirements.txt @@ -0,0 +1,94 @@ +absl-py==1.1.0 +accelerate==0.11.0 +aiohttp==3.8.1 +aiosignal==1.2.0 +antlr4-python3-runtime==4.9.3 +async-timeout==4.0.2 +attrs==21.4.0 +autopep8==1.6.0 +cachetools==5.2.0 +certifi==2022.6.15 +charset-normalizer==2.0.12 +click==8.1.3 +cycler==0.11.0 +Deprecated==1.2.13 +docker-pycreds==0.4.0 +einops==0.4.1 +einops-exts==0.0.3 +ema-pytorch==0.0.8 +fonttools==4.34.4 +frozenlist==1.3.0 +fsspec==2022.5.0 +ftfy==6.1.1 +future==0.18.2 +gitdb==4.0.9 +GitPython==3.1.27 +google-auth==2.9.0 +google-auth-oauthlib==0.4.6 +grpcio==1.47.0 +h5py==3.7.0 +humanize==4.2.2 +hydra-core==1.2.0 +idna==3.3 +imageio==2.19.3 +imageio-ffmpeg==0.4.7 +importlib-metadata==4.12.0 +importlib-resources==5.9.0 +joblib==1.1.0 +kiwisolver==1.4.3 +lxml==4.9.1 +Markdown==3.3.7 +matplotlib==3.5.2 +multidict==6.0.2 +networkx==2.8.5 +nibabel==4.0.1 +nilearn==0.9.1 +numpy==1.23.0 +oauthlib==3.2.0 +omegaconf==2.2.3 +pandas==1.4.3 +Pillow==9.1.1 +pyasn1==0.4.8 +pyasn1-modules==0.2.8 +pycodestyle==2.8.0 +pyDeprecate==0.3.1 +pydicom==2.3.0 +pytorch-lightning==1.6.4 +pytz==2022.1 +PyWavelets==1.3.0 +PyYAML==6.0 +pyzmq==19.0.2 +regex==2022.6.2 +requests==2.28.0 +requests-oauthlib==1.3.1 +rotary-embedding-torch==0.1.5 +rsa==4.8 +scikit-image==0.19.3 +scikit-learn==1.1.2 +scikit-video==1.1.11 +scipy==1.8.1 +seaborn==0.11.2 +sentry-sdk==1.7.2 +setproctitle==1.2.3 +shortuuid==1.0.9 +SimpleITK==2.1.1.2 
+smmap==5.0.0 +tensorboard==2.9.1 +tensorboard-data-server==0.6.1 +tensorboard-plugin-wit==1.8.1 +threadpoolctl==3.1.0 +tifffile==2022.8.3 +toml==0.10.2 +torch-tb-profiler==0.4.0 +torchio==0.18.80 +torchmetrics==0.9.1 +tqdm==4.64.0 +typing_extensions==4.2.0 +urllib3==1.26.9 +wandb==0.12.21 +Werkzeug==2.1.2 +wrapt==1.14.1 +yarl==1.7.2 +zipp==3.8.0 +wandb +tensorboardX==2.4.1 diff --git a/Generation_Pipeline_filter_all/val_set/bodymap_colon.txt b/Generation_Pipeline_filter_all/val_set/bodymap_colon.txt new file mode 100644 index 0000000000000000000000000000000000000000..c529ab8278b77ca39b67afc9ea118c67cb4ebcec --- /dev/null +++ b/Generation_Pipeline_filter_all/val_set/bodymap_colon.txt @@ -0,0 +1,25 @@ +BDMAP_00004910 +BDMAP_00001438 +BDMAP_00000568 +BDMAP_00002828 +BDMAP_00003634 +BDMAP_00004121 +BDMAP_00004764 +BDMAP_00003972 +BDMAP_00003113 +BDMAP_00005001 +BDMAP_00001785 +BDMAP_00005016 +BDMAP_00002739 +BDMAP_00003299 +BDMAP_00003357 +BDMAP_00001078 +BDMAP_00000874 +BDMAP_00003560 +BDMAP_00003373 +BDMAP_00003172 +BDMAP_00002875 +BDMAP_00000552 +BDMAP_00003510 +BDMAP_00004604 +BDMAP_00002598 diff --git a/Generation_Pipeline_filter_all/val_set/bodymap_kidney.txt b/Generation_Pipeline_filter_all/val_set/bodymap_kidney.txt new file mode 100644 index 0000000000000000000000000000000000000000..f420fbe5a18b65775d1eb8bc4362eb4e9cc1421b --- /dev/null +++ b/Generation_Pipeline_filter_all/val_set/bodymap_kidney.txt @@ -0,0 +1,24 @@ +BDMAP_00000487 +BDMAP_00002631 +BDMAP_00002744 +BDMAP_00000833 +BDMAP_00002648 +BDMAP_00002840 +BDMAP_00000608 +BDMAP_00002804 +BDMAP_00002775 +BDMAP_00004551 +BDMAP_00001413 +BDMAP_00000511 +BDMAP_00003150 +BDMAP_00000794 +BDMAP_00001255 +BDMAP_00002242 +BDMAP_00004746 +BDMAP_00002864 +BDMAP_00003486 +BDMAP_00004250 +BDMAP_00003143 +BDMAP_00003164 +BDMAP_00004578 +BDMAP_00001735 diff --git a/Generation_Pipeline_filter_all/val_set/bodymap_liver.txt b/Generation_Pipeline_filter_all/val_set/bodymap_liver.txt new file mode 100644 index 0000000000000000000000000000000000000000..d16f55abd27cc01769b6e96a3a9606a3abb568c9 --- /dev/null +++ b/Generation_Pipeline_filter_all/val_set/bodymap_liver.txt @@ -0,0 +1,25 @@ +BDMAP_00004281 +BDMAP_00003481 +BDMAP_00004890 +BDMAP_00001786 +BDMAP_00000101 +BDMAP_00004117 +BDMAP_00000615 +BDMAP_00000921 +BDMAP_00005130 +BDMAP_00004378 +BDMAP_00004704 +BDMAP_00003439 +BDMAP_00002717 +BDMAP_00004878 +BDMAP_00000100 +BDMAP_00001309 +BDMAP_00002214 +BDMAP_00001198 +BDMAP_00001962 +BDMAP_00002463 +BDMAP_00005139 +BDMAP_00000831 +BDMAP_00002955 +BDMAP_00003272 +BDMAP_00000745 diff --git a/Generation_Pipeline_filter_all/val_set/bodymap_pancreas.txt b/Generation_Pipeline_filter_all/val_set/bodymap_pancreas.txt new file mode 100644 index 0000000000000000000000000000000000000000..6da3361e4b14cfd1237e52dc2fd014ae4681c135 --- /dev/null +++ b/Generation_Pipeline_filter_all/val_set/bodymap_pancreas.txt @@ -0,0 +1,24 @@ +BDMAP_00000332 +BDMAP_00004858 +BDMAP_00005155 +BDMAP_00001205 +BDMAP_00004770 +BDMAP_00001361 +BDMAP_00002944 +BDMAP_00003961 +BDMAP_00000430 +BDMAP_00000679 +BDMAP_00003809 +BDMAP_00004115 +BDMAP_00003367 +BDMAP_00002899 +BDMAP_00003771 +BDMAP_00003502 +BDMAP_00001628 +BDMAP_00003884 +BDMAP_00005074 +BDMAP_00003114 +BDMAP_00004741 +BDMAP_00001746 +BDMAP_00002603 +BDMAP_00004128