Suchinthana commited on
Commit
45f7be1
·
1 Parent(s): 6efeffc

Init code add

Browse files
Files changed (2) hide show
  1. app.py +130 -4
  2. requirements.txt +10 -0
app.py CHANGED
@@ -1,7 +1,133 @@
1
  import gradio as gr
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import os
3
+ import json
4
+ from openai import OpenAI
5
+ from geopy.geocoders import Nominatim
6
+ from folium import Map, GeoJson
7
+ from gradio_folium import Folium
8
+ import cv2
9
+ import numpy as np
10
+ import torch
11
+ from diffusers import ControlNetModel, StableDiffusionControlNetInpaintPipeline
12
+ from PIL import Image
13
+ import io
14
 
15
+ # Initialize APIs
16
+ openai_client = OpenAI(api_key=os.environ['OPENAI_API_KEY'])
17
+ geolocator = Nominatim(user_agent="geoapi")
18
 
19
+ # Function to fetch coordinates
20
+ def get_geo_coordinates(location_name):
21
+ try:
22
+ location = geolocator.geocode(location_name)
23
+ if location:
24
+ return [location.longitude, location.latitude]
25
+ return None
26
+ except Exception as e:
27
+ print(f"Error fetching coordinates for {location_name}: {e}")
28
+ return None
29
+
30
+ # Function to process OpenAI chat response
31
+ def process_openai_response(query):
32
+ response = openai_client.chat.completions.create(
33
+ model="gpt-4o-mini",
34
+ messages=[
35
+ {"role": "system", "content": "You are a skilled assistant answering geographical and historical questions..."},
36
+ {"role": "user", "content": query}
37
+ ],
38
+ temperature=1,
39
+ max_tokens=2048,
40
+ top_p=1,
41
+ frequency_penalty=0,
42
+ presence_penalty=0,
43
+ response_format={"type": "json_object"}
44
+ )
45
+ return json.loads(response.choices[0].message.content)
46
+
47
+ # Generate GeoJSON from OpenAI response
48
+ def generate_geojson(response):
49
+ feature_type = response['output']['feature_representation']['type']
50
+ city_names = response['output']['feature_representation']['cities']
51
+ properties = response['output']['feature_representation']['properties']
52
+
53
+ coordinates = []
54
+ for city in city_names:
55
+ coord = get_geo_coordinates(city)
56
+ if coord:
57
+ coordinates.append(coord)
58
+
59
+ if feature_type == "Polygon":
60
+ coordinates.append(coordinates[0]) # Close the polygon
61
+
62
+ return {
63
+ "type": "FeatureCollection",
64
+ "features": [{
65
+ "type": "Feature",
66
+ "properties": properties,
67
+ "geometry": {
68
+ "type": feature_type,
69
+ "coordinates": [coordinates] if feature_type == "Polygon" else coordinates
70
+ }
71
+ }]
72
+ }
73
+
74
+ # Generate map image
75
+ def save_map_image(geojson_data):
76
+ m = Map()
77
+ geo_layer = GeoJson(geojson_data, name="Feature map")
78
+ geo_layer.add_to(m)
79
+ bounds = get_bounds(geojson_data)
80
+ m.fit_bounds(bounds)
81
+ img_data = m._to_png(5)
82
+ img = Image.open(io.BytesIO(img_data))
83
+ img.save('map_image.png')
84
+ return 'map_image.png'
85
+
86
+ # ControlNet pipeline setup
87
+ controlnet = ControlNetModel.from_pretrained("lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16)
88
+ pipeline = StableDiffusionControlNetInpaintPipeline.from_pretrained(
89
+ "runwayml/stable-diffusion-inpainting", controlnet=controlnet, torch_dtype=torch.float16
90
+ )
91
+ pipeline.enable_model_cpu_offload()
92
+
93
+ def generate_satellite_image(init_image_path, mask_image_path, prompt):
94
+ init_image = Image.open(init_image_path)
95
+ mask_image = Image.open(mask_image_path)
96
+ control_image = make_inpaint_condition(init_image, mask_image)
97
+ result = pipeline(prompt=prompt, image=init_image, mask_image=mask_image, control_image=control_image)
98
+ return result.images[0]
99
+
100
+ # Gradio UI
101
+ def handle_query(query):
102
+ # Process OpenAI response
103
+ response = process_openai_response(query)
104
+ geojson_data = generate_geojson(response)
105
+
106
+ # Save map image
107
+ map_image_path = save_map_image(geojson_data)
108
+
109
+ # Generate mask for ControlNet
110
+ empty_map = cv2.imread("empty_map_image.png")
111
+ map_image = cv2.imread(map_image_path)
112
+ difference = cv2.absdiff(cv2.cvtColor(empty_map, cv2.COLOR_BGR2GRAY), cv2.cvtColor(map_image, cv2.COLOR_BGR2GRAY))
113
+ _, mask = cv2.threshold(difference, 15, 255, cv2.THRESH_BINARY)
114
+ cv2.imwrite("mask.png", mask)
115
+
116
+ # Generate satellite image
117
+ satellite_image = generate_satellite_image("map_image.png", "mask.png", response['output']['feature_representation']['properties']['description'])
118
+
119
+ return map_image_path, satellite_image
120
+
121
+ # Gradio interface
122
+ with gr.Blocks() as demo:
123
+ with gr.Row():
124
+ query_input = gr.Textbox(label="Enter Query")
125
+ submit_btn = gr.Button("Submit")
126
+ with gr.Row():
127
+ map_output = gr.Image(label="Map Visualization")
128
+ satellite_output = gr.Image(label="Generated Satellite Image")
129
+
130
+ submit_btn.click(handle_query, inputs=[query_input], outputs=[map_output, satellite_output])
131
+
132
+ if __name__ == "__main__":
133
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ openai # For interacting with OpenAI API
2
+ gradio # For creating the Gradio UI
3
+ gradio-folium # For embedding Folium maps into Gradio
4
+ folium # For creating maps
5
+ geopy # For fetching geolocation data
6
+ torch # For PyTorch (used by Diffusers and ControlNet)
7
+ diffusers # For the Stable Diffusion inpainting pipeline
8
+ opencv-python-headless # For image processing with OpenCV
9
+ Pillow # For working with images
10
+ numpy # For numerical operations