Suchinthana commited on
Commit
a52b051
·
1 Parent(s): 3e8f3e6

Update for ZeroGPU

Browse files
Files changed (2) hide show
  1. app.py +24 -17
  2. requirements.txt +1 -0
app.py CHANGED
@@ -1,22 +1,24 @@
1
- import gradio as gr
2
  import os
3
  import json
4
- from openai import OpenAI
5
- from geopy.geocoders import Nominatim
6
- from folium import Map, GeoJson
7
- from gradio_folium import Folium
8
  import cv2
9
  import numpy as np
10
  import torch
11
- from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
12
  from PIL import Image
13
  import io
 
 
 
 
 
 
 
14
 
15
  # Initialize APIs
16
  openai_client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])
17
  geolocator = Nominatim(user_agent="geoapi")
18
 
19
  # Function to fetch coordinates
 
20
  def get_geo_coordinates(location_name):
21
  try:
22
  location = geolocator.geocode(location_name)
@@ -28,6 +30,7 @@ def get_geo_coordinates(location_name):
28
  return None
29
 
30
  # Function to process OpenAI chat response
 
31
  def process_openai_response(query):
32
  response = openai_client.chat.completions.create(
33
  model="gpt-4o-mini",
@@ -45,17 +48,18 @@ def process_openai_response(query):
45
  return json.loads(response.choices[0].message.content)
46
 
47
  # Generate GeoJSON from OpenAI response
 
48
  def generate_geojson(response):
49
  feature_type = response['output']['feature_representation']['type']
50
  city_names = response['output']['feature_representation']['cities']
51
  properties = response['output']['feature_representation']['properties']
52
-
53
  coordinates = []
54
  for city in city_names:
55
  coord = get_geo_coordinates(city)
56
  if coord:
57
  coordinates.append(coord)
58
-
59
  if feature_type == "Polygon":
60
  coordinates.append(coordinates[0]) # Close the polygon
61
 
@@ -71,8 +75,8 @@ def generate_geojson(response):
71
  }]
72
  }
73
 
74
-
75
  # Function to compute bounds from GeoJSON
 
76
  def get_bounds(geojson):
77
  coordinates = []
78
  for feature in geojson["features"]:
@@ -94,6 +98,7 @@ def get_bounds(geojson):
94
  return [[min(lats), min(lngs)], [max(lats), max(lngs)]]
95
 
96
  # Generate map image
 
97
  def save_map_image(geojson_data):
98
  m = Map()
99
  geo_layer = GeoJson(geojson_data, name="Feature map")
@@ -110,10 +115,10 @@ controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpai
110
  pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
111
  "stable-diffusion-v1-5/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
112
  )
113
- #pipeline.enable_model_cpu_offload()
114
  pipeline.to('cuda')
115
 
116
-
117
  def make_inpaint_condition(init_image, mask_image):
118
  init_image = np.array(init_image.convert("RGB")).astype(np.float32) / 255.0
119
  mask_image = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
@@ -124,6 +129,7 @@ def make_inpaint_condition(init_image, mask_image):
124
  init_image = torch.from_numpy(init_image)
125
  return init_image
126
 
 
127
  def generate_satellite_image(init_image_path, mask_image_path, prompt):
128
  init_image = Image.open(init_image_path)
129
  mask_image = Image.open(mask_image_path)
@@ -132,24 +138,25 @@ def generate_satellite_image(init_image_path, mask_image_path, prompt):
132
  return result.images[0]
133
 
134
  # Gradio UI
 
135
  def handle_query(query):
136
  # Process OpenAI response
137
  response = process_openai_response(query)
138
  geojson_data = generate_geojson(response)
139
-
140
  # Save map image
141
  map_image_path = save_map_image(geojson_data)
142
-
143
  # Generate mask for ControlNet
144
  empty_map = cv2.imread("empty_map_image.png")
145
  map_image = cv2.imread(map_image_path)
146
  difference = cv2.absdiff(cv2.cvtColor(empty_map, cv2.COLOR_BGR2GRAY), cv2.cvtColor(map_image, cv2.COLOR_BGR2GRAY))
147
  _, mask = cv2.threshold(difference, 15, 255, cv2.THRESH_BINARY)
148
  cv2.imwrite("mask.png", mask)
149
-
150
  # Generate satellite image
151
  satellite_image = generate_satellite_image("map_image.png", "mask.png", response['output']['feature_representation']['properties']['description'])
152
-
153
  return map_image_path, satellite_image
154
 
155
  # Gradio interface
@@ -160,8 +167,8 @@ with gr.Blocks() as demo:
160
  with gr.Row():
161
  map_output = gr.Image(label="Map Visualization")
162
  satellite_output = gr.Image(label="Generated Satellite Image")
163
-
164
  submit_btn.click(handle_query, inputs=[query_input], outputs=[map_output, satellite_output])
