blumenstiel
commited on
Commit
Β·
540601d
1
Parent(s):
58a73c7
Switched to config.json
Browse files
app.py
CHANGED
@@ -10,8 +10,8 @@ from huggingface_hub import hf_hub_download
|
|
10 |
|
11 |
# pull files from hub
|
12 |
token = os.environ.get("HF_TOKEN", None)
|
13 |
-
|
14 |
-
filename="
|
15 |
checkpoint = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
|
16 |
filename='Prithvi_EO_V2_300M_TL.pt', token=token)
|
17 |
model_def = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
|
@@ -67,7 +67,7 @@ def extract_rgb_imgs(input_img, rec_img, mask_img, channels, mean, std):
|
|
67 |
return outputs
|
68 |
|
69 |
|
70 |
-
def predict_on_images(data_files: list,
|
71 |
try:
|
72 |
data_files = [x.name for x in data_files]
|
73 |
print('Path extracted from example')
|
@@ -77,18 +77,17 @@ def predict_on_images(data_files: list, yaml_file_path: str, checkpoint: str, ma
|
|
77 |
# Get parameters --------
|
78 |
print('This is the printout', data_files)
|
79 |
|
80 |
-
with open(
|
81 |
-
config = yaml.safe_load(f)
|
82 |
|
83 |
batch_size = 8
|
84 |
-
bands = config['
|
85 |
num_frames = len(data_files)
|
86 |
-
mean = config['
|
87 |
-
std = config['
|
88 |
-
coords_encoding = config['
|
89 |
-
img_size = config['
|
90 |
-
|
91 |
-
mask_ratio = mask_ratio or config['DATA']['MASK_RATIO']
|
92 |
|
93 |
assert num_frames <= 4, "Demo only supports up to four timestamps"
|
94 |
|
@@ -110,21 +109,12 @@ def predict_on_images(data_files: list, yaml_file_path: str, checkpoint: str, ma
|
|
110 |
|
111 |
# Create model and load checkpoint -------------------------------------------------------------
|
112 |
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
num_heads=config['MODEL']['NUM_HEADS'],
|
120 |
-
decoder_embed_dim=config['MODEL']['DECODER_EMBED_DIM'],
|
121 |
-
decoder_depth=config['MODEL']['DECODER_DEPTH'],
|
122 |
-
decoder_num_heads=config['MODEL']['DECODER_NUM_HEADS'],
|
123 |
-
mlp_ratio=config['MODEL']['MLP_RATIO'],
|
124 |
-
norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
|
125 |
-
norm_pix_loss=config['MODEL']['NORM_PIX_LOSS'],
|
126 |
-
coords_encoding=coords_encoding,
|
127 |
-
coords_scale_learn=config['MODEL']['COORDS_SCALE_LEARN'])
|
128 |
|
129 |
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
130 |
print(f"\n--> Model has {total_params:,} parameters.\n")
|
@@ -196,7 +186,7 @@ def predict_on_images(data_files: list, yaml_file_path: str, checkpoint: str, ma
|
|
196 |
return outputs
|
197 |
|
198 |
|
199 |
-
run_inference = partial(predict_on_images,
|
200 |
|
201 |
with gr.Blocks() as demo:
|
202 |
|
|
|
10 |
|
11 |
# pull files from hub
|
12 |
token = os.environ.get("HF_TOKEN", None)
|
13 |
+
config_path = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
|
14 |
+
filename="config.json", token=token)
|
15 |
checkpoint = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
|
16 |
filename='Prithvi_EO_V2_300M_TL.pt', token=token)
|
17 |
model_def = hf_hub_download(repo_id="ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL",
|
|
|
67 |
return outputs
|
68 |
|
69 |
|
70 |
+
def predict_on_images(data_files: list, config_path: str, checkpoint: str, mask_ratio: float = None):
|
71 |
try:
|
72 |
data_files = [x.name for x in data_files]
|
73 |
print('Path extracted from example')
|
|
|
77 |
# Get parameters --------
|
78 |
print('This is the printout', data_files)
|
79 |
|
80 |
+
with open(config_path, 'r') as f:
|
81 |
+
config = yaml.safe_load(f)['pretrained_cfg']
|
82 |
|
83 |
batch_size = 8
|
84 |
+
bands = config['bands']
|
85 |
num_frames = len(data_files)
|
86 |
+
mean = config['mean']
|
87 |
+
std = config['std']
|
88 |
+
coords_encoding = config['coords_encoding']
|
89 |
+
img_size = config['img_size']
|
90 |
+
mask_ratio = mask_ratio or config['mask_ratio']
|
|
|
91 |
|
92 |
assert num_frames <= 4, "Demo only supports up to four timestamps"
|
93 |
|
|
|
109 |
|
110 |
# Create model and load checkpoint -------------------------------------------------------------
|
111 |
|
112 |
+
config.update(
|
113 |
+
num_frames=num_frames,
|
114 |
+
coords_encoding=coords_encoding,
|
115 |
+
)
|
116 |
+
|
117 |
+
model = PrithviMAE(**config)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
118 |
|
119 |
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
120 |
print(f"\n--> Model has {total_params:,} parameters.\n")
|
|
|
186 |
return outputs
|
187 |
|
188 |
|
189 |
+
run_inference = partial(predict_on_images, config_path=config_path,checkpoint=checkpoint)
|
190 |
|
191 |
with gr.Blocks() as demo:
|
192 |
|