Suchinthana commited on
Commit
9b980f8
·
1 Parent(s): 77ca91f

Difference with actual map

Browse files
Files changed (1) hide show
  1. app.py +68 -9
app.py CHANGED
@@ -164,6 +164,54 @@ def generate_satellite_image(init_image, mask_image, prompt):
164
  )
165
  return result.images[0]
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  # Gradio UI
168
  @spaces.GPU
169
  def handle_query(query):
@@ -171,22 +219,33 @@ def handle_query(query):
171
  response = process_openai_response(query)
172
  geojson_data = generate_geojson(response)
173
 
174
- # Generate map image
175
  map_image = generate_static_map(geojson_data)
176
 
177
- # Generate mask for ControlNet
178
- empty_map = Image.new("RGB", map_image.size, "white")
179
- difference = np.array(map_image) - np.array(empty_map)
180
- mask = np.any(difference != 0, axis=-1).astype(np.uint8) * 255
 
 
 
 
 
 
 
 
181
 
182
- # Convert mask to PIL Image
183
- mask_image = Image.fromarray(mask)
184
 
185
- # Generate satellite image
186
- satellite_image = generate_satellite_image(map_image, mask_image, response['output']['feature_representation']['properties']['description'])
 
 
187
 
188
  return map_image, satellite_image, mask_image, response
189
 
 
190
  # Gradio interface
191
  with gr.Blocks() as demo:
192
  with gr.Row():
 
164
  )
165
  return result.images[0]
166
 
167
+ def get_bounds(geojson):
168
+ coordinates = []
169
+ for feature in geojson["features"]:
170
+ geom_type = feature["geometry"]["type"]
171
+ coords = feature["geometry"]["coordinates"]
172
+ if geom_type == "Point":
173
+ coordinates.append(coords)
174
+ elif geom_type in ["MultiPoint", "LineString"]:
175
+ coordinates.extend(coords)
176
+ elif geom_type in ["MultiLineString", "Polygon"]:
177
+ for part in coords:
178
+ coordinates.extend(part)
179
+ elif geom_type == "MultiPolygon":
180
+ for polygon in coords:
181
+ for part in polygon:
182
+ coordinates.extend(part)
183
+ lats = [coord[1] for coord in coordinates]
184
+ lngs = [coord[0] for coord in coordinates]
185
+ return [[min(lats), min(lngs)], [max(lats), max(lngs)]]
186
+
187
+ @spaces.GPU
188
+ def generate_static_map(geojson_data, bounds=None):
189
+ # Create a static map object with specified dimensions
190
+ m = StaticMap(600, 600)
191
+
192
+ if bounds:
193
+ center_lat = (bounds[0][0] + bounds[1][0]) / 2
194
+ center_lng = (bounds[0][1] + bounds[1][1]) / 2
195
+ zoom = 10 # Adjust zoom level as needed
196
+ m.set_center(center_lat, center_lng, zoom)
197
+
198
+ # Process each feature in the GeoJSON
199
+ for feature in geojson_data["features"]:
200
+ geom_type = feature["geometry"]["type"]
201
+ coords = feature["geometry"]["coordinates"]
202
+
203
+ if geom_type == "Point":
204
+ m.add_marker(CircleMarker((coords[0], coords[1]), 'blue', 10))
205
+ elif geom_type in ["MultiPoint", "LineString"]:
206
+ for coord in coords:
207
+ m.add_marker(CircleMarker((coord[0], coord[1]), 'blue', 10))
208
+ elif geom_type in ["Polygon", "MultiPolygon"]:
209
+ for polygon in coords:
210
+ m.add_polygon(Polygon([(c[0], c[1]) for c in polygon], 'blue', 3))
211
+
212
+ return m.render(zoom=10)
213
+
214
+
215
  # Gradio UI
216
  @spaces.GPU
217
  def handle_query(query):
 
219
  response = process_openai_response(query)
220
  geojson_data = generate_geojson(response)
221
 
222
+ # Generate the main map image
223
  map_image = generate_static_map(geojson_data)
224
 
225
+ # Generate the empty map using the same bounds
226
+ bounds = get_bounds(geojson_data)
227
+ empty_geojson = {
228
+ "type": "FeatureCollection",
229
+ "features": [] # Empty map contains no features
230
+ }
231
+ empty_map_image = generate_static_map(empty_geojson) # Empty map with the same bounds
232
+
233
+ # Create the mask
234
+ difference = np.abs(np.array(map_image.convert("RGB")) - np.array(empty_map_image.convert("RGB")))
235
+ threshold = 10 # Tolerance for difference
236
+ mask = (np.sum(difference, axis=-1) > threshold).astype(np.uint8) * 255
237
 
238
+ # Convert the mask to a PIL image
239
+ mask_image = Image.fromarray(mask, mode="L")
240
 
241
+ # Generate the satellite image
242
+ satellite_image = generate_satellite_image(
243
+ map_image, mask_image, response['output']['feature_representation']['properties']['description']
244
+ )
245
 
246
  return map_image, satellite_image, mask_image, response
247
 
248
+
249
  # Gradio interface
250
  with gr.Blocks() as demo:
251
  with gr.Row():