Spaces:
Runtime error
Runtime error
File size: 6,522 Bytes
5c718d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 |
import torch.multiprocessing
import torchvision.transforms as T
import numpy as np
from utils import transform_to_pil, create_video
from utils_gee import extract_img, transform_ee_img
from dateutil.relativedelta import relativedelta
import datetime
from dateutil.relativedelta import relativedelta
import cv2
from joblib import Parallel, cpu_count, delayed
def get_image(location, d1, d2):
    """Download and rescale the Earth Engine image of ``location`` between two dates.

    Best-effort helper: any failure is printed and swallowed, so a single
    missing month does not abort a whole time series.

    Args:
        location (list[float]): coordinates forwarded unchanged to ``extract_img``.
        d1 (str): start date of the extraction window.
        d2 (str): end date of the extraction window.

    Returns:
        The result of ``transform_ee_img(..., max=0.3)``, or ``None`` if the
        extraction or the transformation raised.
    """
    print(f"getting image for {d1} to {d2}")
    try:
        raw_img = extract_img(location, d1, d2)
        rescaled = transform_ee_img(raw_img, max=0.3)
    except Exception as exc:  # deliberately broad: callers drop the None results
        print(exc)
        return None
    return rescaled
def inference_on_location(model, latitude=2.98, longitude=48.81, start_date=2020, end_date=2022):
    """Segment one image per month of a location and render the results as a video.

    One image is requested for every calendar month from January of
    ``start_date`` through December of ``end_date - 1``. Months whose extraction
    fails are skipped. Every remaining image is segmented by ``model`` and the
    labeled frames are written to ``output/output.mp4``.

    NOTE(review): the defaults (2.98, 48.81) look like (longitude, latitude) of a
    point near Paris despite the parameter names; confirm the coordinate order
    ``extract_img`` expects.

    Args:
        model: segmentation model exposing ``net``, ``linear_probe`` and ``cpu()``.
            It is moved to CPU and switched to eval mode.
        latitude (float): first coordinate of the location passed to ``extract_img``.
        longitude (float): second coordinate of the location passed to ``extract_img``.
        start_date (int): first year of the time series.
        end_date (int): year at which the series stops (exclusive). It must be
            strictly greater than ``start_date``.

    Returns:
        str: ``'output.mp4'``, the file name of the rendered video.

    Raises:
        AssertionError: if ``end_date`` is not strictly greater than ``start_date``.
        RuntimeError: if no image could be extracted for any month.
    """
    import os

    assert end_date > start_date, "end date must be stricly higher than start date"
    location = [float(latitude), float(longitude)]

    # Month boundaries from Jan 1st of start_date up to Jan 1st of end_date.
    # Consecutive pairs form the monthly extraction windows.
    dates = [datetime.datetime(start_date, 1, 1, 0, 0, 0)]
    while dates[-1] < datetime.datetime(end_date, 1, 1, 0, 0, 0):
        dates.append(dates[-1] + relativedelta(months=1))
    dates = [d.strftime("%Y-%m-%d") for d in dates]

    # Fetch all months concurrently with threads.
    all_image = Parallel(n_jobs=cpu_count(), prefer="threads")(
        delayed(get_image)(location, d1, d2) for d1, d2 in zip(dates[:-1], dates[1:])
    )
    # get_image returns None for months it could not extract.
    all_image = [image for image in all_image if image is not None]
    if not all_image:
        # Without this check, torch.stack([]) fails with an unhelpful message.
        raise RuntimeError(
            f"no image could be extracted for {location} between {dates[0]} and {dates[-1]}"
        )

    # Resize every frame to 320x320 and normalize with the ImageNet mean/std.
    preprocess = T.Compose(
        [
            T.ToPILImage(),
            T.Resize((320, 320)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    x = torch.stack([preprocess(image) for image in all_image]).cpu()

    # Inference on CPU. eval() makes dropout / batch-norm layers (if the model
    # has any) behave as at inference time rather than training time.
    model = model.cpu()
    model.eval()
    with torch.no_grad():
        _, code = model.net(x)
        preds = model.linear_probe(x, code).argmax(1)

    # Each frame is handed to transform_to_pil as its own batch of one.
    labeled_frames = []
    for frame, frame_pred in zip(x, preds):
        output = {
            "img": torch.unsqueeze(frame, dim=0).detach().cpu(),
            "linear_preds": torch.unsqueeze(frame_pred, dim=0).detach().cpu(),
        }
        _, _, labeled_img = transform_to_pil(output)
        # Reverse the channel axis (RGB -> BGR), presumably for an OpenCV writer.
        labeled_frames.append(np.array(labeled_img)[:, :, ::-1])

    # Ensure the target directory exists. Whether create_video handles a
    # missing directory is not visible here.
    os.makedirs("output", exist_ok=True)
    create_video(labeled_frames, output_path='output/output.mp4')
    # NOTE(review): the video is written under output/ but only the bare file
    # name is returned; confirm the caller resolves it relative to output/.
    return 'output.mp4'
def inference_on_location_and_month(model, latitude=2.98, longitude=48.81, start_date='2020-03-20'):
    """Segment the image of a location over a one-month window and score it.

    The image is extracted between ``start_date`` and exactly one month later,
    segmented by ``model``, and summarised into a biodiversity score.

    Args:
        model: segmentation model exposing ``net``, ``linear_probe``,
            ``cfg.n_images`` and ``cpu()``. It is moved to CPU and switched to
            eval mode.
        latitude (float): first coordinate of the location passed to ``extract_img``.
        longitude (float): second coordinate of the location passed to ``extract_img``.
        start_date (str): start of the window, formatted ``"%Y-%m-%d"``.

    Returns:
        tuple: ``(img, labeled_img, score)``.
            ``img`` and ``labeled_img`` are the original and labeled images
            produced by ``transform_to_pil``.
            ``score`` is the biodiversity score, a float in [0, 1]. It is 0.0
            when no pixel is predicted as one of the scored classes 1..7.

    Raises:
        ValueError: if a coordinate is not numeric or ``start_date`` does not
            match ``"%Y-%m-%d"``.
    """
    location = [float(latitude), float(longitude)]

    # The extraction window is exactly one calendar month long.
    window_end = datetime.datetime.strptime(start_date, "%Y-%m-%d") + relativedelta(months=1)
    end_date = window_end.strftime("%Y-%m-%d")
    img = extract_img(location, start_date, end_date)
    # max value is the value from numpy file that will be equal to 255
    img_test = transform_ee_img(img, max=0.3)

    # Resize to 320x320 and normalize with the ImageNet mean/std.
    preprocess = T.Compose(
        [
            T.ToPILImage(),
            T.Resize((320, 320)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Add a batch dimension of one.
    x = torch.unsqueeze(preprocess(img_test), dim=0).cpu()

    # Inference on CPU. eval() makes dropout / batch-norm layers (if the model
    # has any) behave as at inference time rather than training time.
    model = model.cpu()
    model.eval()
    with torch.no_grad():
        _, code = model.net(x)
        linear_pred = model.linear_probe(x, code).argmax(1)

    n_images = model.cfg.n_images
    output = {
        "img": x[:n_images].detach().cpu(),
        "linear_preds": linear_pred[:n_images].detach().cpu(),
    }

    # Weight of predicted classes 1..7 (in that order) in the score.
    # Class 0 and any class above 7 are not counted.
    scores_init = [2, 3, 4, 3, 1, 4, 0]
    preds = output["linear_preds"][0]
    nb_values = [int((preds == cls).sum()) for cls in range(1, len(scores_init) + 1)]
    total = sum(nb_values)
    if total == 0:
        # No pixel falls in a scored class: avoid a ZeroDivisionError.
        score = 0.0
    else:
        # Pixel-weighted mean class weight, rescaled into [0, 1] by the largest weight.
        weighted = sum(weight * count for weight, count in zip(scores_init, nb_values))
        score = weighted / total / max(scores_init)

    img, _, labeled_img = transform_to_pil(output)
    return img, labeled_img, score
if __name__ == "__main__":
    # Script entry point:
    # 1. build the segmenter from the hydra config;
    # 2. load its trained weights on CPU;
    # 3. run the default multi-year inference.
    import logging
    import hydra
    from model import LitUnsupervisedSegmenter

    logging.basicConfig(filename="example.log", encoding="utf-8", level=logging.INFO)

    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")
    logging.info("config : %s", cfg)

    model = LitUnsupervisedSegmenter(cfg.dir_dataset_n_classes, cfg)
    logging.info("Model Initialiazed")

    # torch.load unpickles the checkpoint; only load trusted files.
    weights = torch.load("checkpoint/model/model.pt", map_location=torch.device("cpu"))
    logging.info("Model weights Loaded")
    model.load_state_dict(weights)
    logging.info("Model Loaded")

    inference_on_location(model)
|