Spaces:
Runtime error
Runtime error
File size: 1,951 Bytes
82449ec c6670f7 82449ec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from .colorization_model import ColorizationModel # Import your model class
# Location of the trained generator weights, and the device inference runs
# on: CUDA when torch reports it as available, otherwise the CPU.
model_path = "generator.pth"
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Model options (replace with your configuration)
class Options:
    """Plain attribute bag handed to ``ColorizationModel``.

    NOTE(review): the meaning of each field is defined by the model code,
    which is not visible here; the comments below are assumptions based on
    the field names and should be confirmed against that code.
    """

    # Channel layout: presumably one input channel and two output channels.
    input_nc = 1
    output_nc = 2

    # Generator architecture settings.
    netG = "unet_256"
    ngf = 64
    norm = "batch"
    no_dropout = False

    # Weight-initialisation settings.
    init_type = "normal"
    init_gain = 0.02

    # Use GPU 0 when CUDA is available, otherwise no GPUs.
    gpu_ids = [0] if torch.cuda.is_available() else []
# Build the model from the options and keep only its generator network.
opt = Options()
generator = ColorizationModel(opt).netG

# Restore the trained weights, mapping stored tensors onto `device`.
generator.load_state_dict(
    torch.load(model_path, map_location=device)
)

# Inference only: switch to eval mode, then move to the selected device.
generator.eval()
generator.to(device)
# Define preprocessing and postprocessing steps
def preprocess_image(image):
    """Turn an input image into the batched tensor fed to the generator.

    The image is reduced to a single grayscale channel, resized to
    256x256, converted to a tensor, normalised with mean 0.5 / std 0.5,
    given a leading batch dimension and moved to ``device``.

    Args:
        image: An image accepted by the torchvision transforms below.

    Returns:
        The preprocessed tensor, located on ``device``.
    """
    steps = [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
    pipeline = transforms.Compose(steps)
    single = pipeline(image)
    batched = single.unsqueeze(0)
    return batched.to(device)
def _lab_to_rgb(lab):
    """Convert a CIE L*a*b* tensor to gamma-encoded sRGB.

    Args:
        lab: Float tensor of shape (3, H, W) holding L* (0..100), a*, b*.

    Returns:
        Float tensor of shape (3, H, W) with sRGB values clamped to [0, 1].
    """
    # L*a*b* -> XYZ, using the D65 reference white.
    fy = (lab[0] + 16.0) / 116.0
    fx = lab[1] / 500.0 + fy
    fz = (fy - lab[2] / 200.0).clamp(min=0.0)
    f = torch.stack([fx, fy, fz])
    xyz = torch.where(f > 0.2068966, f ** 3, (f - 16.0 / 116.0) / 7.787)
    white = torch.tensor(
        [0.95047, 1.0, 1.08883], dtype=lab.dtype, device=lab.device
    )
    xyz = xyz * white.view(3, 1, 1)

    # XYZ -> linear sRGB.
    rgb_from_xyz = torch.tensor(
        [
            [3.2404542, -1.5371385, -0.4985314],
            [-0.9692660, 1.8760108, 0.0415560],
            [0.0556434, -0.2040259, 1.0572252],
        ],
        dtype=lab.dtype,
        device=lab.device,
    )
    linear = torch.einsum("ij,jhw->ihw", rgb_from_xyz, xyz)

    # Linear -> gamma-encoded sRGB. The clamp keeps the fractional power
    # away from negative values; those pixels take the linear branch anyway.
    encoded = 1.055 * linear.clamp(min=0.0031308) ** (1.0 / 2.4) - 0.055
    srgb = torch.where(linear > 0.0031308, encoded, 12.92 * linear)
    return srgb.clamp(0.0, 1.0)


def postprocess_image(output, l_channel=None):
    """Convert generator output into an RGB PIL image.

    The network output is interpreted in normalised Lab space, using the
    scaling the original code applied (L = x * 50 + 50, ab = x * 110).
    NOTE(review): this assumes the model was trained with that pix2pix
    colorization normalisation -- confirm against the training code.

    Previously the scaled tensor was handed straight to ``ToPILImage``
    with no Lab -> RGB conversion; for a 2-channel (ab-only) output that
    produced a two-channel "LA" image rather than a colour image.

    Args:
        output: Generator output, shape (1, C, H, W) or (C, H, W).
            C == 2 is treated as predicted ab channels; C == 3 as
            normalised L, a, b.
        l_channel: Normalised lightness (values in [-1, 1]) with H * W
            elements, e.g. the tensor that was fed to the generator.
            Required when ``output`` has two channels.

    Returns:
        A PIL image in RGB mode.

    Raises:
        ValueError: If the channel count is not 2 or 3, or if a 2-channel
            output is given without ``l_channel``.
    """
    prediction = output.squeeze(0).detach().cpu().float()
    if prediction.ndim != 3:
        raise ValueError(
            f"expected a (C, H, W) image after squeezing, got shape "
            f"{tuple(prediction.shape)}"
        )

    channels = prediction.shape[0]
    if channels == 3:
        lightness, ab = prediction[0:1], prediction[1:]
    elif channels == 2:
        if l_channel is None:
            raise ValueError(
                "a 2-channel (ab) output needs l_channel to build a colour image"
            )
        lightness = (
            l_channel.detach().cpu().float().reshape(1, *prediction.shape[1:])
        )
        ab = prediction
    else:
        raise ValueError(f"expected 2 or 3 output channels, got {channels}")

    # Undo the [-1, 1] normalisation to get real Lab values.
    lab = torch.cat([lightness * 50.0 + 50.0, ab * 110.0], dim=0)
    return transforms.ToPILImage()(_lab_to_rgb(lab))


# Gradio interface function
def colorize(grayscale_image):
    """Colorize an image with the generator.

    Args:
        grayscale_image: The uploaded image, or None when nothing was
            uploaded.

    Returns:
        The colorized PIL image, or None when no image was provided.
    """
    # Submitting with no image passes None; return an empty output
    # instead of crashing inside the transforms.
    if grayscale_image is None:
        return None

    input_tensor = preprocess_image(grayscale_image)
    with torch.no_grad():
        colorized = generator(input_tensor)
    # The generator predicts colour only; lightness comes from the input.
    return postprocess_image(colorized, l_channel=input_tensor)
# Gradio UI: one PIL image in, one PIL image out, handled by `colorize`.
interface = gr.Interface(
    title="Pix2Pix Image Colorization",
    description="Upload a grayscale image, and the model will colorize it using Pix2Pix GAN.",
    fn=colorize,
    inputs=gr.Image(label="Grayscale Image", type="pil"),
    outputs=gr.Image(label="Colorized Image", type="pil"),
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()
|