# Biomap / biomap/inference.py
# Commit 5c718d1 ("first commit") by jeremyLE-Ekimetrics; file size 9.55 kB.
import torch.multiprocessing
import torchvision.transforms as T
from utils import transform_to_pil
def inference(image, model):
    """Segment a single image with ``model`` on the CPU.

    The image is converted to a PIL image, resized to 320x320, turned back
    into a tensor and normalized with the mean/std below (the standard
    ImageNet statistics). The batch is then passed through ``model.net`` and
    ``model.linear_probe``. The predicted class per pixel is the argmax over
    dimension 1 of the probe output.

    Args:
        image: input accepted by ``torchvision.transforms.ToPILImage``, i.e. a
            tensor or ndarray. NOTE(review): the ``__main__`` script passes
            the result of ``transform_ee_img``. Confirm it is not already a
            PIL image, which ``ToPILImage`` rejects with a ``TypeError``.
        model: segmenter exposing ``net``, ``linear_probe``, ``cfg.n_images``
            and the usual ``torch.nn.Module`` API. It is moved to the CPU in
            place (``Module.cpu``). Its train/eval mode is restored on return.

    Returns:
        tuple: ``(img, labeled_img, label)`` built by ``transform_to_pil``.
        Note that ``transform_to_pil`` itself returns
        ``(img, label, labeled_img)``, so the last two items are swapped.
    """
    # Tensorize & normalize the image.
    preprocess = T.Compose(
        [
            T.ToPILImage(),
            T.Resize((320, 320)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Add a leading batch dimension and run everything on the CPU.
    x = torch.unsqueeze(preprocess(image), dim=0).cpu()
    model = model.cpu()

    # Inference must run in eval mode so that dropout / batch-norm layers
    # behave deterministically. Remember the caller's mode and restore it
    # afterwards, so this function has no lasting effect on the mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            _feats, code = model.net(x)
            linear_pred = model.linear_probe(x, code).argmax(1)
    finally:
        model.train(was_training)

    # Keep at most cfg.n_images items of the batch (the batch holds one image here).
    n_images = model.cfg.n_images
    output = {
        "img": x[:n_images].detach().cpu(),
        "linear_preds": linear_pred[:n_images].detach().cpu(),
    }
    img, label, labeled_img = transform_to_pil(output)
    return img, labeled_img, label
def main():
    """Segment one Earth Engine image and write the results to ``output/``.

    Steps:
        1. Fetch an image for a hard-coded location and date range with
           ``extract_img`` and convert it with ``transform_ee_img``.
        2. Build a ``LitUnsupervisedSegmenter`` from the Hydra config
           ``configs/my_train_config.yml``.
        3. Load the state dict stored in ``checkpoint/model/model.pt``.
        4. Run :func:`inference` and save the original image, the label map
           and the labeled image as PNG files.
    """
    # Script-only imports stay local so that importing this module (e.g. to use
    # ``inference``) does not require hydra / Earth Engine helpers.
    import os

    import hydra
    from model import LitUnsupervisedSegmenter
    from utils_gee import extract_img, transform_ee_img

    # 48.81 N / 2.98 E lies east of Paris, so 2.98 is the longitude and 48.81
    # the latitude. The original variable names were swapped. The list given to
    # extract_img keeps its original value, [2.98, 48.81].
    # NOTE(review): presumably extract_img expects [lon, lat] (Earth Engine
    # coordinate order) — confirm against utils_gee.
    longitude = 2.98
    latitude = 48.81
    start_date = "2020-03-20"
    end_date = "2020-04-20"
    location = [float(longitude), float(latitude)]

    # Extract the raw image from Earth Engine and rescale it for display.
    img = extract_img(location, start_date, end_date)
    # ``max`` is the raw value that is mapped to 255.
    image = transform_ee_img(img, max=0.3)
    print("image loaded")

    # Initialize hydra with configs.
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")

    # Load the model. The checkpoint is a local, trusted file; torch.load
    # unpickles it, so never point this at untrusted data.
    model_path = "checkpoint/model/model.pt"
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg)
    print("model initialized")
    model.load_state_dict(saved_state_dict)
    # A freshly built module is in training mode; switch to eval for prediction.
    model.eval()
    print("model loaded")

    img, labeled_img, label = inference(image, model)

    # Ensure the output directory exists; PIL's save does not create it.
    os.makedirs("output", exist_ok=True)
    img.save("output/img.png")
    label.save("output/label.png")
    labeled_img.save("output/labeled_img.png")


if __name__ == "__main__":
    main()
# def get_list_date(start_date, end_date):
# """Get all the date between the start date and the end date
# Args:
# start_date (str): start date at the format '%Y-%m-%d'
# end_date (str): end date at the format '%Y-%m-%d'
# Returns:
# list[str]: all the date between the start date and the end date
# """
# start_date = datetime.datetime.strptime(start_date, "%Y-%m-%d").date()
# end_date = datetime.datetime.strptime(end_date, "%Y-%m-%d").date()
# list_date = [start_date]
# date = start_date
# while date < end_date:
# date = date + datetime.timedelta(days=1)
# list_date.append(date)
# list_date.append(end_date)
# list_date2 = [x.strftime("%Y-%m-%d") for x in list_date]
# return list_date2
# def get_length_interval(start_date, end_date):
# """Return how many days there is between the start date and the end date
# Args:
# start_date (str): start date at the format '%Y-%m-%d'
# end_date (str): end date at the format '%Y-%m-%d'
# Returns:
# int : number of days between start date and the end date
# """
# try:
# return len(get_list_date(start_date, end_date))
# except ValueError:
# return 0
# def infer_unique_date(latitude, longitude, date, model=model):
# """Perform an inference on a latitude and a longitude at a specific date
# Args:
# latitude (float): the latitude of the landscape
# longitude (float): the longitude of the landscape
# date (str): date for the inference at the format '%Y-%m-%d'
# model (_type_, optional): _description_. Defaults to model.
# Returns:
# img, labeled_img,biodiv_score: the original landscape, the labeled landscape and the biodiversity score and the landscape
# """
# start_date = date
# end_date = date
# location = [float(latitude), float(longitude)]
# # Extract img numpy from earth engine and transform it to PIL img
# img = extract_img(location, start_date, end_date)
# img_test = transform_ee_img(
# img, max=0.3
# ) # max value is the value from numpy file that will be equal to 255
# # tensorize & normalize img
# preprocess = T.Compose(
# [
# T.ToPILImage(),
# T.Resize((320, 320)),
# # T.CenterCrop(224),
# T.ToTensor(),
# T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
# ]
# )
# # Preprocess opened img
# x = preprocess(img_test)
# # launch inference on cpu
# x = torch.unsqueeze(x, dim=0).cpu()
# model = model.cpu()
# with torch.no_grad():
# feats, code = model.net(x)
# linear_pred = model.linear_probe(x, code)
# linear_pred = linear_pred.argmax(1)
# output = {
# "img": x[: model.cfg.n_images].detach().cpu(),
# "linear_preds": linear_pred[: model.cfg.n_images].detach().cpu(),
# }
# img, label, labeled_img = transform_to_pil(output)
# biodiv_score = compute_biodiv_score(labeled_img)
# return img, labeled_img, biodiv_score
# def get_img_array(start_date, end_date, latitude, longitude, model=model):
# list_date = get_list_date(start_date, end_date)
# list_img = []
# for date in list_date:
# list_img.append(img)
# return list_img
# def variable_outputs(start_date, end_date, latitude, longitude, day, model=model):
# """Perform an inference on the day number day starting from the start at the latitude and longitude selected
# Args:
# latitude (float): the latitude of the landscape
# longitude (float): the longitude of the landscape
# start_date (str): the start date for our inference
# end_date (str): the end date for our inference
# model (_type_, optional): _description_. Defaults to model.
# Returns:
# img,labeled_img,biodiv_score: the original landscape, the labeled landscape and the biodiversity score and the landscape at the selected, longitude, latitude and date
# """
# list_date = get_list_date(start_date, end_date)
# k = int(day)
# date = list_date[k]
# img, labeled_img, biodiv_score = infer_unique_date(
# latitude, longitude, date, model=model
# )
# return img, labeled_img, biodiv_score
# def variable_outputs2(
# start_date, end_date, latitude, longitude, day_number, model=model
# ):
# """Perform an inference on the day number day starting from the start at the latitude and longitude selected
# Args:
# latitude (float): the latitude of the landscape
# longitude (float): the longitude of the landscape
# start_date (str): the start date for our inference
# end_date (str): the end date for our inference
# model (_type_, optional): _description_. Defaults to model.
# Returns:
# list[img,labeled_img,biodiv_score]: the original landscape, the labeled landscape and the biodiversity score and the landscape at the selected, longitude, latitude and date
# """
# list_date = get_list_date(start_date, end_date)
# k = int(day_number)
# date = list_date[k]
# img, labeled_img, biodiv_score = infer_unique_date(
# latitude, longitude, date, model=model
# )
# return [img, labeled_img, biodiv_score]
# def gif_maker(img_array):
# output_file = "test2.mkv"
# image_test = img_array[0]
# size = (320, 320)
# print(size)
# out = cv2.VideoWriter(
# output_file, cv2.VideoWriter_fourcc(*"avc1"), 15, frameSize=size
# )
# for i in range(len(img_array)):
# image = img_array[i]
# pix = np.array(image.getdata())
# out.write(pix)
# out.release()
# return output_file
# def infer_multiple_date(start_date, end_date, latitude, longitude, model=model):
# """Perform an inference on all the dates between the start date and the end date at the latitude and longitude
# Args:
# latitude (float): the latitude of the landscape
# longitude (float): the longitude of the landscape
# start_date (str): the start date for our inference
# end_date (str): the end date for our inference
# model (_type_, optional): _description_. Defaults to model.
# Returns:
# list_img,list_labeled_img,list_biodiv_score: list of the original landscape, the labeled landscape and the biodiversity score and the landscape
# """
# list_date = get_list_date(start_date, end_date)
# list_img = []
# list_labeled_img = []
# list_biodiv_score = []
# for date in list_date:
# img, labeled_img, biodiv_score = infer_unique_date(
# latitude, longitude, date, model=model
# )
# list_img.append(img)
# list_labeled_img.append(labeled_img)
# list_biodiv_score.append(biodiv_score)
# return gif_maker(list_img), gif_maker(list_labeled_img), list_biodiv_score[0]