gunateja commited on
Commit
02cb6a1
·
verified ·
1 Parent(s): 6dd6101

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -29
app.py CHANGED
@@ -1,38 +1,51 @@
1
  import gradio as gr
2
- from huggingface_hub import from_pretrained_keras
3
- import tensorflow as tf
4
- import numpy as np
5
  from PIL import Image
6
  import io
7
 
8
- # Load the model from Hugging Face Hub using from_pretrained_keras
9
- model = from_pretrained_keras("google/maxim-s3-deblurring-reds")
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def deblur_image(input_image):
12
- # Preprocess the input image
13
- image = np.array(input_image)
14
- image = tf.convert_to_tensor(image)
15
- image = tf.image.resize(image, (256, 256))
16
-
17
- # Make predictions with the model
18
- predictions = model.predict(tf.expand_dims(image, 0))
19
-
20
- # Convert the prediction back to an image
21
- output_image = predictions[0].numpy().astype(np.uint8)
22
- output_image = Image.fromarray(output_image)
23
-
24
- # Save the result in memory as a byte array
25
- byte_arr = io.BytesIO()
26
- output_image.save(byte_arr, format='PNG')
27
- byte_arr.seek(0)
28
-
29
- return byte_arr
30
 
31
- # Set up the Gradio interface
32
- iface = gr.Interface(fn=deblur_image,
33
- inputs=gr.inputs.Image(type="pil"),
34
- outputs=gr.outputs.Image(type="file"),
35
- live=True)
 
 
 
 
 
 
 
36
 
37
- # Launch the Gradio interface
38
  iface.launch()
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
+ import torch
 
4
  from PIL import Image
5
  import io
6
 
7
+ # Check for CUDA availability and set device
8
+ if torch.cuda.is_available():
9
+ device = torch.device("cuda")
10
+ print(f"Using CUDA device: {torch.cuda.get_device_name(0)}")
11
+ else:
12
+ device = torch.device("cpu")
13
+ print("Using CPU. CUDA is not available.")
14
+
15
+ try:
16
+ # Initialize the deblurring pipeline with the specified model and device
17
+ deblurrer = pipeline("image-to-image", model="google/maxim-s3-deblurring-reds", device=device)
18
+ except Exception as e:
19
+ print(f"Error loading the model: {e}")
20
+ exit() # Exit if model loading fails
21
 
22
  def deblur_image(input_image):
23
+ try:
24
+ output = deblurrer(input_image)
25
+ deblurred_image = output[0]
26
+
27
+ # Convert PIL Image to Bytes for download
28
+ img_byte_arr = io.BytesIO()
29
+ deblurred_image.save(img_byte_arr, format='PNG') # or JPEG, etc.
30
+ img_byte_arr = img_byte_arr.getvalue()
31
+
32
+ return deblurred_image, img_byte_arr # Return both image and bytes
33
+ except Exception as e:
34
+ print(f"Error during deblurring: {e}")
35
+ return None, None
 
 
 
 
 
36
 
37
+ # Create the Gradio interface
38
+ iface = gr.Interface(
39
+ fn=deblur_image,
40
+ inputs=gr.Image(type="pil", label="Upload Blurred Image"),
41
+ outputs=[
42
+ gr.Image(type="pil", label="Deblurred Image"),
43
+ gr.File(label="Download Deblurred Image", file_types=[".png", ".jpg", ".jpeg"]) # Added File output
44
+ ],
45
+ title="Deblurring App",
46
+ description="Deblur your images using the google/maxim-s3-deblurring-reds model.",
47
+ examples=[["blurred_image.jpg"]],
48
+ )
49
 
50
+ # Launch the Gradio app
51
  iface.launch()