Spaces:
Runtime error
Runtime error
import torch.multiprocessing | |
import torchvision.transforms as T | |
import numpy as np | |
from utils import transform_to_pil, create_video | |
from utils_gee import extract_img, transform_ee_img | |
from dateutil.relativedelta import relativedelta | |
import datetime | |
from dateutil.relativedelta import relativedelta | |
import cv2 | |
from joblib import Parallel, cpu_count, delayed | |
def get_image(location, d1, d2):
    """Fetch and rescale the Earth Engine image for one date window.

    Args:
        location: coordinates, forwarded unchanged to ``extract_img``.
        d1: start of the window, forwarded to ``extract_img``.
        d2: end of the window, forwarded to ``extract_img``.

    Returns:
        The result of ``transform_ee_img(..., max=0.3)``, or ``None`` when
        either helper raises (the error is printed, not re-raised).
    """
    print(f"getting image for {d1} to {d2}")
    try:
        return transform_ee_img(extract_img(location, d1, d2), max=0.3)
    except Exception as err:
        # Best effort: report the failure and let the caller drop this window.
        print(err)
    return None
def inference_on_location(model, latitude=2.98, longitude=48.81, start_date=2020, end_date=2022):
    """Run monthly inference on a location and render the labeled frames as a video.

    One image is requested per one-month window between 1 January of
    ``start_date`` and 1 January of ``end_date``. Windows for which
    ``get_image`` returns ``None`` are skipped.

    Args:
        model: segmenter exposing ``net`` and ``linear_probe``; it is moved to CPU.
        latitude (float): the latitude of the landscape.
        longitude (float): the longitude of the landscape.
        start_date (int): year in which the period starts (1 January).
        end_date (int): year in which the period ends (1 January); must be
            strictly greater than ``start_date``.

    Returns:
        str: the literal ``'output.mp4'``.
        NOTE(review): the video itself is written to ``'output/output.mp4'``
        while ``'output.mp4'`` is returned; confirm which path callers expect.

    Raises:
        AssertionError: if ``end_date`` is not strictly greater than ``start_date``.
        RuntimeError: if no image could be retrieved for any month of the period.
    """
    assert end_date > start_date, "end date must be stricly higher than start date"
    location = [float(latitude), float(longitude)]

    # Monthly boundaries: consecutive pairs delimit the windows to download.
    dates = [datetime.datetime(start_date, 1, 1, 0, 0, 0)]
    period_end = datetime.datetime(end_date, 1, 1, 0, 0, 0)
    while dates[-1] < period_end:
        dates.append(dates[-1] + relativedelta(months=1))
    dates = [d.strftime("%Y-%m-%d") for d in dates]

    # Downloads run in a thread pool; failed windows come back as None.
    all_image = Parallel(n_jobs=cpu_count(), prefer="threads")(
        delayed(get_image)(location, d1, d2) for d1, d2 in zip(dates[:-1], dates[1:])
    )
    all_image = [image for image in all_image if image is not None]
    if not all_image:
        # torch.stack would otherwise fail on the empty list with an opaque message.
        raise RuntimeError(
            f"no image could be retrieved for {location} between {dates[0]} and {dates[-1]}"
        )

    # tensorize & normalize img
    preprocess = T.Compose(
        [
            T.ToPILImage(),
            T.Resize((320, 320)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    x = torch.stack([preprocess(imag) for imag in all_image]).cpu()

    # launch inference on cpu
    model = model.cpu()
    with torch.no_grad():
        _, code = model.net(x)
        linear_pred = model.linear_probe(x, code)
        linear_pred = linear_pred.argmax(1)

    # One single-image batch per month, in the dict layout transform_to_pil is given.
    outputs = [
        {
            "img": torch.unsqueeze(img, dim=0).detach().cpu(),
            "linear_preds": torch.unsqueeze(pred, dim=0).detach().cpu(),
        }
        for img, pred in zip(x, linear_pred)
    ]

    frames = []
    for output in outputs:
        _, _, labeled_img = transform_to_pil(output)
        # Reverse the last axis (presumably RGB -> BGR for the video writer;
        # confirm against create_video).
        frames.append(np.array(labeled_img)[:, :, ::-1])

    create_video(frames, output_path='output/output.mp4')
    return 'output.mp4'
def inference_on_location_and_month(model, latitude=2.98, longitude=48.81, start_date='2020-03-20'):
    """Run inference on the one-month window starting at ``start_date`` and score it.

    Args:
        model: segmenter exposing ``net``, ``linear_probe`` and ``cfg.n_images``;
            it is moved to CPU.
        latitude (float): the latitude of the landscape.
        longitude (float): the longitude of the landscape.
        start_date (str): first day of the window, formatted ``"%Y-%m-%d"``;
            the window ends one calendar month later.

    Returns:
        tuple: ``(img, labeled_img, score)``. ``img`` and ``labeled_img`` are the
        first and third values returned by ``transform_to_pil``. ``score`` is the
        pixel-count-weighted mean of the per-label scores for predicted labels
        1..7, divided by the highest per-label score. It is ``0.0`` when no
        pixel is predicted in any of those labels.
    """
    location = [float(latitude), float(longitude)]

    # The window ends one calendar month after start_date.
    end_date = datetime.datetime.strptime(start_date, "%Y-%m-%d") + relativedelta(months=1)
    end_date = end_date.strftime("%Y-%m-%d")

    # Extract img numpy from earth engine and transform it to PIL img
    img = extract_img(location, start_date, end_date)
    img_test = transform_ee_img(
        img, max=0.3
    )  # max value is the value from numpy file that will be equal to 255

    # tensorize & normalize img
    preprocess = T.Compose(
        [
            T.ToPILImage(),
            T.Resize((320, 320)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    x = torch.unsqueeze(preprocess(img_test), dim=0).cpu()

    # launch inference on cpu
    model = model.cpu()
    with torch.no_grad():
        _, code = model.net(x)
        linear_pred = model.linear_probe(x, code)
        linear_pred = linear_pred.argmax(1)

    output = {
        "img": x[: model.cfg.n_images].detach().cpu(),
        "linear_preds": linear_pred[: model.cfg.n_images].detach().cpu(),
    }

    # Per-label weights for predicted labels 1..7, in label order.
    scores_init = [2, 3, 4, 3, 1, 4, 0]
    first_preds = output["linear_preds"][0]
    nb_values = [
        np.count_nonzero(first_preds == label)
        for label in range(1, len(scores_init) + 1)
    ]
    total = sum(nb_values)
    if total == 0:
        # No pixel in any scored label: avoid dividing by zero.
        score = 0.0
    else:
        weighted = sum(weight * count for weight, count in zip(scores_init, nb_values))
        score = weighted / total / max(scores_init)

    img, _, labeled_img = transform_to_pil(output)
    return img, labeled_img, score
if __name__ == "__main__":
    # Script entry point: compose the config, restore the model, run the default inference.
    import logging

    import hydra
    from model import LitUnsupervisedSegmenter

    logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.INFO)

    # Compose the hydra configuration
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")
    logging.info("config : %s", cfg)

    # Build the segmenter, then restore its weights on CPU
    model = LitUnsupervisedSegmenter(cfg.dir_dataset_n_classes, cfg)
    logging.info("Model Initialiazed")
    weights = torch.load("checkpoint/model/model.pt", map_location=torch.device("cpu"))
    logging.info("Model weights Loaded")
    model.load_state_dict(weights)
    logging.info("Model Loaded")

    inference_on_location(model)