Biomap / biomap /helper.py
jeremyLE-Ekimetrics's picture
first commit
5c718d1
raw
history blame
6.52 kB
import torch.multiprocessing
import torchvision.transforms as T
import numpy as np
from utils import transform_to_pil, create_video
from utils_gee import extract_img, transform_ee_img
from dateutil.relativedelta import relativedelta
import datetime
from dateutil.relativedelta import relativedelta
import cv2
from joblib import Parallel, cpu_count, delayed
def get_image(location, d1, d2):
    """Fetch one Earth Engine image for a date window and rescale it.

    Args:
        location: Coordinate pair forwarded unchanged to ``extract_img``.
        d1: Start of the window, forwarded unchanged to ``extract_img``.
        d2: End of the window, forwarded unchanged to ``extract_img``.

    Returns:
        Whatever ``transform_ee_img(..., max=0.3)`` returns, or ``None`` when
        extraction or transformation raised. The error is printed instead of
        propagated, so callers have to filter out ``None`` results.
    """
    print(f"getting image for {d1} to {d2}")
    try:
        raw_img = extract_img(location, d1, d2)
        return transform_ee_img(raw_img, max=0.3)
    except Exception as failure:
        # Best effort: one failing window must not abort the other downloads.
        print(failure)
    return None
def inference_on_location(model, latitude=2.98, longitude=48.81, start_date=2020, end_date=2022):
    """Segment one image per month over a range of years and render a video.

    One image is requested for every one-month window between 1 January of
    ``start_date`` and 1 January of ``end_date``. Windows whose download
    failed (``get_image`` returned ``None``) are skipped. The remaining
    images are run through the model on CPU as a single batch and the labeled
    frames are written to ``output/output.mp4`` with ``create_video``.

    Args:
        model: Object providing ``cpu()``, ``net(x)`` returning a
            ``(feats, code)`` pair and ``linear_probe(x, code)``.
        latitude (float): First element of the location handed to
            ``get_image``.
        longitude (float): Second element of the location handed to
            ``get_image``.
        start_date (int): Year whose 1 January opens the range.
        end_date (int): Year whose 1 January closes the range; must be
            strictly greater than ``start_date``.

    Returns:
        str: The literal file name ``'output.mp4'``.

    Raises:
        AssertionError: If ``end_date`` is not strictly greater than
            ``start_date``.
        RuntimeError: If no image could be retrieved for any month.
    """
    # NOTE(review): the defaults (2.98, 48.81) look like a (longitude,
    # latitude) pair rather than (latitude, longitude) -- confirm the order
    # extract_img expects.
    # Raised explicitly (not via `assert`) so the check survives `python -O`
    # while callers still see the same exception type and message.
    if not end_date > start_date:
        raise AssertionError("end date must be stricly higher than start date")
    location = [float(latitude), float(longitude)]

    # Month boundaries from 1 January of start_date to 1 January of end_date.
    last_boundary = datetime.datetime(end_date, 1, 1, 0, 0, 0)
    dates = [datetime.datetime(start_date, 1, 1, 0, 0, 0)]
    while dates[-1] < last_boundary:
        dates.append(dates[-1] + relativedelta(months=1))
    dates = [d.strftime("%Y-%m-%d") for d in dates]

    # One download per consecutive pair of boundaries, fanned out on threads.
    all_image = Parallel(n_jobs=cpu_count(), prefer="threads")(
        delayed(get_image)(location, d1, d2) for d1, d2 in zip(dates[:-1], dates[1:])
    )
    all_image = [image for image in all_image if image is not None]
    if not all_image:
        # torch.stack would fail on an empty list with a far less helpful
        # RuntimeError; fail early with the actual cause instead.
        raise RuntimeError(
            f"no image could be retrieved for {location} between {dates[0]} and {dates[-1]}"
        )

    # tensorize & normalize img
    preprocess = T.Compose(
        [
            T.ToPILImage(),
            T.Resize((320, 320)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    x = torch.stack([preprocess(image) for image in all_image]).cpu()

    # launch inference on cpu
    model = model.cpu()
    with torch.no_grad():
        _, code = model.net(x)
        linear_pred = model.linear_probe(x, code)
        linear_pred = linear_pred.argmax(1)

    labeled_frames = []
    for img, pred in zip(x, linear_pred):
        # transform_to_pil receives one sample at a time, with a leading
        # dimension of size 1 restored on both tensors.
        output = {
            "img": torch.unsqueeze(img, dim=0).detach().cpu(),
            "linear_preds": torch.unsqueeze(pred, dim=0).detach().cpu(),
        }
        _, _, labeled_img = transform_to_pil(output)
        # Reverse the last axis of the frame (presumably RGB -> BGR for the
        # OpenCV-based writer -- confirm against create_video).
        labeled_frames.append(np.array(labeled_img)[:, :, ::-1])

    create_video(labeled_frames, output_path='output/output.mp4')
    # NOTE(review): the video is written under 'output/' but the bare file
    # name is returned; kept as-is because callers may rely on it -- confirm.
    return 'output.mp4'
# Weight applied to the pixel count of each predicted class id 1..7
# (index 0 is class 1) when computing the biodiversity score.
_CLASS_SCORES = (2, 3, 4, 3, 1, 4, 0)


def _biodiversity_score(label_map):
    """Return the weighted share of classes 1..7 in ``label_map``, in [0, 1]."""
    counts = [
        np.count_nonzero(label_map == class_id)
        for class_id in range(1, len(_CLASS_SCORES) + 1)
    ]
    total = sum(counts)
    if total == 0:
        # No pixel belongs to a scored class: report 0.0 instead of raising
        # ZeroDivisionError.
        return 0.0
    weighted = sum(weight * count for weight, count in zip(_CLASS_SCORES, counts))
    return weighted / total / max(_CLASS_SCORES)


def inference_on_location_and_month(model, latitude=2.98, longitude=48.81, start_date='2020-03-20'):
    """Segment the image of one location for the month following a date.

    The image is requested for the window from ``start_date`` to one calendar
    month later, run through the model on CPU, and scored.

    Args:
        model: Object providing ``cpu()``, ``net(x)`` returning a
            ``(feats, code)`` pair, ``linear_probe(x, code)`` and
            ``cfg.n_images``.
        latitude (float): First element of the location handed to
            ``extract_img``.
        longitude (float): Second element of the location handed to
            ``extract_img``.
        start_date (str): Start of the window, formatted ``YYYY-MM-DD``.

    Returns:
        tuple: ``(img, labeled_img, score)`` where ``img`` and
        ``labeled_img`` come from ``transform_to_pil`` and ``score`` is the
        biodiversity score, a float between 0 and 1.

    Raises:
        ValueError: If ``latitude``/``longitude`` cannot be converted to
            float or ``start_date`` does not match ``YYYY-MM-DD``.
    """
    location = [float(latitude), float(longitude)]

    # The window ends one calendar month after start_date.
    window_start = datetime.datetime.strptime(start_date, "%Y-%m-%d")
    end_date = (window_start + relativedelta(months=1)).strftime("%Y-%m-%d")

    # Extract img numpy from earth engine and transform it to PIL img
    img = extract_img(location, start_date, end_date)
    # max value is the value from numpy file that will be equal to 255
    img_test = transform_ee_img(img, max=0.3)

    # tensorize & normalize img
    preprocess = T.Compose(
        [
            T.ToPILImage(),
            T.Resize((320, 320)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Add a leading dimension of size 1 so the single image forms a batch.
    x = torch.unsqueeze(preprocess(img_test), dim=0).cpu()

    # launch inference on cpu
    model = model.cpu()
    with torch.no_grad():
        _, code = model.net(x)
        linear_pred = model.linear_probe(x, code)
        linear_pred = linear_pred.argmax(1)

    output = {
        "img": x[: model.cfg.n_images].detach().cpu(),
        "linear_preds": linear_pred[: model.cfg.n_images].detach().cpu(),
    }
    score = _biodiversity_score(output["linear_preds"][0])
    img, label, labeled_img = transform_to_pil(output)
    return img, labeled_img, score
if __name__ == "__main__":
    import logging

    import hydra
    from model import LitUnsupervisedSegmenter

    # Script entry point: log to example.log, build the segmenter from the
    # hydra configuration, restore its weights and run the default inference.
    logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.INFO)

    # Initialize hydra with configs
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")
    logging.info(f"config : {cfg}")

    # Instantiate the model with the class count taken from the config.
    model = LitUnsupervisedSegmenter(cfg.dir_dataset_n_classes, cfg)
    logging.info("Model Initialiazed")

    # Read the checkpoint onto the CPU, then copy the weights into the model.
    saved_state_dict = torch.load(
        "checkpoint/model/model.pt",
        map_location=torch.device("cpu"),
    )
    logging.info("Model weights Loaded")
    model.load_state_dict(saved_state_dict)
    logging.info("Model Loaded")

    inference_on_location(model)