File size: 6,522 Bytes
5c718d1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
import torch.multiprocessing
import torchvision.transforms as T
import numpy as np
from utils import transform_to_pil, create_video
from utils_gee import extract_img, transform_ee_img
from dateutil.relativedelta import relativedelta
import datetime
from dateutil.relativedelta import relativedelta
import cv2

from joblib import Parallel, cpu_count, delayed

def get_image(location, d1, d2):
    """Fetch the Earth Engine image for one date window and rescale it.

    Args:
        location: coordinate pair forwarded unchanged to ``extract_img``.
        d1: start date string of the window, forwarded to ``extract_img``.
        d2: end date string of the window, forwarded to ``extract_img``.

    Returns:
        The result of ``transform_ee_img(..., max=0.3)``, or ``None`` if any
        step raised. The error is printed instead of propagated so that one
        failed window does not abort a whole batch of downloads; callers are
        expected to filter out ``None`` results.
    """
    print(f"getting image for {d1} to {d2}")
    try:
        raw_image = extract_img(location, d1, d2)
        return transform_ee_img(raw_image, max=0.3)
    except Exception as exc:
        # Best effort: report the failure and signal it with None.
        print(exc)
    return None
        

def _monthly_date_strings(start_year, end_year):
    """Return first-of-month dates from Jan 1 ``start_year`` to Jan 1 ``end_year`` inclusive.

    Dates are formatted as ``"%Y-%m-%d"``; consecutive pairs delimit the
    one-month query windows.
    """
    current = datetime.datetime(start_year, 1, 1, 0, 0, 0)
    stop = datetime.datetime(end_year, 1, 1, 0, 0, 0)
    dates = [current]
    while current < stop:
        current = current + relativedelta(months=1)
        dates.append(current)
    return [d.strftime("%Y-%m-%d") for d in dates]


