# Spaces: Running on Zero
import os | |
import json | |
import numpy as np | |
import torch | |
from PIL import Image, ImageDraw | |
import gradio as gr | |
from openai import OpenAI | |
from geopy.geocoders import Nominatim | |
from staticmap import StaticMap, CircleMarker, Polygon | |
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline | |
import spaces | |
import logging | |
import math | |
from typing import List, Union | |
# Logging: timestamped INFO-level records, shared by every function in this module.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
)
logger = logging.getLogger(__name__)

# External service clients, created once at import time.
# The OpenAI key is read from the environment; a missing key raises KeyError here.
openai_client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])
geolocator = Nominatim(user_agent="geoapi")
def get_geo_coordinates(location_name):
    """Geocode a place name with the module-level Nominatim client.

    Args:
        location_name: Free-text place name passed to ``geolocator.geocode``.

    Returns:
        ``[longitude, latitude]`` when the geocoder returns a match, otherwise
        ``None``. Any exception raised during the lookup is logged and also
        results in ``None``.
    """
    try:
        match = geolocator.geocode(location_name)
        # Longitude first: this is the x/y order used by the GeoJSON built downstream.
        return [match.longitude, match.latitude] if match else None
    except Exception as e:
        logger.error(f"Error fetching coordinates for {location_name}: {e}")
        return None
def process_openai_response(query):
    """Ask the chat model to turn a free-text geographic query into structured JSON.

    Args:
        query: The user's natural-language request.

    Returns:
        The model's reply parsed with ``json.loads``. The system prompt asks for
        an object with ``input`` and ``output`` keys, where ``output`` carries
        ``answer`` and ``feature_representation`` (``type``, ``cities``,
        ``properties.description``).
    """
    system_prompt = """
You are an assistant that generates structured JSON output for geographical queries with city names. Your task is to generate a JSON object containing information about geographical features and their representation based on the user's query. Follow these rules:
1. The JSON should always have the following structure:
{
"input": "<user's query>",
"output": {
"answer": "<concise text answering the query>",
"feature_representation": {
"type": "<one of: Point, LineString, Polygon, MultiPoint, MultiLineString, MultiPolygon, GeometryCollection>",
"cities": ["<list of city names>"],
"properties": {
"description": "<a prompt for a diffusion model describing the geographical feature>"
}
}
}
}
2. For the `type` field in `feature_representation`:
- Use "Point" for single city queries.
- Use "MultiPoint" for queries involving multiple cities not forming a line or area.
- Use "LineString" for queries about paths between two or more cities.
- Use "Polygon" for queries about areas formed by three or more cities.
3. For the `cities` field:
- List the names of cities mentioned in the query in the order they appear.
- If no cities are mentioned, try to add them with your knowledge.
4. For the `properties.description` field:
- Describe the geographical feature in a creative way, suitable for generating an image with a diffusion model.
### Example Input:
"Mark a triangular area of 3 US cities."
### Example Output:
{
"input": "Mark a triangular area of 3 US cities.",
"output": {
"answer": "The cities New York, Boston, and Philadelphia form a triangle.",
"feature_representation": {
"type": "Polygon",
"cities": ["New York", "Boston", "Philadelphia"],
"properties": {
"description": "A satellite image of a triangular area formed by New York, Boston, and Philadelphia, with green fields and urban regions, 4k resolution, highly detailed."
}
}
}
}
Generate similar JSON for the following query:
"""
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": query},
    ]
    completion = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=conversation,
        temperature=1,
        max_tokens=2048,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        # JSON mode: the reply body is requested as a single JSON object.
        response_format={"type": "json_object"},
    )
    raw_reply = completion.choices[0].message.content
    return json.loads(raw_reply)
