"""Gradio demo that colorizes grayscale images with a trained Pix2Pix generator.

The generator (``ColorizationModel(opt).netG``) takes one input channel and
produces two output channels. The input is the image's normalized lightness.
The two outputs are treated as CIELAB a/b chrominance scaled by 1/110, the
convention implied by the ``* 110`` factor in the postprocessing.
NOTE(review): confirm this channel convention against the training pipeline.
"""

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

try:
    # Works when this file is imported as part of a package.
    from .colorization_model import ColorizationModel  # Import your model class
except ImportError:
    # Running the file directly (``python app.py``) has no parent package, so
    # the relative import above fails; fall back to a sibling-module import.
    from colorization_model import ColorizationModel

# Load the trained generator model
model_path = "generator.pth"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Define model options (replace with your configuration)
class Options:
    """Static configuration passed to ``ColorizationModel``.

    ``input_nc`` / ``output_nc`` fix the generator's input and output channel
    counts (1 in, 2 out). ``gpu_ids`` is empty when CUDA is unavailable.
    """

    input_nc = 1
    output_nc = 2
    ngf = 64
    netG = "unet_256"
    norm = "batch"
    no_dropout = False
    init_type = "normal"
    init_gain = 0.02
    gpu_ids = [0] if torch.cuda.is_available() else []


opt = Options()
generator = ColorizationModel(opt).netG
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary
# code; only load weights from a trusted source (torch >= 1.13 also accepts
# weights_only=True for plain state_dicts).
generator.load_state_dict(torch.load(model_path, map_location=device))
generator.eval().to(device)

# Grayscale -> 256x256 -> [0, 1] tensor -> [-1, 1]. Built once and reused per request.
# NOTE(review): PIL's grayscale is luma, not exactly CIELAB L; this assumes the
# model was trained with the same normalization — confirm.
_PREPROCESS = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# CIE D65 reference white (Xn, Yn, Zn) used in the Lab -> XYZ step.
_D65_WHITE = (0.95047, 1.0, 1.08883)
# Linear XYZ (D65) -> linear sRGB transform.
_XYZ_TO_LINEAR_RGB = torch.tensor([
    [3.2404542, -1.5371385, -0.4985314],
    [-0.9692660, 1.8760108, 0.0415560],
    [0.0556434, -0.2040259, 1.0572252],
])


def _lab_to_rgb(lab: torch.Tensor) -> torch.Tensor:
    """Convert a CIELAB image tensor of shape (3, H, W) to sRGB in [0, 1].

    Channel 0 is L in [0, 100]; channels 1 and 2 are a and b. Out-of-gamut
    colors are clipped to [0, 1].
    """
    lab = lab.to(torch.float32)
    lightness, a, b = lab[0], lab[1], lab[2]

    fy = (lightness + 16.0) / 116.0
    fx = fy + a / 500.0
    fz = fy - b / 200.0

    delta = 6.0 / 29.0

    def _f_inv(t: torch.Tensor) -> torch.Tensor:
        # Inverse of the CIELAB companding function (cube above delta, linear below).
        return torch.where(t > delta, t ** 3, 3.0 * delta ** 2 * (t - 4.0 / 29.0))

    xyz = torch.stack([
        _f_inv(fx) * _D65_WHITE[0],
        _f_inv(fy) * _D65_WHITE[1],
        _f_inv(fz) * _D65_WHITE[2],
    ])
    # Clamp before gamma so pow() never sees a negative base.
    linear = torch.einsum("ij,jhw->ihw", _XYZ_TO_LINEAR_RGB, xyz).clamp(0.0, 1.0)
    rgb = torch.where(
        linear <= 0.0031308,
        12.92 * linear,
        1.055 * linear.pow(1.0 / 2.4) - 0.055,
    )
    return rgb.clamp(0.0, 1.0)


# Define preprocessing and postprocessing steps
def preprocess_image(image):
    """Turn a PIL image into a normalized (1, 1, 256, 256) tensor on ``device``."""
    return _PREPROCESS(image).unsqueeze(0).to(device)


def postprocess_image(output, l_channel=None):
    """Convert generator output into an RGB PIL image.

    Args:
        output: Generator output with a leading batch dimension of 1.
            When ``l_channel`` is given, its channels are taken as a/b scaled by
            1/110. Otherwise the output must have at least 3 channels, read as
            normalized L followed by a/b (the layout the original code assumed).
        l_channel: Optional normalized lightness input in [-1, 1] with a leading
            batch dimension of 1 (the tensor produced by ``preprocess_image``).

    Returns:
        A ``PIL.Image`` in RGB mode.

    Raises:
        ValueError: If ``l_channel`` is omitted and ``output`` has fewer than
            3 channels, because there is no lightness to reconstruct from.
    """
    output = output.squeeze(0).detach().cpu().float()
    if l_channel is not None:
        lightness = l_channel.squeeze(0).detach().cpu().float()
        ab = output[0:2]
    elif output.shape[0] >= 3:
        lightness, ab = output[0:1], output[1:3]
    else:
        raise ValueError(
            "postprocess_image needs l_channel when the generator output has "
            f"only {output.shape[0]} channel(s)"
        )
    # Undo the normalization: L from [-1, 1] to [0, 100], a/b from [-1, 1] to [-110, 110].
    lab = torch.cat([lightness * 50.0 + 50.0, ab * 110.0], dim=0)
    output_image = transforms.ToPILImage()(_lab_to_rgb(lab))
    return output_image


# Gradio interface function
def colorize(grayscale_image):
    """Colorize one uploaded image; returns ``None`` when nothing was uploaded."""
    if grayscale_image is None:
        return None
    input_tensor = preprocess_image(grayscale_image)
    with torch.no_grad():
        colorized = generator(input_tensor)
    # The generator predicts only chrominance; reuse the input as lightness.
    return postprocess_image(colorized, input_tensor)


# Define Gradio interface
interface = gr.Interface(
    fn=colorize,
    inputs=gr.Image(type="pil", label="Grayscale Image"),
    outputs=gr.Image(type="pil", label="Colorized Image"),
    title="Pix2Pix Image Colorization",
    description="Upload a grayscale image, and the model will colorize it using Pix2Pix GAN.",
)

# Launch the app
if __name__ == "__main__":
    interface.launch()