def inference_on_location(model, latitude = 2.98, longitude = 48.81, start_date=2020, end_date=2022, output_path="output/output.mp4"):
    """Segment a location month by month and render the predictions as a video.

    For every one-month window between January 1st of ``start_date`` and
    January 1st of ``end_date``, an image is fetched through ``get_image``.
    Windows whose download fails are skipped. Each remaining image is run
    through ``model.net`` / ``model.linear_probe``, and the labelled frames
    are written to ``output_path`` with ``create_video``.

    Args:
        model: segmentation model exposing ``net`` and ``linear_probe``
            (presumably a ``LitUnsupervisedSegmenter``); ``model.cpu()`` is
            called on it before inference.
        latitude (float): first element of the location passed to
            ``extract_img``.
        longitude (float): second element of the location passed to
            ``extract_img``.
            NOTE(review): the defaults (2.98, 48.81) look like (lon, lat) for
            the Paris area; confirm the order ``extract_img`` expects.
        start_date (int | str | float): first year of the period; coerced
            with ``int()``.
        end_date (int | str | float): year whose January 1st closes the
            period; coerced with ``int()`` and must be strictly greater than
            ``start_date``.
        output_path (str): where the video is written.

    Returns:
        str: ``output_path``, the path of the video that was written.

    Raises:
        ValueError: if ``end_date`` is not strictly after ``start_date``.
        RuntimeError: if no image could be retrieved for any window.
    """
    start_year, end_year = int(start_date), int(end_date)
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if end_year <= start_year:
        raise ValueError("end date must be stricly higher than start date")
    location = [float(latitude), float(longitude)]

    dates = _monthly_date_strings(start_year, end_year)
    # Downloads are I/O-bound, so a thread pool lets them overlap.
    fetched = Parallel(n_jobs=cpu_count(), prefer="threads")(
        delayed(get_image)(location, d1, d2) for d1, d2 in zip(dates[:-1], dates[1:])
    )
    # get_image returns None for windows it could not retrieve; drop them.
    images = [image for image in fetched if image is not None]
    if not images:
        # torch.stack([]) would fail with an opaque message; be explicit.
        raise RuntimeError(
            f"no image could be retrieved for location {location} "
            f"between {dates[0]} and {dates[-1]}"
        )

    # Tensorize, resize and normalize with the ImageNet statistics.
    preprocess = T.Compose(
        [
            T.ToPILImage(),
            T.Resize((320, 320)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    batch = torch.stack([preprocess(image) for image in images]).cpu()

    # Run the inference on CPU.
    model = model.cpu()
    with torch.no_grad():
        _, code = model.net(batch)
        predictions = model.linear_probe(batch, code).argmax(1)

    labeled_frames = []
    for frame, prediction in zip(batch, predictions):
        output = {
            "img": torch.unsqueeze(frame, dim=0).detach().cpu(),
            "linear_preds": torch.unsqueeze(prediction, dim=0).detach().cpu(),
        }
        _, _, labeled_img = transform_to_pil(output)
        # Reverse the channel axis (presumably RGB -> BGR for OpenCV writing).
        labeled_frames.append(np.array(labeled_img)[:, :, ::-1])

    create_video(labeled_frames, output_path=output_path)
    # Return the path that was actually written; the original returned the
    # bare name 'output.mp4', which did not match the file location.
    return output_path

def _biodiversity_score(labels):
    """Weighted share of pixels in classes 1-7, normalised to [0, 1].

    Each class 1..7 gets a fixed weight; the weighted pixel count is divided
    by the number of pixels in those classes and by the maximum weight.
    Pixels of any other class are ignored. Returns 0.0 when no pixel
    belongs to classes 1-7, instead of dividing by zero.
    """
    class_weights = [2, 3, 4, 3, 1, 4, 0]  # weights for classes 1..7, in order
    counts = [
        np.count_nonzero(labels == cls)
        for cls in range(1, len(class_weights) + 1)
    ]
    total = sum(counts)
    if total == 0:
        return 0.0
    weighted = sum(weight * count for weight, count in zip(class_weights, counts))
    return weighted / total / max(class_weights)


def inference_on_location_and_month(model, latitude = 2.98, longitude = 48.81, start_date = '2020-03-20'):
    """Segment a location for the one-month window starting at ``start_date``.

    Args:
        model: segmentation model exposing ``net``, ``linear_probe`` and
            ``cfg.n_images`` (presumably a ``LitUnsupervisedSegmenter``);
            ``model.cpu()`` is called on it before inference.
        latitude (float): first element of the location passed to
            ``extract_img``.
        longitude (float): second element of the location passed to
            ``extract_img``.
        start_date (str): window start in ``"%Y-%m-%d"`` format; the window
            ends one calendar month later.

    Returns:
        tuple: ``(img, labeled_img, score)``. ``img`` and ``labeled_img`` are
        the first and third values returned by ``transform_to_pil``. ``score``
        is a float in [0, 1], computed from the predicted classes 1-7 of the
        first image (0.0 when none of those classes is present).

    Raises:
        ValueError: if ``start_date`` does not match ``"%Y-%m-%d"``.
    """
    location = [float(latitude), float(longitude)]

    # Query window: [start_date, start_date + 1 month).
    window_start = datetime.datetime.strptime(start_date, "%Y-%m-%d")
    end_date = (window_start + relativedelta(months=1)).strftime("%Y-%m-%d")
    img = extract_img(location, start_date, end_date)
    # max value is the value from numpy file that will be equal to 255
    img_test = transform_ee_img(img, max=0.3)

    # Tensorize, resize and normalize with the ImageNet statistics.
    preprocess = T.Compose(
        [
            T.ToPILImage(),
            T.Resize((320, 320)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]
    )
    batch = torch.unsqueeze(preprocess(img_test), dim=0).cpu()

    # Run the inference on CPU.
    model = model.cpu()
    with torch.no_grad():
        _, code = model.net(batch)
        linear_pred = model.linear_probe(batch, code).argmax(1)
        output = {
            "img": batch[: model.cfg.n_images].detach().cpu(),
            "linear_preds": linear_pred[: model.cfg.n_images].detach().cpu(),
        }

    score = _biodiversity_score(output["linear_preds"][0])

    img, _, labeled_img = transform_to_pil(output)
    return img, labeled_img, score


if __name__ == "__main__":
    # Script entry point: build the model from the hydra configuration, load
    # the checkpoint weights on CPU, then run the default multi-month inference.
    import logging
    import hydra

    from model import LitUnsupervisedSegmenter

    logging.basicConfig(filename='example.log', encoding='utf-8', level=logging.INFO)

    # Compose the configuration from the ./configs directory.
    hydra.initialize(config_path="configs", job_name="corine")
    cfg = hydra.compose(config_name="my_train_config.yml")
    logging.info("config : %s", cfg)

    # Build the network sized for the dataset's number of classes.
    model = LitUnsupervisedSegmenter(cfg.dir_dataset_n_classes, cfg)
    logging.info("Model Initialiazed")

    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    state_dict = torch.load("checkpoint/model/model.pt", map_location=torch.device("cpu"))
    logging.info("Model weights Loaded")
    model.load_state_dict(state_dict)
    logging.info("Model Loaded")

    inference_on_location(model)