NimaBoscarino commited on
Commit
c557eb7
·
1 Parent(s): 4756ce1

Fix image output dims

Browse files
Files changed (2) hide show
  1. app.py +11 -6
  2. inferences.py +7 -5
app.py CHANGED
@@ -17,7 +17,7 @@ model = ClimateGAN(model_path="config/model/masker")
17
  def predict(place):
18
  geocode_result = gmaps.geocode(place)
19
  loc = geocode_result[0]['geometry']['location']
20
- static_map_url = f"https://maps.googleapis.com/maps/api/streetview?size=400x400&location={loc['lat']},{loc['lng']}&fov=80&heading=70&pitch=0&key={API_KEY}"
21
 
22
  img_np = io.imread(static_map_url)
23
  flood, wildfire, smog = model.inference(img_np)
@@ -29,14 +29,19 @@ gr.Interface(
29
  inputs=[
30
  gr.inputs.Textbox(label="Address or place name")
31
  ],
32
- outputs=["image", "image", "image", "image"],
 
 
 
 
 
33
  title="ClimateGAN",
34
  description="Enter an address or place name, and ClimateGAN will generate images showing how the location could be impacted by flooding, wildfires, or smog.",
35
  article="<p style='text-align: center'>This project is a clone of <a href='https://thisclimatedoesnotexist.com/'>ThisClimateDoesNotExist</a> | <a href='https://github.com/cc-ai/climategan'>ClimateGAN GitHub Repo</a></p>",
36
  examples=[
37
- "Kafka's Great Northern Way, Vancouver",
38
- "Simon Fraser University",
39
- "Duomo, Milano"
40
  ],
41
  css=".footer{display:none !important}",
42
- ).launch()
 
17
  def predict(place):
18
  geocode_result = gmaps.geocode(place)
19
  loc = geocode_result[0]['geometry']['location']
20
+ static_map_url = f"https://maps.googleapis.com/maps/api/streetview?size=640x640&location={loc['lat']},{loc['lng']}&source=outdoor&key={API_KEY}"
21
 
22
  img_np = io.imread(static_map_url)
23
  flood, wildfire, smog = model.inference(img_np)
 
29
  inputs=[
30
  gr.inputs.Textbox(label="Address or place name")
31
  ],
32
+ outputs=[
33
+ gr.outputs.Image(type="numpy", label="Original image"),
34
+ gr.outputs.Image(type="numpy", label="Flooding"),
35
+ gr.outputs.Image(type="numpy", label="Wildfire"),
36
+ gr.outputs.Image(type="numpy", label="Smog"),
37
+ ],
38
  title="ClimateGAN",
39
  description="Enter an address or place name, and ClimateGAN will generate images showing how the location could be impacted by flooding, wildfires, or smog.",
40
  article="<p style='text-align: center'>This project is a clone of <a href='https://thisclimatedoesnotexist.com/'>ThisClimateDoesNotExist</a> | <a href='https://github.com/cc-ai/climategan'>ClimateGAN GitHub Repo</a></p>",
41
  examples=[
42
+ "Vancouver Art Gallery",
43
+ "Chicago Bean",
44
+ "Duomo Siracusa"
45
  ],
46
  css=".footer{display:none !important}",
47
+ ).launch(cache_examples=True)
inferences.py CHANGED
@@ -82,9 +82,8 @@ class ClimateGAN():
82
 
83
  # Does all three inferences at the moment.
84
  def inference(self, orig_image):
85
- image, new_size = self._preprocess_image(orig_image)
86
 
87
- image = np.stack(image)
88
  # Retreive numpy events as a dict {event: array[BxHxWxC]}
89
  outputs = self.trainer.infer_all(
90
  image,
@@ -92,7 +91,11 @@ class ClimateGAN():
92
  bin_value=0.5,
93
  )
94
 
95
- return outputs['flood'], outputs['wildfire'], outputs['smog']
 
 
 
 
96
 
97
  def _preprocess_image(self, img):
98
  # rgba to rgb
@@ -100,8 +103,7 @@ class ClimateGAN():
100
 
101
  # to args.target_size
102
  data = resize_and_crop(data, self.target_size)
103
- new_size = (self.target_size, self.target_size)
104
 
105
  # resize() produces [0, 1] images, rescale to [-1, 1]
106
  data = to_m1_p1(data)
107
- return data, new_size
 
82
 
83
  # Does all three inferences at the moment.
84
  def inference(self, orig_image):
85
+ image = self._preprocess_image(orig_image)
86
 
 
87
  # Retreive numpy events as a dict {event: array[BxHxWxC]}
88
  outputs = self.trainer.infer_all(
89
  image,
 
91
  bin_value=0.5,
92
  )
93
 
94
+ return (
95
+ outputs['flood'].squeeze(),
96
+ outputs['wildfire'].squeeze(),
97
+ outputs['smog'].squeeze()
98
+ )
99
 
100
  def _preprocess_image(self, img):
101
  # rgba to rgb
 
103
 
104
  # to args.target_size
105
  data = resize_and_crop(data, self.target_size)
 
106
 
107
  # resize() produces [0, 1] images, rescale to [-1, 1]
108
  data = to_m1_p1(data)
109
+ return data