def generate_geojson(response):
    """Build a one-feature GeoJSON FeatureCollection from the parsed model reply.

    Each city named in the reply is geocoded with ``get_geo_coordinates``;
    cities that cannot be resolved are logged and skipped.

    Coordinate layout of the returned geometry:
      * ``Polygon``: a single closed ring, ``[[p0, p1, ..., p0]]``.
      * every other type (including ``Point``): a flat list of
        ``[lon, lat]`` positions, one per resolved city.

    Args:
        response: Dict with ``output.feature_representation`` containing
            ``type``, ``cities`` and ``properties``.

    Returns:
        A dict shaped as a GeoJSON ``FeatureCollection`` with one feature.

    Raises:
        KeyError: If the expected keys are missing from ``response``.
        ValueError: If a ``Polygon`` has fewer than 3 resolved coordinates, or
            if no city could be resolved at all.
    """
    logger.info(f"OpenAI response: {response}")
    representation = response['output']['feature_representation']
    feature_type = representation['type']
    city_names = representation['cities']
    properties = representation['properties']

    # Best-effort geocoding: one bad city must not abort the whole request.
    coordinates = []
    for city in city_names:
        try:
            coord = get_geo_coordinates(city)
        except Exception as e:
            logger.error(f"Error fetching coordinates for {city}: {e}")
            continue
        if coord:
            coordinates.append(coord)
        else:
            logger.warning(f"Coordinates not found for city: {city}")

    if feature_type == "Polygon":
        if len(coordinates) < 3:
            raise ValueError("Polygon requires at least 3 coordinates.")
        # Close the ring by repeating the first point, then nest it as the
        # polygon's only ring.
        coordinates.append(coordinates[0])
        coordinates = [coordinates]
    elif not coordinates:
        # Without any position there is nothing to draw; fail here with a
        # clear message instead of an obscure error further downstream.
        raise ValueError("No coordinates could be resolved for the requested cities.")

    return {
        "type": "FeatureCollection",
        "features": [
            {
                "type": "Feature",
                "properties": properties,
                "geometry": {
                    "type": feature_type,
                    "coordinates": coordinates,
                },
            }
        ],
    }
def sort_coordinates_for_simple_polygon(geojson):
    """Reorder the first feature's outer ring by angle around its centroid.

    Ordering the vertices by their polar angle about the centroid reduces
    self-intersections in the drawn polygon.

    The ring is replaced **in place** inside ``geojson`` and the same object
    is returned.

    Args:
        geojson: FeatureCollection-like dict whose first feature has a
            geometry with ``coordinates[0]`` holding a list of ``[x, y]`` points.

    Returns:
        The same ``geojson`` object, with ``coordinates[0]`` replaced by the
        sorted ring, closed by repeating its first point.
    """
    geometry = geojson['features'][0]['geometry']
    ring = geometry['coordinates'][0]

    # A closed ring repeats its first vertex at the end; drop the duplicate so
    # it does not skew the centroid.
    if ring[0] == ring[-1]:
        ring = ring[:-1]

    vertex_count = len(ring)
    cx = sum(vertex[0] for vertex in ring) / vertex_count
    cy = sum(vertex[1] for vertex in ring) / vertex_count

    ordered = sorted(
        ring,
        key=lambda vertex: math.atan2(vertex[1] - cy, vertex[0] - cx),
    )
    # Re-close the ring.
    ordered.append(ordered[0])

    geometry['coordinates'][0] = ordered
    return geojson
def _iter_positions(coords):
    """Yield every ``[x, y, ...]`` position found in arbitrarily nested coordinates."""
    if not coords:
        return
    if isinstance(coords[0], (int, float)):
        # ``coords`` itself is a single position.
        yield coords
        return
    for part in coords:
        yield from _iter_positions(part)


def _iter_rings(coords):
    """Yield every ring (a sequence of positions) found in nested polygon coordinates."""
    if not coords:
        return
    first = coords[0]
    if isinstance(first, (int, float)):
        # A bare position is not a ring.
        return
    if first and isinstance(first[0], (int, float)):
        # ``coords`` is a flat list of positions, i.e. one ring.
        yield coords
        return
    for part in coords:
        yield from _iter_rings(part)


