|
import os |
|
import json |
|
import shutil |
|
from pathlib import Path |
|
from typing import Dict |
|
|
|
from PIL import ImageFile |
|
ImageFile.LOAD_TRUNCATED_IMAGES = True |
|
|
|
LOCAL_DATADIR = None |
|
|
|
def setup(local_dir='./data/usm-training-data/data'): |
|
|
|
|
|
tmp_datadir = Path('/tmp/data/data') |
|
local_test_datadir = Path('./data/usm-test-data-x/data') |
|
local_val_datadir = Path(local_dir) |
|
|
|
os.system('pwd') |
|
os.system('ls -lahtr .') |
|
|
|
if tmp_datadir.exists() and not local_test_datadir.exists(): |
|
global LOCAL_DATADIR |
|
LOCAL_DATADIR = local_test_datadir |
|
|
|
print(f"Linking {tmp_datadir} to {LOCAL_DATADIR} (we are in the test environment)") |
|
LOCAL_DATADIR.parent.mkdir(parents=True, exist_ok=True) |
|
LOCAL_DATADIR.symlink_to(tmp_datadir) |
|
else: |
|
LOCAL_DATADIR = local_val_datadir |
|
print(f"Using {LOCAL_DATADIR} as the data directory (we are running locally)") |
|
|
|
|
|
|
|
|
|
assert LOCAL_DATADIR.exists(), f"Data directory {LOCAL_DATADIR} does not exist" |
|
return LOCAL_DATADIR |
|
|
|
|
|
|
|
|
|
import importlib |
|
from pathlib import Path |
|
import subprocess |
|
|
|
def download_package(package_name, path_to_save='packages',
                     platform='manylinux1_x86_64', python_version='38'):
    """
    Downloads a package using pip and saves it to a specified directory.

    The files are written to `<path_to_save>/<package_name>`. Only binary
    distributions are requested (`--only-binary=:all:`). A failing pip call is
    reported on stdout and not re-raised.

    Parameters:
    package_name (str): The name of the package to download.
    path_to_save (str): The path to the directory where the package will be saved.
    platform (str): Value passed to pip's `--platform` option.
    python_version (str): Value passed to pip's `--python-version` option.
    """
    # `sys.executable` is the documented way to reach the running interpreter;
    # `subprocess.sys` only works because subprocess happens to import sys.
    import sys

    command = [
        sys.executable, "-m", "pip", "download", package_name,
        "-d", str(Path(path_to_save) / package_name),
        "--platform", platform,
        "--python-version", python_version,
        "--only-binary=:all:",
    ]
    try:
        subprocess.check_call(command)
        print(f'Package "{package_name}" downloaded successfully')
    except subprocess.CalledProcessError as e:
        print(f'Failed to downloaded package "{package_name}". Error: {e}')
|
|
|
|
|
def install_package_from_local_file(package_name, folder='packages'):
    """
    Installs a package with pip from locally stored files, without using a package index.

    pip is run with `--no-index --find-links <folder>/<package_name>`, so the
    distribution files must already be in that directory. A failing pip call
    is reported on stdout and not re-raised.

    Parameters:
    package_name (str): The name of the package to install.
    folder (str): The directory that contains one sub-directory per package.
    """
    # `sys.executable` is the documented way to reach the running interpreter;
    # `subprocess.sys` only works because subprocess happens to import sys.
    import sys

    # Computed outside the try block so it is always bound in the except handler.
    pth = str(Path(folder) / package_name)
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install",
                               "--no-index",
                               "--find-links", pth,
                               package_name])
        print(f"Package installed successfully from {pth}")
    except subprocess.CalledProcessError as e:
        print(f"Failed to install package from {pth}. Error: {e}")
