datafreak commited on
Commit
a41434f
·
verified ·
1 Parent(s): a361ab7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from PIL import Image
5
+ from tensorflow_examples.models.pix2pix import pix2pix
6
+
7
+ OUTPUT_CHANNELS = 3
8
+
9
+ generator_g = pix2pix.unet_generator(OUTPUT_CHANNELS, norm_type='instancenorm')
10
+ generator_g.load_weights("vibrantGAN-generator-g-final.weights.h5")
11
+
12
+ def preprocess_single_image(image, target_height=256, target_width=256):
13
+ # Convert PIL image to tensorflow tensor
14
+ image = tf.convert_to_tensor(np.array(image))
15
+
16
+ # Ensure image has 3 channels (RGB)
17
+ if len(image.shape) == 2: # If grayscale
18
+ image = tf.stack([image, image, image], axis=-1)
19
+ elif image.shape[-1] == 4: # If RGBA
20
+ image = image[:, :, :3]
21
+
22
+ # Resize the image
23
+ image = tf.image.resize(image, [target_height, target_width])
24
+
25
+ # Normalize to [-1, 1]
26
+ image = tf.cast(image, tf.float32)
27
+ image = (image / 127.5) - 1
28
+
29
+ return image
30
+
31
+ def process_image(input_image):
32
+ if input_image is None:
33
+ return None
34
+
35
+ # Get original input image size
36
+ original_size = input_image.size
37
+
38
+ # Preprocess the image
39
+ processed_input = preprocess_single_image(input_image)
40
+
41
+ # Add batch dimension
42
+ processed_input = tf.expand_dims(processed_input, 0)
43
+
44
+ # Generate prediction
45
+ prediction = generator_g(processed_input)
46
+
47
+ # Convert the prediction to displayable format
48
+ output_image = prediction[0] * 0.5 + 0.5 # Denormalize to [0, 1]
49
+ output_image = tf.clip_by_value(output_image, 0, 1)
50
+
51
+ # Convert to numpy array and then to PIL Image
52
+ output_array = (output_image.numpy() * 255).astype(np.uint8)
53
+ output_pil = Image.fromarray(output_array)
54
+
55
+ return output_pil
56
+
57
+ # Create Gradio interface
58
+ demo = gr.Interface(
59
+ fn=process_image,
60
+ inputs=gr.Image(type="pil", label="Input Image"),
61
+ outputs=gr.Image(type="pil", label="Generated Output"),
62
+ title="Image Processing Model",
63
+ description="Upload an image to see the model's output.",
64
+ )
65
+
66
+ # Launch the interface
67
+ demo.launch(debug=True)