# Page-header residue from the original listing: "Spaces" - status "Runtime error" (shown twice).
| from PIL import Image | |
| import json | |
| import torch | |
| from torchvision import transforms | |
| import cv2 | |
| import numpy as np | |
| import os | |
| import torch.nn as nn | |
def show_cam_on_img(img, mask, img_path_save):
    """Blend a JET-colored heat map of ``mask`` onto ``img`` and write it to disk.

    Args:
        img: Image array that is added element-wise to the colored heat map.
            Presumably a float image scaled to [0, 1] with the same height and
            width as ``mask`` -- TODO confirm against callers.
        mask: Activation map; it is multiplied by 255 and cast to ``uint8``
            before color mapping, so values are presumably expected in [0, 1].
        img_path_save: Destination path handed to ``cv2.imwrite``.
    """
    # Color-map the mask, then bring the colored map back to a 0..1 float scale.
    mask_u8 = np.uint8(255 * mask)
    colored = cv2.applyColorMap(mask_u8, cv2.COLORMAP_JET)
    overlay = np.float32(colored) / 255 + np.float32(img)

    # Rescale so the brightest value of the sum becomes exactly 1.0.
    overlay = overlay / np.max(overlay)

    cv2.imwrite(img_path_save, np.uint8(255 * overlay))
# Path of the input image that main() loads with cv2.imread; empty placeholder
# that has to be filled in before running.
img_path_read = ""
# Path that main() passes to show_cam_on_img as the output file; empty
# placeholder that has to be filled in before running.
img_path_save = ""
def main():
    """Overlay one class activation map on the input image and save the result.

    Loads the image at ``img_path_read``, resizes it to 224x224 and scales it
    to the 0..1 range, selects ``cam_all[cls_idx]``, divides it by its maximum,
    resizes it to 224x224 and hands both arrays to ``show_cam_on_img``, which
    writes the blended image to ``img_path_save``.

    NOTE(review): ``cam_all`` is not defined in the visible code; it has to be
    made available at module level (e.g. from a model forward pass) before this
    function can run.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot load ``img_path_read``.
    """
    size = (224, 224)

    # cv2.imread signals failure by returning None instead of raising, so check
    # explicitly rather than letting cv2.resize fail with an opaque error.
    img = cv2.imread(img_path_read, flags=1)
    if img is None:
        raise FileNotFoundError(f"could not read image: {img_path_read!r}")
    img = np.float32(cv2.resize(img, size)) / 255

    # cam_all is the score tensor of shape (B, C, H, W), similar to y_raw in our Figure 1
    # cls_idx specifying the i-th class out of C class
    # NOTE(review): with a (B, C, H, W) tensor, cam_all[cls_idx] indexes the
    # first (batch) axis, not the class axis -- presumably cam_all is already a
    # single sample of shape (C, H, W) here; confirm against the caller.
    # visualize the 0's class heatmap
    cls_idx = 0
    cam = cam_all[cls_idx]

    # Optional: clamp negative scores before normalizing.
    # cam = nn.ReLU()(cam)
    cam = cam / torch.max(cam)

    # Detach and move to CPU before converting: direct NumPy conversion fails
    # for tensors that require grad or live on a non-CPU device.
    cam = cv2.resize(cam.detach().cpu().numpy(), size)

    show_cam_on_img(img, cam, img_path_save)