Suchinthana committed on
Commit
36e6906
·
1 Parent(s): e3bfe63

in memory update

Browse files
Files changed (1) hide show
  1. app.py +25 -73
app.py CHANGED
@@ -13,9 +13,6 @@ from gradio_folium import Folium
13
  from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
14
  import spaces
15
 
16
- from selenium.webdriver.firefox.options import Options
17
- from selenium import webdriver
18
-
19
  # Initialize APIs
20
  openai_client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])
21
  geolocator = Nominatim(user_agent="geoapi")
@@ -38,51 +35,15 @@ def process_openai_response(query):
38
  response = openai_client.chat.completions.create(
39
  model="gpt-4o-mini",
40
  messages=[
41
- {
42
- "role": "system",
43
- "content": [
44
- {
45
- "type": "text",
46
- "text": "\"input\": \"\"\"You are a skilled assistant answering geographical and historical questions. For each question, generate a structured output in JSON format, based on city names without coordinates. The response should include:\
47
- Answer: A concise response to the question.\
48
- Feature Representation: A feature type based on city names (Point, LineString, Polygon, MultiPoint, MultiLineString, MultiPolygon, GeometryCollection).\
49
- Description: A prompt for a diffusion model describing the what should we draw regarding that.\
50
- \
51
- Handle the following cases:\
52
- \
53
- 1. **Single or Multiple Points**: Create a point or a list of points for multiple cities.\
54
- 2. **LineString**: Create a line between two cities.\
55
- 3. **Polygon**: Represent an area formed by three or more cities (closed). Example: Cities forming a triangle (A, B, C).\
56
- 4. **MultiPoint, MultiLineString, MultiPolygon, GeometryCollection**: Use as needed based on the question.\
57
- \
58
- For example, if asked about cities forming a polygon, create a feature like this:\
59
- \
60
- Input: Mark an area with three cities.\
61
- Output: {\"input\": \"Mark an area with three cities.\", \"output\": {\"answer\": \"The cities A, B, and C form a triangle.\", \"feature_representation\": {\"type\": \"Polygon\", \"cities\": [\"A\", \"B\", \"C\"], \"properties\": {\"description\": \"satelite image of a plantation, green fill, 4k, map, detailed, greenary, plants, vegitation, high contrast\"}}}}\
62
- \
63
- Ensure all responses are descriptive and relevant to city names only, without coordinates.\
64
- \"}\"}"
65
- }
66
- ]
67
- },
68
- {
69
- "role": "user",
70
- "content": [
71
- {
72
- "type": "text",
73
- "text": query
74
- }
75
- ]
76
- }
77
- ],
78
- temperature=1,
79
- max_tokens=2048,
80
- top_p=1,
81
- frequency_penalty=0,
82
- presence_penalty=0,
83
- response_format={
84
- "type": "json_object"
85
- }
86
  )
87
  return json.loads(response.choices[0].message.content)
88
 
@@ -136,25 +97,16 @@ def get_bounds(geojson):
136
  lngs = [coord[0] for coord in coordinates]
137
  return [[min(lats), min(lngs)], [max(lats), max(lngs)]]
138
 
139
- # Generate map image
140
  @spaces.GPU
141
- def save_map_image(geojson_data):
142
  m = Map()
143
  geo_layer = GeoJson(geojson_data, name="Feature map")
144
  geo_layer.add_to(m)
145
  bounds = get_bounds(geojson_data)
146
  m.fit_bounds(bounds)
147
-
148
- # Configure Selenium for headless operation
149
- options = Options()
150
- options.add_argument("--headless") # Enable headless mode
151
- driver = webdriver.Firefox(options=options) # Ensure GeckoDriver is properly installed and in PATH
152
- img_data = m._to_png(5, driver=driver)
153
- driver.quit()
154
-
155
- img = Image.open(io.BytesIO(img_data))
156
- img.save('map_image.png')
157
- return 'map_image.png'
158
 
159
  # ControlNet pipeline setup
160
  controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16)
@@ -176,9 +128,7 @@ def make_inpaint_condition(init_image, mask_image):
176
  return init_image
177
 
178
  @spaces.GPU
179
- def generate_satellite_image(init_image_path, mask_image_path, prompt):
180
- init_image = Image.open(init_image_path)
181
- mask_image = Image.open(mask_image_path)
182
  control_image = make_inpaint_condition(init_image, mask_image)