|
|
|
|
|
def importt(module_name, as_name=None):
    """
    Imports a module and returns it, trying a local install once if it is missing.

    If the first import raises ModuleNotFoundError, the package is installed via
    install_package_from_local_file(module_name) and the import is retried once.

    Parameters:
    module_name (str): The name of the module to import.
    as_name (str): Alias shown in the log message. It does not bind any name;
        the caller chooses the name by assigning the return value.

    Returns:
    The imported module, or None if it could not be imported after the retry.
    """
    label = module_name if as_name is None else f'{module_name} as {as_name}'

    for attempt in range(2):
        try:
            # `as_name` is deliberately not passed on: import_module's second
            # argument is the anchor package for relative imports, not an alias.
            module = importlib.import_module(module_name)
        except ModuleNotFoundError as e:
            print(f"Failed to import module {module_name}. Error: {e}")
            # Only install before the retry; installing after the last attempt is wasted work.
            if attempt == 0:
                install_package_from_local_file(module_name)
        else:
            # Report success only after the import has actually succeeded.
            print(f'imported {label}')
            return module

    return None
|
|
|
|
|
def prepare_submission():
    """
    Downloads every package listed in ./requirements.txt into ./packages.

    Each non-empty, non-comment line of requirements.txt is passed to
    download_package(). If requirements.txt does not exist nothing is
    downloaded. A closing reminder is printed in both cases.
    """
    requirements = Path('requirements.txt')

    if requirements.exists():
        print('downloading packages from requirements.txt')
        Path('packages').mkdir(exist_ok=True)
        with requirements.open() as f:
            for line in f:
                requirement = line.strip()
                # Blank lines and comments are not package names; handing them
                # to pip only produces a failed download.
                if not requirement or requirement.startswith('#'):
                    continue
                download_package(requirement)

    print('all packages downloaded. Don\'t foget to include the packages in the submission by adding them with git lfs.')
|
|
|
|
|
def Rt_to_eye_target(im, K, R, t):
    """
    Converts K, R, t and an image into eye / target / up vectors and a field of view.

    Parameters:
    im: Object with a `height` attribute (e.g. a PIL image).
    K (array-like): Matrix whose [0, 0] entry is used as the focal length.
    R (array-like): Matrix with exactly three rows; rows two and three are
        used as the y and z axes.
        # NOTE(review): presumably the 3x3 world-to-camera rotation -- confirm.
    t (array-like): Vector multiplied by R.T; shape (3,) or (3, 1) both work.
        # NOTE(review): presumably the matching translation -- confirm.

    Returns:
    tuple: (eye, target, up, fov) where
        eye    = -(R.T @ t), squeezed,
        target = eye + third row of R,
        up     = negated second row of R,
        fov    = 2 * atan2(height / 2, focal_length), in degrees.
    """
    # Accept plain nested lists as well as ndarrays.
    K = np.asarray(K)
    R = np.asarray(R)
    t = np.asarray(t)

    focal_length = K[0, 0]
    # Dividing by (pi / 180) converts the angle from radians to degrees.
    fov = 2.0 * np.arctan2(0.5 * im.height, focal_length) / (np.pi / 180.0)

    # The first row of R is not needed for any of the outputs.
    _, y_axis, z_axis = R

    eye = -(R.T @ t).squeeze()
    target = eye + z_axis.squeeze()
    up = -y_axis

    return eye, target, up, fov
|
|
|
|
|
|
|
import contextlib |
|
import tempfile |
|
from pathlib import Path |
|
|
|
@contextlib.contextmanager
def working_directory(path):
    """Run the body of the `with` block inside `path`, then go back.

    The directory that was current on entry is restored on exit, whether the
    body finishes normally or raises.
    """
    # Remember where we started before moving.
    origin = Path(os.getcwd())
    os.chdir(path)
    try:
        yield
    finally:
        os.chdir(origin)
|
|
|
@contextlib.contextmanager
def temp_working_directory():
    """Run the body of the `with` block inside a fresh temporary directory.

    The directory is created under the current working directory, made the
    working directory for the duration of the block, and removed afterwards;
    the previous working directory is restored before removal.
    """
    # Managers exit right-to-left: the cwd is restored first, then the
    # temporary directory is deleted.
    with tempfile.TemporaryDirectory(dir='.') as scratch, working_directory(scratch):
        yield
