File size: 33,428 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import queue
import re
from collections import defaultdict
import torch
from torch.utils.tensorboard._pytorch_graph import NodePy, NodePyIO, NodePyOP, GraphPy
# Kind strings used by the graph builders in this module. CLASSTYPE_KIND is
# compared against ``Value.type().kind()``; the others are compared against
# TorchScript ``Node.kind()``.
CLASSTYPE_KIND = 'ClassType'
GETATTR_KIND = 'prim::GetAttr'
CAT_KIND = 'aten::cat'
LIST_CONSTRUCT_KIND = 'prim::ListConstruct'
LIST_UNPACK_KIND = 'prim::ListUnpack'
TUPLE_CONSTRUCT_KIND = 'prim::TupleConstruct'
TUPLE_UNPACK_KIND = 'prim::TupleUnpack'
# module-level logger for debug output of the graph construction
_logger = logging.getLogger(__name__)
def build_module_graph(model, dummy_input):
    """
    Trace ``model`` with ``dummy_input`` and build its module-level graph.

    Returns
    -------
    TorchModuleGraph
        The graph whose nodes are groups of traced nodes (per module / function).
    """
    module_graph = TorchModuleGraph(model, dummy_input)
    return module_graph
def build_graph(model, dummy_input, verbose=False):
    """
    Trace ``model`` and export its graph in TensorBoard protobuf form.

    Returns
    -------
    tuple
        ``(graph_def, stepstats)`` taken from the ``TorchProtoGraph`` built
        for ``model``.
    """
    proto_graph = TorchProtoGraph(model, dummy_input, verbose)
    graph_def = proto_graph.graph_def
    stepstats = proto_graph.stepstats
    return graph_def, stepstats
def parse_traced_name(module_name):
    """
    Strip the ``TracedModule[...]`` wrapper from a traced module name.

    Names that are not wrapped are returned unchanged.
    For example ``'TracedModule[Net]'`` becomes ``'Net'``.
    """
    head, tail = 'TracedModule[', ']'
    is_wrapped = module_name.startswith(head) and module_name.endswith(tail)
    if not is_wrapped:
        return module_name
    return module_name[len(head):len(module_name) - len(tail)]
