Spaces:
				
			
			
	
			
			
		Build error
		
	
	
	
			
			
	
	
	
	
		
		
		Build error
		
	Commit 
							
							·
						
						bee801c
	
1
								Parent(s):
							
							dde7894
								
Add model and demo
Browse files- Procfile +1 -0
 - app.py +91 -0
 - autoencoder_model.png +0 -0
 - model-final.pth +3 -0
 - predict.py +79 -0
 - prediction.ipynb +0 -0
 - requirements.txt +8 -0
 
    	
        Procfile
    ADDED
    
    | 
         @@ -0,0 +1 @@ 
     | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
              web: sh setup.sh && streamlit run app.py
         
     | 
    	
        app.py
    ADDED
    
    | 
         @@ -0,0 +1,91 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import PIL
         
     | 
| 2 | 
         
            +
            import torch
         
     | 
| 3 | 
         
            +
            import torch.nn as nn
         
     | 
| 4 | 
         
            +
            import cv2
         
     | 
| 5 | 
         
            +
            from skimage.color import lab2rgb, rgb2lab, rgb2gray
         
     | 
| 6 | 
         
            +
            from skimage import io
         
     | 
| 7 | 
         
            +
            import matplotlib.pyplot as plt
         
     | 
| 8 | 
         
            +
            import numpy as np
         
     | 
| 9 | 
         
            +
             
     | 
| 10 | 
         
            +
            class ColorizationNet(nn.Module):
         
     | 
| 11 | 
         
            +
              def __init__(self, input_size=128):
         
     | 
| 12 | 
         
            +
                super(ColorizationNet, self).__init__()
         
     | 
| 13 | 
         
            +
             
     | 
| 14 | 
         
            +
                MIDLEVEL_FEATURE_SIZE = 128
         
     | 
| 15 | 
         
            +
                resnet=models.resnet18(pretrained=True)
         
     | 
| 16 | 
         
            +
                resnet.conv1.weight=nn.Parameter(resnet.conv1.weight.sum(dim=1).unsqueeze(1))
         
     | 
| 17 | 
         
            +
                
         
     | 
| 18 | 
         
            +
                self.midlevel_resnet =nn.Sequential(*list(resnet.children())[0:6])
         
     | 
| 19 | 
         
            +
             
     | 
| 20 | 
         
            +
                self.upsample = nn.Sequential(     
         
     | 
| 21 | 
         
            +
                  nn.Conv2d(MIDLEVEL_FEATURE_SIZE, 128, kernel_size=3, stride=1, padding=1),
         
     | 
| 22 | 
         
            +
                  nn.BatchNorm2d(128),
         
     | 
| 23 | 
         
            +
                  nn.ReLU(),
         
     | 
| 24 | 
         
            +
                  nn.Upsample(scale_factor=2),
         
     | 
| 25 | 
         
            +
                  nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
         
     | 
| 26 | 
         
            +
                  nn.BatchNorm2d(64),
         
     | 
| 27 | 
         
            +
                  nn.ReLU(),
         
     | 
| 28 | 
         
            +
                  nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
         
     | 
| 29 | 
         
            +
                  nn.BatchNorm2d(64),
         
     | 
| 30 | 
         
            +
                  nn.ReLU(),
         
     | 
| 31 | 
         
            +
                  nn.Upsample(scale_factor=2),
         
     | 
| 32 | 
         
            +
                  nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
         
     | 
| 33 | 
         
            +
                  nn.BatchNorm2d(32),
         
     | 
| 34 | 
         
            +
                  nn.ReLU(),
         
     | 
| 35 | 
         
            +
                  nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1),
         
     | 
| 36 | 
         
            +
                  nn.Upsample(scale_factor=2)
         
     | 
| 37 | 
         
            +
                )
         
     | 
| 38 | 
         
            +
             
     | 
| 39 | 
         
            +
              def forward(self, input):
         
     | 
| 40 | 
         
            +
             
     | 
| 41 | 
         
            +
                # Pass input through ResNet-gray to extract features
         
     | 
