Spaces:
Runtime error
Runtime error
import torch.multiprocessing | |
import torchvision.transforms as T | |
from utils import transform_to_pil | |
def inference(image, model):
    """Run the segmentation model on one image and return PIL visualisations.

    Preprocessing steps:
    - resize the image to 320x320;
    - convert it to a tensor;
    - normalize it with the (ImageNet) channel mean/std below;
    - add a batch dimension.

    The tensor is passed through ``model.net`` and ``model.linear_probe``. The
    argmax over dim 1 of the probe output is taken as the predicted label map.

    Args:
        image: input for ``torchvision.transforms.ToPILImage``, i.e. a tensor
            or numpy array. NOTE(review): ``ToPILImage`` rejects PIL images;
            confirm what callers pass in.
        model: segmenter exposing ``net``, ``linear_probe`` and
            ``cfg.n_images`` (presumably a ``torch.nn.Module``). ``model.cpu()``
            is called on it. It is switched to eval mode for the forward pass,
            and its previous train/eval mode is restored before returning.

    Returns:
        tuple: ``(img, labeled_img, label)``. These are the outputs of
        ``transform_to_pil``, which returns them as
        ``(img, label, labeled_img)``; this function reorders them.
    """
    preprocess = T.Compose(
        [
            T.ToPILImage(),
            T.Resize((320, 320)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    # Add a batch dimension of size 1; inference runs on the CPU.
    x = torch.unsqueeze(preprocess(image), dim=0).cpu()
    model = model.cpu()

    # A freshly constructed module is in training mode, where dropout and
    # batch-norm behave differently. Predict in eval mode, but restore the
    # caller's mode so this helper has no lasting effect on the model.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            _feats, code = model.net(x)
            linear_pred = model.linear_probe(x, code).argmax(1)
    finally:
        model.train(was_training)

    n_images = model.cfg.n_images
    output = {
        "img": x[:n_images].detach().cpu(),
        "linear_preds": linear_pred[:n_images].detach().cpu(),
    }
    img, label, labeled_img = transform_to_pil(output)
    return img, labeled_img, label
if __name__ == "__main__":
    # Script entry point: fetch one Earth Engine image, segment it, and save
    # the input, the label map and the labelled image under output/.
    import os

    import hydra
    from model import LitUnsupervisedSegmenter
    from utils_gee import extract_img, transform_ee_img

    # NOTE(review): the original script named 2.98 "latitude" and 48.81
    # "longitude". For a European (CORINE) scene the values read as
    # (longitude, latitude), which is also the [lon, lat] order used by Earth
    # Engine geometries. The list passed to extract_img is unchanged
    # ([2.98, 48.81]); confirm the order extract_img expects.
    longitude = 2.98
    latitude = 48.81
    location = [longitude, latitude]
    start_date = "2020-03-20"
    end_date = "2020-04-20"
    output_dir = "output"

    # Extract the image array from Earth Engine and convert it for the model.
    img = extract_img(location, start_date, end_date)
    # max is the value in the extracted array that is mapped to 255.
    image = transform_ee_img(img, max=0.3)
    print("image loaded")

    # Initialize hydra and compose the training config.
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")

    # Load the model weights onto the CPU.
    # torch.load unpickles the file: only load checkpoints from a trusted source.
    model_path = "checkpoint/model/model.pt"
    saved_state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    nbclasses = cfg.dir_dataset_n_classes
    model = LitUnsupervisedSegmenter(nbclasses, cfg)
    print("model initialized")
    model.load_state_dict(saved_state_dict)
    # Modules start in training mode; switch to inference behaviour.
    model.eval()
    print("model loaded")

    img, labeled_img, label = inference(image, model)

    # PIL's save() does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)
    img.save("output/img.png")
    label.save("output/label.png")
    labeled_img.save("output/labeled_img.png")
# def get_list_date(start_date, end_date): | |
# """Get all the date between the start date and the end date | |
# Args: | |
# start_date (str): start date at the format '%Y-%m-%d' | |
# end_date (str): end date at the format '%Y-%m-%d' | |
# Returns: | |
# list[str]: all the date between the start date and the end date | |
# """ | |
# start_date = datetime.datetime.strptime(start_date, "%Y-%m-%d").date() | |
# end_date = datetime.datetime.strptime(end_date, "%Y-%m-%d").date() | |
# list_date = [start_date] | |
# date = start_date | |
# while date < end_date: | |
# date = date + datetime.timedelta(days=1) | |
# list_date.append(date) | |
# list_date.append(end_date) | |
# list_date2 = [x.strftime("%Y-%m-%d") for x in list_date] | |
# return list_date2 | |
# def get_length_interval(start_date, end_date): | |
# """Return how many days there is between the start date and the end date | |
# Args: | |
# start_date (str): start date at the format '%Y-%m-%d' | |
# end_date (str): end date at the format '%Y-%m-%d' | |
# Returns: | |
# int : number of days between start date and the end date | |
# """ | |
# try: | |
# return len(get_list_date(start_date, end_date)) | |
# except ValueError: | |
# return 0 | |
# def infer_unique_date(latitude, longitude, date, model=model): | |
# """Perform an inference on a latitude and a longitude at a specific date | |
# Args: | |
# latitude (float): the latitude of the landscape | |
# longitude (float): the longitude of the landscape | |
# date (str): date for the inference at the format '%Y-%m-%d' | |
# model (_type_, optional): _description_. Defaults to model. | |
# Returns: | |
# img, labeled_img,biodiv_score: the original landscape, the labeled landscape and the biodiversity score and the landscape | |
# """ | |
# start_date = date | |
# end_date = date | |
# location = [float(latitude), float(longitude)] | |
# # Extract img numpy from earth engine and transform it to PIL img | |
# img = extract_img(location, start_date, end_date) | |
# img_test = transform_ee_img( | |
# img, max=0.3 | |
# ) # max value is the value from numpy file that will be equal to 255 | |
# # tensorize & normalize img | |
# preprocess = T.Compose( | |
# [ | |
# T.ToPILImage(), | |
# T.Resize((320, 320)), | |
# # T.CenterCrop(224), | |
# T.ToTensor(), | |
# T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), | |
# ] | |
# ) | |
# # Preprocess opened img | |
# x = preprocess(img_test) | |
# # launch inference on cpu | |
# x = torch.unsqueeze(x, dim=0).cpu() | |
# model = model.cpu() | |
# with torch.no_grad(): | |
# feats, code = model.net(x) | |
# linear_pred = model.linear_probe(x, code) | |
# linear_pred = linear_pred.argmax(1) | |
# output = { | |
# "img": x[: model.cfg.n_images].detach().cpu(), | |
# "linear_preds": linear_pred[: model.cfg.n_images].detach().cpu(), | |
# } | |
# img, label, labeled_img = transform_to_pil(output) | |
# biodiv_score = compute_biodiv_score(labeled_img) | |
# return img, labeled_img, biodiv_score | |
# def get_img_array(start_date, end_date, latitude, longitude, model=model): | |
# list_date = get_list_date(start_date, end_date) | |
# list_img = [] | |
# for date in list_date: | |
# list_img.append(img) | |
# return list_img | |
# def variable_outputs(start_date, end_date, latitude, longitude, day, model=model): | |
# """Perform an inference on the day number day starting from the start at the latitude and longitude selected | |
# Args: | |
# latitude (float): the latitude of the landscape | |
# longitude (float): the longitude of the landscape | |
# start_date (str): the start date for our inference | |
# end_date (str): the end date for our inference | |
# model (_type_, optional): _description_. Defaults to model. | |
# Returns: | |
# img,labeled_img,biodiv_score: the original landscape, the labeled landscape and the biodiversity score and the landscape at the selected, longitude, latitude and date | |
# """ | |
# list_date = get_list_date(start_date, end_date) | |
# k = int(day) | |
# date = list_date[k] | |
# img, labeled_img, biodiv_score = infer_unique_date( | |
# latitude, longitude, date, model=model | |
# ) | |
# return img, labeled_img, biodiv_score | |
# def variable_outputs2( | |
# start_date, end_date, latitude, longitude, day_number, model=model | |
# ): | |
# """Perform an inference on the day number day starting from the start at the latitude and longitude selected | |
# Args: | |
# latitude (float): the latitude of the landscape | |
# longitude (float): the longitude of the landscape | |
# start_date (str): the start date for our inference | |
# end_date (str): the end date for our inference | |
# model (_type_, optional): _description_. Defaults to model. | |
# Returns: | |
# list[img,labeled_img,biodiv_score]: the original landscape, the labeled landscape and the biodiversity score and the landscape at the selected, longitude, latitude and date | |
# """ | |
# list_date = get_list_date(start_date, end_date) | |
# k = int(day_number) | |
# date = list_date[k] | |
# img, labeled_img, biodiv_score = infer_unique_date( | |
# latitude, longitude, date, model=model | |
# ) | |
# return [img, labeled_img, biodiv_score] | |
# def gif_maker(img_array): | |
# output_file = "test2.mkv" | |
# image_test = img_array[0] | |
# size = (320, 320) | |
# print(size) | |
# out = cv2.VideoWriter( | |
# output_file, cv2.VideoWriter_fourcc(*"avc1"), 15, frameSize=size | |
# ) | |
# for i in range(len(img_array)): | |
# image = img_array[i] | |
# pix = np.array(image.getdata()) | |
# out.write(pix) | |
# out.release() | |
# return output_file | |
# def infer_multiple_date(start_date, end_date, latitude, longitude, model=model): | |
# """Perform an inference on all the dates between the start date and the end date at the latitude and longitude | |
# Args: | |
# latitude (float): the latitude of the landscape | |
# longitude (float): the longitude of the landscape | |
# start_date (str): the start date for our inference | |
# end_date (str): the end date for our inference | |
# model (_type_, optional): _description_. Defaults to model. | |
# Returns: | |
# list_img,list_labeled_img,list_biodiv_score: list of the original landscape, the labeled landscape and the biodiversity score and the landscape | |
# """ | |
# list_date = get_list_date(start_date, end_date) | |
# list_img = [] | |
# list_labeled_img = [] | |
# list_biodiv_score = [] | |
# for date in list_date: | |
# img, labeled_img, biodiv_score = infer_unique_date( | |
# latitude, longitude, date, model=model | |
# ) | |
# list_img.append(img) | |
# list_labeled_img.append(labeled_img) | |
# list_biodiv_score.append(biodiv_score) | |
# return gif_maker(list_img), gif_maker(list_labeled_img), list_biodiv_score[0] | |