def generate_static_map(geojson_data, invisible=False):
    """Render the features of a GeoJSON-like dict onto a 600x600 static map.

    Drawing rules per geometry type:
      * ``Point``: one circle marker at the first position. Both a bare
        ``[lon, lat]`` and a list of positions are accepted.
      * ``MultiPoint`` / ``LineString`` / ``MultiLineString``: one circle
        marker per position.
      * ``Polygon`` / ``MultiPolygon``: one filled polygon per ring, at
        whatever nesting depth the rings are found.
      * anything else: skipped with a warning.

    Args:
        geojson_data: Dict with a ``features`` list; each feature needs
            ``geometry.type`` and ``geometry.coordinates``.
        invisible: When true, shapes are drawn with the zero-alpha colour
            ``'#1C00ff00'`` instead of ``'#42445A85'``.

    Returns:
        Whatever ``StaticMap.render()`` returns for the populated map.
    """
    m = StaticMap(600, 600)
    logger.info(f"GeoJSON data: {geojson_data}")
    color = '#1C00ff00' if invisible else '#42445A85'

    for feature in geojson_data["features"]:
        geom_type = feature["geometry"]["type"]
        coords = feature["geometry"]["coordinates"]

        if geom_type == "Point":
            for position in list(_iter_positions(coords))[:1]:
                m.add_marker(CircleMarker((position[0], position[1]), color, 100))
        elif geom_type in ("MultiPoint", "LineString", "MultiLineString"):
            for position in _iter_positions(coords):
                m.add_marker(CircleMarker((position[0], position[1]), color, 100))
        elif geom_type in ("Polygon", "MultiPolygon"):
            for ring in _iter_rings(coords):
                # NOTE(review): the third positional argument of staticmap's
                # Polygon appears to be an outline colour, not a width --
                # kept as 3 to preserve current rendering; confirm intent.
                m.add_polygon(Polygon([(c[0], c[1]) for c in ring], color, 3))
        else:
            logger.warning(f"Unsupported geometry type, nothing drawn: {geom_type}")

    return m.render()
# Diffusion models are loaded once at import time, in half precision, and the
# assembled inpainting pipeline is moved to the CUDA device.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_inpaint",
    torch_dtype=torch.float16,
)
pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-inpainting",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
pipeline.to('cuda')
def make_inpaint_condition(init_image, mask_image):
    """Build the ControlNet inpaint conditioning tensor from an image and a mask.

    The image is scaled to ``[0, 1]`` and every pixel whose mask value exceeds
    0.5 (after scaling to ``[0, 1]``) is set to ``-1.0`` in all channels.

    Args:
        init_image: Object with ``convert("RGB")`` (e.g. a PIL image).
        mask_image: Object with ``convert("L")`` and the same height and
            width as ``init_image``.

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, 3, height, width)``.

    Raises:
        ValueError: If the image and mask differ in height or width.
    """
    image = np.array(init_image.convert("RGB")).astype(np.float32) / 255.0
    mask = np.array(mask_image.convert("L")).astype(np.float32) / 255.0

    # Compare both spatial dimensions; ``shape[0:1]`` would only check height.
    if image.shape[:2] != mask.shape[:2]:
        raise ValueError("image and image_mask must have the same image size")

    image[mask > 0.5] = -1.0  # set as masked pixel
    # HWC -> NCHW with a leading batch dimension of 1.
    image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
    return torch.from_numpy(image)
