# Hugging Face Space status: Running on Zero (ZeroGPU hardware)
import os | |
import json | |
import cv2 | |
import numpy as np | |
import torch | |
from PIL import Image | |
import io | |
import gradio as gr | |
from openai import OpenAI | |
from geopy.geocoders import Nominatim | |
from folium import Map, GeoJson | |
from gradio_folium import Folium | |
from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline | |
import spaces | |
# Shared API clients, created once at import time and used by the helpers below.
# Nominatim (OpenStreetMap geocoding) asks each application to identify itself
# with its own user agent string.
geolocator = Nominatim(user_agent="geoapi")
# The key is read straight from the environment; a missing OPENAI_API_KEY raises
# KeyError here, so the app fails at startup rather than on the first query.
openai_client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
# Function to fetch coordinates
def get_geo_coordinates(location_name):
    """Geocode *location_name* with the shared Nominatim client.

    Args:
        location_name: Free-text place name, e.g. a city.

    Returns:
        ``[longitude, latitude]`` (GeoJSON position order) of the first match,
        or ``None`` when there is no match or the lookup raises.
    """
    try:
        match = geolocator.geocode(location_name)
        return [match.longitude, match.latitude] if match else None
    except Exception as exc:
        # Best-effort lookup: report the failure and let the caller skip this place.
        print(f"Error fetching coordinates for {location_name}: {exc}")
        return None
# Function to process OpenAI chat response
def process_openai_response(query):
    """Send *query* to the OpenAI chat model and return its reply as a dict.

    The request uses JSON mode (``response_format={"type": "json_object"}``).
    The API rejects JSON-mode requests whose messages never mention "JSON", so
    the system prompt asks for JSON explicitly. It also describes the schema
    that ``generate_geojson`` and ``handle_query`` read from the reply
    (``output.feature_representation.{type, cities, properties.description}``).

    Args:
        query: The user's free-text question.

    Returns:
        The parsed JSON object returned by the model.

    Raises:
        ValueError: If the model returned no message content (for example a refusal).
        json.JSONDecodeError: If the content is not valid JSON (for example truncated output).
    """
    system_prompt = (
        "You are a skilled assistant answering geographical and historical questions... "
        "Always answer with a single JSON object shaped like "
        '{"output": {"feature_representation": {'
        '"type": "Point" | "MultiPoint" | "LineString" | "Polygon", '
        '"cities": ["<place name>", ...], '
        '"properties": {"description": "<short visual description of the area>"}}}}. '
        '"cities" lists real, geocodable place names that outline the feature, in drawing order; '
        '"description" is used as the prompt for a satellite-style image of the area.'
    )
    response = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": query},
        ],
        temperature=1,
        max_tokens=2048,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        response_format={"type": "json_object"},
    )
    message = response.choices[0].message
    if not message.content:
        # json.loads(None) would fail with an unhelpful TypeError; surface the refusal instead.
        refusal = getattr(message, "refusal", None)
        raise ValueError(f"OpenAI returned no JSON content (refusal: {refusal!r})")
    return json.loads(message.content)
# Generate GeoJSON from OpenAI response
def generate_geojson(response):
    """Build a one-feature GeoJSON FeatureCollection from the model's reply.

    Each name in ``response["output"]["feature_representation"]["cities"]`` is
    geocoded in order, and names that cannot be resolved are skipped. The
    resolved ``[lon, lat]`` positions are arranged to fit the requested
    geometry type:

    * ``"Point"``: the first resolved position, because a Point holds one position.
    * ``"Polygon"``: a single ring, closed by repeating the first position if needed.
    * any other type (``"MultiPoint"``, ``"LineString"``, ...): the list of positions.

    Args:
        response: Parsed model reply containing ``output.feature_representation``
            with ``type``, ``cities`` and ``properties`` keys.

    Returns:
        A GeoJSON ``FeatureCollection`` dict with exactly one ``Feature``.

    Raises:
        KeyError: If the reply lacks one of the keys above.
        ValueError: If none of the cities could be geocoded.
    """
    feature = response['output']['feature_representation']
    feature_type = feature['type']
    city_names = feature['cities']
    properties = feature['properties']

    resolved = (get_geo_coordinates(name) for name in city_names)
    coordinates = [coord for coord in resolved if coord]
    if not coordinates:
        # An empty geometry can be neither drawn nor bounded, and a polygon ring
        # cannot be closed without a first position.
        raise ValueError(f"None of the locations could be geocoded: {city_names!r}")

    if feature_type == "Point":
        # A GeoJSON Point's coordinates are one position, not a list of positions.
        geometry_coordinates = coordinates[0]
    elif feature_type == "Polygon":
        ring = list(coordinates)
        if ring[0] != ring[-1]:
            ring.append(ring[0])  # close the linear ring
        geometry_coordinates = [ring]
    else:
        geometry_coordinates = coordinates

    return {
        "type": "FeatureCollection",
        "features": [{
            "type": "Feature",
            "properties": properties,
            "geometry": {
                "type": feature_type,
                "coordinates": geometry_coordinates,
            },
        }],
    }
