Paul-Edouard Sarlin commited on
Commit
b0cf684
·
unverified ·
1 Parent(s): e63f4c9

Code formatting (#47)

Browse files

* Add formatting CI
* Apply black and isort
* Fix flake8 errors
* Add bash script

.github/workflows/code-quality.yml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Format and Lint
2
+ on:
3
+ push:
4
+ branches:
5
+ - master
6
+ paths:
7
+ - '*.py'
8
+ - '*.ipynb'
9
+ pull_request:
10
+ types: [ assigned, opened, synchronize, reopened ]
11
+ jobs:
12
+ check:
13
+ runs-on: ubuntu-latest
14
+ steps:
15
+ - uses: actions/checkout@v4
16
+ - uses: actions/setup-python@v5
17
+ with:
18
+ python-version: '3.10'
19
+ cache: 'pip'
20
+ cache-dependency-path: 'requirements/dev.txt'
21
+ - run: python -m pip install --upgrade pip
22
+ - run: python -m pip install -r requirements/dev.txt
23
+ - run: python -m flake8 maploc
24
+ - run: python -m isort maploc *.ipynb --check-only --diff
25
+ - run: python -m black maploc *.ipynb --check --diff
.isort.cfg ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ [settings]
2
+ profile=black
demo.ipynb CHANGED
@@ -14,7 +14,7 @@
14
  "# The highest accuracy is achieved with num_rotations=360\n",
15
  "# but num_rotations=64~128 is often sufficient.\n",
16
  "# To reduce the memory usage, we can reduce the tile size in the next cell.\n",
17
- "demo = Demo(num_rotations=256, device='cpu')"
18
  ]
19
  },
20
  {
@@ -135,6 +135,7 @@
135
  "\n",
136
  "# Show the query area in an interactive map\n",
137
  "from maploc.osm.viz import GeoPlotter\n",
 
138
  "plot = GeoPlotter(zoom=16)\n",
139
  "plot.points(prior_latlon[:2], \"red\", name=\"location prior\", size=10)\n",
140
  "plot.bbox(proj.unproject(bbox), \"blue\", name=\"map tile\")\n",
@@ -142,12 +143,14 @@
142
  "\n",
143
  "# Query OpenStreetMap for this area\n",
144
  "from maploc.osm.tiling import TileManager\n",
 
145
  "tiler = TileManager.from_bbox(proj, bbox + 10, demo.config.data.pixel_per_meter)\n",
146
  "canvas = tiler.query(bbox)\n",
147
  "\n",
148
  "# Show the inputs to the model: image and raster map\n",
149
  "from maploc.osm.viz import Colormap, plot_nodes\n",
150
  "from maploc.utils.viz_2d import plot_images\n",
 
151
  "map_viz = Colormap.apply(canvas.raster)\n",
152
  "plot_images([image, map_viz], titles=[\"input image\", \"OpenStreetMap raster\"])\n",
153
  "plot_nodes(1, canvas.raster[2], fontsize=6, size=10)"
@@ -1186,7 +1189,8 @@
1186
  "\n",
1187
  "# Run the inference\n",
1188
  "uv, yaw, prob, neural_map, image_rectified = demo.localize(\n",
1189
- " image, camera, canvas, roll_pitch=gravity)\n",
 
1190
  "\n",
1191
  "# Visualize the predictions\n",
1192
  "overlay = likelihood_overlay(prob.numpy().max(-1), map_viz.mean(-1, keepdims=True))\n",
@@ -1194,7 +1198,7 @@
1194
  "plot_images([overlay, neural_map_rgb], titles=[\"prediction\", \"neural map\"])\n",
1195
  "ax = plt.gcf().axes[0]\n",
1196
  "ax.scatter(*canvas.to_uv(bbox.center), s=5, c=\"red\")\n",
1197
- "plot_dense_rotations(ax, prob, w=0.005, s=1/25)\n",
1198
  "add_circle_inset(ax, uv)\n",
1199
  "plt.show(\"notebook\")\n",
1200
  "\n",
 
14
  "# The highest accuracy is achieved with num_rotations=360\n",
15
  "# but num_rotations=64~128 is often sufficient.\n",
16
  "# To reduce the memory usage, we can reduce the tile size in the next cell.\n",
17
+ "demo = Demo(num_rotations=256, device=\"cpu\")"
18
  ]
19
  },
