|
|
|
|
|
import collections
|
|
import copy
|
|
import functools
|
|
import logging
|
|
import numpy as np
|
|
import os
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
|
from unittest import mock
|
|
import caffe2.python.utils as putils
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from caffe2.proto import caffe2_pb2
|
|
from caffe2.python import core, net_drawer, workspace
|
|
from torch.nn.functional import interpolate as interp
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
|
|
|
def to_device(t, device_str):
    """
    Traceable replacement for ``.to(another_device)``: the cast is expressed
    through explicit caffe2 copy ops so it is captured properly by tracing.
    Returns the tensor unchanged when it already lives on the target device,
    avoiding an unnecessary op.
    """
    source = t.device
    target = torch.device(device_str)

    # Same device: nothing to do.
    if source == target:
        return t

    direction = (source.type, target.type)
    if direction == ("cuda", "cpu"):
        return torch.ops._caffe2.CopyGPUToCPU(t)
    if direction == ("cpu", "cuda"):
        return torch.ops._caffe2.CopyCPUToGPU(t)
    raise RuntimeError("Can't cast tensor from device {} to device {}".format(source, target))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def BilinearInterpolation(tensor_in, up_scale):
    """
    Upsample a NCHW tensor by an even integer factor ``up_scale`` using a
    fixed (non-learned) bilinear kernel applied per channel through
    ``F.conv_transpose2d``.
    """
    assert up_scale % 2 == 0, "Scale should be even"

    def _bilinear_kernel(size):
        # Classic bilinear upsampling filter of shape (size, size).
        factor = (size + 1) // 2
        center = factor - 1 if size % 2 == 1 else factor - 0.5
        rows, cols = np.ogrid[:size, :size]
        return (1 - abs(rows - center) / factor) * (1 - abs(cols - center) / factor)

    ksize = int(up_scale) * 2
    filt = _bilinear_kernel(ksize)

    channels = int(tensor_in.shape[1])
    # Depthwise weight: channel i maps only to output channel i.
    weight = np.zeros((channels, channels, ksize, ksize), dtype=np.float32)
    weight[range(channels), range(channels), :, :] = filt

    return F.conv_transpose2d(
        tensor_in,
        weight=to_device(torch.Tensor(weight), tensor_in.device),
        bias=None,
        stride=int(up_scale),
        padding=int(up_scale / 2),
    )
|
|
|
|
|
|
|
|
|
|
|
|
def onnx_compatibale_interpolate(
    input, size=None, scale_factor=None, mode="nearest", align_corners=None
):
    """
    Drop-in replacement for ``F.interpolate`` whose trace maps to caffe2 ops.

    When upsampling a 4D tensor by a static ``scale_factor``, "nearest" mode
    is lowered to the caffe2 ``ResizeNearest`` op and "bilinear" mode to a
    fixed deconvolution (``BilinearInterpolation``); every other case falls
    back to the regular ``F.interpolate``.
    """
    # Only the static-scale, 4D case can be lowered to a caffe2 op.
    if size is None and scale_factor is not None:
        if input.dim() == 4:
            if isinstance(scale_factor, (int, float)):
                height_scale, width_scale = (scale_factor, scale_factor)
            else:
                assert isinstance(scale_factor, (tuple, list))
                assert len(scale_factor) == 2
                height_scale, width_scale = scale_factor

            assert not align_corners, "No matching C2 op for align_corners == True"
            if mode == "nearest":
                # ResizeNearest expects NCHW order and per-axis scales.
                return torch.ops._caffe2.ResizeNearest(
                    input, order="NCHW", width_scale=width_scale, height_scale=height_scale
                )
            elif mode == "bilinear":
                logger.warning(
                    "Use F.conv_transpose2d for bilinear interpolate"
                    " because there's no such C2 op, this may cause significant"
                    " slowdown and the boundary pixels won't be as same as"
                    " using F.interpolate due to padding."
                )
                # BilinearInterpolation only supports one uniform scale.
                assert height_scale == width_scale
                return BilinearInterpolation(input, up_scale=height_scale)
        # Reached when the input isn't 4D or the mode isn't handled above.
        logger.warning("Output size is not static, it might cause ONNX conversion issue")

    return interp(input, size, scale_factor, mode, align_corners)
|
|
|
|
|
|
def mock_torch_nn_functional_interpolate():
    """
    Decorator factory: while tracing for ONNX export, transparently swap
    ``torch.nn.functional.interpolate`` for the caffe2-compatible
    ``onnx_compatibale_interpolate``; outside of export, call through
    unchanged.
    """

    def decorator(func):
        @functools.wraps(func)
        def _mock_torch_nn_functional_interpolate(*args, **kwargs):
            if not torch.onnx.is_in_onnx_export():
                # Normal execution path: no patching needed.
                return func(*args, **kwargs)
            patcher = mock.patch(
                "torch.nn.functional.interpolate",
                side_effect=onnx_compatibale_interpolate,
            )
            with patcher:
                return func(*args, **kwargs)

        return _mock_torch_nn_functional_interpolate

    return decorator
|
|
|
|
|
|
|
|
|
|
|
|
class ScopedWS:
    """
    Context manager for working inside a named caffe2 workspace.

    On entry it records the currently active workspace, switches to
    ``ws_name`` (creating it if needed) and optionally resets it; on exit it
    optionally resets again and switches back to the original workspace.
    Yields the ``workspace`` module itself.
    """

    def __init__(self, ws_name, is_reset, is_cleanup=False):
        # Target workspace name; None means "stay on the current workspace".
        self.ws_name = ws_name
        # Whether to reset the workspace right after entering it.
        self.is_reset = is_reset
        # Whether to reset the workspace again on exit.
        self.is_cleanup = is_cleanup
        # Name of the workspace that was active before __enter__.
        self.org_ws = ""

    def __enter__(self):
        self.org_ws = workspace.CurrentWorkspace()
        if self.ws_name is not None:
            # Second argument True: create the workspace if missing.
            workspace.SwitchWorkspace(self.ws_name, True)
        if self.is_reset:
            workspace.ResetWorkspace()

        return workspace

    def __exit__(self, *args):
        if self.is_cleanup:
            workspace.ResetWorkspace()
        if self.ws_name is not None:
            workspace.SwitchWorkspace(self.org_ws)