| 42 | 
         
            +
                midlevel_features = self.midlevel_resnet(input)
         
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
                # Upsample to get colors
         
     | 
| 45 | 
         
            +
                output = self.upsample(midlevel_features)
         
     | 
| 46 | 
         
            +
                return output
         
     | 
| 47 | 
         
            +
             
     | 
| 48 | 
         
            +
                
         
     | 
| 49 | 
         
            +
                
         
     | 
| 50 | 
         
            +
            def show_output(grayscale_input, ab_input):
         
     | 
| 51 | 
         
            +
              '''Show/save rgb image from grayscale and ab channels
         
     | 
| 52 | 
         
            +
                 Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
         
     | 
| 53 | 
         
            +
              color_image = torch.cat((grayscale_input, ab_input), 0).detach().numpy() # combine channels
         
     | 
| 54 | 
         
            +
              color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
         
     | 
| 55 | 
         
            +
              color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
         
     | 
| 56 | 
         
            +
              color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
         
     | 
| 57 | 
         
            +
              color_image = lab2rgb(color_image.astype(np.float64))
         
     | 
| 58 | 
         
            +
              grayscale_input = grayscale_input.squeeze().numpy()
         
     | 
| 59 | 
         
            +
              # plt.imshow(grayscale_input)
         
     | 
| 60 | 
         
            +
              # plt.imshow(color_image)
         
     | 
| 61 | 
         
            +
              return color_image
         
     | 
| 62 | 
         
            +
             
     | 
| 63 | 
         
            +
            def colorize(img,print_img=True):
         
     | 
| 64 | 
         
            +
                # img=cv2.imread(img)
         
     | 
| 65 | 
         
            +
                img=cv2.resize(img,(224,224))
         
     | 
| 66 | 
         
            +
                grayscale_input= torch.Tensor(rgb2gray(img))
         
     | 
| 67 | 
         
            +
                ab_input=model(grayscale_input.unsqueeze(0).unsqueeze(0)).squeeze(0)
         
     | 
| 68 | 
         
            +
                predicted=show_output(grayscale_input.unsqueeze(0), ab_input)
         
     | 
| 69 | 
         
            +
                if print_img:
         
     | 
| 70 | 
         
            +
                    plt.imshow(predicted)
         
     | 
| 71 | 
         
            +
                return predicted
         
     | 
| 72 | 
         
            +
             
     | 
| 73 | 
         
            +
            # device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
         
     | 
| 74 | 
         
            +
            # torch.load with map_location=torch.device('cpu') 
         
     | 
| 75 | 
         
            +
            model=torch.load("model-final.pth",map_location ='cpu')
         
     | 
| 76 | 
         
            +
             
     | 
| 77 | 
         
            +
             
     | 
| 78 | 
         
            +
            import streamlit as st
         
     | 
| 79 | 
         
            +
            st.title("Image Colorizer")
         
     | 
| 80 | 
         
            +
                    
         
     | 
| 81 | 
         
            +
            file=st.file_uploader("Please upload the B/W image",type=["jpg","jpeg","png"])
         
     | 
| 82 | 
         
            +
            print(file)
         
     | 
| 83 | 
         
            +
            if file is None:
         
     | 
| 84 | 
         
            +
              st.text("Please Upload an image")
         
     | 
| 85 | 
         
            +
            else:
         
     | 
| 86 | 
         
            +
                file_bytes = np.asarray(bytearray(file.read()), dtype=np.uint8)
         
     | 
| 87 | 
         
            +
                opencv_image = cv2.imdecode(file_bytes, 1)
         
     | 
| 88 | 
         
            +
                im=colorize(opencv_image)
         
     | 
| 89 | 
         
            +
                st.image(im)
         
     | 
| 90 | 
         
            +
                st.text("Colorized!!")
         
     | 
| 91 | 
         
            +
                # st.image(file)
         
     | 
    	
        autoencoder_model.png
    ADDED
    
    
											 
									 | 
									
								
    	
        model-final.pth
    ADDED
    
    | 
         @@ -0,0 +1,3 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            version https://git-lfs.github.com/spec/v1
         
     | 
| 2 | 
         
            +
            oid sha256:6268c0b73c7bc3fefd3918d113fb74976f9780f4737bf6e4c088811a1a6872ec
         
     | 
| 3 | 
         
            +
            size 3867929
         
     | 
    	
        predict.py
    ADDED
    
    | 
         @@ -0,0 +1,79 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            import sys
         
     | 
| 2 | 
         
            +
            sys.path.insert(0, './WordLM')
         
     | 
| 3 | 
         
            +
             
     | 
| 4 | 
         
            +
            import PIL
         
     | 
| 5 | 
         
            +
            import torch
         
     | 
| 6 | 
         
            +
            import torch.nn as nn
         
     | 
| 7 | 
         
            +
            import cv2
         
     | 
| 8 | 
         
            +
            from skimage.color import lab2rgb, rgb2lab, rgb2gray
         
     | 
| 9 | 
         
            +
            from skimage import io
         
     | 
| 10 | 
         
            +
            import matplotlib.pyplot as plt
         
     | 
| 11 | 
         
            +
            import numpy as np
         
     | 
| 12 | 
         
            +
             
     | 
| 13 | 
         
            +
            class ColorizationNet(nn.Module):
         
     | 
| 14 | 
         
            +
              def __init__(self, input_size=128):
         
     | 
| 15 | 
         
            +
                super(ColorizationNet, self).__init__()
         
     | 
| 16 | 
         
            +
             
     | 
| 17 | 
         
            +
                MIDLEVEL_FEATURE_SIZE = 128
         
     | 
| 18 | 
         
            +
                resnet=models.resnet18(pretrained=True)
         
     | 
| 19 | 
         
            +
                resnet.conv1.weight=nn.Parameter(resnet.conv1.weight.sum(dim=1).unsqueeze(1))
         
     | 
| 20 | 
         
            +
                
         
     | 
| 21 | 
         
            +
                self.midlevel_resnet =nn.Sequential(*list(resnet.children())[0:6])
         
     | 
| 22 | 
         
            +
             
     | 
| 23 | 
         
            +
                self.upsample = nn.Sequential(     
         
     | 
| 24 | 
         
            +
                  nn.Conv2d(MIDLEVEL_FEATURE_SIZE, 128, kernel_size=3, stride=1, padding=1),
         
     | 
| 25 | 
         
            +
                  nn.BatchNorm2d(128),
         
     | 
| 26 | 
         
            +
                  nn.ReLU(),
         
     | 
| 27 | 
         
            +
                  nn.Upsample(scale_factor=2),
         
     | 
| 28 | 
         
            +
                  nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
         
     | 
| 29 | 
         
            +
                  nn.BatchNorm2d(64),
         
     | 
| 30 | 
         
            +
                  nn.ReLU(),
         
     | 
| 31 | 
         
            +
                  nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
         
     | 
| 32 | 
         
            +
                  nn.BatchNorm2d(64),
         
     | 
| 33 | 
         
            +
                  nn.ReLU(),
         
     | 
| 34 | 
         
            +
                  nn.Upsample(scale_factor=2),
         
     | 
| 35 | 
         
            +
                  nn.Conv2d(64, 32, kernel_size=3, stride=1, padding=1),
         
     | 
| 36 | 
         
            +
                  nn.BatchNorm2d(32),
         
     | 
| 37 | 
         
            +
                  nn.ReLU(),
         
     | 
| 38 | 
         
            +
                  nn.Conv2d(32, 2, kernel_size=3, stride=1, padding=1),
         
     | 
| 39 | 
         
            +
                  nn.Upsample(scale_factor=2)
         
     | 
| 40 | 
         
            +
                )
         
     | 
| 41 | 
         
            +
             
     | 
| 42 | 
         
            +
              def forward(self, input):
         
     | 
| 43 | 
         
            +
             
     | 
| 44 | 
         
            +
                # Pass input through ResNet-gray to extract features
         
     | 
| 45 | 
         
            +
                midlevel_features = self.midlevel_resnet(input)
         
     | 
| 46 | 
         
            +
             
     | 
| 47 | 
         
            +
                # Upsample to get colors
         
     | 
| 48 | 
         
            +
                output = self.upsample(midlevel_features)
         
     | 
| 49 | 
         
            +
                return output
         
     | 
| 50 | 
         
            +
             
     | 
| 51 | 
         
            +
                
         
     | 
| 52 | 
         
            +
                
         
     | 
| 53 | 
         
            +
            def show_output(grayscale_input, ab_input):
         
     | 
| 54 | 
         
            +
              '''Show/save rgb image from grayscale and ab channels
         
     | 