class TorchGraph:
    """
    This class is to extract pytorch model topology graph by tracing
    """

    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
        Parameters
        ----------
        model : pytorch model
            The model user wants to speed up
        dummy_input : pytorch tensor
            The dummy input for ```jit.trace```, users should put it on right device before pass in
        traced_model : torch._C.torch.jit.TopLevelTracedModule
            An already traced model, if traced_model is not None, then TorchGraph will build the graph
            based on this traced model and won't trace the model again.

        Raises
        ------
        Exception
            If neither ``traced_model`` nor both ``model`` and ``dummy_input``
            are provided.
        """
        # compare parsed numbers, not strings: '1.10.0' < '1.3.1' lexicographically
        assert self._version_tuple(torch.__version__) >= (1, 3, 1), \
            'pytorch >= 1.3.1 is required'
        # check if the input is legal
        if traced_model is not None:
            assert isinstance(traced_model, torch.jit.TopLevelTracedModule)
            self.trace = traced_model
            # it's ok if the graph is already unpacked
            torch._C._jit_pass_inline(self.trace.graph)
        elif model is not None and dummy_input is not None:
            self.bound_model = model
            self._trace(model, dummy_input)
        else:
            raise Exception(
                'Please provide model & dummy_input or the traced_model as inputs')

    @staticmethod
    def _version_tuple(version):
        """
        Convert a version string such as ``'1.10.0+cu113'`` into ``(1, 10, 0)``.

        Only the first three numeric components are kept. Missing components
        are padded with 0, so ``'1.4'`` compares equal to ``'1.4.0'``.
        """
        numbers = [int(part) for part in re.findall(r'\d+', str(version))[:3]]
        numbers += [0] * (3 - len(numbers))
        return tuple(numbers)

    def _trace(self, model, dummy_input):
        """
        Trace ``model`` in eval mode and inline the resulting graph.

        The training mode of ``model`` is restored afterwards, even if
        tracing raises.
        """
        # ``torch.onnx.set_training`` is gone from recent PyTorch releases.
        # Toggle the mode by hand with the same semantics: only switch (and
        # switch back) when the model is currently in training mode.
        was_training = model.training
        if was_training:
            model.train(False)
        try:
            self.trace = torch.jit.trace(model, dummy_input, check_trace=False)
        finally:
            if was_training:
                model.train(True)
        torch._C._jit_pass_inline(self.trace.graph)
class TorchProtoGraph(TorchGraph):
    """
    Generates model graph for pytorch models in protobuf, this implementation
    is borrowed from pytorch v1.4.0, and fixed following issues:
    https://github.com/pytorch/pytorch/issues/33691
    https://github.com/pytorch/pytorch/issues/33670

    After construction, ``graph_def`` holds a TensorBoard ``GraphDef`` built
    from the traced graph. ``stepstats`` holds a ``RunMetadata`` containing a
    single CPU ``DeviceStepStats`` entry with no node stats.
    """

    def __init__(self, model, dummy_input, verbose=False):
        """
        Parameters
        ----------
        model : pytorch model
            The model to trace.
        dummy_input : pytorch tensor
            The dummy input used to trace ``model``.
        verbose : bool
            If True, print the traced (inlined) TorchScript graph.
        """
        super().__init__(model, dummy_input)
        # the tensorboard protobuf classes are imported here, at construction time
        from tensorboard.compat.proto.config_pb2 import RunMetadata
        from tensorboard.compat.proto.graph_pb2 import GraphDef
        from tensorboard.compat.proto.step_stats_pb2 import StepStats, DeviceStepStats
        from tensorboard.compat.proto.versions_pb2 import VersionDef
        list_of_nodes = self.parse(self.trace.graph, self.trace, dummy_input)
        if verbose:
            print(self.trace.graph)
        # a single CPU device entry; no per-node stats are recorded
        self.stepstats = RunMetadata(step_stats=StepStats(
            dev_stats=[DeviceStepStats(device="/device:CPU:0")]))
        self.graph_def = GraphDef(
            node=list_of_nodes, versions=VersionDef(producer=22))

    def parse(self, graph, trace, args=None, omit_useless_nodes=True):
        """This method parses an optimized PyTorch model graph and produces
        a list of nodes and node stats for eventual conversion to TensorBoard
        protobuf format.

        Args:
          graph (PyTorch module): The model graph to be parsed.
          trace (PyTorch JIT TracedModule): The model trace to be parsed.
          args (tuple): input tensor[s] for the model. Not used by this
            implementation.
          omit_useless_nodes (boolean): Whether to remove nodes from the graph.

        Returns:
          The result of ``GraphPy.to_proto()`` for the collected nodes.
        """
        nodes_py = GraphPy()
        # graph inputs: optionally skip unused ones; never emit ClassType
        # inputs (the module ``self`` argument)
        for node in graph.inputs():
            if omit_useless_nodes:
                if not node.uses():  # number of user of the node (= number of outputs/ fanout)
                    continue
            if node.type().kind() != CLASSTYPE_KIND:
                nodes_py.append(NodePyIO(node, 'input'))
        # maps the debug name of each prim::GetAttr output to a scope string,
        # e.g. '__module.backbone' or '__module.backbone/__module.backbone.conv'
        attr_to_scope = dict()

        def node_to_name(d):
            # the text before ':' in the node's string form (its output debug name)
            return str(d).split(":")[0].strip()
        for node in graph.nodes():
            if node.kind() == GETATTR_KIND:
                attr_name = node.s('name')
                node_name = node_to_name(node)
                parent = node.input().node()
                # If the parent node is not the top-level "self" node
                if parent.kind() == GETATTR_KIND:
                    # parent GetAttr nodes appear earlier in graph.nodes(), so
                    # their scope is already recorded
                    parent_scope = attr_to_scope[node_to_name(parent)]
                    attr_scope = parent_scope.split('/')[-1]
                    attr_to_scope[node_name] = '{}/{}.{}'.format(
                        parent_scope, attr_scope, attr_name)
                else:
                    attr_to_scope[node_name] = '__module.{}'.format(attr_name)
                # We don't need classtype nodes; scope will provide this information
                if node.output().type().kind() != CLASSTYPE_KIND:
                    node_py = NodePyOP(node)
                    node_py.scopeName = attr_to_scope[node_name]
                    nodes_py.append(node_py)
            else:
                nodes_py.append(NodePyOP(node))
        # Create sink nodes for output ops
        for i, node in enumerate(graph.outputs()):
            node_py = NodePyIO(node, 'output')
            node_py.debugName = "output.{}".format(i + 1)
            node_py.inputs = [node.debugName()]
            nodes_py.append(node_py)
        # map '__module'-prefixed module paths to 'Type[attr]' display names
        alias_to_name = dict()
        base_name = parse_traced_name(trace._name)
        for name, module in trace.named_modules(prefix='__module'):
            mod_name = parse_traced_name(module._name)
            attr_name = name.split('.')[-1]
            alias_to_name[name] = '{}[{}]'.format(mod_name, attr_name)
        # rewrite each op's scope to base_name/Type[attr]/..., walking the
        # dotted path of its innermost scope segment one component at a time
        for node in nodes_py.nodes_op:
            module_aliases = node.scopeName.split('/')[-1].split('.')
            module_name = ''
            for i, alias in enumerate(module_aliases):
                if i == 0:
                    # the first component ('__module') stands for the root model
                    module_name = alias
                    node.scopeName = base_name
                else:
                    module_name += '.' + alias
                    node.scopeName += '/' + \
                        (alias_to_name[module_name]
                         if module_name in alias_to_name else alias)
        nodes_py.populate_namespace_from_OP_to_IO()
        return nodes_py.to_proto()
class NodePyGroup(NodePy):
    """
    This class is used to represent a graph node which consists of multiple jit traced nodes. In a pytorch trace graph,
    multiple nodes are traced for one torch.nn.Module object, we group them together to form a single node to
    represent the torch.nn.Module object. We also group some functional call trace nodes together to form a new node.
    """

    def __init__(self, name, unique_name, node_type, op_type, node_cpps, inputs=None, outputs=None, key_node=None):
        """
        Parameters:
        -----------
        name: str
            node name, such as `conv1`, `backbone.classifier`
        unique_name: str
            A global unique name for current node. Due to some modules,
            such as relu, may be reused several times, so the scopename
            is not suitable as the global unique identifier, so we add a
            unique_name for each node as the global unique identifier.
            We should use the unique_name to traverse the module graph.
        node_type: str
            `module` or `func`
        op_type: str
            operation type, such as `Conv2d`, `aten::view`
        node_cpps: list of torch._C.Node
            jit trace nodes which are included in this new node
        inputs: list of str
            All the inputs of this node, each element is debugName of one input.
            Defaults to an empty list.
        outputs: list of str
            All the outputs of this node, each element is debugName of one output.
            Defaults to an empty list.
        key_node: torch._C.Node
            The key node of this NodePyGroup.
        """
        # no methods are copied from a cpp node: this group's attributes are set below
        super().__init__(name, [])
        self.node_cpps = node_cpps
        self.name = name
        self.unique_name = unique_name
        self.op_type = op_type
        self.type = node_type
        self.nodes = []
        self.auxiliary = None
        self.add_nodes(node_cpps)
        # use empty lists instead of None so that code iterating over a
        # group's inputs/outputs (e.g. TorchModuleGraph._build_index) works
        self.inputs = inputs if inputs is not None else []
        self.outputs = outputs if outputs is not None else []
        # The core node in this NodePyGroup
        self.key_node = key_node

    def add_nodes(self, node_cpps):
        """Wrap each jit node in a NodePyOP named '<scopeName>_<kind>' and store it."""
        for node_cpp in node_cpps:
            nodepy = NodePyOP(node_cpp)
            nodepy.name = node_cpp.scopeName() + '_' + node_cpp.kind()
            self.nodes.append(nodepy)

    def sub_node_names(self):
        """Return the names of the wrapped sub nodes, in insertion order."""
        return [x.name for x in self.nodes]

    def __repr__(self):
        return 'name: {}, type: {}, op_type: {}, sub_nodes: {}, inputs: {}, outputs: {}, aux: {}'.format(
            self.name, self.type, self.op_type, self.sub_node_names(),
            self.inputs, self.outputs, self.auxiliary
        )
class TorchModuleGraph(TorchGraph):
    """
    Generates model graph, each node is created from single or multiple jit trace nodes.

    Trace nodes under the scope of the same leaf module are merged into
    ``NodePyGroup`` objects of type ``module``. The remaining trace nodes are
    grouped around their key function node (see ``_is_key_func``) into
    ``NodePyGroup`` objects of type ``func``.
    """

    def __init__(self, model=None, dummy_input=None, traced_model=None):
        """
        Parameters
        ----------
        model : pytorch model
            The model to trace.
        dummy_input : pytorch tensor
            The dummy input used to trace ``model``.
        traced_model : torch.jit.TopLevelTracedModule
            An already traced model; when given, the model is not traced again.
        """
        super().__init__(model, dummy_input, traced_model)
        # counter appended to func node names to make them unique
        self.global_count = 0
        self.name_to_node, self.input_to_node, self.output_to_node = self._build_graph()
        self._extract_auxiliary_info()

    @staticmethod
    def _version_tuple(version):
        """
        Convert a version string such as ``'1.10.0+cu113'`` into ``(1, 10, 0)``.

        Only the first three numeric components are kept; missing components
        are padded with 0.
        """
        numbers = [int(part) for part in re.findall(r'\d+', str(version))[:3]]
        numbers += [0] * (3 - len(numbers))
        return tuple(numbers)

    def _expand_key_func_node(self, node, nodes, input_to_node, output_to_node,
                              module_type):
        """
        For trace graph nodes, some nodes are not in modules, these nodes are usually generated by
        the functions directly called in module ```forward```. For such nodes, some of them are
        trivial ops which are labeled by ```prim::```, some of them are not such ops which are
        called non-prim ops. This function is to merge neighbor prim ops to a non-prim op, to
        construct a node.

        Parameters
        ----------
        node : trace graph node
            The non-prim node to expand
        nodes : list of trace graph node
            All the trace graph nodes within the same scope as the non-prim node
        input_to_node : dict
            key: input name, value: list of nodes that use this input.
            Not used by this method; kept for symmetry with ``_expand_module_node``.
        output_to_node : dict
            key: output name, value: a node that generates this output
        module_type : str
            can be 'module' or 'func'

        Returns
        -------
        node
            the expanded non-prim node
        """
        # TODO: scope name could be empty
        node_name = '.'.join([self._get_module_name(
            node.scopeName()), node.kind(), str(self.global_count)])
        unique_name = node_name
        _logger.debug("expand non-prim node, node name: %s", node_name)
        self.global_count += 1
        op_type = node.kind()
        node_group = [node]
        inputs = []
        outputs = []
        # a prim node shared by several members of the group must be merged
        # (and expanded) only once
        visited = {node}
        node_queue = queue.Queue()
        node_queue.put(node)
        while not node_queue.empty():
            curr_node = node_queue.get()
            for _input in curr_node.inputs():
                input_name = _input.debugName()
                if input_name in output_to_node and output_to_node[input_name] in nodes:
                    predecessor_node = output_to_node[input_name]
                    if not self._is_key_func(predecessor_node):
                        if predecessor_node not in visited:
                            visited.add(predecessor_node)
                            node_group.append(predecessor_node)
                            node_queue.put(predecessor_node)
                    else:
                        # another key function produces this tensor: it is
                        # an input edge of this group
                        inputs.append(input_name)
                else:
                    inputs.append(input_name)
        for output in node.outputs():
            outputs.append(output.debugName())
        return NodePyGroup(node_name, unique_name, module_type, op_type,
                           node_group, inputs=inputs, outputs=outputs, key_node=node)

    def _expand_module_node(self, node, node_name, unique_name, op_type, nodes,
                            input_to_node, output_to_node, module_type):
        """
        merge the adjacent nodes of the module. The difference between the
        _expand_module_node and _expand_key_func_node is that, the _expand_key_func_node
        only merges the prim:: nodes into the key (e.g. aten::) node, in contrast, the
        _expand_module_node will merge all adjacent nodes into the same nodepy group.

        Parameters
        ----------
        node : trace graph node
            The node to start the expansion from
        node_name : str
            specify the node_name for NodePyGroup
        unique_name : str
            unique_name for the NodePyGroup
        op_type : str
            specify the op_type for the NodePyGroup
        nodes : list of trace graph node
            All the trace graph nodes within the same scope as ``node``
        input_to_node : dict
            key: input name, value: list of all nodes that use this input
        output_to_node : dict
            key: output name, value: a node that generates this output
        module_type : str
            can be 'module' or 'func'

        Returns
        -------
        node
            the expanded module node
        """
        _logger.debug("expand module node, node name: %s", node_name)
        self.global_count += 1
        if not op_type:
            op_type = node.kind()
        node_group = [node]
        inputs = []
        outputs = []
        node_queue = queue.Queue()
        node_queue.put(node)
        visited = {node}
        while not node_queue.empty():
            curr_node = node_queue.get()
            for _input in curr_node.inputs():
                input_name = _input.debugName()
                if input_name in output_to_node and output_to_node[input_name] in nodes:
                    predecessor_node = output_to_node[input_name]
                    if predecessor_node not in visited:
                        node_group.append(predecessor_node)
                        node_queue.put(predecessor_node)
                        visited.add(predecessor_node)
                else:
                    inputs.append(input_name)
            for _output in curr_node.outputs():
                output_name = _output.debugName()
                consumers = input_to_node.get(output_name, [])
                # the tensor leaves this module if no node consumes it, or if
                # at least one consumer lies outside the module scope
                leaves_module = not consumers
                for successor_node in consumers:
                    if successor_node in nodes:
                        if successor_node not in visited:
                            node_group.append(successor_node)
                            node_queue.put(successor_node)
                            visited.add(successor_node)
                    else:
                        leaves_module = True
                if leaves_module:
                    outputs.append(output_name)
        return NodePyGroup(node_name, unique_name, module_type, op_type,
                           node_group, inputs=inputs, outputs=outputs)

    def _extract_cat_info(self, node_group, cpp_node):
        """
        Extract the detail information of the cat operation,
        such the order of the input tensor, the shape of each
        input tensor, the output shape, and the cat dimension.

        Parameters
        ----------
        node_group : NodePyGroup
        cpp_node: torch._C.Node
            It should be ```aten::cat``` node

        Returns
        -------
        dict
            Include auxiliary information for the cat operation.
            This dict object has four keys: 'cat_dim', 'out_shape',
            'in_order' and 'in_shape'. cat_dim is the dimension of
            the cat operation to concat the input tensors. out_shape
            is the shape of the output tensor of the cat operation.
            in_order is an ordered list which contains the unique names
            of the parent operation nodes of the input tensors; an entry
            is None when the tensor is not produced by any node (e.g. it
            is an input of the whole model). in_shape is also an ordered
            list that contains the input shapes of the input tensors.
        """
        # only support the cat operation
        assert cpp_node.kind() == CAT_KIND
        cat_info = {}
        # get the shape of the output tensor
        t_output = cpp_node.output()
        cat_info['out_shape'] = t_output.type().sizes()
        cpp_inputs = list(cpp_node.inputs())
        # get the cat dimension (second argument of aten::cat)
        cat_info['cat_dim'] = cpp_inputs[1].toIValue()
        # get the order of the input tensors
        # To get the order of the input tensors, we need
        # to be aware of the topology of the model, which
        # means we should extract the auxiliary information
        # after the build_index function.
        list_construct_cpp = cpp_inputs[0].node()
        input_tensors = list(list_construct_cpp.inputs())
        input_order = []
        for _tensor in input_tensors:
            producer = self.output_to_node.get(_tensor.debugName())
            input_order.append(producer.unique_name if producer is not None else None)
        cat_info['in_order'] = input_order
        cat_info['in_shape'] = [t.type().sizes() for t in input_tensors]
        return cat_info

    def _extract_linear_shape_info(self, node_group):
        """
        Extract linear shape input/output tensor shape info from its
        aten::addmm (or aten::linear) op.

        Parameters
        ----------
        node_group : NodePyGroup
            NodePyGroup object associated with the linear module.

        Returns
        -------
        dict or None
            Include shape of input tensor and shape of output tensor, or None
            if the group contains neither an aten::addmm nor an aten::linear node.
        """
        for cpp_node in node_group.node_cpps:
            kind = cpp_node.kind()
            if kind == 'aten::addmm':
                # https://github.com/pytorch/pytorch/blob/1.6/torch/nn/functional.py#L1682
                # inputs of aten::addmm: [0] bias, [1] input data, [2] weight
                data_index = 1
            elif kind == 'aten::linear':
                # inputs of aten::linear: [0] input data, [1] weight, [2] bias
                data_index = 0
            else:
                continue
            t_input = list(cpp_node.inputs())[data_index]
            t_output = cpp_node.output()
            assert isinstance(t_input.type(), torch._C.TensorType)
            assert isinstance(t_output.type(), torch._C.TensorType)
            return {'in_shape': t_input.type().sizes(),
                    'out_shape': t_output.type().sizes()}
        return None

    def _extract_shape_info(self, node):
        """
        Extract the shape information of ```aten::view``` node

        Parameters
        ----------
        node : trace graph node
            It should be ```aten::view``` node

        Returns
        -------
        dict
            Include shape of input tensor and shape of output tensor
        """
        # the first input is the tensor being reshaped
        t_input = next(iter(node.inputs()), None)
        t_output = node.output()
        assert isinstance(t_input.type(), torch._C.TensorType)
        assert isinstance(t_output.type(), torch._C.TensorType)
        return {'in_shape': t_input.type().sizes(),
                'out_shape': t_output.type().sizes()}

    def _extract_leaf_modules(self):
        """
        Extract leaf modules from the given graph. Leaf module means it does not have submodules.
        To extract leaf modules because only leaf module can be replaced. And shape inference can
        be done in leaf module level. Other shape inference is done in lower level i.e.,
        operation level.

        Returns
        -------
        list
            a list of scope name of all the leaf modules, sorted by name
        """
        module_names = sorted(name for name, _ in self.trace.named_modules() if name)
        # every proper dotted prefix of a module name is a parent module;
        # e.g. 'aa.bb.cc' makes 'aa' and 'aa.bb' parents, but not 'aa.b'
        parent_names = set()
        for name in module_names:
            parts = name.split('.')
            for depth in range(1, len(parts)):
                parent_names.add('.'.join(parts[:depth]))
        return [name for name in module_names if name not in parent_names]

    def _get_module_name(self, scope_name):
        """
        Retrieve module name from scope name.

        Parameters:
        -----------
        scope_name: str
            scope_name of a graph node, for example:
            for pytorch 1.3.1: MyModel/BackboneModel[backbone]/Conv2d[conv2]
            for pytorch 1.4.0: __module.backbone/__module.backbone.conv2

        Returns:
        -------
        str
            module name, such as backbone.conv2
        """
        if self._version_tuple(torch.__version__) >= (1, 4, 0):
            return scope_name.split('/')[-1].replace('__module.', '')
        return '.'.join(re.findall(r'\[(.*?)\]', scope_name))

    def _build_index(self, nodes_op):
        """
        Index the node groups by unique name, by consumed tensor and by
        produced tensor.

        Returns
        -------
        tuple of dict
            (name_to_node, input_to_node, output_to_node); input_to_node maps a
            tensor name to the list of groups consuming it.
        """
        name_to_node = dict()
        input_to_node = defaultdict(list)
        output_to_node = dict()
        for node in nodes_op:
            name_to_node[node.unique_name] = node
            for _input in node.inputs:
                input_to_node[_input].append(node)
            for output in node.outputs:
                assert output not in output_to_node, \
                    "One output cannot be generated by multiple nodes"
                output_to_node[output] = node
        return name_to_node, input_to_node, output_to_node

    def _is_key_func(self, node_cpp):
        """
        Judge if a cpp node is a key function node.
        If so, we should not merge this node into the
        adjacent node.
        """
        if node_cpp.kind().startswith('aten::'):
            # the nodes that start with 'aten' are key function
            # nodes
            return True
        if node_cpp.kind() in [LIST_UNPACK_KIND, TUPLE_UNPACK_KIND]:
            # We cannot merge the List/Tuple
            # Unpack func into other nodes, else it
            # may lead to a graph construction error.
            # The reason why we do not take the construct node
            # also as a key node is that `cat` operation node need
            # the last(previous) visited node to infer the mask. If
            # we take the Construct node as the important node, the
            # predecessor of the `cat` node will always be a construct
            # node, which means we cannot infer the mask for the cat
            # operation.
            return True
        return False

    def unpack_manually(self):
        """
        Unpack the tensor tuple or tensor list manually,
        and remove the ListUnpack/TupleUnpack node from
        the graph. Note: this function will change the
        graph structure.
        """
        if hasattr(self, 'unpacked'):
            # if already unpacked the tuple/list manually
            return
        for node in self.nodes_py.nodes_op:
            if node.op_type in [TUPLE_UNPACK_KIND, LIST_UNPACK_KIND]:
                unpack_cpp = node.key_node
                last_cpp = list(unpack_cpp.inputs())[0].node()
                if last_cpp.kind() in [TUPLE_CONSTRUCT_KIND, LIST_CONSTRUCT_KIND]:
                    # we need check if the tensor tuple or tensor list is produced
                    # by a list/tuple construct node. If so, we can unpack the tuple
                    # or list manually.
                    _logger.debug('List/Tuple Construct Node(cpp) %s', str(last_cpp))
                    _logger.debug('List/Tuple Unpack Node(cpp) %s', str(unpack_cpp))
                    assert len(list(unpack_cpp.outputs())) == len(list(last_cpp.inputs()))
                    errmsg = '%s Input number: %d is inconsistent with the output number %d' % (
                        unpack_cpp, len(node.inputs), len(list(last_cpp.inputs())))
                    assert len(node.inputs) == len(list(last_cpp.inputs())), errmsg
                    for _debug_input, _debug_output in zip(node.inputs, node.outputs):
                        if _debug_input in self.input_to_node and _debug_output in self.input_to_node:
                            # input_to_node[_debug_input] is a list of NodePyGroup, because
                            # one tensor can be used as input for multiple nodes at the same time.
                            # note that, in this case, the construct cpp node and unpack cpp node
                            # will be merged into the same NodePyGroup, so we remove the `node` from
                            # input_to_node[_debug_input] and directly connect this tensor to the
                            # input_to_node[_debug_output]
                            self.input_to_node[_debug_input].remove(node)
                            # add the following nodes of _output into the input_to_node[_debug_input]
                            self.input_to_node[_debug_input].extend(self.input_to_node[_debug_output])
                        # just remove the _debug_output from the graph index. So that we can also skip
                        # the construct and tuple
                        if _debug_output in self.input_to_node:
                            # one entry per use, so each iteration rewires one occurrence
                            for following_node in self.input_to_node[_debug_output]:
                                _tmp_index = following_node.inputs.index(_debug_output)
                                following_node.inputs[_tmp_index] = _debug_input
        self.unpacked = True

    def _build_graph(self):
        """
        Build graph using our defined format from jit trace.
        There are basically three steps: first, construct necessary information (data structures),
        second, extract all the modules to convert to node, Third, extract all functions to convert
        to node.

        Returns
        -------
        dict
            use name to index nodes, key: node name, value: node
        dict
            use input (its name) to index nodes,
            key: input, value: list of nodes that take this input
        dict
            use output (its name) to index nodes,
            key: output, value: node that generates this output
        """
        omit_useless_nodes = True
        graph = self.trace.graph
        # build output mapping, from output debugName to the node producing it
        output_to_node = {x.debugName(): n for n in graph.nodes()
                          for x in n.outputs()}
        # build input mapping, from input debugName to EVERY node consuming it;
        # one tensor may feed several nodes, so keeping only one loses edges
        input_to_node = defaultdict(list)
        for n in graph.nodes():
            for x in n.inputs():
                input_to_node[x.debugName()].append(n)
        # build module mapping, from module name to all nodes (as list) under this module scope
        module_to_nodes = defaultdict(list)
        # the mapping of function (non-module in forward) to nodes, key is scope name
        func_to_nodes = defaultdict(list)
        nodes_py = GraphPy()
        for node in graph.inputs():
            if omit_useless_nodes:
                if not node.uses():  # number of user of the node (= number of outputs/ fanout)
                    continue
            if node.type().kind() != CLASSTYPE_KIND:
                nodes_py.append(NodePyIO(node, 'input'))
        self.leaf_modules = self._extract_leaf_modules()
        module_to_type = {name: parse_traced_name(module._name)
                          for name, module in self.trace.named_modules()}
        # associate module name with their trace graph nodes
        for node in graph.nodes():
            module_name = self._get_module_name(node.scopeName())
            if module_name in self.leaf_modules:
                module_to_nodes[module_name].append(node)
            else:
                func_to_nodes[node.scopeName()].append(node)
        # build node group for module
        for module_name, node_cpps in module_to_nodes.items():
            use_count = 0
            merged = set()
            for node in node_cpps:
                if node not in merged:
                    # modules that have same scope name may have different locations in the
                    # graph. Furthermore, there are also lots of prim:: nodes in node_cpps,
                    # so we also need to call the expand_module_node.
                    unique_name = module_name
                    if use_count > 0:
                        unique_name = module_name + '.%d' % use_count
                    node_group = self._expand_module_node(
                        node, module_name, unique_name, module_to_type[module_name],
                        node_cpps, input_to_node, output_to_node, 'module')
                    nodes_py.nodes_op.append(node_group)
                    use_count += 1
                    merged.update(node_group.node_cpps)
        # each scope_name may have multiple funcs, we split them and create node for each of them
        # build node group for torch.nn.functional
        for nodes in func_to_nodes.values():
            # find the key function nodes and expand each of them
            key_func_nodes = [node for node in nodes if self._is_key_func(node)]
            for node in key_func_nodes:
                node_group = self._expand_key_func_node(
                    node, nodes, input_to_node, output_to_node, 'func')
                nodes_py.nodes_op.append(node_group)
        for node in graph.outputs():  # Create sink nodes for output ops
            node_py = NodePyIO(node, 'output')
            nodes_py.append(node_py)
        self.nodes_py = nodes_py
        # build index
        return self._build_index(self.nodes_py.nodes_op)

    def _extract_auxiliary_info(self):
        """
        Extract the auxiliary information for the nodegroups
        if necessary. For example, view/flatten operations may
        need the shape of the input tensor and output tensor.
        """
        for node_group in self.nodes_py.nodes_op:
            if node_group.op_type in ['aten::view', 'aten::flatten', 'aten::mean', 'aten::reshape']:
                # get the input & output shape for the view-like func
                cpp_node = [x for x in node_group.node_cpps
                            if x.kind() == node_group.op_type][0]
                node_group.auxiliary = self._extract_shape_info(cpp_node)
            elif node_group.op_type == 'Linear':
                node_group.auxiliary = self._extract_linear_shape_info(node_group)
            elif node_group.op_type == CAT_KIND:
                # get the detail information for cat func
                cpp_node = [x for x in node_group.node_cpps
                            if x.kind() == node_group.op_type][0]
                node_group.auxiliary = self._extract_cat_info(
                    node_group, cpp_node)

    def find_predecessors(self, unique_name):
        """
        Find predecessor node of the given node

        Parameters
        ----------
        unique_name : str
            The unique name of the node

        Returns
        -------
        list
            a list of nodes who are the given node's predecessor
        """
        predecessors = []
        for _input in self.name_to_node[unique_name].inputs:
            if _input not in self.output_to_node:
                _logger.debug("cannot find node with %s as its output", _input)
            else:
                predecessors.append(self.output_to_node[_input].unique_name)
        return predecessors

    def find_successors(self, unique_name):
        """
        Find successor nodes of the given node

        Parameters
        ----------
        unique_name : str
            The unique name of the node

        Returns
        -------
        list
            a list of nodes who are the given node's successor
        """
        successors = []
        for output in self.name_to_node[unique_name].outputs:
            if output not in self.input_to_node:
                # may reach the output of the whole graph
                continue
            for node_py in self.input_to_node[output]:
                successors.append(node_py.unique_name)
        return successors
|