183
  result = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, control_image=control_image)
184
  return result.images[0]
@@ -190,20 +140,22 @@ def handle_query(query):
190
  response = process_openai_response(query)
191
  geojson_data = generate_geojson(response)
192
 
193
- # Save map image
194
- map_image_path = save_map_image(geojson_data)
195
 
196
  # Generate mask for ControlNet
197
- empty_map = cv2.imread("empty_map_image.png")
198
- map_image = cv2.imread(map_image_path)
199
- difference = cv2.absdiff(cv2.cvtColor(empty_map, cv2.COLOR_BGR2GRAY), cv2.cvtColor(map_image, cv2.COLOR_BGR2GRAY))
200
  _, mask = cv2.threshold(difference, 15, 255, cv2.THRESH_BINARY)
201
- cv2.imwrite("mask.png", mask)
 
 
202
 
203
  # Generate satellite image
204
- satellite_image = generate_satellite_image("map_image.png", "mask.png", response['output']['feature_representation']['properties']['description'])
205
 
206
- return map_image_path, satellite_image
207
 
208
  # Gradio interface
209
  with gr.Blocks() as demo:
@@ -217,4 +169,4 @@ with gr.Blocks() as demo:
217
  submit_btn.click(handle_query, inputs=[query_input], outputs=[map_output, satellite_output])
218
 
219
  if __name__ == "__main__":
220
- demo.launch()
 
13
  from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
14
  import spaces
15
 
 
 
 
16
  # Initialize APIs
17
  openai_client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])
18
  geolocator = Nominatim(user_agent="geoapi")
 
35
  response = openai_client.chat.completions.create(
36
  model="gpt-4o-mini",
37
  messages=[
38
+ {"role": "system", "content": "You are a skilled assistant answering geographical and historical questions..."},
39
+ {"role": "user", "content": query}
40
+ ],
41
+ temperature=1,
42
+ max_tokens=2048,
43
+ top_p=1,
44
+ frequency_penalty=0,
45
+ presence_penalty=0,
46
+ response_format={"type": "json_object"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  )
48
  return json.loads(response.choices[0].message.content)
49
 
 
97
  lngs = [coord[0] for coord in coordinates]
98
  return [[min(lats), min(lngs)], [max(lats), max(lngs)]]
99
 
100
+ # Generate map image in memory
101
  @spaces.GPU
102
+ def generate_map_image(geojson_data):
103
  m = Map()
104
  geo_layer = GeoJson(geojson_data, name="Feature map")
105
  geo_layer.add_to(m)
106
  bounds = get_bounds(geojson_data)
107
  m.fit_bounds(bounds)
108
+ img_data = m._to_png(5)
109
+ return Image.open(io.BytesIO(img_data))
 
 
 
 
 
 
 
 
 
110
 
111
  # ControlNet pipeline setup
112
  controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16)
 
128
  return init_image
129
 
130
  @spaces.GPU
131
+ def generate_satellite_image(init_image, mask_image, prompt):
 
 
132
  control_image = make_inpaint_condition(init_image, mask_image)
133
  result = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, control_image=control_image)
134
  return result.images[0]
 
140
  response = process_openai_response(query)
141
  geojson_data = generate_geojson(response)
142
 
143
+ # Generate map image
144
+ map_image = generate_map_image(geojson_data)
145
 
146
  # Generate mask for ControlNet
147
+ empty_map = cv2.cvtColor(np.array(generate_map_image({"type": "FeatureCollection", "features": []})), cv2.COLOR_BGR2GRAY)
148
+ map_image_array = cv2.cvtColor(np.array(map_image), cv2.COLOR_BGR2GRAY)
149
+ difference = cv2.absdiff(empty_map, map_image_array)
150
  _, mask = cv2.threshold(difference, 15, 255, cv2.THRESH_BINARY)
151
+
152
+ # Convert mask to PIL Image
153
+ mask_image = Image.fromarray(mask)
154
 
155
  # Generate satellite image
156
+ satellite_image = generate_satellite_image(map_image, mask_image, response['output']['feature_representation']['properties']['description'])
157
 
158
+ return map_image, satellite_image
159
 
160
  # Gradio interface
161
  with gr.Blocks() as demo:
 
169
  submit_btn.click(handle_query, inputs=[query_input], outputs=[map_output, satellite_output])
170
 
171
  if __name__ == "__main__":
172
+ demo.launch()