| 55 | 
         
            +
                 Input save_path in the form {'grayscale': '/path/', 'colorized': '/path/'}'''
         
     | 
| 56 | 
         
            +
              color_image = torch.cat((grayscale_input, ab_input), 0).detach().numpy() # combine channels
         
     | 
| 57 | 
         
            +
              color_image = color_image.transpose((1, 2, 0))  # rescale for matplotlib
         
     | 
| 58 | 
         
            +
              color_image[:, :, 0:1] = color_image[:, :, 0:1] * 100
         
     | 
| 59 | 
         
            +
              color_image[:, :, 1:3] = color_image[:, :, 1:3] * 255 - 128   
         
     | 
| 60 | 
         
            +
              color_image = lab2rgb(color_image.astype(np.float64))
         
     | 
| 61 | 
         
            +
              grayscale_input = grayscale_input.squeeze().numpy()
         
     | 
| 62 | 
         
            +
              # plt.imshow(grayscale_input)
         
     | 
| 63 | 
         
            +
              # plt.imshow(color_image)
         
     | 
| 64 | 
         
            +
              return color_image
         
     | 
| 65 | 
         
            +
             
     | 
| 66 | 
         
            +
            model=torch.load("model-final.pth")
         
     | 
| 67 | 
         
            +
             
     | 
| 68 | 
         
            +
            def colorize(img_path,print_img=True):
         
     | 
| 69 | 
         
            +
                img=cv2.imread(img_path)
         
     | 
| 70 | 
         
            +
                img=cv2.resize(img,(224,224))
         
     | 
| 71 | 
         
            +
                grayscale_input= torch.Tensor(rgb2gray(img))
         
     | 
| 72 | 
         
            +
                ab_input=model(grayscale_input.unsqueeze(0).unsqueeze(0)).squeeze(0)
         
     | 
| 73 | 
         
            +
                predicted=show_output(grayscale_input.unsqueeze(0), ab_input)
         
     | 
| 74 | 
         
            +
                if print_img:
         
     | 
| 75 | 
         
            +
                    plt.imshow(predicted)
         
     | 
| 76 | 
         
            +
                return predicted
         
     | 
| 77 | 
         
            +
             
     | 
| 78 | 
         
            +
            # out=colorize("download.png")
         
     | 
| 79 | 
         
            +
            # print(out)
         
     | 
    	
        prediction.ipynb
    ADDED
    
    | 
         The diff for this file is too large to render. 
		See raw diff 
     | 
| 
         | 
    	
        requirements.txt
    ADDED
    
    | 
         @@ -0,0 +1,8 @@ 
     | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
| 
         | 
|
| 1 | 
         
            +
            -f https://download.pytorch.org/whl/torch_stable.html
         
     | 
| 2 | 
         
            +
            torch==1.7.1+cpu
         
     | 
| 3 | 
         
            +
            torchvision==0.9.1+cpu
         
     | 
| 4 | 
         
            +
            numpy==1.18.5
         
     | 
| 5 | 
         
            +
            opencv-python-headless==4.4.0.46
         
     | 
| 6 | 
         
            +
            matplotlib==3.4.2
         
     | 
| 7 | 
         
            +
            scikit-image==0.18.1
         
     | 
| 8 | 
         
            +
            streamlit==0.81.1
         
     |