# The Space runs on ZeroGPU hardware, where a GPU is only attached while a
# function decorated with ``spaces.GPU`` is executing.
@spaces.GPU
def generate_satellite_image(init_image, mask_image, prompt):
    """Inpaint the masked region of ``init_image`` guided by ``prompt``.

    Args:
        init_image: Base image passed to the pipeline as ``image``.
        mask_image: Mask passed as ``mask_image``; also used to build the
            ControlNet conditioning via ``make_inpaint_condition``.
        prompt: Text prompt for the diffusion model.

    Returns:
        The first image of the pipeline result.
    """
    control_image = make_inpaint_condition(init_image, mask_image)
    result = pipeline(
        prompt=prompt,
        image=init_image,
        mask_image=mask_image,
        control_image=control_image,
        strength=0.47,
        guidance_scale=95,
        num_inference_steps=250,
    )
    return result.images[0]
# Gradio callback: full query -> map -> mask -> inpainted image pipeline
def handle_query(query):
    """Run one user query end to end and return everything the UI displays.

    Steps: ask the chat model for a structured description, geocode it into
    GeoJSON, render the map twice (with visible and with transparent shapes),
    derive a mask from the pixels that differ between the two renders, and
    inpaint the masked region using the model-provided description as prompt.

    Args:
        query: The user's natural-language request.

    Returns:
        Tuple ``(map_image, satellite_image, empty_map_image, mask_image,
        response)`` where ``response`` is the parsed model reply.
    """
    response = process_openai_response(query)
    geojson_data = generate_geojson(response)

    if geojson_data["features"][0]["geometry"]["type"] == 'Polygon':
        # Keep working with the sorted result so that both renders below draw
        # exactly the same ring.
        geojson_data = sort_coordinates_for_simple_polygon(geojson_data)

    map_image = generate_static_map(geojson_data)
    empty_map_image = generate_static_map(geojson_data, invisible=True)

    # Subtract in a signed type: on 8-bit arrays ``a - b`` wraps around for
    # negative results, which makes ``np.abs`` a no-op and the threshold
    # meaningless.
    visible = np.array(map_image.convert("RGB"), dtype=np.int16)
    hidden = np.array(empty_map_image.convert("RGB"), dtype=np.int16)
    difference = np.abs(visible - hidden)

    threshold = 10
    mask = (np.sum(difference, axis=-1) > threshold).astype(np.uint8) * 255
    mask_image = Image.fromarray(mask, mode="L")

    prompt = response['output']['feature_representation']['properties']['description']
    satellite_image = generate_satellite_image(empty_map_image, mask_image, prompt)
    return map_image, satellite_image, empty_map_image, mask_image, response
def update_query(selected_query):
    """Return the dropdown selection unchanged.

    Used as the dropdown's ``change`` handler so the chosen preset is copied
    into the query textbox.
    """
    return selected_query
# Preset queries offered in the dropdown; the last one is the initial value.
query_options = [
    "Area covering south asian subcontinent",
    "Mark a triangular area using New York, Boston, and Texas",
    "Mark cities in India",
    "Show me Lotus Tower in a Map",
    "Mark the area of west germany",
    "Mark the area of the Amazon rainforest",
    "Mark the area of the Sahara desert",
]

with gr.Blocks() as demo:
    # Query selection: pick a preset or type a custom query.
    with gr.Row():
        selected_query = gr.Dropdown(
            label="Select Query",
            choices=query_options,
            value=query_options[-1],
        )
        query_input = gr.Textbox(label="Enter Query", value=query_options[-1])
    # Choosing a preset copies it into the editable textbox.
    selected_query.change(update_query, inputs=selected_query, outputs=query_input)
    submit_btn = gr.Button("Submit")

    # Result images.
    with gr.Row():
        map_output = gr.Image(label="Map Visualization")
        satellite_output = gr.Image(label="Generated Map Image")
    with gr.Row():
        empty_map_output = gr.Image(label="Empty Visualization")
        mask_output = gr.Image(label="Mask")
    image_prompt = gr.Textbox(label="Image Prompt Used")

    # Output order matches the tuple returned by handle_query.
    submit_btn.click(
        handle_query,
        inputs=[query_input],
        outputs=[map_output, satellite_output, empty_map_output, mask_output, image_prompt],
    )

if __name__ == "__main__":
    demo.launch()