|
|
|
|
|
|
|
def proc(row, split='train'):
    """
    Regroups a flat row, keyed by '<column>.<rest>', into a Sample.

    The column name is the part of each key before the first '.':
      * 'ade20k', 'depthcm', 'gestalt': values are collected into a list per
        column; for 'ade20k' the second key component is also appended to
        '__imagekey__'.
      * 'wireframe', 'mesh': the value's items are merged into the output.
      * 'k', 'r': stored under the upper-cased name ('K', 'R').
      * anything else: stored under the column name.

    Parameters:
    row (dict): Mapping from key to value for one sample.
    split (str): Unused; kept for interface compatibility.

    Returns:
    Sample: The regrouped sample. '__key__' is None unless the row provides it;
        '__imagekey__' is always a list.
    """
    out = {'__key__': None, '__imagekey__': []}

    for k, v in row.items():
        key_parts = k.split('.')
        colname = key_parts[0]

        if colname == 'ade20k':
            out['__imagekey__'].append(key_parts[1])

        if colname in {'ade20k', 'depthcm', 'gestalt'}:
            out.setdefault(colname, []).append(v)
        elif colname in {'wireframe', 'mesh'}:
            out.update(v.items())
        # Set membership, not `in 'kr'`: the substring test also matched '' and 'kr'.
        elif colname in {'k', 'r'}:
            out[colname.upper()] = v
        else:
            out[colname] = v

    return Sample(out)
|
|
|
|
|
|
|
|
|
|
|
|
|
from . import read_write_colmap |
|
def decode_colmap(s):
    """
    Rebuilds a COLMAP model from the raw bytes stored in `s`.

    The byte strings under 'points3d', 'cameras' and 'images' are written to
    'points3D.bin', 'cameras.bin' and 'images.bin' inside a temporary working
    directory, which is then read back with read_write_colmap.read_model.

    Parameters:
    s (dict): Mapping that holds the three byte strings.

    Returns:
    tuple: (cameras, images, points3D) as returned by read_model.
    """
    # (file name on disk, key in `s`), written in this order.
    blobs = (
        ('points3D.bin', 'points3d'),
        ('cameras.bin', 'cameras'),
        ('images.bin', 'images'),
    )

    with temp_working_directory():
        for filename, key in blobs:
            with open(filename, 'wb') as stream:
                stream.write(s[key])

        cameras, images, points3D = read_write_colmap.read_model(
            path='.', ext='.bin'
        )
        return cameras, images, points3D
|
|
|
from PIL import Image |
|
import io |
|
def decode(row):
    """
    Decodes the raw byte payloads of a row into Python objects.

      * 'ade20k', 'depthcm', 'gestalt': each element is opened with PIL.
      * 'wireframe', 'mesh': loaded with np.load and merged into the output.
      * 'k', 'r': stored under the upper-cased name ('K', 'R').
      * 'cameras', 'images', 'points3d': replaced by the model from decode_colmap.
      * anything else: copied unchanged.

    Parameters:
    row (dict): Mapping from column name to raw value; must contain the
        'cameras', 'images' and 'points3d' entries that decode_colmap reads.

    Returns:
    Sample: The decoded sample.
    """
    cameras, images, points3D = decode_colmap(row)

    out = {}

    for k, v in row.items():
        if k in {'ade20k', 'depthcm', 'gestalt'}:
            v = [Image.open(io.BytesIO(im)) for im in v]
            if k in out:
                out[k].extend(v)
            else:
                out[k] = v
        elif k in {'wireframe', 'mesh'}:
            v = dict(np.load(io.BytesIO(v)))
            out.update(v)
        # Set membership, not `in 'kr'`: the substring test also matched '' and 'kr'.
        elif k in {'k', 'r'}:
            out[k.upper()] = v
        elif k == 'cameras':
            out[k] = cameras
        elif k == 'images':
            out[k] = images
        elif k == 'points3d':
            out[k] = points3D
        else:
            out[k] = v

    return Sample(out)
