Spaces:
Runtime error
Runtime error
Update backend.py
Browse files- backend.py +24 -1
backend.py
CHANGED
@@ -82,6 +82,17 @@ def building_model(building_version_dropdown, building_pth_dropdown, building_th
|
|
82 |
building_predictor = DefaultPredictor(building_cfg)
|
83 |
return building_predictor
|
84 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
# A function that runs the buildings model on an given image and confidence threshold
|
86 |
def segment_building(im, building_predictor):
|
87 |
outputs = building_predictor(im)
|
@@ -96,6 +107,13 @@ def segment_tree(im, tree_predictor):
|
|
96 |
|
97 |
return tree_instances
|
98 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
99 |
# Function to map strings to color mode
|
100 |
def map_color_mode(color_mode):
|
101 |
if color_mode == "Black/white":
|
@@ -122,7 +140,7 @@ def get_metadata(dataset_name, coco_file):
|
|
122 |
metadata.thing_classes = [c["name"] for c in categories]
|
123 |
return metadata
|
124 |
|
125 |
-
def visualize_image(im, mode, tree_threshold, building_threshold, color_mode, tree_version, tree_pth, building_version, building_pth):
|
126 |
im = np.array(im)
|
127 |
color_mode = map_color_mode(color_mode)
|
128 |
|
@@ -138,6 +156,11 @@ def visualize_image(im, mode, tree_threshold, building_threshold, color_mode, tr
|
|
138 |
building_instances = load_instances(im, building_predictor, segment_building)
|
139 |
instances = building_instances if mode == "Buildings" else combine_instances(instances, building_instances)
|
140 |
|
|
|
|
|
|
|
|
|
|
|
141 |
# Assuming 'urban-small_train' is intended for both Trees and Buildings
|
142 |
metadata = get_metadata("urban-small_train", "building_model_weight/_annotations.coco.json")
|
143 |
visualizer = Visualizer(im[:, :, ::-1], metadata=metadata, scale=0.5, instance_mode=color_mode)
|
|
|
82 |
building_predictor = DefaultPredictor(building_cfg)
|
83 |
return building_predictor
|
84 |
|
85 |
+
# Model for LCZs
|
86 |
+
def tree_model(lcz_version_dropdown, lcz_pth_dropdown, lcz_threshold, device="cpu"):
|
87 |
+
lcz_cfg = get_cfg()
|
88 |
+
lcz_cfg.merge_from_file(get_version_cfg_yml("lcz_model_weights/lczs_cfg.yaml"))
|
89 |
+
lcz_cfg.MODEL.DEVICE=device
|
90 |
+
lcz_cfg.MODEL.WEIGHTS = f"tree_model_weights/{tree_pth_dropdown}"
|
91 |
+
lcz_cfg.MODEL.ROI_HEADS.NUM_CLASSES = 14 # TODO change this
|
92 |
+
lcz_cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = lcz_threshold
|
93 |
+
lcz_predictor = DefaultPredictor(lcz_cfg)
|
94 |
+
return lcz_predictor
|
95 |
+
|
96 |
# A function that runs the buildings model on an given image and confidence threshold
|
97 |
def segment_building(im, building_predictor):
|
98 |
outputs = building_predictor(im)
|
|
|
107 |
|
108 |
return tree_instances
|
109 |
|
110 |
+
# A function that runs the trees model on an given image and confidence threshold
|
111 |
+
def segment_lcz(im, lcz_predictor):
|
112 |
+
outputs = lcz_predictor(im)
|
113 |
+
lcz_instances = outputs["instances"].to("cpu")
|
114 |
+
|
115 |
+
return lcz_instances
|
116 |
+
|
117 |
# Function to map strings to color mode
|
118 |
def map_color_mode(color_mode):
|
119 |
if color_mode == "Black/white":
|
|
|
140 |
metadata.thing_classes = [c["name"] for c in categories]
|
141 |
return metadata
|
142 |
|
143 |
+
def visualize_image(im, mode, tree_threshold, building_threshold, color_mode, tree_version, tree_pth, building_version, building_pth, lcz_version, lcz_pth):
|
144 |
im = np.array(im)
|
145 |
color_mode = map_color_mode(color_mode)
|
146 |
|
|
|
156 |
building_instances = load_instances(im, building_predictor, segment_building)
|
157 |
instances = building_instances if mode == "Buildings" else combine_instances(instances, building_instances)
|
158 |
|
159 |
+
if mode in {"LCZ", "Both"}:
|
160 |
+
lcz_predictor = load_predictor(lcz_model, lcz_version, lcz_pth, lcz_threshold)
|
161 |
+
lcz_instances = load_instances(im, lcz_predictor, segment_lcz)
|
162 |
+
instances = lcz_instances if mode == "LCZ" else combine_instances(instances, LCZ_instances)
|
163 |
+
|
164 |
# Assuming 'urban-small_train' is intended for both Trees and Buildings
|
165 |
metadata = get_metadata("urban-small_train", "building_model_weight/_annotations.coco.json")
|
166 |
visualizer = Visualizer(im[:, :, ::-1], metadata=metadata, scale=0.5, instance_mode=color_mode)
|