165
 
166
  if __name__ == "__main__":
167
- demo.launch()
 
 
1
  import os
2
  import json
 
 
 
 
3
  import cv2
4
  import numpy as np
5
  import torch
 
6
  from PIL import Image
7
  import io
8
+ import gradio as gr
9
+ from openai import OpenAI
10
+ from geopy.geocoders import Nominatim
11
+ from folium import Map, GeoJson
12
+ from gradio_folium import Folium
13
+ from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
14
+ import spaces
15
 
16
  # Initialize APIs
17
  openai_client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])
18
  geolocator = Nominatim(user_agent="geoapi")
19
 
20
  # Function to fetch coordinates
21
+ @spaces.GPU
22
  def get_geo_coordinates(location_name):
23
  try:
24
  location = geolocator.geocode(location_name)
 
30
  return None
31
 
32
  # Function to process OpenAI chat response
33
+ @spaces.GPU
34
  def process_openai_response(query):
35
  response = openai_client.chat.completions.create(
36
  model="gpt-4o-mini",
 
48
  return json.loads(response.choices[0].message.content)
49
 
50
  # Generate GeoJSON from OpenAI response
51
+ @spaces.GPU
52
  def generate_geojson(response):
53
  feature_type = response['output']['feature_representation']['type']
54
  city_names = response['output']['feature_representation']['cities']
55
  properties = response['output']['feature_representation']['properties']
56
+
57
  coordinates = []
58
  for city in city_names:
59
  coord = get_geo_coordinates(city)
60
  if coord:
61
  coordinates.append(coord)
62
+
63
  if feature_type == "Polygon":
64
  coordinates.append(coordinates[0]) # Close the polygon
65
 
 
75
  }]
76
  }
77
 
 
78
  # Function to compute bounds from GeoJSON
79
+ @spaces.GPU
80
  def get_bounds(geojson):
81
  coordinates = []
82
  for feature in geojson["features"]:
 
98
  return [[min(lats), min(lngs)], [max(lats), max(lngs)]]
99
 
100
  # Generate map image
101
+ @spaces.GPU
102
  def save_map_image(geojson_data):
103
  m = Map()
104
  geo_layer = GeoJson(geojson_data, name="Feature map")
 
115
  pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
116
  "stable-diffusion-v1-5/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
117
  )
118
+ # ZeroGPU compatibility
119
  pipeline.to('cuda')
120
 
121
+ @spaces.GPU
122
  def make_inpaint_condition(init_image, mask_image):
123
  init_image = np.array(init_image.convert("RGB")).astype(np.float32) / 255.0
124
  mask_image = np.array(mask_image.convert("L")).astype(np.float32) / 255.0
 
129
  init_image = torch.from_numpy(init_image)
130
  return init_image
131
 
132
+ @spaces.GPU
133
  def generate_satellite_image(init_image_path, mask_image_path, prompt):
134
  init_image = Image.open(init_image_path)
135
  mask_image = Image.open(mask_image_path)
 
138
  return result.images[0]
139
 
140
  # Gradio UI
141
+ @spaces.GPU
142
  def handle_query(query):
143
  # Process OpenAI response
144
  response = process_openai_response(query)
145
  geojson_data = generate_geojson(response)
146
+
147
  # Save map image
148
  map_image_path = save_map_image(geojson_data)
149
+
150
  # Generate mask for ControlNet
151
  empty_map = cv2.imread("empty_map_image.png")
152
  map_image = cv2.imread(map_image_path)
153
  difference = cv2.absdiff(cv2.cvtColor(empty_map, cv2.COLOR_BGR2GRAY), cv2.cvtColor(map_image, cv2.COLOR_BGR2GRAY))
154
  _, mask = cv2.threshold(difference, 15, 255, cv2.THRESH_BINARY)
155
  cv2.imwrite("mask.png", mask)
156
+
157
  # Generate satellite image
158
  satellite_image = generate_satellite_image("map_image.png", "mask.png", response['output']['feature_representation']['properties']['description'])
159
+
160
  return map_image_path, satellite_image
161
 
162
  # Gradio interface
 
167
  with gr.Row():
168
  map_output = gr.Image(label="Map Visualization")
169
  satellite_output = gr.Image(label="Generated Satellite Image")
170
+
171
  submit_btn.click(handle_query, inputs=[query_input], outputs=[map_output, satellite_output])
172
 
173
  if __name__ == "__main__":
174
+ demo.launch()
requirements.txt CHANGED
@@ -6,6 +6,7 @@ geopy # For fetching geolocation data # For PyTorch (used by Diffuser
6
  numpy # For numerical operations
7
  diffusers
8
  transformers
 
9
  torchvision
10
  opencv-python
11
  torch
 
6
  numpy # For numerical operations
7
  diffusers
8
  transformers
9
+ spaces
10
  torchvision
11
  opencv-python
12
  torch