# Function to compute bounds from GeoJSON
def _iter_positions(node):
    """Yield each position (``[lng, lat, ...]``) nested anywhere inside *node*."""
    if not node:
        return
    if not isinstance(node[0], (list, tuple)):
        # The first element is a scalar, so *node* itself is a position.
        yield node
        return
    for child in node:
        yield from _iter_positions(child)


def _geometry_positions(geometry):
    """Yield every position of a GeoJSON geometry, descending into GeometryCollections."""
    if not geometry:
        return  # GeoJSON features may carry a null geometry
    if geometry.get("type") == "GeometryCollection":
        for member in geometry.get("geometries") or ():
            yield from _geometry_positions(member)
    else:
        yield from _iter_positions(geometry.get("coordinates"))


def get_bounds(geojson):
    """Return the bounding box of all coordinates in a GeoJSON FeatureCollection.

    Positions are found by walking each geometry's nested ``coordinates``
    arrays. Every geometry type therefore works, from Point through
    MultiPolygon, and GeometryCollection members are included. Features whose
    geometry is null are skipped.

    Args:
        geojson: FeatureCollection dict with a ``"features"`` list.

    Returns:
        ``[[min_lat, min_lng], [max_lat, max_lng]]``, the order folium's
        ``fit_bounds`` expects.

    Raises:
        ValueError: If the collection contains no coordinates at all.
    """
    positions = [
        position
        for feature in geojson["features"]
        for position in _geometry_positions(feature.get("geometry"))
    ]
    if not positions:
        raise ValueError("GeoJSON contains no coordinates; cannot compute bounds")
    lats = [position[1] for position in positions]
    lngs = [position[0] for position in positions]
    return [[min(lats), min(lngs)], [max(lats), max(lngs)]]
# Generate map image in memory
def generate_map_image(geojson_data):
    """Render *geojson_data* on a folium map and return a screenshot of it.

    When the collection has features, the view is fitted to their bounds. An
    empty ``FeatureCollection`` is rendered with folium's default view instead
    of failing in ``get_bounds``, which has no coordinates to work with.

    Args:
        geojson_data: GeoJSON ``FeatureCollection`` dict.

    Returns:
        ``PIL.Image.Image`` in RGB mode.
    """
    folium_map = Map()
    GeoJson(geojson_data, name="Feature map").add_to(folium_map)
    if geojson_data.get("features"):
        folium_map.fit_bounds(get_bounds(geojson_data))
    # ``_to_png`` is a private folium helper that screenshots the page in a
    # headless browser; 5 is the delay in seconds before capture (presumably
    # so that map tiles have time to load).
    png_bytes = folium_map._to_png(5)
    # Normalize to RGB so callers never receive an alpha channel from the screenshot.
    return Image.open(io.BytesIO(png_bytes)).convert("RGB")