|
|
|
|
|
class Sample(Dict):
    """Dictionary whose repr shows a compact summary of each value, not the value itself."""

    def __repr__(self):
        return str({k: self._summarize(v) for k, v in self.items()})

    @staticmethod
    def _summarize(value):
        """Return `value.shape`, `[type of first element]` for a list, or `type(value)`."""
        if hasattr(value, 'shape'):
            return value.shape
        if isinstance(value, list):
            # An empty list has no first element to inspect; show it as [].
            return [type(value[0])] if value else []
        return type(value)
|
|
|
|
|
|
|
def get_params():
    """
    Returns the run parameters as a dict.

    If 'params.json' exists in the current working directory its parsed
    content is returned; otherwise a built-in example parameter set is
    returned. The chosen parameters are printed before returning.

    Returns:
    dict: The parameters.
    """
    param_path = Path('params.json')

    if param_path.exists():
        print('found params.json (this means we are probably in the test env). Using params from file.')
        with param_path.open() as f:
            params = json.load(f)
    else:
        print("params.json not found (this means we probably aren't in the test env). Using example params.")
        # Example values used when no params.json is present.
        params = {
            "competition_id": "usm3d/S23DR",
            "competition_type": "script",
            "metric": "custom",
            "token": "hf_**********************************",
            "team_id": "local-test-team_id",
            "submission_id": "local-test-submission_id",
            "submission_id_col": "__key__",
            "submission_cols": [
                "__key__",
                "wf_edges",
                "wf_vertices",
                "edge_semantics",
            ],
            "submission_rows": 180,
            "output_path": ".",
            "submission_repo": "<THE HF MODEL ID of THIS REPO",
            "time_limit": 7200,
            "dataset": "usm3d/usm-test-data-x",
            "submission_filenames": [
                "submission.parquet",
            ],
        }

    # NOTE(review): this prints every entry, including "token" -- confirm that
    # is acceptable for logs produced from a real params.json.
    print(params)
    return params
|
|
|
|
|
|
|
import webdataset as wds |
|
import numpy as np |
|
|
|
def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset'):
    """
    Builds a dataset over the '*.tar.gz' shards found below the data directory.

    Parameters:
    decode (str): Argument passed to WebDataset.decode(); if None, decode() is
        called without arguments.
    proc (callable): Function mapped over every decoded sample.
    split (str): Sub-directory of LOCAL_DATADIR to search; 'all' searches
        LOCAL_DATADIR itself (recursively).
    dataset_type (str): 'webdataset' returns the WebDataset pipeline; 'hf'
        wraps it in a datasets.IterableDataset.

    Returns:
    The dataset in the requested flavour.

    Raises:
    ValueError: If setup() has not been called yet, or `dataset_type` is unknown.
    """
    if LOCAL_DATADIR is None:
        raise ValueError('LOCAL_DATADIR is not set. Please run setup() first.')

    local_dir = Path(LOCAL_DATADIR)
    if split != 'all':
        local_dir = local_dir / split

    paths = [str(p) for p in local_dir.rglob('*.tar.gz')]

    dataset = wds.WebDataset(paths)
    dataset = dataset.decode(decode) if decode is not None else dataset.decode()
    dataset = dataset.map(proc)

    if dataset_type == 'webdataset':
        return dataset

    if dataset_type == 'hf':
        # Imported lazily so `datasets` is only needed for this flavour.
        import datasets

        # The wrapping does not depend on the split; previously any split
        # other than 'train' / 'val' silently returned None here.
        return datasets.IterableDataset.from_generator(lambda: dataset.iterator())

    # Fail loudly instead of implicitly returning None for a mistyped flavour.
    raise ValueError(
        f"Unknown dataset_type {dataset_type!r}; expected 'webdataset' or 'hf'."
    )
|
|
|
|
|
|