Deblurring-App / app.py
gunateja's picture
Update app.py
02cb6a1 verified
raw
history blame
1.67 kB
import gradio as gr
from transformers import pipeline
import torch
from PIL import Image
import io
# Choose the compute device: the default CUDA device when PyTorch can see a
# GPU, otherwise the CPU. The choice is reported on stdout at start-up.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"Using CUDA device: {torch.cuda.get_device_name(0)}")
else:
    print("Using CPU. CUDA is not available.")
try:
    # Build the deblurring pipeline once, at start-up, on the device chosen above.
    # NOTE(review): this checkpoint appears to be published as a Keras model;
    # confirm the transformers "image-to-image" pipeline can actually load it.
    deblurrer = pipeline("image-to-image", model="google/maxim-s3-deblurring-reds", device=device)
except Exception as e:
    print(f"Error loading the model: {e}")
    # Terminate with a non-zero status so the failure is visible to whatever
    # launched the app. The bare `exit()` builtin exits with status 0 and is
    # only defined when the `site` module has been loaded.
    raise SystemExit(1) from e
def _first_image(output):
    """Return the single image contained in an image-to-image pipeline result.

    The pipeline may hand back either one image or a list of images, depending
    on the input and library version, so both shapes are accepted.

    Raises:
        ValueError: if the pipeline returned an empty list/tuple.
    """
    if isinstance(output, (list, tuple)):
        if not output:
            raise ValueError("The model returned no image.")
        return output[0]
    return output


def _save_png(image):
    """Save *image* as PNG into a new temporary file and return the file's path.

    The file is created with ``delete=False`` so it still exists after this
    function returns; nothing in this module removes it afterwards.
    """
    import tempfile  # local import: only this helper needs it

    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
        image.save(tmp, format="PNG")
    return tmp.name


def deblur_image(input_image, model=None):
    """Deblur *input_image* and provide a downloadable PNG copy of the result.

    Args:
        input_image: The uploaded image (a PIL image from the Gradio input),
            or ``None`` when the image input is empty.
        model: Callable used for inference. It defaults to the module-level
            ``deblurrer`` pipeline and can be overridden with a stand-in
            model for testing.

    Returns:
        ``(deblurred_image, png_path)`` on success. ``png_path`` is the path
        of a temporary ``.png`` file, which is what a ``gr.File`` output
        expects. Returns ``(None, None)`` when there is no input or when
        inference/saving fails (the error is printed).
    """
    if input_image is None:
        # Nothing was uploaded; clear both outputs rather than calling the model.
        return None, None
    if model is None:
        model = deblurrer
    try:
        deblurred_image = _first_image(model(input_image))
        png_path = _save_png(deblurred_image)
    except Exception as e:
        # Best effort: report the failure and clear the outputs instead of
        # propagating the exception out of the request handler.
        print(f"Error during deblurring: {e}")
        return None, None
    return deblurred_image, png_path
# Assemble the Gradio UI: one uploaded image goes in; the deblurred image and
# a File component offering it for download come out.
upload_input = gr.Image(type="pil", label="Upload Blurred Image")
preview_output = gr.Image(type="pil", label="Deblurred Image")
# NOTE(review): file_types restricts which files a user may upload; on an
# output component it may have no effect — confirm whether it is needed.
download_output = gr.File(label="Download Deblurred Image", file_types=[".png", ".jpg", ".jpeg"])

iface = gr.Interface(
    fn=deblur_image,
    inputs=upload_input,
    outputs=[preview_output, download_output],
    title="Deblurring App",
    description="Deblur your images using the google/maxim-s3-deblurring-reds model.",
    # NOTE(review): assumes blurred_image.jpg ships alongside this script — confirm.
    examples=[["blurred_image.jpg"]],
)

# Start serving the app; this runs whenever the module is executed.
iface.launch()