|
|
|
|
|
|
def fetch_any_blob(name):
    """
    Fetch a blob from the current workspace by name, falling back to the
    int8 fetcher when the regular one rejects the blob type. Returns None
    (after logging) when fetching fails for any other reason.
    """
    blob = None
    try:
        blob = workspace.FetchBlob(name)
    except TypeError:
        # FetchBlob raises TypeError for Int8 blobs; use the dedicated API.
        blob = workspace.FetchInt8Blob(name)
    except Exception as e:
        logger.error("Get blob {} error: {}".format(name, e))

    return blob
|
|
|
|
|
|
|
|
|
|
|
|
def get_pb_arg(pb, arg_name):
    """Return the Argument named ``arg_name`` from ``pb.arg``, or None."""
    return next((a for a in pb.arg if a.name == arg_name), None)
|
|
|
|
|
|
def get_pb_arg_valf(pb, arg_name, default_val):
    """Return the float field ``f`` of argument ``arg_name``, or ``default_val``."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return arg.f
|
|
|
|
|
|
def get_pb_arg_floats(pb, arg_name, default_val):
    """Return the ``floats`` field as a list of float, or ``default_val``."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return [float(x) for x in arg.floats]
|
|
|
|
|
|
def get_pb_arg_ints(pb, arg_name, default_val):
    """Return the ``ints`` field as a list of int, or ``default_val``."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return [int(x) for x in arg.ints]
|
|
|
|
|
|
def get_pb_arg_vali(pb, arg_name, default_val):
    """Return the int field ``i`` of argument ``arg_name``, or ``default_val``."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return arg.i
