Spaces:
Runtime error
Runtime error
File size: 4,522 Bytes
4121bec |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
# Copyright (c) Facebook, Inc. and its affiliates.
import io
import numpy as np
import torch
from detectron2 import model_zoo
from detectron2.data import DatasetCatalog
from detectron2.data.detection_utils import read_image
from detectron2.modeling import build_model
from detectron2.structures import Boxes, Instances, ROIMasks
from detectron2.utils.file_io import PathManager
"""
Internal utilities for tests. Don't use except for writing tests.
"""
def get_model_no_weights(config_path):
    """
    Build the model described by a model-zoo config without loading weights.

    Like model_zoo.get, but do not load any weights (even pretrained).

    Args:
        config_path: config identifier, forwarded to ``model_zoo.get_config``.

    Returns:
        the object produced by ``build_model`` for that config.
    """
    config = model_zoo.get_config(config_path)
    cuda_present = torch.cuda.is_available()
    if not cuda_present:
        # No GPU on this machine: force the model onto the CPU.
        config.MODEL.DEVICE = "cpu"
    model = build_model(config)
    return model
def random_boxes(num_boxes, max_coord=100, device="cpu"):
    """
    Create a random Nx4 boxes tensor, with coordinates < max_coord.

    Args:
        num_boxes (int): number of rows N in the result.
        max_coord: every raw sample is drawn below ``max_coord * 0.5``.
        device: device on which the random tensor is created.

    Returns:
        an Nx4 tensor in which columns 2 and 3 are never smaller than
        columns 0 and 1 respectively.
    """
    half_range = max_coord * 0.5
    samples = torch.rand(num_boxes, 4, device=device) * half_range
    # tiny boxes cause numerical instability in box regression
    samples = samples.clamp(min=1.0)
    # Note: the implementation of this function in torchvision is:
    # boxes[:, 2:] += torch.rand(N, 2) * 100
    # but it does not guarantee non-negative widths/heights constraints:
    # boxes[:, 2] >= boxes[:, 0] and boxes[:, 3] >= boxes[:, 1]:
    # Here the last two columns are treated as extents added to the first two.
    corner = samples[:, :2]
    extent = samples[:, 2:]
    return torch.cat((corner, extent + corner), dim=1)
def get_sample_coco_image(tensor=True):
    """
    Load one sample image, preferring the local "coco_2017_val_100" dataset.

    Args:
        tensor (bool): if True, returns 3xHxW tensor.
            else, returns a HxWx3 numpy array.

    Returns:
        an image, in BGR color.
    """
    try:
        image_path = DatasetCatalog.get("coco_2017_val_100")[0]["file_name"]
        available = PathManager.exists(image_path)
    except OSError:  # IOError is an alias of OSError on Python 3
        available = False
    if not available:
        # for public CI to run
        image_path = "http://images.cocodataset.org/train2017/000000000009.jpg"

    image = read_image(image_path, format="BGR")
    if not tensor:
        return image
    # Move the channel axis first and make the memory contiguous for torch.
    channels_first = np.ascontiguousarray(image.transpose(2, 0, 1))
    return torch.from_numpy(channels_first)
def convert_scripted_instances(instances):
    """
    Convert a scripted Instances object to a regular :class:`Instances` object

    Each name in ``instances._field_names`` is read from the attribute
    ``"_" + name``; names whose attribute is missing or None are skipped.

    Args:
        instances: object exposing ``image_size`` and ``_field_names``.

    Returns:
        Instances: a new object with the same image size and the populated fields.
    """
    converted = Instances(instances.image_size)
    for field_name in instances._field_names:
        value = getattr(instances, f"_{field_name}", None)
        if value is None:
            continue
        converted.set(field_name, value)
    return converted
def assert_instances_allclose(input, other, *, rtol=1e-5, msg="", size_as_tensor=False):
    """
    Assert that two Instances have equal image sizes, the same set of fields,
    and field values that are equal (integer tensors) or close (float tensors,
    Boxes, ROIMasks).

    Args:
        input, other (Instances): objects to compare. An argument that is not an
            :class:`Instances` is first converted with
            :func:`convert_scripted_instances`.
        rtol (float): base tolerance. Boxes / ROIMasks are compared with
            ``atol=100 * rtol``; floating-point tensors with
            ``atol=max(|input field|) * rtol``. Both are passed to
            ``torch.allclose`` as ``atol`` only.
        msg (str): prefix for assertion messages; a default prefix is used if empty.
        size_as_tensor: compare image_size of the Instances as tensors (instead of tuples).
            Useful for comparing outputs of tracing.

    Raises:
        AssertionError: if image sizes, field names or field values differ.
        ValueError: if a field is neither Boxes, ROIMasks nor a torch.Tensor.
    """
    if not isinstance(input, Instances):
        input = convert_scripted_instances(input)
    if not isinstance(other, Instances):
        other = convert_scripted_instances(other)

    # Normalize the prefix so every message reads "<prefix> <detail>".
    if not msg:
        msg = "Two Instances are different! "
    else:
        msg = msg.rstrip() + " "

    size_error_msg = msg + f"image_size is {input.image_size} vs. {other.image_size}!"
    if size_as_tensor:
        assert torch.equal(
            torch.tensor(input.image_size), torch.tensor(other.image_size)
        ), size_error_msg
    else:
        assert input.image_size == other.image_size, size_error_msg

    fields = sorted(input.get_fields().keys())
    fields_other = sorted(other.get_fields().keys())
    assert fields == fields_other, msg + f"Fields are {fields} vs {fields_other}!"

    for f in fields:
        val1, val2 = input.get(f), other.get(f)
        if isinstance(val1, (Boxes, ROIMasks)):
            # boxes in the range of O(100) and can have a larger tolerance
            assert torch.allclose(val1.tensor, val2.tensor, atol=100 * rtol), (
                msg + f"Field {f} differs too much!"
            )
        elif isinstance(val1, torch.Tensor):
            if val1.dtype.is_floating_point:
                if val1.numel() == 0:
                    # max() over an empty tensor raises RuntimeError, which used
                    # to crash comparisons of Instances with zero elements.
                    # With nothing to compare numerically, only the shapes
                    # have to agree.
                    is_close = val1.shape == val2.shape
                else:
                    # Scale the absolute tolerance by the field's magnitude.
                    mag = torch.abs(val1).max().cpu().item()
                    is_close = torch.allclose(val1, val2, atol=mag * rtol)
                assert is_close, msg + f"Field {f} differs too much!"
            else:
                # Non-float tensors (e.g. integer or bool) must match exactly.
                assert torch.equal(val1, val2), msg + f"Field {f} is different!"
        else:
            raise ValueError(f"Don't know how to compare type {type(val1)}")
def reload_script_model(module):
    """
    Round-trip a jit module through an in-memory serialization.

    Save a jit module and load it back.
    Similar to the `getExportImportCopy` function in torch/testing/

    Args:
        module: object accepted by ``torch.jit.save``.

    Returns:
        the module produced by ``torch.jit.load`` from the serialized bytes.
    """
    sink = io.BytesIO()
    torch.jit.save(module, sink)
    # Read back from a fresh stream positioned at the start of the saved bytes.
    serialized = sink.getvalue()
    return torch.jit.load(io.BytesIO(serialized))
|