20
  {
 
135
  "\n",
136
  "# Show the query area in an interactive map\n",
137
  "from maploc.osm.viz import GeoPlotter\n",
138
+ "\n",
139
  "plot = GeoPlotter(zoom=16)\n",
140
  "plot.points(prior_latlon[:2], \"red\", name=\"location prior\", size=10)\n",
141
  "plot.bbox(proj.unproject(bbox), \"blue\", name=\"map tile\")\n",
 
143
  "\n",
144
  "# Query OpenStreetMap for this area\n",
145
  "from maploc.osm.tiling import TileManager\n",
146
+ "\n",
147
  "tiler = TileManager.from_bbox(proj, bbox + 10, demo.config.data.pixel_per_meter)\n",
148
  "canvas = tiler.query(bbox)\n",
149
  "\n",
150
  "# Show the inputs to the model: image and raster map\n",
151
  "from maploc.osm.viz import Colormap, plot_nodes\n",
152
  "from maploc.utils.viz_2d import plot_images\n",
153
+ "\n",
154
  "map_viz = Colormap.apply(canvas.raster)\n",
155
  "plot_images([image, map_viz], titles=[\"input image\", \"OpenStreetMap raster\"])\n",
156
  "plot_nodes(1, canvas.raster[2], fontsize=6, size=10)"
 
1189
  "\n",
1190
  "# Run the inference\n",
1191
  "uv, yaw, prob, neural_map, image_rectified = demo.localize(\n",
1192
+ " image, camera, canvas, roll_pitch=gravity\n",
1193
+ ")\n",
1194
  "\n",
1195
  "# Visualize the predictions\n",
1196
  "overlay = likelihood_overlay(prob.numpy().max(-1), map_viz.mean(-1, keepdims=True))\n",
 
1198
  "plot_images([overlay, neural_map_rgb], titles=[\"prediction\", \"neural map\"])\n",
1199
  "ax = plt.gcf().axes[0]\n",
1200
  "ax.scatter(*canvas.to_uv(bbox.center), s=5, c=\"red\")\n",
1201
+ "plot_dense_rotations(ax, prob, w=0.005, s=1 / 25)\n",
1202
  "add_circle_inset(ax, uv)\n",
1203
  "plt.show(\"notebook\")\n",
1204
  "\n",
format.sh ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ all_files=$(git ls-tree --full-tree -r --name-only HEAD .)
2
+ py_files=$(echo "$all_files" | grep ".*\.py$")
3
+ nb_files=$(echo "$all_files" | grep ".*\.ipynb$" | grep -v "^notebooks")
4
+ python -m black $py_files $nb_files
5
+ python -m isort $py_files $nb_files
6
+ python -m flake8 $py_files
maploc/__init__.py CHANGED
@@ -1,11 +1,10 @@
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
- from pathlib import Path
4
  import logging
 
5
 
6
  import pytorch_lightning # noqa: F401
7
 
8
-
9
  formatter = logging.Formatter(
10
  fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s",
11
  datefmt="%Y-%m-%d %H:%M:%S",
 
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
 
3
  import logging
4
+ from pathlib import Path
5
 
6
  import pytorch_lightning # noqa: F401
7
 
 
8
  formatter = logging.Formatter(
9
  fmt="[%(asctime)s %(name)s %(levelname)s] %(message)s",
10
  datefmt="%Y-%m-%d %H:%M:%S",
maploc/data/image.py CHANGED
@@ -1,11 +1,11 @@
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
- from typing import Callable, Optional, Union, Sequence
 
4
 
5
  import numpy as np
6
  import torch
7
  import torchvision.transforms.functional as tvf
8
- import collections
9
  from scipy.spatial.transform import Rotation
10
 
11
  from ..utils.geometry import from_homogeneous, to_homogeneous
 
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
+ import collections
4
+ from typing import Callable, Optional, Sequence, Union
5
 
6
  import numpy as np
7
  import torch
8
  import torchvision.transforms.functional as tvf
 
9
  from scipy.spatial.transform import Rotation
10
 
11
  from ..utils.geometry import from_homogeneous, to_homogeneous
maploc/data/kitti/dataset.py CHANGED
@@ -13,12 +13,12 @@ import torch.utils.data as torchdata
13
  from omegaconf import OmegaConf
14
  from scipy.spatial.transform import Rotation
15
 
16
- from ... import logger, DATASETS_PATH
17
  from ...osm.tiling import TileManager
18
  from ..dataset import MapLocDataset
19
  from ..sequential import chunk_sequence
20
  from ..torch import collate, worker_init_fn
21
- from .utils import parse_split_file, parse_gps_file, get_camera_calibration
22
 
23
 
24
  class KittiDataModule(pl.LightningDataModule):
 
13
  from omegaconf import OmegaConf
14
  from scipy.spatial.transform import Rotation
15
 
16
+ from ... import DATASETS_PATH, logger
17
  from ...osm.tiling import TileManager
18
  from ..dataset import MapLocDataset
19
  from ..sequential import chunk_sequence
20
  from ..torch import collate, worker_init_fn
21
+ from .utils import get_camera_calibration, parse_gps_file, parse_split_file
22
 
23
 
24
  class KittiDataModule(pl.LightningDataModule):
maploc/data/kitti/prepare.py CHANGED
@@ -1,9 +1,9 @@
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
  import argparse
4
- from pathlib import Path
5
  import shutil
6
  import zipfile
 
7
 
8
  import numpy as np
9
  from tqdm.auto import tqdm
@@ -12,9 +12,9 @@ from ... import logger
12
  from ...osm.tiling import TileManager
13
  from ...osm.viz import GeoPlotter
14
  from ...utils.geo import BoundaryBox, Projection
15
- from ...utils.io import download_file, DATA_URL
16
- from .utils import parse_gps_file
17
  from .dataset import KittiDataModule
 
18
 
19
  split_files = ["test1_files.txt", "test2_files.txt", "train_files.txt"]
20
 
@@ -70,7 +70,7 @@ def download(data_dir: Path):
70
  for seq in tqdm(seqs):
71
  logger.info("Working on %s.", seq)
72
  date = "_".join(seq.split("_")[:3])
73
- url = f"https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/{seq}/{seq}_sync.zip"
74
  seq_dir = data_dir / date / f"{seq}_sync"
75
  if seq_dir.exists():
76
  continue
 
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
  import argparse
 
4
  import shutil
5
  import zipfile
6
+ from pathlib import Path
7
 
8
  import numpy as np
9
  from tqdm.auto import tqdm
 
12
  from ...osm.tiling import TileManager
13
  from ...osm.viz import GeoPlotter
14
  from ...utils.geo import BoundaryBox, Projection
15
+ from ...utils.io import DATA_URL, download_file
 
16
  from .dataset import KittiDataModule
17
+ from .utils import parse_gps_file
18
 
19
  split_files = ["test1_files.txt", "test2_files.txt", "train_files.txt"]
20
 
 
70
  for seq in tqdm(seqs):
71
  logger.info("Working on %s.", seq)
72
  date = "_".join(seq.split("_")[:3])
73
+ url = f"https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/{seq}/{seq}_sync.zip" # noqa E501
74
  seq_dir = data_dir / date / f"{seq}_sync"
75
  if seq_dir.exists():
76
  continue
maploc/data/mapillary/dataset.py CHANGED
@@ -1,10 +1,10 @@
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
  import json
4
- from collections import defaultdict
5
  import os
6
  import shutil
7
  import tarfile
 
8
  from pathlib import Path
9
  from typing import Any, Dict, Optional
10
 
@@ -14,7 +14,7 @@ import torch
14
  import torch.utils.data as torchdata
15
  from omegaconf import DictConfig, OmegaConf
16
 
17
- from ... import logger, DATASETS_PATH
18
  from ...osm.tiling import TileManager
19
  from ..dataset import MapLocDataset
20
  from ..sequential import chunk_sequence
@@ -115,7 +115,8 @@ class MapillaryDataModule(pl.LightningDataModule):
115
  if self.cfg.num_classes: # check consistency
116
  if set(groups.keys()) != set(self.cfg.num_classes.keys()):
117
  raise ValueError(
118
- f"Inconsistent groups: {groups.keys()} {self.cfg.num_classes.keys()}"
 
119
  )
120
  for k in groups:
121
  if len(groups[k]) != self.cfg.num_classes[k]:
@@ -125,8 +126,8 @@ class MapillaryDataModule(pl.LightningDataModule):
125
  ppm = self.tile_managers[scene].ppm
126
  if ppm != self.cfg.pixel_per_meter:
127
  raise ValueError(
128
- "The tile manager and the config/model have different ground resolutions: "
129
- f"{ppm} vs {self.cfg.pixel_per_meter}"
130
  )
131
 
132
  logger.info("Loading dump json file %s.", self.dump_filename)
@@ -136,7 +137,8 @@ class MapillaryDataModule(pl.LightningDataModule):
136
  for cam_id, cam_dict in per_seq["cameras"].items():
137
  if cam_dict["model"] != "PINHOLE":
138
  raise ValueError(
139
- f"Unsupported camera model: {cam_dict['model']} for {scene},{seq},{cam_id}"
 
140
  )
141
 
142
  self.image_dirs[scene] = (
@@ -154,7 +156,8 @@ class MapillaryDataModule(pl.LightningDataModule):
154
  self.pack_data()
155
 
156
  def pack_data(self):
157
- # We pack the data into compact tensors that can be shared across processes without copy
 
158
  exclude = {
159
  "compass_angle",
160
  "compass_accuracy",
 
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
  import json
 
4
  import os
5
  import shutil
6
  import tarfile
7
+ from collections import defaultdict
8
  from pathlib import Path
9
  from typing import Any, Dict, Optional
10
 
 
14
  import torch.utils.data as torchdata
15
  from omegaconf import DictConfig, OmegaConf
16
 
17
+ from ... import DATASETS_PATH, logger
18
  from ...osm.tiling import TileManager
19
  from ..dataset import MapLocDataset
20
  from ..sequential import chunk_sequence
 
115
  if self.cfg.num_classes: # check consistency
116
  if set(groups.keys()) != set(self.cfg.num_classes.keys()):
117
  raise ValueError(
118
+ "Inconsistent groups: "
119
+ f"{groups.keys()} {self.cfg.num_classes.keys()}"
120
  )
121
  for k in groups:
122
  if len(groups[k]) != self.cfg.num_classes[k]:
 
126
  ppm = self.tile_managers[scene].ppm
127
  if ppm != self.cfg.pixel_per_meter:
128
  raise ValueError(
129
+ "The tile manager and the config/model have different ground "
130
+ f"resolutions: {ppm} vs {self.cfg.pixel_per_meter}"
131
  )
132
 
133
  logger.info("Loading dump json file %s.", self.dump_filename)
 
137
  for cam_id, cam_dict in per_seq["cameras"].items():
138
  if cam_dict["model"] != "PINHOLE":
139
  raise ValueError(
140
+ "Unsupported camera model: "
141
+ f"{cam_dict['model']} for {scene},{seq},{cam_id}"
142
  )
143
 
144
  self.image_dirs[scene] = (
 
156
  self.pack_data()
157
 
158
  def pack_data(self):
159
+ # We pack the data into compact tensors
160
+ # that can be shared across processes without copy.
161
  exclude = {
162
  "compass_angle",
163
  "compass_accuracy",
maploc/data/mapillary/download.py CHANGED
@@ -1,25 +1,24 @@
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
 
3
  import json
4
  from pathlib import Path
5
 
6
- import numpy as np
7
  import httpx
8
- import asyncio
9
- from aiolimiter import AsyncLimiter
10
  import tqdm
11
-
12
  from opensfm.pygeometry import Camera, Pose
13
  from opensfm.pymap import Shot
14
 
15
  from ... import logger
16
  from ...utils.geo import Projection
17
 
18
-
19
  semaphore = asyncio.Semaphore(100) # number of parallel threads.
20
  image_filename = "{image_id}.jpg"
21
  info_filename = "{image_id}.json"
22
 
 
23
  def retry(times, exceptions):
24
  def decorator(func):
25
  async def wrapper(*args, **kwargs):
@@ -30,9 +29,12 @@ def retry(times, exceptions):
30
  except exceptions:
31
  attempt += 1
32
  return await func(*args, **kwargs)
 
33
  return wrapper
 
34
  return decorator
35
 
 
36
  class MapillaryDownloader:
37
  image_fields = (
38
  "id",
@@ -56,7 +58,7 @@ class MapillaryDownloader:
56
  image_info_url = (
57
  "https://graph.mapillary.com/{image_id}?access_token={token}&fields={fields}"
58
  )
59
- seq_info_url = "https://graph.mapillary.com/image_ids?access_token={token}&sequence_id={seq_id}"
60
  max_requests_per_minute = 50_000
61
 
62
  def __init__(self, token: str):
 
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
+ import asyncio
4
  import json
5
  from pathlib import Path
6
 
 
7
  import httpx
8
+ import numpy as np
 
9
  import tqdm
10
+ from aiolimiter import AsyncLimiter
11
  from opensfm.pygeometry import Camera, Pose
12
  from opensfm.pymap import Shot
13
 
14
  from ... import logger
15
  from ...utils.geo import Projection
16
 
 
17
  semaphore = asyncio.Semaphore(100) # number of parallel threads.
18
  image_filename = "{image_id}.jpg"
19
  info_filename = "{image_id}.json"
20
 
21
+
22
  def retry(times, exceptions):
23
  def decorator(func):
24
  async def wrapper(*args, **kwargs):
 
29
  except exceptions:
30
  attempt += 1
31
  return await func(*args, **kwargs)
32
+
33
  return wrapper
34
+
35
  return decorator
36
 
37
+
38
  class MapillaryDownloader:
39
  image_fields = (
40
  "id",
 
58
  image_info_url = (
59
  "https://graph.mapillary.com/{image_id}?access_token={token}&fields={fields}"
60
  )
61
+ seq_info_url = "https://graph.mapillary.com/image_ids?access_token={token}&sequence_id={seq_id}" # noqa E501
62
  max_requests_per_minute = 50_000
63
 
64
  def __init__(self, token: str):
maploc/data/mapillary/prepare.py CHANGED
@@ -1,17 +1,15 @@
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
- import asyncio
4
  import argparse
5
- from collections import defaultdict
6
  import json
7
  import shutil
 
8
  from pathlib import Path
9
  from typing import List
10
 
11
- import numpy as np
12
  import cv2
13
- from tqdm import tqdm
14
- from tqdm.contrib.concurrent import thread_map
15
  from omegaconf import DictConfig, OmegaConf
16
  from opensfm.pygeometry import Camera
17
  from opensfm.pymap import Shot
@@ -19,30 +17,31 @@ from opensfm.undistort import (
19
  perspective_camera_from_fisheye,
20
  perspective_camera_from_perspective,
21
  )
 
 
22
 
23
  from ... import logger
24
  from ...osm.tiling import TileManager
25
  from ...osm.viz import GeoPlotter
26
  from ...utils.geo import BoundaryBox, Projection
27
- from ...utils.io import write_json, download_file, DATA_URL
28
  from ..utils import decompose_rotmat
 
 
 
 
 
 
 
 
29
  from .utils import (
 
 
30
  keyframe_selection,
31
  perspective_camera_from_pano,
32
  scale_camera,
33
- CameraUndistorter,
34
- PanoramaUndistorter,
35
  undistort_shot,
36
  )
37
- from .download import (
38
- MapillaryDownloader,
39
- opensfm_shot_from_info,
40
- image_filename,
41
- fetch_image_infos,
42
- fetch_images_pixels,
43
- )
44
- from .dataset import MapillaryDataModule
45
-
46
 
47
  location_to_params = {
48
  "sanfrancisco_soma": {
 
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
 
3
  import argparse
4
+ import asyncio
5
  import json
6
  import shutil
7
+ from collections import defaultdict
8
  from pathlib import Path
9
  from typing import List
10
 
 
11
  import cv2
12
+ import numpy as np
 
13
  from omegaconf import DictConfig, OmegaConf
14
  from opensfm.pygeometry import Camera
15
  from opensfm.pymap import Shot
 
17
  perspective_camera_from_fisheye,
18
  perspective_camera_from_perspective,
19
  )
20
+ from tqdm import tqdm
21
+ from tqdm.contrib.concurrent import thread_map
22
 
23
  from ... import logger
24
  from ...osm.tiling import TileManager
25
  from ...osm.viz import GeoPlotter
26
  from ...utils.geo import BoundaryBox, Projection
27
+ from ...utils.io import DATA_URL, download_file, write_json
28
  from ..utils import decompose_rotmat
29
+ from .dataset import MapillaryDataModule
30
+ from .download import (
31
+ MapillaryDownloader,
32
+ fetch_image_infos,
33
+ fetch_images_pixels,
34
+ image_filename,
35
+ opensfm_shot_from_info,
36
+ )
37
  from .utils import (
38
+ CameraUndistorter,
39
+ PanoramaUndistorter,
40
  keyframe_selection,
41
  perspective_camera_from_pano,
42
  scale_camera,
 
 
43
  undistort_shot,
44
  )
 
 
 
 
 
 
 
 
 
45
 
46
  location_to_params = {
47
  "sanfrancisco_soma": {
maploc/data/mapillary/utils.py CHANGED
@@ -6,7 +6,7 @@ from typing import List, Tuple
6
  import cv2
7
  import numpy as np
8
  from opensfm import features
9
- from opensfm.pygeometry import Camera, compute_camera_mapping, Pose
10
  from opensfm.pymap import Shot
11
  from scipy.spatial.transform import Rotation
12
 
 
6
  import cv2
7
  import numpy as np
8
  from opensfm import features
9
+ from opensfm.pygeometry import Camera, Pose, compute_camera_mapping
10
  from opensfm.pymap import Shot
11
  from scipy.spatial.transform import Rotation
12
 
maploc/data/torch.py CHANGED
@@ -4,14 +4,14 @@ import collections
4
  import os
5
 
6
  import torch
 
 
 
7
  from torch.utils.data import get_worker_info
8
  from torch.utils.data._utils.collate import (
9
  default_collate_err_msg_format,
10
  np_str_obj_array_pattern,
11
  )
12
- from lightning_fabric.utilities.seed import pl_worker_init_function
13
- from lightning_utilities.core.apply_func import apply_to_collection
14
- from lightning_fabric.utilities.apply_func import move_data_to_device
15
 
16
 
17
  def collate(batch):
 
4
  import os
5
 
6
  import torch
7
+ from lightning_fabric.utilities.apply_func import move_data_to_device
8
+ from lightning_fabric.utilities.seed import pl_worker_init_function
9
+ from lightning_utilities.core.apply_func import apply_to_collection
10
  from torch.utils.data import get_worker_info
11
  from torch.utils.data._utils.collate import (
12
  default_collate_err_msg_format,
13
  np_str_obj_array_pattern,
14
  )
 
 
 
15
 
16
 
17
  def collate(batch):
maploc/data/utils.py CHANGED
@@ -54,7 +54,5 @@ def decompose_rotmat(R_c2w):
54
  R_cv2xyz = Rotation.from_euler("X", -90, degrees=True)
55
  rot_w2c = R_cv2xyz * Rotation.from_matrix(R_c2w).inv()
56
  roll, pitch, yaw = rot_w2c.as_euler("YXZ", degrees=True)
57
- # rot_w2c_check = R_cv2xyz.inv() * Rotation.from_euler('YXZ', [roll, pitch, yaw], degrees=True)
58
- # np.testing.assert_allclose(rot_w2c_check.as_matrix(), R_c2w.T, rtol=1e-6, atol=1e-6)
59
  # R_plane2c = Rotation.from_euler("ZX", [roll, pitch], degrees=True).as_matrix()
60
  return roll, pitch, yaw
 
54
  R_cv2xyz = Rotation.from_euler("X", -90, degrees=True)
55
  rot_w2c = R_cv2xyz * Rotation.from_matrix(R_c2w).inv()
56
  roll, pitch, yaw = rot_w2c.as_euler("YXZ", degrees=True)
 
 
57
  # R_plane2c = Rotation.from_euler("ZX", [roll, pitch], degrees=True).as_matrix()
58
  return roll, pitch, yaw
maploc/demo.py CHANGED
@@ -2,19 +2,19 @@
2
 
3
  from typing import Optional, Tuple
4
 
5
- import torch
6
  import numpy as np
 
7
 
8
  from . import logger
9
- from .evaluation.run import resolve_checkpoint_path, pretrained_models
 
10
  from .models.orienternet import OrienterNet
11
- from .models.voting import fuse_gps, argmax_xyr
12
- from .data.image import resize_image, pad_image, rectify_image
13
  from .osm.raster import Canvas
14
- from .utils.wrappers import Camera
15
- from .utils.io import read_image
16
- from .utils.geo import BoundaryBox, Projection
17
  from .utils.exif import EXIF
 
 
 
18
 
19
  try:
20
  from geopy.geocoders import Nominatim
 
2
 
3
  from typing import Optional, Tuple
4
 
 
5
  import numpy as np
6
+ import torch
7
 
8
  from . import logger
9
+ from .data.image import pad_image, rectify_image, resize_image
10
+ from .evaluation.run import pretrained_models, resolve_checkpoint_path
11
  from .models.orienternet import OrienterNet
12
+ from .models.voting import argmax_xyr, fuse_gps
 
13
  from .osm.raster import Canvas
 
 
 
14
  from .utils.exif import EXIF
15
+ from .utils.geo import BoundaryBox, Projection
16
+ from .utils.io import read_image
17
+ from .utils.wrappers import Camera
18
 
19
  try:
20
  from geopy.geocoders import Nominatim
maploc/evaluation/kitti.py CHANGED
@@ -4,17 +4,16 @@ import argparse
4
  from pathlib import Path
5
  from typing import Optional, Tuple
6
 
7
- from omegaconf import OmegaConf, DictConfig
8
 
9
  from .. import logger
10
  from ..data import KittiDataModule
11
  from .run import evaluate
12
 
13
-
14
  default_cfg_single = OmegaConf.create({})
15
  # For the sequential evaluation, we need to center the map around the GT location,
16
- # since random offsets would accumulate and leave only the GT location with a valid mask.
17
- # This should not have much impact on the results.
18
  default_cfg_sequential = OmegaConf.create(
19
  {
20
  "data": {
 
4
  from pathlib import Path
5
  from typing import Optional, Tuple
6
 
7
+ from omegaconf import DictConfig, OmegaConf
8
 
9
  from .. import logger
10
  from ..data import KittiDataModule
11
  from .run import evaluate
12
 
 
13
  default_cfg_single = OmegaConf.create({})
14
  # For the sequential evaluation, we need to center the map around the GT location,
15
+ # since random offsets would accumulate and leave only the GT location with
16
+ # a valid mask. This should not have much impact on the results.
17
  default_cfg_sequential = OmegaConf.create(
18
  {
19
  "data": {
maploc/evaluation/mapillary.py CHANGED
@@ -4,14 +4,13 @@ import argparse
4
  from pathlib import Path
5
  from typing import Optional, Tuple
6
 
7
- from omegaconf import OmegaConf, DictConfig
8
 
9
  from .. import logger
10
  from ..conf import data as conf_data_dir
11
  from ..data import MapillaryDataModule
12
  from .run import evaluate
13
 
14
-
15
  split_overrides = {
16
  "val": {
17
  "scenes": [
 
4
  from pathlib import Path
5
  from typing import Optional, Tuple
6
 
7
+ from omegaconf import DictConfig, OmegaConf
8
 
9
  from .. import logger
10
  from ..conf import data as conf_data_dir
11
  from ..data import MapillaryDataModule
12
  from .run import evaluate
13
 
 
14
  split_overrides = {
15
  "val": {
16
  "scenes": [
maploc/evaluation/run.py CHANGED
@@ -2,26 +2,25 @@
2
 
3
  import functools
4
  from itertools import islice
5
- from typing import Callable, Dict, Optional, Tuple
6
  from pathlib import Path
 
7
 
8
  import numpy as np
9
  import torch
10
  from omegaconf import DictConfig, OmegaConf
11
- from torchmetrics import MetricCollection
12
  from pytorch_lightning import seed_everything
 
13
  from tqdm import tqdm
14
 
15
- from .. import logger, EXPERIMENTS_PATH
16
  from ..data.torch import collate, unbatch_to_device
17
- from ..models.voting import argmax_xyr, fuse_gps
18
  from ..models.metrics import AngleError, LateralLongitudinalError, Location2DError
19
  from ..models.sequential import GPSAligner, RigidAligner
 
20
  from ..module import GenericModule
21
- from ..utils.io import download_file, DATA_URL
22
- from .viz import plot_example_single, plot_example_sequential
23
  from .utils import write_dump
24
-
25
 
26
  pretrained_models = dict(
27
  OrienterNet_MGL=("orienternet_mgl.ckpt", dict(num_rotations=256)),
 
2
 
3
  import functools
4
  from itertools import islice
 
5
  from pathlib import Path
6
+ from typing import Callable, Dict, Optional, Tuple
7
 
8
  import numpy as np
9
  import torch
10
  from omegaconf import DictConfig, OmegaConf
 
11
  from pytorch_lightning import seed_everything
12
+ from torchmetrics import MetricCollection
13
  from tqdm import tqdm
14
 
15
+ from .. import EXPERIMENTS_PATH, logger
16
  from ..data.torch import collate, unbatch_to_device
 
17
  from ..models.metrics import AngleError, LateralLongitudinalError, Location2DError
18
  from ..models.sequential import GPSAligner, RigidAligner
19
+ from ..models.voting import argmax_xyr, fuse_gps
20
  from ..module import GenericModule
21
+ from ..utils.io import DATA_URL, download_file
 
22
  from .utils import write_dump
23
+ from .viz import plot_example_sequential, plot_example_single
24
 
25
  pretrained_models = dict(
26
  OrienterNet_MGL=("orienternet_mgl.ckpt", dict(num_rotations=256)),
maploc/evaluation/viz.py CHANGED
@@ -1,18 +1,18 @@
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
 
3
  import numpy as np
4
  import torch
5
- import matplotlib.pyplot as plt
6
 
 
7
  from ..utils.io import write_torch_image
8
- from ..utils.viz_2d import plot_images, features_to_RGB, save_plot
9
  from ..utils.viz_localization import (
 
10
  likelihood_overlay,
11
- plot_pose,
12
  plot_dense_rotations,
13
- add_circle_inset,
14
  )
15
- from ..osm.viz import Colormap, plot_nodes
16
 
17
 
18
  def plot_example_single(
 
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
+ import matplotlib.pyplot as plt
4
  import numpy as np
5
  import torch
 
6
 
7
+ from ..osm.viz import Colormap, plot_nodes
8
  from ..utils.io import write_torch_image
9
+ from ..utils.viz_2d import features_to_RGB, plot_images, save_plot
10
  from ..utils.viz_localization import (
11
+ add_circle_inset,
12
  likelihood_overlay,
 
13
  plot_dense_rotations,
14
+ plot_pose,
15
  )
 
16
 
17
 
18
  def plot_example_single(
maploc/models/orienternet.py CHANGED
@@ -8,7 +8,10 @@ from . import get_model
8
  from .base import BaseModel
9
  from .bev_net import BEVNet
10
  from .bev_projection import CartesianProjection, PolarProjectionDepth
 
 
11
  from .voting import (
 
12
  argmax_xyr,
13
  conv2d_fft_batchwise,
14
  expectation_xyr,
@@ -16,10 +19,7 @@ from .voting import (
16
  mask_yaw_prior,
17
  nll_loss_xyr,
18
  nll_loss_xyr_smoothed,
19
- TemplateSampler,
20
  )
21
- from .map_encoder import MapEncoder
22
- from .metrics import AngleError, AngleRecall, Location2DError, Location2DRecall
23
 
24
 
25
  class OrienterNet(BaseModel):
@@ -106,8 +106,8 @@ class OrienterNet(BaseModel):
106
  if self.conf.add_temperature:
107
  scores = scores * torch.exp(self.temperature)
108
 
109
- # Reweight the different rotations based on the number of valid pixels
110
- # in each template. Axis-aligned rotation have the maximum number of valid pixels.
111
  valid_templates = self.template_sampler(valid_bev.float()[None]) > (1 - 1e-4)
112
  num_valid = valid_templates.float().sum((-3, -2, -1))
113
  scores = scores / num_valid[..., None, None]
 
8
  from .base import BaseModel
9
  from .bev_net import BEVNet
10
  from .bev_projection import CartesianProjection, PolarProjectionDepth
11
+ from .map_encoder import MapEncoder
12
+ from .metrics import AngleError, AngleRecall, Location2DError, Location2DRecall
13
  from .voting import (
14
+ TemplateSampler,
15
  argmax_xyr,
16
  conv2d_fft_batchwise,
17
  expectation_xyr,
 
19
  mask_yaw_prior,
20
  nll_loss_xyr,
21
  nll_loss_xyr_smoothed,
 
22
  )
 
 
23
 
24
 
25
  class OrienterNet(BaseModel):
 
106
  if self.conf.add_temperature:
107
  scores = scores * torch.exp(self.temperature)
108
 
109
+ # Reweight the different rotations based on the number of valid pixels in each
110
+ # template. Axis-aligned rotation have the maximum number of valid pixels.
111
  valid_templates = self.template_sampler(valid_bev.float()[None]) > (1 - 1e-4)
112
  num_valid = valid_templates.float().sum((-3, -2, -1))
113
  scores = scores / num_valid[..., None, None]
maploc/models/sequential.py CHANGED
@@ -3,8 +3,8 @@
3
  import numpy as np
4
  import torch
5
 
6
- from .voting import argmax_xyr, log_softmax_spatial, sample_xyr
7
  from .utils import deg2rad, make_grid, rotmat2d
 
8
 
9
 
10
  def log_gaussian(points, mean, sigma):
 
3
  import numpy as np
4
  import torch
5
 
 
6
  from .utils import deg2rad, make_grid, rotmat2d
7
+ from .voting import argmax_xyr, log_softmax_spatial, sample_xyr
8
 
9
 
10
  def log_gaussian(points, mean, sigma):
maploc/osm/analysis.py CHANGED
@@ -8,6 +8,7 @@ import numpy as np
8
  import plotly.graph_objects as go
9
 
10
  from .parser import (
 
11
  filter_area,
12
  filter_node,
13
  filter_way,
@@ -15,7 +16,6 @@ from .parser import (
15
  parse_area,
16
  parse_node,
17
  parse_way,
18
- Patterns,
19
  )
20
  from .reader import OSMData
21
 
 
8
  import plotly.graph_objects as go
9
 
10
  from .parser import (
11
+ Patterns,
12
  filter_area,
13
  filter_node,
14
  filter_way,
 
16
  parse_area,
17
  parse_node,
18
  parse_way,
 
19
  )
20
  from .reader import OSMData
21
 
maploc/osm/data.py CHANGED
@@ -7,6 +7,7 @@ from typing import Dict, List, Optional, Set, Tuple
7
  import numpy as np
8
 
9
  from .parser import (
 
10
  filter_area,
11
  filter_node,
12
  filter_way,
@@ -14,11 +15,9 @@ from .parser import (
14
  parse_area,
15
  parse_node,
16
  parse_way,
17
- Patterns,
18
  )
19
  from .reader import OSMData, OSMNode, OSMRelation, OSMWay
20
 
21
-
22
  logger = logging.getLogger(__name__)
23
 
24
 
 
7
  import numpy as np
8
 
9
  from .parser import (
10
+ Patterns,
11
  filter_area,
12
  filter_node,
13
  filter_way,
 
15
  parse_area,
16
  parse_node,
17
  parse_way,
 
18
  )
19
  from .reader import OSMData, OSMNode, OSMRelation, OSMWay
20
 
 
21
  logger = logging.getLogger(__name__)
22
 
23
 
maploc/osm/download.py CHANGED
@@ -1,9 +1,9 @@
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
  import json
 
4
  from pathlib import Path
5
  from typing import Any, Dict, Optional
6
- from http.client import responses
7
 
8
  import urllib3
9
 
 
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
  import json
4
+ from http.client import responses
5
  from pathlib import Path
6
  from typing import Any, Dict, Optional
 
7
 
8
  import urllib3
9
 
maploc/osm/parser.py CHANGED
@@ -144,7 +144,7 @@ class Patterns:
144
  building="building($|:.*?)*",
145
  parking="amenity:parking",
146
  playground="leisure:(playground|pitch)",
147
- grass="(landuse:grass|landcover:grass|landuse:meadow|landuse:flowerbed|natural:grassland)",
148
  park="leisure:(park|garden|dog_park)",
149
  forest="(landuse:forest|natural:wood)",
150
  water="(natural:water|waterway:*)",
@@ -160,7 +160,7 @@ class Patterns:
160
  building_outline="building($|:.*?)*",
161
  cycleway="highway:cycleway",
162
  path="highway:(pedestrian|footway|steps|path|corridor)",
163
- road="highway:(motorway|trunk|primary|secondary|tertiary|service|construction|track|unclassified|residential|.*_link)",
164
  busway="highway:busway",
165
  tree_row="natural:tree_row", # maybe merge with node?
166
  )
 
144
  building="building($|:.*?)*",
145
  parking="amenity:parking",
146
  playground="leisure:(playground|pitch)",
147
+ grass="(landuse:grass|landcover:grass|landuse:meadow|landuse:flowerbed|natural:grassland)", # noqa E501
148
  park="leisure:(park|garden|dog_park)",
149
  forest="(landuse:forest|natural:wood)",
150
  water="(natural:water|waterway:*)",
 
160
  building_outline="building($|:.*?)*",
161
  cycleway="highway:cycleway",
162
  path="highway:(pedestrian|footway|steps|path|corridor)",
163
+ road="highway:(motorway|trunk|primary|secondary|tertiary|service|construction|track|unclassified|residential|.*_link)", # noqa E501
164
  busway="highway:busway",
165
  tree_row="natural:tree_row", # maybe merge with node?
166
  )
maploc/osm/reader.py CHANGED
@@ -6,8 +6,8 @@ from dataclasses import dataclass, field
6
  from pathlib import Path
7
  from typing import Any, Dict, List, Optional
8
 
9
- from lxml import etree
10
  import numpy as np
 
11
 
12
  from ..utils.geo import BoundaryBox, Projection
13
 
 
6
  from pathlib import Path
7
  from typing import Any, Dict, List, Optional
8
 
 
9
  import numpy as np
10
+ from lxml import etree
11
 
12
  from ..utils.geo import BoundaryBox, Projection
13
 
maploc/osm/tiling.py CHANGED
@@ -6,8 +6,8 @@ from pathlib import Path
6
  from typing import Dict, List, Optional, Tuple
7
 
8
  import numpy as np
9
- from PIL import Image
10
  import rtree
 
11
 
12
  from ..utils.geo import BoundaryBox, Projection
13
  from .data import MapData
 
6
  from typing import Dict, List, Optional, Tuple
7
 
8
  import numpy as np
 
9
  import rtree
10
+ from PIL import Image
11
 
12
  from ..utils.geo import BoundaryBox, Projection
13
  from .data import MapData
maploc/osm/viz.py CHANGED
@@ -3,8 +3,8 @@
3
  import matplotlib as mpl
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
- import plotly.graph_objects as go
7
  import PIL.Image
 
8
 
9
  from ..utils.viz_2d import add_text
10
  from .parser import Groups
 
3
  import matplotlib as mpl
4
  import matplotlib.pyplot as plt
5
  import numpy as np
 
6
  import PIL.Image
7
+ import plotly.graph_objects as go
8
 
9
  from ..utils.viz_2d import add_text
10
  from .parser import Groups
maploc/train.py CHANGED
@@ -1,8 +1,8 @@
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
  import os.path as osp
4
- from typing import Optional
5
  from pathlib import Path
 
6
 
7
  import hydra
8
  import pytorch_lightning as pl
@@ -10,7 +10,7 @@ import torch
10
  from omegaconf import DictConfig, OmegaConf
11
  from pytorch_lightning.utilities import rank_zero_only
12
 
13
- from . import logger, pl_logger, EXPERIMENTS_PATH
14
  from .data import modules as data_modules
15
  from .module import GenericModule
16
 
 
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
  import os.path as osp
 
4
  from pathlib import Path
5
+ from typing import Optional
6
 
7
  import hydra
8
  import pytorch_lightning as pl
 
10
  from omegaconf import DictConfig, OmegaConf
11
  from pytorch_lightning.utilities import rank_zero_only
12
 
13
+ from . import EXPERIMENTS_PATH, logger, pl_logger
14
  from .data import modules as data_modules
15
  from .module import GenericModule
16
 
maploc/utils/exif.py CHANGED
@@ -1,9 +1,10 @@
1
  """Copied from opensfm.exif to minimize hard dependencies."""
2
- from pathlib import Path
3
- import json
4
  import datetime
 
5
  import logging
6
- from codecs import encode, decode
 
7
  from typing import Any, Dict, Optional, Tuple
8
 
9
  import exifread
@@ -209,7 +210,7 @@ class EXIF:
209
  orientation = 1
210
  if "Image Orientation" in self.tags:
211
  value = self.tags.get("Image Orientation").values[0]
212
- if type(value) == int and value != 0:
213
  orientation = value
214
  return orientation
215
 
@@ -243,7 +244,8 @@ class EXIF:
243
  else:
244
  altitude = None
245
 
246
- # Check if GPSAltitudeRef is equal to 1, which means GPSAltitude should be negative, reference: http://www.exif.org/Exif2-2.PDF#page=53
 
247
  if (
248
  "GPS GPSAltitudeRef" in self.tags
249
  and self.tags["GPS GPSAltitudeRef"].values[0] == 1
 
1
  """Copied from opensfm.exif to minimize hard dependencies."""
2
+
 
3
  import datetime
4
+ import json
5
  import logging
6
+ from codecs import decode, encode
7
+ from pathlib import Path
8
  from typing import Any, Dict, Optional, Tuple
9
 
10
  import exifread
 
210
  orientation = 1
211
  if "Image Orientation" in self.tags:
212
  value = self.tags.get("Image Orientation").values[0]
213
+ if isinstance(value, int) and value != 0:
214
  orientation = value
215
  return orientation
216
 
 
244
  else:
245
  altitude = None
246
 
247
+ # Check if GPSAltitudeRef is equal to 1, which means GPSAltitude
248
+ # should be negative, reference: http://www.exif.org/Exif2-2.PDF#page=53
249
  if (
250
  "GPS GPSAltitudeRef" in self.tags
251
  and self.tags["GPS GPSAltitudeRef"].values[0] == 1
maploc/utils/geo.py CHANGED
@@ -5,7 +5,6 @@ from typing import Union
5
  import numpy as np
6
  import torch
7
 
8
- from .. import logger
9
  from .geo_opensfm import TopocentricConverter
10
 
11
 
 
5
  import numpy as np
6
  import torch
7
 
 
8
  from .geo_opensfm import TopocentricConverter
9
 
10
 
maploc/utils/geo_opensfm.py CHANGED
@@ -1,7 +1,9 @@
1
  """Copied from opensfm.geo to minimize hard dependencies."""
 
 
 
2
  import numpy as np
3
  from numpy import ndarray
4
- from typing import Tuple
5
 
6
  WGS84_a = 6378137.0
7
  WGS84_b = 6356752.314245
 
1
  """Copied from opensfm.geo to minimize hard dependencies."""
2
+
3
+ from typing import Tuple
4
+
5
  import numpy as np
6
  from numpy import ndarray
 
7
 
8
  WGS84_a = 6378137.0
9
  WGS84_b = 6356752.314245
maploc/utils/io.py CHANGED
@@ -1,12 +1,12 @@
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
  import json
4
- import requests
5
  import shutil
6
  from pathlib import Path
7
 
8
  import cv2
9
  import numpy as np
 
10
  import torch
11
  from tqdm.auto import tqdm
12
 
 
1
  # Copyright (c) Meta Platforms, Inc. and affiliates.
2
 
3
  import json
 
4
  import shutil
5
  from pathlib import Path
6
 
7
  import cv2
8
  import numpy as np
9
+ import requests
10
  import torch
11
  from tqdm.auto import tqdm
12
 
requirements/dev.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ black[jupyter]==24.2.0
2
+ isort
3
+ flake8
setup.py CHANGED
@@ -1,11 +1,11 @@
1
- from setuptools import setup, find_packages
2
 
3
  setup(
4
- name='maploc',
5
- version='0.0.0',
6
  packages=find_packages(),
7
- python_requires='>=3.8',
8
- author='Paul-Edouard Sarlin',
9
  long_description_content_type="text/markdown",
10
  classifiers=[
11
  "Programming Language :: Python :: 3",
 
1
+ from setuptools import find_packages, setup
2
 
3
  setup(
4
+ name="maploc",
5
+ version="0.0.0",
6
  packages=find_packages(),
7
+ python_requires=">=3.8",
8
+ author="Paul-Edouard Sarlin",
9
  long_description_content_type="text/markdown",
10
  classifiers=[
11
  "Programming Language :: Python :: 3",