# Load the inpainting ControlNet and attach it to the Stable Diffusion
# inpainting checkpoint. Both are loaded in half precision (float16).
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/control_v11p_sd15_inpaint",
    torch_dtype=torch.float16,
)
pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-inpainting",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
# Move the weights to CUDA at import time. On ZeroGPU Spaces this is the
# documented pattern, and the device is actually attached only inside
# functions decorated with ``spaces.GPU``. Elsewhere it needs a local CUDA GPU.
pipeline.to('cuda')
def make_inpaint_condition(init_image, mask_image):
    """Build the control image for the inpainting ControlNet.

    Both images are scaled to [0, 1]. Pixels where the mask is bright (> 0.5)
    are set to -1.0, which marks them as pixels to be filled in.

    Args:
        init_image: Image object with a PIL-style ``convert`` method; converted to RGB.
        mask_image: Image of the same height and width; converted to grayscale ("L").

    Returns:
        ``torch.float32`` tensor of shape ``(1, 3, H, W)``.

    Raises:
        ValueError: If the two images differ in height or width.
    """
    image = np.array(init_image.convert("RGB")).astype(np.float32) / 255.0
    mask = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
    # Compare both height and width (shape[:2]). An explicit raise is used
    # rather than assert, which is stripped under ``python -O``.
    if image.shape[:2] != mask.shape[:2]:
        raise ValueError("image and image_mask must have the same image size")
    image[mask > 0.5] = -1.0  # set as masked pixel
    image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)  # HWC -> NCHW
    return torch.from_numpy(image)
@spaces.GPU
def generate_satellite_image(init_image, mask_image, prompt):
    """Inpaint the masked part of *init_image* with the ControlNet pipeline.

    ``spaces.GPU`` requests a GPU for the duration of each call when the app
    runs on ZeroGPU hardware; outside ZeroGPU the decorator has no effect.
    Without it, the module-level ``pipeline.to('cuda')`` has no GPU to run on
    there, because ZeroGPU attaches a device only inside decorated functions.

    Args:
        init_image: PIL image to repaint (in this app, the rendered map).
        mask_image: PIL mask of the same size; bright pixels are regenerated.
        prompt: Text prompt describing the desired result.

    Returns:
        The first image generated by the pipeline.
    """
    control_image = make_inpaint_condition(init_image, mask_image)
    result = pipeline(
        prompt=prompt,
        image=init_image,
        mask_image=mask_image,
        control_image=control_image,
    )
    return result.images[0]
# Gradio UI
def _render_reference_map(geojson_data):
    """Render a map with no overlay, framed exactly like the rendering of *geojson_data*."""
    reference = Map()
    reference.fit_bounds(get_bounds(geojson_data))
    return Image.open(io.BytesIO(reference._to_png(5)))


def handle_query(query):
    """Answer *query* with a rendered feature map and a generated satellite image.

    Steps:
        1. Ask the model for a feature description (type, cities, properties).
        2. Geocode it into GeoJSON and render it on a map.
        3. Render the same view without the overlay and diff the two renders to
           obtain a mask of the drawn feature.
        4. Inpaint the masked area, prompted by the feature's description
           (falling back to the query text when no description is given).

    Args:
        query: The user's free-text question.

    Returns:
        ``(map_image, satellite_image)`` as PIL images.
    """
    response = process_openai_response(query)
    geojson_data = generate_geojson(response)

    map_image = generate_map_image(geojson_data)

    # The mask marks pixels changed by the feature overlay. The reference map
    # must use the same view as ``map_image``, so it is fitted to the same
    # bounds; an empty FeatureCollection has no bounds to fit.
    empty_map = _render_reference_map(geojson_data)
    # PIL's "L" conversion applies RGB luma weights and accepts RGB or RGBA
    # input. The PIL arrays are RGB, so cv2's BGR2GRAY would swap R and B.
    empty_gray = np.array(empty_map.convert("L"))
    map_gray = np.array(map_image.convert("L"))
    difference = cv2.absdiff(empty_gray, map_gray)
    _, mask = cv2.threshold(difference, 15, 255, cv2.THRESH_BINARY)
    mask_image = Image.fromarray(mask)

    properties = response['output']['feature_representation'].get('properties') or {}
    prompt = properties.get('description') or query
    satellite_image = generate_satellite_image(map_image, mask_image, prompt)
    return map_image, satellite_image
# Gradio interface: a query box with a submit button, and the rendered map
# shown next to the generated satellite-style image.
with gr.Blocks() as demo:
    with gr.Row():
        query_input = gr.Textbox(label="Enter Query")
        submit_btn = gr.Button("Submit")
    with gr.Row():
        map_output = gr.Image(label="Map Visualization")
        satellite_output = gr.Image(label="Generated Satellite Image")
    # handle_query returns (map_image, satellite_image), matching this output order.
    submit_btn.click(
        fn=handle_query,
        inputs=[query_input],
        outputs=[map_output, satellite_output],
    )

if __name__ == "__main__":
    demo.launch()