|
|
|
|
|
|
def get_pb_arg_vals(pb, arg_name, default_val):
    """Return the bytes field ``s`` of argument ``arg_name``, or ``default_val``."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return arg.s
|
|
|
|
|
|
def get_pb_arg_valstrings(pb, arg_name, default_val):
    """Return the ``strings`` field as a list, or ``default_val``."""
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        return default_val
    return list(arg.strings)
|
|
|
|
|
|
def check_set_pb_arg(pb, arg_name, arg_attr, arg_value, allow_override=False):
    """
    Ensure proto ``pb`` has an argument ``arg_name`` whose field ``arg_attr``
    (e.g. "i", "f", "s") equals ``arg_value``.

    Creates the argument if missing. If it already exists with a different
    value, it is either overridden with a warning (when ``allow_override``)
    or the assertion below fails.
    """
    arg = get_pb_arg(pb, arg_name)
    if arg is None:
        arg = putils.MakeArgument(arg_name, arg_value)
        # MakeArgument must have stored the value in the expected field.
        assert hasattr(arg, arg_attr)
        pb.arg.extend([arg])
    if allow_override and getattr(arg, arg_attr) != arg_value:
        logger.warning(
            "Override argument {}: {} -> {}".format(arg_name, getattr(arg, arg_attr), arg_value)
        )
        setattr(arg, arg_attr, arg_value)
    else:
        # The existing (or freshly created) argument must already match.
        assert arg is not None
        assert getattr(arg, arg_attr) == arg_value, "Existing value {}, new value {}".format(
            getattr(arg, arg_attr), arg_value
        )
|
|
|
|
|
|
def _create_const_fill_op_from_numpy(name, tensor, device_option=None):
    """
    Build a caffe2 GivenTensor*Fill op that materializes the numpy array
    ``tensor`` as the constant blob ``name``.
    """
    assert type(tensor) == np.ndarray
    # Map numpy dtype -> caffe2 fill-op type.
    kTypeNameMapper = {
        np.dtype("float32"): "GivenTensorFill",
        np.dtype("int32"): "GivenTensorIntFill",
        np.dtype("int64"): "GivenTensorInt64Fill",
        np.dtype("uint8"): "GivenTensorStringFill",
    }

    if tensor.dtype == np.dtype("uint8"):
        # uint8 tensors are stored as a single serialized string value.
        args_dict = {"values": [str(tensor.data)], "shape": [1]}
    else:
        args_dict = {"values": tensor, "shape": tensor.shape}

    if device_option is not None:
        args_dict["device_option"] = device_option

    return core.CreateOperator(kTypeNameMapper[tensor.dtype], [], [name], **args_dict)
|
|
|
|
|
|
def _create_const_fill_op_from_c2_int8_tensor(name, int8_tensor):
    """
    Build an Int8Given*TensorFill op that materializes a caffe2 Int8Tensor
    (quantized payload plus scale/zero_point) as the constant blob ``name``.
    """
    assert type(int8_tensor) == workspace.Int8Tensor
    kTypeNameMapper = {
        np.dtype("int32"): "Int8GivenIntTensorFill",
        np.dtype("uint8"): "Int8GivenTensorFill",
    }

    tensor = int8_tensor.data
    assert tensor.dtype in [np.dtype("uint8"), np.dtype("int32")]
    # uint8 payloads are passed as raw bytes; int32 values are passed as-is.
    if tensor.dtype == np.dtype("uint8"):
        values = tensor.tobytes()
    else:
        values = tensor

    return core.CreateOperator(
        kTypeNameMapper[tensor.dtype],
        [],
        [name],
        values=values,
        shape=tensor.shape,
        Y_scale=int8_tensor.scale,
        Y_zero_point=int8_tensor.zero_point,
    )
|
|
|
|
|
|
def create_const_fill_op(
    name: str,
    blob: Union[np.ndarray, workspace.Int8Tensor],
    device_option: Optional[caffe2_pb2.DeviceOption] = None,
) -> caffe2_pb2.OperatorDef:
    """
    Given a blob object, return the Caffe2 operator that creates this blob
    as constant. Currently support NumPy tensor and Caffe2 Int8Tensor.
    """
    blob_type = type(blob)
    assert blob_type in [
        np.ndarray,
        workspace.Int8Tensor,
    ], 'Error when creating const fill op for "{}", unsupported blob type: {}'.format(
        name, type(blob)
    )

    # Dispatch on the concrete blob type.
    if blob_type == workspace.Int8Tensor:
        # Int8 tensors carry their own quantization params; a device option
        # is not supported for them.
        assert device_option is None
        return _create_const_fill_op_from_c2_int8_tensor(name, blob)
    return _create_const_fill_op_from_numpy(name, blob, device_option)
|
|
|
|
|
|
def construct_init_net_from_params(
    params: Dict[str, Any], device_options: Optional[Dict[str, caffe2_pb2.DeviceOption]] = None
) -> caffe2_pb2.NetDef:
    """
    Build an init_net whose ops fill each entry of ``params`` as a constant
    blob, marking every created blob as an external output. String-valued
    entries cannot be expressed as fill ops and are skipped with a warning.
    """
    device_options = device_options or {}
    init_net = caffe2_pb2.NetDef()
    for name, blob in params.items():
        if isinstance(blob, str):
            logger.warning(
                (
                    "Blob {} with type {} is not supported in generating init net,"
                    " skipped.".format(name, type(blob))
                )
            )
            continue
        fill_op = create_const_fill_op(
            name, blob, device_option=device_options.get(name, None)
        )
        init_net.op.extend([fill_op])
        init_net.external_output.append(name)
    return init_net
|
|
|
|
|
|
def get_producer_map(ssa):
    """
    Return dict from versioned blob to (i, j),
    where i is index of producer op, j is the index of output of that op.

    When several ops write the same versioned blob (should not happen in
    valid SSA), the last producer wins, matching the original behavior.
    """
    producer_map = {}
    # enumerate over (inputs, outputs) pairs; inputs are irrelevant here.
    for i, (_, outputs) in enumerate(ssa):
        for j, outp in enumerate(outputs):
            producer_map[outp] = (i, j)
    return producer_map
|
|
|
|
|
|
def get_consumer_map(ssa):
    """
    Return dict from versioned blob to list of (i, j),
    where i is index of consumer op, j is the index of input of that op.

    The returned mapping is a defaultdict(list): looking up a blob with no
    consumers yields an empty list (callers rely on this).
    """
    consumer_map = collections.defaultdict(list)
    # enumerate over (inputs, outputs) pairs; outputs are irrelevant here.
    for i, (inputs, _) in enumerate(ssa):
        for j, inp in enumerate(inputs):
            consumer_map[inp].append((i, j))
    return consumer_map
|
|
|
|
|
|
def get_params_from_init_net(
    init_net: caffe2_pb2.NetDef,
) -> Tuple[Dict[str, Any], Dict[str, caffe2_pb2.DeviceOption]]:
    """
    Take the output blobs from init_net by running it.

    Outputs:
        params: dict from blob name to numpy array
        device_options: dict from blob name to the device option of its creating op

    NOTE: the return annotation was previously the list literal
    ``[Dict, Dict]``, which is not a valid typing construct; it is now the
    ``Tuple`` the function actually returns.
    """

    # NOTE: a blob's device is assumed to be that of its producer op.
    def _get_device_option(producer_op):
        # A CopyGPUToCPU producer means the resulting blob lives on CPU even
        # though the op itself may carry a GPU device option.
        if producer_op.type == "CopyGPUToCPU":
            return caffe2_pb2.DeviceOption()
        else:
            return producer_op.device_option

    # Run in a scratch workspace so the caller's workspace is untouched.
    with ScopedWS("__get_params_from_init_net__", is_reset=True, is_cleanup=True) as ws:
        ws.RunNetOnce(init_net)
        params = {b: fetch_any_blob(b) for b in init_net.external_output}
    ssa, versions = core.get_ssa(init_net)
    producer_map = get_producer_map(ssa)
    device_options = {
        b: _get_device_option(init_net.op[producer_map[(b, versions[b])][0]])
        for b in init_net.external_output
    }
    return params, device_options
|
|
|
|
|
|
def _updater_raise(op, input_types, output_types):
|
|
raise RuntimeError(
|
|
"Failed to apply updater for op {} given input_types {} and"
|
|
" output_types {}".format(op, input_types, output_types)
|
|
)
|
|
|
|
|
|
def _generic_status_identifier(
    predict_net: caffe2_pb2.NetDef,
    status_updater: Callable,
    known_status: Dict[Tuple[str, int], Any],
) -> Dict[Tuple[str, int], Any]:
    """
    Statically infer the status of each blob, the status can be such as device type
    (CPU/GPU), layout (NCHW/NHWC), data type (float32/int8), etc. "Blob" here
    is versioned blob (Tuple[str, int]) in the format compatible with ssa.
    Inputs:
        predict_net: the caffe2 network
        status_updater: a callable, given an op and the status of its input/output,
            it returns the updated status of input/output. `None` is used for
            representing unknown status.
        known_status: a dict containing known status, used as initialization.
    Outputs:
        A dict mapping from versioned blob to its status
    """
    ssa, versions = core.get_ssa(predict_net)
    versioned_ext_input = [(b, 0) for b in predict_net.external_input]
    versioned_ext_output = [(b, versions[b]) for b in predict_net.external_output]
    all_versioned_blobs = set().union(*[set(x[0] + x[1]) for x in ssa])

    # Sanity-check the initialization: every known blob must exist in the
    # net, and unknown status may not be passed in explicitly.
    allowed_vbs = all_versioned_blobs.union(versioned_ext_input).union(versioned_ext_output)
    assert all(k in allowed_vbs for k in known_status)
    assert all(v is not None for v in known_status.values())
    _known_status = copy.deepcopy(known_status)

    def _check_and_update(key, value):
        # Record a newly inferred status; re-inferring must agree with what
        # is already recorded.
        assert value is not None
        if key in _known_status:
            if not _known_status[key] == value:
                # (typo "Confilict" fixed to "Conflict" in this message)
                raise RuntimeError(
                    "Conflict status for {}, existing status {}, new status {}".format(
                        key, _known_status[key], value
                    )
                )
        _known_status[key] = value

    def _update_i(op, ssa_i):
        # Ask the updater for this op's input/output status given what is
        # currently known, then merge any newly determined statuses.
        versioned_inputs = ssa_i[0]
        versioned_outputs = ssa_i[1]

        inputs_status = [_known_status.get(b, None) for b in versioned_inputs]
        outputs_status = [_known_status.get(b, None) for b in versioned_outputs]

        new_inputs_status, new_outputs_status = status_updater(op, inputs_status, outputs_status)

        for versioned_blob, status in zip(
            versioned_inputs + versioned_outputs, new_inputs_status + new_outputs_status
        ):
            if status is not None:
                _check_and_update(versioned_blob, status)

    # One forward pass (propagates from inputs) and one backward pass
    # (propagates from outputs) over the ops.
    for op, ssa_i in zip(predict_net.op, ssa):
        _update_i(op, ssa_i)
    for op, ssa_i in zip(reversed(predict_net.op), reversed(ssa)):
        _update_i(op, ssa_i)

    # NOTE: This strictly checks all the blob from predict_net must be assigned
    # a known status. However, it's possible that some of them are not used in
    # the original graph but exist in ssa; relax here if that ever happens.
    for k in all_versioned_blobs:
        if k not in _known_status:
            raise NotImplementedError(
                "Can not infer the status for {}. Currently only support the case where"
                " a single forward and backward pass can identify status for all blobs.".format(k)
            )

    return _known_status
|
|
|
|
|
|
def infer_device_type(
    predict_net: caffe2_pb2.NetDef,
    known_status: Dict[Tuple[str, int], Any],
    device_name_style: str = "caffe2",
) -> Dict[Tuple[str, int], str]:
    """Return the device type ("cpu" or "gpu"/"cuda") of each (versioned) blob"""

    assert device_name_style in ["caffe2", "pytorch"]
    _CPU_STR = "cpu"
    _GPU_STR = "gpu" if device_name_style == "caffe2" else "cuda"

    def _copy_cpu_to_gpu_updater(op, input_types, output_types):
        # Input must be cpu (or unknown) and output gpu (or unknown).
        if input_types[0] == _GPU_STR or output_types[0] == _CPU_STR:
            _updater_raise(op, input_types, output_types)
        return ([_CPU_STR], [_GPU_STR])

    def _copy_gpu_to_cpu_updater(op, input_types, output_types):
        # Input must be gpu (or unknown) and output cpu (or unknown).
        if input_types[0] == _CPU_STR or output_types[0] == _GPU_STR:
            _updater_raise(op, input_types, output_types)
        return ([_GPU_STR], [_CPU_STR])

    def _other_ops_updater(op, input_types, output_types):
        # Non-copy ops keep all their blobs on a single device: propagate the
        # one known type to every input/output, or leave everything unknown.
        non_none_types = [x for x in input_types + output_types if x is not None]
        if len(non_none_types) > 0:
            the_type = non_none_types[0]
            if not all(x == the_type for x in non_none_types):
                _updater_raise(op, input_types, output_types)
        else:
            the_type = None
        return ([the_type for _ in op.input], [the_type for _ in op.output])

    def _device_updater(op, *args, **kwargs):
        # Dispatch per op type; anything that isn't an explicit copy keeps
        # its blobs on one device.
        return {
            "CopyCPUToGPU": _copy_cpu_to_gpu_updater,
            "CopyGPUToCPU": _copy_gpu_to_cpu_updater,
        }.get(op.type, _other_ops_updater)(op, *args, **kwargs)

    return _generic_status_identifier(predict_net, _device_updater, known_status)
|
|
|
|
|
|
|
|
|
|
|
|
def _modify_blob_names(ops, blob_rename_f):
|
|
ret = []
|
|
|
|
def _replace_list(blob_list, replaced_list):
|
|
del blob_list[:]
|
|
blob_list.extend(replaced_list)
|
|
|
|
for x in ops:
|
|
cur = copy.deepcopy(x)
|
|
_replace_list(cur.input, list(map(blob_rename_f, cur.input)))
|
|
_replace_list(cur.output, list(map(blob_rename_f, cur.output)))
|
|
ret.append(cur)
|
|
|
|
return ret
|
|
|
|
|
|
def _rename_blob(name, blob_sizes, blob_ranges):
|
|
def _list_to_str(bsize):
|
|
ret = ", ".join([str(x) for x in bsize])
|
|
ret = "[" + ret + "]"
|
|
return ret
|
|
|
|
ret = name
|
|
if blob_sizes is not None and name in blob_sizes:
|
|
ret += "\n" + _list_to_str(blob_sizes[name])
|
|
if blob_ranges is not None and name in blob_ranges:
|
|
ret += "\n" + _list_to_str(blob_ranges[name])
|
|
|
|
return ret
|
|
|
|
|
|
|
|
def save_graph(net, file_name, graph_name="net", op_only=True, blob_sizes=None, blob_ranges=None):
    """
    Render ``net`` to an image file, annotating each blob with its optional
    size/range info; thin wrapper around ``save_graph_base``.
    """
    renamer = functools.partial(_rename_blob, blob_sizes=blob_sizes, blob_ranges=blob_ranges)
    return save_graph_base(net, file_name, graph_name, op_only, renamer)
|
|
|
|
|
|
def save_graph_base(net, file_name, graph_name="net", op_only=True, blob_rename_func=None):
    """
    Draw a caffe2 net with net_drawer and write the rendering to ``file_name``.

    Args:
        net: caffe2 NetDef to draw.
        file_name: output path; its extension picks the format (.png/.pdf/.svg).
        graph_name: title of the rendered graph.
        op_only: when True, draw the minimal operator-only graph.
        blob_rename_func: optional callable applied to every blob name
            (ops are deep-copied first, so ``net`` is not modified).

    Returns:
        The pydot graph object (returned even if writing the file failed).
    """
    graph = None
    ops = net.op
    if blob_rename_func is not None:
        ops = _modify_blob_names(ops, blob_rename_func)
    if not op_only:
        graph = net_drawer.GetPydotGraph(ops, graph_name, rankdir="TB")
    else:
        graph = net_drawer.GetPydotGraphMinimal(
            ops, graph_name, rankdir="TB", minimal_dependency=True
        )

    try:
        par_dir = os.path.dirname(file_name)
        # BUGFIX: os.makedirs("") raises when file_name has no directory
        # component; only create the directory when there is one.
        if par_dir and not os.path.exists(par_dir):
            os.makedirs(par_dir)

        # File extension such as ".png" (renamed from `format`, which
        # shadowed the builtin of the same name).
        ext = os.path.splitext(os.path.basename(file_name))[-1]
        if ext == ".png":
            graph.write_png(file_name)
        elif ext == ".pdf":
            graph.write_pdf(file_name)
        elif ext == ".svg":
            graph.write_svg(file_name)
        else:
            print("Incorrect format {}".format(ext))
    except Exception as e:
        # Best-effort: rendering failures are reported but not fatal.
        print("Error when writing graph to image {}".format(e))

    return graph
|
|
|
|
|
|
|
|
|
|
|
|
def group_norm_replace_aten_with_caffe2(predict_net: caffe2_pb2.NetDef):
    """
    For ONNX exported model, GroupNorm will be represented as ATen op,
    this can be a drop in replacement from ATen to GroupNorm
    """
    count = 0
    for op in predict_net.op:
        if op.type == "ATen":
            # get_pb_arg_vals returns bytes; decode before comparing.
            op_name = get_pb_arg_vals(op, "operator", None)
            if op_name and op_name.decode() == "group_norm":
                # The "operator" arg is ATen-specific; drop it.
                op.arg.remove(get_pb_arg(op, "operator"))

                # cudnn_enabled has no caffe2 GroupNorm counterpart.
                if get_pb_arg_vali(op, "cudnn_enabled", None):
                    op.arg.remove(get_pb_arg(op, "cudnn_enabled"))

                # Caffe2 GroupNorm takes the group count under the name "group".
                num_groups = get_pb_arg_vali(op, "num_groups", None)
                if num_groups is not None:
                    op.arg.remove(get_pb_arg(op, "num_groups"))
                    check_set_pb_arg(op, "group", "i", num_groups)

                op.type = "GroupNorm"
                count += 1
    # NOTE(review): a single replacement (count == 1) is not logged;
    # presumably intentional, but confirm if logging every case is desired.
    if count > 1:
        logger.info("Replaced {} ATen operator to GroupNormOp".format(count))
|
|
|
|
|
|
|
|
|
|
|
|
def alias(x, name, is_backward=False):
    """
    During ONNX export, tag tensor ``x`` with ``name`` via the caffe2
    AliasWithName placeholder op (fused away later by
    ``fuse_alias_placeholder``); outside of export this is a no-op.
    """
    if torch.onnx.is_in_onnx_export():
        assert isinstance(x, torch.Tensor)
        return torch.ops._caffe2.AliasWithName(x, name, is_backward=is_backward)
    return x
|
|
|
|
|
|
def fuse_alias_placeholder(predict_net, init_net):
    """Remove AliasWithName placeholder and rename the input/output of it"""
    # First pass: push each alias' name onto the real producing/consuming
    # blobs. For backward aliases the rename goes through the producer so
    # every consumer is re-routed.
    for i, op in enumerate(predict_net.op):
        if op.type == "AliasWithName":
            assert len(op.input) == 1
            assert len(op.output) == 1
            name = get_pb_arg_vals(op, "name", None).decode()
            is_backward = bool(get_pb_arg_vali(op, "is_backward", 0))
            rename_op_input(predict_net, init_net, i, 0, name, from_producer=is_backward)
            rename_op_output(predict_net, i, 0, name)

    # Second pass: drop the placeholder ops, which should now be pure
    # identities (input == output == the alias name).
    new_ops = []
    for op in predict_net.op:
        if op.type != "AliasWithName":
            new_ops.append(op)
        else:
            # Sanity check that the rename above actually took effect.
            assert op.input == op.output
            assert op.input[0] == op.arg[0].s.decode()
    del predict_net.op[:]
    predict_net.op.extend(new_ops)
|
|
|
|
|
|
|
|
|
|
|
|
class IllegalGraphTransformError(ValueError):
    """Raised when a requested graph transform cannot be applied safely."""
|
|
|
|
|
|
def _rename_versioned_blob_in_proto(
    proto: caffe2_pb2.NetDef,
    old_name: str,
    new_name: str,
    version: int,
    ssa: List[Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]],
    start_versions: Dict[str, int],
    end_versions: Dict[str, int],
):
    """
    In-place rename of every occurrence of the versioned blob
    ``(old_name, version)`` in ``proto``: op inputs/outputs whose ssa entry
    matches, plus external_input/external_output when the matched version is
    the blob's starting/final version respectively.
    """
    target = (old_name, version)

    # Rewrite op inputs/outputs whose ssa version matches the target.
    for op, (versioned_inputs, versioned_outputs) in zip(proto.op, ssa):
        for i, vin in enumerate(versioned_inputs):
            if vin == target:
                op.input[i] = new_name
        for i, vout in enumerate(versioned_outputs):
            if vout == target:
                op.output[i] = new_name

    # external_input refers to the blob's initial version.
    if start_versions.get(old_name, 0) == version:
        for i, blob in enumerate(proto.external_input):
            if blob == old_name:
                proto.external_input[i] = new_name

    # external_output refers to the blob's final version.
    if end_versions.get(old_name, 0) == version:
        for i, blob in enumerate(proto.external_output):
            if blob == old_name:
                proto.external_output[i] = new_name
|
|
|
|
|
|
def rename_op_input(
    predict_net: caffe2_pb2.NetDef,
    init_net: caffe2_pb2.NetDef,
    op_id: int,
    input_id: int,
    new_name: str,
    from_producer: bool = False,
):
    """
    Rename the op_id-th operator in predict_net, change it's input_id-th input's
    name to the new_name. It also does automatic re-route and change
    external_input and init_net if necessary.
    - It requires the input is only consumed by this op.
    - This function modifies predict_net and init_net in-place.
    - When from_producer is enable, this also updates other operators that consumes
    the same input. Be cautious because may trigger unintended behavior.
    """
    assert isinstance(predict_net, caffe2_pb2.NetDef)
    assert isinstance(init_net, caffe2_pb2.NetDef)

    # Compute predict_net's ssa starting from init_net's versions so blobs
    # shared between the two nets line up.
    init_net_ssa, init_net_versions = core.get_ssa(init_net)
    predict_net_ssa, predict_net_versions = core.get_ssa(
        predict_net, copy.deepcopy(init_net_versions)
    )

    versioned_inputs, versioned_outputs = predict_net_ssa[op_id]
    old_name, version = versioned_inputs[input_id]

    if from_producer:
        # Rename at the producing op's output instead, which re-routes all
        # consumers of this versioned blob at once.
        producer_map = get_producer_map(predict_net_ssa)
        if not (old_name, version) in producer_map:
            raise NotImplementedError(
                "Can't find producer, the input {} is probably from"
                " init_net, this is not supported yet.".format(old_name)
            )
        producer = producer_map[(old_name, version)]
        rename_op_output(predict_net, producer[0], producer[1], new_name)
        return

    def contain_targets(op_ssa):
        # True when the op consumes the versioned blob being renamed.
        return (old_name, version) in op_ssa[0]

    # Renaming only this op's input would silently detach other consumers,
    # so refuse when the versioned blob feeds more than one op.
    is_consumer = [contain_targets(op_ssa) for op_ssa in predict_net_ssa]
    if sum(is_consumer) > 1:
        raise IllegalGraphTransformError(
            (
                "Input '{}' of operator(#{}) are consumed by other ops, please use"
                + " rename_op_output on the producer instead. Offending op: \n{}"
            ).format(old_name, op_id, predict_net.op[op_id])
        )

    # Update init_net (in case the blob originates there; external_output of
    # init_net is covered via its end versions).
    _rename_versioned_blob_in_proto(
        init_net, old_name, new_name, version, init_net_ssa, {}, init_net_versions
    )

    # Update predict_net, including external_input when this is the blob's
    # starting version.
    _rename_versioned_blob_in_proto(
        predict_net,
        old_name,
        new_name,
        version,
        predict_net_ssa,
        init_net_versions,
        predict_net_versions,
    )
|
|
|
|
|
|
def rename_op_output(predict_net: caffe2_pb2.NetDef, op_id: int, output_id: int, new_name: str):
    """
    Rename the op_id-th operator in predict_net, change it's output_id-th input's
    name to the new_name. It also does automatic re-route and change
    external_output and if necessary.
    - It allows multiple consumers of its output.
    - This function modifies predict_net in-place, doesn't need init_net.
    """
    assert isinstance(predict_net, caffe2_pb2.NetDef)

    ssa, blob_versions = core.get_ssa(predict_net)

    versioned_inputs, versioned_outputs = ssa[op_id]
    old_name, version = versioned_outputs[output_id]

    # Renaming the producer's output re-routes every consumer of this
    # versioned blob (and external_output, via the end-versions map).
    _rename_versioned_blob_in_proto(
        predict_net, old_name, new_name, version, ssa, {}, blob_versions
    )
|
|
|
|
|
|
def get_sub_graph_external_input_output(
    predict_net: caffe2_pb2.NetDef, sub_graph_op_indices: List[int]
) -> Tuple[List[Tuple[str, int]], List[Tuple[str, int]]]:
    """
    Return the list of external input/output of sub-graph,
    each element is tuple of the name and corresponding version in predict_net.

    external input/output is defined the same way as caffe2 NetDef.
    """
    ssa, versions = core.get_ssa(predict_net)

    all_inputs = []
    all_outputs = []
    for op_id in sub_graph_op_indices:
        # Dedup inputs while preserving first-seen order; ssa outputs carry
        # unique versions so they can be appended directly.
        all_inputs += [inp for inp in ssa[op_id][0] if inp not in all_inputs]
        all_outputs += list(ssa[op_id][1])

    # External inputs: consumed inside the sub-graph but produced outside it.
    ext_inputs = [inp for inp in all_inputs if inp not in all_outputs]

    # External outputs: produced inside the sub-graph and either consumed by
    # some op outside it, or listed in the net's external_output.
    all_other_inputs = sum(
        (ssa[i][0] for i in range(len(ssa)) if i not in sub_graph_op_indices),
        [(outp, versions[outp]) for outp in predict_net.external_output],
    )
    ext_outputs = [outp for outp in all_outputs if outp in set(all_other_inputs)]

    return ext_inputs, ext_outputs
|
|
|
|
|
|
class DiGraph:
    """A DAG representation of caffe2 graph, each vertice is a versioned blob."""

    def __init__(self):
        # All known vertices; edges live in an adjacency list.
        self.vertices = set()
        self.graph = collections.defaultdict(list)

    def add_edge(self, u, v):
        """Add a directed edge u -> v, registering both endpoints."""
        self.graph[u].append(v)
        self.vertices.update((u, v))

    def get_all_paths(self, s, d):
        """Return every simple path from ``s`` to ``d`` as a list of vertex lists."""
        visited = dict.fromkeys(self.vertices, False)
        path = []
        all_paths = []

        def _dfs(u):
            visited[u] = True
            path.append(u)
            if u == d:
                all_paths.append(list(path))
            else:
                for nxt in self.graph[u]:
                    if not visited[nxt]:
                        _dfs(nxt)
            # Backtrack so other branches may revisit this vertex.
            path.pop()
            visited[u] = False

        _dfs(s)
        return all_paths

    @staticmethod
    def from_ssa(ssa):
        """Build the blob-level dependency DAG of an ssa: one edge from every
        input of an op to every output of that op."""
        dag = DiGraph()
        for inputs, outputs in ssa:
            for src in inputs:
                for dst in outputs:
                    dag.add_edge(src, dst)
        return dag
|
|
|
|
|
|
def _get_dependency_chain(ssa, versioned_target, versioned_source):
    """
    Return the index list of relevant operator to produce target blob from source blob,
    if there's no dependency, return empty list.
    """

    # Enumerating all paths can be very expensive, so only search inside a
    # fixed window: 15 ops before the first consumer of the source to 15 ops
    # after the producer of the target.
    consumer_map = get_consumer_map(ssa)
    producer_map = get_producer_map(ssa)
    start_op = min(x[0] for x in consumer_map[versioned_source]) - 15
    end_op = (
        producer_map[versioned_target][0] + 15 if versioned_target in producer_map else start_op
    )
    # NOTE(review): start_op can be negative here, which would make this
    # slice wrap around from the end of ssa — presumably nets are deep enough
    # in practice that this doesn't trigger; confirm.
    sub_graph_ssa = ssa[start_op : end_op + 1]
    if len(sub_graph_ssa) > 30:
        logger.warning(
            "Subgraph bebetween {} and {} is large (from op#{} to op#{}), it"
            " might take non-trival time to find all paths between them.".format(
                versioned_source, versioned_target, start_op, end_op
            )
        )

    # Collect the producer op of every blob along every source->target path
    # (skipping the source itself, which has no producer on the path).
    dag = DiGraph.from_ssa(sub_graph_ssa)
    paths = dag.get_all_paths(versioned_source, versioned_target)
    ops_in_paths = [[producer_map[blob][0] for blob in path[1:]] for path in paths]
    return sorted(set().union(*[set(ops) for ops in ops_in_paths]))
|
|
|
|
|
|
def identify_reshape_sub_graph(predict_net: caffe2_pb2.NetDef) -> List[List[int]]:
    """
    Identify the reshape sub-graph in a protobuf.
    The reshape sub-graph is defined as matching the following pattern:

    (input_blob) -> Op_1 -> ... -> Op_N -> (new_shape) -+
        +--------------------------------------------> Reshape -> (output_blob)

    Return:
        List of sub-graphs, each sub-graph is represented as a list of indices
        of the relevant ops, [Op_1, Op_2, ..., Op_N, Reshape]
    """
    ssa, _ = core.get_ssa(predict_net)

    ret = []
    for i, op in enumerate(predict_net.op):
        if op.type == "Reshape":
            assert len(op.input) == 2
            input_ssa = ssa[i][0]
            data_source = input_ssa[0]
            shape_source = input_ssa[1]
            # Ops on the chain from the data blob to the shape blob are the
            # ones computing new_shape; they form the sub-graph with Reshape.
            op_indices = _get_dependency_chain(ssa, shape_source, data_source)
            ret.append(op_indices + [i])
    return ret
|
|
|
|
|
|
def remove_reshape_for_fc(predict_net, params):
    """
    In PyTorch nn.Linear has to take 2D tensor, this often leads to reshape
    a 4D tensor to 2D by calling .view(). However this (dynamic) reshaping
    doesn't work well with ONNX and Int8 tools, and cause using extra
    ops (eg. ExpandDims) that might not be available on mobile.
    Luckily Caffe2 supports 4D tensor for FC, so we can remove those reshape
    after exporting ONNX model.
    """
    from caffe2.python import core

    # Step 1: find removable reshape sub-graphs — those whose Reshape output
    # is consumed exclusively by FC ops and which are isolated (exactly one
    # non-param external input and one external output).
    reshape_sub_graphs = identify_reshape_sub_graph(predict_net)
    sub_graphs_to_remove = []
    for reshape_sub_graph in reshape_sub_graphs:
        reshape_op_id = reshape_sub_graph[-1]
        assert predict_net.op[reshape_op_id].type == "Reshape"
        ssa, _ = core.get_ssa(predict_net)
        reshape_output = ssa[reshape_op_id][1][0]
        consumers = [i for i in range(len(ssa)) if reshape_output in ssa[i][0]]
        if all(predict_net.op[consumer].type == "FC" for consumer in consumers):
            # Isolation check: blobs with version 0 are params/init blobs,
            # so a removable sub-graph has exactly one non-param input.
            ext_inputs, ext_outputs = get_sub_graph_external_input_output(
                predict_net, reshape_sub_graph
            )
            non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
            if len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1:
                sub_graphs_to_remove.append(reshape_sub_graph)

    # Step 2: for each sub-graph, rename the Reshape's output to its input so
    # the sub-graph becomes a pure identity, then drop its ops and params.
    remove_op_ids = []
    params_to_remove = []
    for sub_graph in sub_graphs_to_remove:
        logger.info(
            "Remove Reshape sub-graph:\n{}".format(
                "".join(["(#{:>4})\n{}".format(i, predict_net.op[i]) for i in sub_graph])
            )
        )
        reshape_op_id = sub_graph[-1]
        new_reshap_output = predict_net.op[reshape_op_id].input[0]
        rename_op_output(predict_net, reshape_op_id, 0, new_reshap_output)
        ext_inputs, ext_outputs = get_sub_graph_external_input_output(predict_net, sub_graph)
        non_params_ext_inputs = [inp for inp in ext_inputs if inp[1] != 0]
        params_ext_inputs = [inp for inp in ext_inputs if inp[1] == 0]
        # After renaming, the sub-graph's single data input and output must
        # share the same name with versions exactly one apart (identity).
        assert len(non_params_ext_inputs) == 1 and len(ext_outputs) == 1
        assert ext_outputs[0][0] == non_params_ext_inputs[0][0]
        assert ext_outputs[0][1] == non_params_ext_inputs[0][1] + 1
        remove_op_ids.extend(sub_graph)
        params_to_remove.extend(params_ext_inputs)

    # Rebuild the net without the removed ops; note predict_net is copied,
    # so the caller must use the returned net (params is mutated in place).
    predict_net = copy.deepcopy(predict_net)
    new_ops = [op for i, op in enumerate(predict_net.op) if i not in remove_op_ids]
    del predict_net.op[:]
    predict_net.op.extend(new_ops)
    for versioned_params in params_to_remove:
        name = versioned_params[0]
        logger.info("Remove params: {} from init_net and predict_net.external_input".format(name))
        del params[name]
        predict_net.external_input.remove(name)

    return predict_net, params
|
|
|
|
|
|
def fuse_copy_between_cpu_and_gpu(predict_net: caffe2_pb2.NetDef):
    """
    In-place fuse extra copy ops between cpu/gpu for the following case:
        a -CopyAToB-> b -CopyBToA> c1 -NextOp1-> d1
                        -CopyBToA> c2 -NextOp2-> d2
    The fused network will look like:
        a -NextOp1-> d1
          -NextOp2-> d2
    """
    _COPY_OPS = ["CopyCPUToGPU", "CopyGPUToCPU"]

    def _fuse_once(predict_net):
        # Performs at most one fusion, returning True when the net changed
        # (the ssa becomes stale after editing, so the scan must restart).
        ssa, blob_versions = core.get_ssa(predict_net)
        consumer_map = get_consumer_map(ssa)
        versioned_external_output = [
            (name, blob_versions[name]) for name in predict_net.external_output
        ]

        for op_id, op in enumerate(predict_net.op):
            if op.type in _COPY_OPS:
                fw_copy_versioned_output = ssa[op_id][1][0]
                consumer_ids = [x[0] for x in consumer_map[fw_copy_versioned_output]]
                # The opposite-direction copy op type.
                reverse_op_type = _COPY_OPS[1 - _COPY_OPS.index(op.type)]

                # Fusable when every consumer of this copy is the reverse
                # copy, and none of the intermediate blobs is an external
                # output (removing those would change the net's interface).
                is_fusable = (
                    len(consumer_ids) > 0
                    and fw_copy_versioned_output not in versioned_external_output
                    and all(
                        predict_net.op[_op_id].type == reverse_op_type
                        and ssa[_op_id][1][0] not in versioned_external_output
                        for _op_id in consumer_ids
                    )
                )

                if is_fusable:
                    for rv_copy_op_id in consumer_ids:
                        # Rewire the op downstream of the reverse copy to
                        # read the original (pre-copy) blob directly.
                        rs_copy_versioned_output = ssa[rv_copy_op_id][1][0]
                        next_op_id, inp_id = consumer_map[rs_copy_versioned_output][0]
                        predict_net.op[next_op_id].input[inp_id] = op.input[0]
                    # Drop the forward copy and all reverse copies.
                    new_ops = [
                        op
                        for i, op in enumerate(predict_net.op)
                        if i != op_id and i not in consumer_ids
                    ]
                    del predict_net.op[:]
                    predict_net.op.extend(new_ops)
                    return True

        return False

    # Iterate to a fixed point: fusing one pattern can expose another.
    while _fuse_once(predict_net):
        pass
|
|
|
|
|
|
def remove_dead_end_ops(net_def: caffe2_pb2.NetDef):
    """remove ops if its output is not used or not in external_output"""
    ssa, versions = core.get_ssa(net_def)
    versioned_external_output = [(name, versions[name]) for name in net_def.external_output]
    consumer_map = get_consumer_map(ssa)
    removed_op_ids = set()

    def _is_dead_end(versioned_blob):
        # A blob is dead when it is not an external output and it either has
        # no consumers or every consumer has already been marked removed.
        return not (
            versioned_blob in versioned_external_output
            or (
                len(consumer_map[versioned_blob]) > 0
                and all(x[0] not in removed_op_ids for x in consumer_map[versioned_blob])
            )
        )

    # Walk ops in reverse so that removing a dead op can cascade to the ops
    # that only fed it.
    for i, ssa_i in reversed(list(enumerate(ssa))):
        versioned_outputs = ssa_i[1]
        if all(_is_dead_end(outp) for outp in versioned_outputs):
            removed_op_ids.add(i)

    # Rebuild the op list in place without the removed ops.
    new_ops = [op for i, op in enumerate(net_def.op) if i not in removed_op_ids]
    del net_def.op[:]
    net_def.op.extend(new_ops)
|
|
|