reach-vb's picture
reach-vb HF staff
544099ee60f85f9b5462f64b19d52cabecc6d7f3e7dc75f037db0030e18c7d16
a7e4fab
raw
history blame
14.2 kB
import unittest
from llvmlite import ir
from llvmlite import binding as llvm
from llvmlite.tests import TestCase
from . import refprune_proto as proto
def _iterate_cases(generate_test):
    """Yield ``(test_name, test_method)`` pairs for each ``case*`` in proto.

    Every attribute of the ``refprune_proto`` module whose name starts with
    ``case`` is turned into a method taking only ``self``; calling it runs
    ``generate_test(self, case_fn)``.  The yielded name is ``test_`` plus the
    attribute name, so unittest will collect it once attached to a class.
    """
    def _bind(case_fn):
        # Close over ``case_fn`` in its own scope so each generated method
        # keeps the case it was created for.
        def wrapped(self):
            return generate_test(self, case_fn)
        wrapped.__doc__ = (
            f"generated test for {case_fn.__module__}.{case_fn.__name__}"
        )
        return wrapped

    for attr_name, attr_value in proto.__dict__.items():
        if not attr_name.startswith('case'):
            continue
        yield f'test_{attr_name}', _bind(attr_value)
class TestRefPrunePrototype(TestCase):
    """
    Test that the prototype is working.

    One ``test_case*`` method is generated for every ``case*`` function in
    ``refprune_proto``; each one runs the prototype ``FanoutAlgorithm`` on the
    case and compares the result with the case's expected value.
    """

    def generate_test(self, case_gen):
        """Run the prototype fanout algorithm on one case and compare.

        ``case_gen`` is called with no arguments and must return
        ``(nodes, edges, expected)``; the value produced by
        ``proto.FanoutAlgorithm(nodes, edges).run()`` must equal ``expected``.
        """
        nodes, edges, expected = case_gen()
        got = proto.FanoutAlgorithm(nodes, edges).run()
        self.assertEqual(expected, got)

    # Generate tests.  Updating the class namespace in one call (rather than
    # assigning inside a ``for`` loop) avoids leaking the loop variables
    # ``name``/``case`` as class attributes -- ``case`` would otherwise end up
    # as a stray method on this test class.
    locals().update(_iterate_cases(generate_test))
# ``i8*`` -- the pointer type of the single argument of ``main`` and of the
# NRT_incref / NRT_decref declarations built by TestRefPrunePass.
ptr_ty = ir.IntType(8).as_pointer(addrspace=0)
class TestRefPrunePass(TestCase):
    """
    Test that the C++ implementation matches the expected behavior as for
    the prototype.

    This generates a LLVM module for each test case, runs the pruner and checks
    that the expected results are achieved.
    """

    def make_incref(self, m):
        """Declare ``void @NRT_incref(i8*)`` in module ``m`` and return it."""
        fnty = ir.FunctionType(ir.VoidType(), [ptr_ty])
        return ir.Function(m, fnty, name='NRT_incref')

    def make_decref(self, m):
        """Declare ``void @NRT_decref(i8*)`` in module ``m`` and return it."""
        fnty = ir.FunctionType(ir.VoidType(), [ptr_ty])
        return ir.Function(m, fnty, name='NRT_decref')

    def make_switcher(self, m):
        """Declare body-less ``i32 @switcher()``; its result feeds switches."""
        fnty = ir.FunctionType(ir.IntType(32), ())
        return ir.Function(m, fnty, name='switcher')

    def make_brancher(self, m):
        """Declare body-less ``i1 @brancher()``; its result feeds cbranches."""
        fnty = ir.FunctionType(ir.IntType(1), ())
        return ir.Function(m, fnty, name='brancher')

    def generate_ir(self, nodes, edges):
        """Build an LLVM module whose ``main`` function mirrors a test CFG.

        Parameters
        ----------
        nodes : mapping
            Basic-block name -> sequence of ``'incref'`` / ``'decref'``
            actions, emitted in order as calls on ``main``'s ``%mem``
            argument.  Blocks absent from the mapping get no calls.
        edges : mapping
            Basic-block name -> list of successor block names.  Blocks are
            appended in iteration order, so the first key is the entry block.

        Returns
        -------
        llvmlite.ir.Module
            The module containing the declarations and ``main``.

        Raises
        ------
        AssertionError
            If ``nodes`` contains an action other than incref/decref.
        """
        m = ir.Module()
        incref_fn = self.make_incref(m)
        decref_fn = self.make_decref(m)
        switcher_fn = self.make_switcher(m)
        brancher_fn = self.make_brancher(m)

        fnty = ir.FunctionType(ir.VoidType(), [ptr_ty])
        fn = ir.Function(m, fnty, name='main')
        [ptr] = fn.args
        ptr.name = 'mem'

        # Create every basic block first so terminators can target blocks
        # that have not been populated yet.
        bbmap = {bb: fn.append_basic_block(bb) for bb in edges}

        builder = ir.IRBuilder()
        for bb, jump_targets in edges.items():
            builder.position_at_end(bbmap[bb])
            # Insert increfs and decrefs.  ``.get`` means a block needs no
            # entry in ``nodes`` (nor a defaultdict) when it has no actions.
            for action in nodes.get(bb, ()):
                if action == 'incref':
                    builder.call(incref_fn, [ptr])
                elif action == 'decref':
                    builder.call(decref_fn, [ptr])
                else:
                    raise AssertionError(f'unknown action: {action!r}')
            # Insert the terminator, chosen by the number of jump targets.
            n_targets = len(jump_targets)
            if n_targets == 0:
                builder.ret_void()
            elif n_targets == 1:
                [dst] = jump_targets
                builder.branch(bbmap[dst])
            elif n_targets == 2:
                [left, right] = jump_targets
                sel = builder.call(brancher_fn, ())
                builder.cbranch(sel, bbmap[left], bbmap[right])
            else:
                # More than two successors: the first is the switch default
                # and the remaining ones become cases 0, 1, 2, ...
                sel = builder.call(switcher_fn, ())
                [head, *tail] = jump_targets
                sw = builder.switch(sel, default=bbmap[head])
                for i, dst in enumerate(tail):
                    sw.add_case(sel.type(i), bbmap[dst])
        return m

    def apply_refprune(self, irmod):
        """Parse ``irmod`` into a binding module and run the refprune pass."""
        mod = llvm.parse_assembly(str(irmod))
        pm = llvm.ModulePassManager()
        pm.add_refprune_pass()
        pm.run(mod)
        return mod

    def check(self, mod, expected, nodes):
        """Assert that the calls left in ``main`` match ``expected``.

        ``expected`` maps an incref block to the decref blocks it pairs with.
        Each non-empty entry accounts for one pruned incref in that block and
        one pruned decref in every listed block; all other calls recorded in
        ``nodes`` must still be present after the pass.
        """
        # Count the calls originally emitted in each block.
        d = {}
        for k, vs in nodes.items():
            d[k] = {'incref': vs.count('incref'), 'decref': vs.count('decref')}
        # Subtract the calls the pruner is expected to remove.
        for k, stats in d.items():
            if expected.get(k):
                stats['incref'] -= 1
                for dec_bb in expected[k]:
                    d[dec_bb]['decref'] -= 1
        # ``get_function`` raises NameError when ``main`` is missing instead
        # of silently checking whichever function a search loop stopped on.
        f = mod.get_function('main')
        # check each BB
        for bb in f.blocks:
            # Blocks with no recorded actions must contain no calls at all.
            stats = d.get(bb.name, {'incref': 0, 'decref': 0})
            text = str(bb)
            n_incref = text.count('NRT_incref')
            n_decref = text.count('NRT_decref')
            self.assertEqual(stats['incref'], n_incref, msg=f'BB {bb}')
            self.assertEqual(stats['decref'], n_decref, msg=f'BB {bb}')

    def generate_test(self, case_gen):
        """Build IR for one prototype case, prune it and verify the result."""
        nodes, edges, expected = case_gen()
        irmod = self.generate_ir(nodes, edges)
        outmod = self.apply_refprune(irmod)
        self.check(outmod, expected, nodes)

    # Generate tests.  A single namespace update avoids leaking the loop
    # variables ``name``/``case`` as class attributes.
    locals().update(_iterate_cases(generate_test))
class BaseTestByIR(TestCase):
    """Base class for tests that run refprune on hand-written textual IR.

    Subclasses override ``refprune_bitmask`` to choose the refprune subpasses
    and hand IR snippets (without the NRT declarations) to ``check``.
    """

    # Subpass selection passed as the first argument of add_refprune_pass.
    refprune_bitmask = 0

    # Declarations prepended to every IR snippet given to ``check``.
    prologue = r"""
declare void @NRT_incref(i8* %ptr)
declare void @NRT_decref(i8* %ptr)
"""

    def check(self, irmod, subgraph_limit=None):
        """Run refprune over ``irmod`` and report what changed.

        Returns ``(module, stats)``: the pruned binding module and the
        difference between ``llvm.dump_refprune_stats()`` taken after and
        before the pass ran.  ``subgraph_limit`` is forwarded to
        ``add_refprune_pass`` only when it is not None.
        """
        source = f"{self.prologue}\n{irmod}"
        mod = llvm.parse_assembly(source)
        pm = llvm.ModulePassManager()
        options = {}
        if subgraph_limit is not None:
            options['subgraph_limit'] = subgraph_limit
        pm.add_refprune_pass(self.refprune_bitmask, **options)
        stats_before = llvm.dump_refprune_stats()
        pm.run(mod)
        stats_after = llvm.dump_refprune_stats()
        return mod, stats_after - stats_before
class TestPerBB(BaseTestByIR):
    """Tests for the PER_BB refprune subpass on single-block functions."""

    refprune_bitmask = llvm.RefPruneSubpasses.PER_BB

    # One incref and one decref of %ptr: both are pruned.
    per_bb_ir_1 = r"""
define void @main(i8* %ptr) {
call void @NRT_incref(i8* %ptr)
call void @NRT_decref(i8* %ptr)
ret void
}
"""

    def test_per_bb_1(self):
        _, stats = self.check(self.per_bb_ir_1)
        self.assertEqual(stats.basicblock, 2)

    # Three increfs, two decrefs: two pairs pruned, one incref left over.
    per_bb_ir_2 = r"""
define void @main(i8* %ptr) {
call void @NRT_incref(i8* %ptr)
call void @NRT_incref(i8* %ptr)
call void @NRT_incref(i8* %ptr)
call void @NRT_decref(i8* %ptr)
call void @NRT_decref(i8* %ptr)
ret void
}
"""

    def test_per_bb_2(self):
        pruned, stats = self.check(self.per_bb_ir_2)
        self.assertEqual(stats.basicblock, 4)
        # The unmatched incref must still be there.
        self.assertIn("call void @NRT_incref(i8* %ptr)", str(pruned))

    # The decref of %other has no matching incref, so only one pair goes.
    per_bb_ir_3 = r"""
define void @main(i8* %ptr, i8* %other) {
call void @NRT_incref(i8* %ptr)
call void @NRT_incref(i8* %ptr)
call void @NRT_decref(i8* %ptr)
call void @NRT_decref(i8* %other)
ret void
}
"""

    def test_per_bb_3(self):
        pruned, stats = self.check(self.per_bb_ir_3)
        self.assertEqual(stats.basicblock, 2)
        # The decref on a different pointer must survive.
        self.assertIn("call void @NRT_decref(i8* %other)", str(pruned))

    # Calls in a different order: two %ptr pairs pruned, %other kept.
    per_bb_ir_4 = r"""
; reordered
define void @main(i8* %ptr, i8* %other) {
call void @NRT_incref(i8* %ptr)
call void @NRT_decref(i8* %ptr)
call void @NRT_decref(i8* %ptr)
call void @NRT_decref(i8* %other)
call void @NRT_incref(i8* %ptr)
ret void
}
"""

    def test_per_bb_4(self):
        pruned, stats = self.check(self.per_bb_ir_4)
        self.assertEqual(stats.basicblock, 4)
        # The decref on a different pointer must survive.
        self.assertIn("call void @NRT_decref(i8* %other)", str(pruned))
class TestDiamond(BaseTestByIR):
    """Tests for the DIAMOND refprune subpass."""

    refprune_bitmask = llvm.RefPruneSubpasses.DIAMOND

    # Incref in bb_A, decref in its only successor bb_B: both pruned.
    per_diamond_1 = r"""
define void @main(i8* %ptr) {
bb_A:
call void @NRT_incref(i8* %ptr)
br label %bb_B
bb_B:
call void @NRT_decref(i8* %ptr)
ret void
}
"""

    def test_per_diamond_1(self):
        _, stats = self.check(self.per_diamond_1)
        self.assertEqual(stats.diamond, 2)

    # Incref in bb_A, decref in bb_D where both branch arms rejoin.
    per_diamond_2 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
call void @NRT_incref(i8* %ptr)
br i1 %cond, label %bb_B, label %bb_C
bb_B:
br label %bb_D
bb_C:
br label %bb_D
bb_D:
call void @NRT_decref(i8* %ptr)
ret void
}
"""

    def test_per_diamond_2(self):
        _, stats = self.check(self.per_diamond_2)
        self.assertEqual(stats.diamond, 2)

    # A decref on one arm (bb_C) prevents any pruning.
    per_diamond_3 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
call void @NRT_incref(i8* %ptr)
br i1 %cond, label %bb_B, label %bb_C
bb_B:
br label %bb_D
bb_C:
call void @NRT_decref(i8* %ptr) ; reject because of decref in diamond
br label %bb_D
bb_D:
call void @NRT_decref(i8* %ptr)
ret void
}
"""

    def test_per_diamond_3(self):
        _, stats = self.check(self.per_diamond_3)
        self.assertEqual(stats.diamond, 0)

    # An extra incref on one arm does not stop the outer pair being pruned.
    per_diamond_4 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
call void @NRT_incref(i8* %ptr)
br i1 %cond, label %bb_B, label %bb_C
bb_B:
call void @NRT_incref(i8* %ptr) ; extra incref will not affect prune
br label %bb_D
bb_C:
br label %bb_D
bb_D:
call void @NRT_decref(i8* %ptr)
ret void
}
"""

    def test_per_diamond_4(self):
        _, stats = self.check(self.per_diamond_4)
        self.assertEqual(stats.diamond, 2)

    # Two increfs in bb_A and two decrefs in bb_D: both pairs pruned.
    per_diamond_5 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
call void @NRT_incref(i8* %ptr)
call void @NRT_incref(i8* %ptr)
br i1 %cond, label %bb_B, label %bb_C
bb_B:
br label %bb_D
bb_C:
br label %bb_D
bb_D:
call void @NRT_decref(i8* %ptr)
call void @NRT_decref(i8* %ptr)
ret void
}
"""

    def test_per_diamond_5(self):
        _, stats = self.check(self.per_diamond_5)
        self.assertEqual(stats.diamond, 4)
class TestFanout(BaseTestByIR):
    """More complex cases are tested in TestRefPrunePass
    """

    refprune_bitmask = llvm.RefPruneSubpasses.FANOUT

    # Incref in bb_A with a decref on each of its two exit paths.
    fanout_1 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
call void @NRT_incref(i8* %ptr)
br i1 %cond, label %bb_B, label %bb_C
bb_B:
call void @NRT_decref(i8* %ptr)
ret void
bb_C:
call void @NRT_decref(i8* %ptr)
ret void
}
"""

    def test_fanout_1(self):
        _, stats = self.check(self.fanout_1)
        self.assertEqual(stats.fanout, 3)

    # bb_C decrefs and then jumps into bb_B, which decrefs again: no pruning.
    fanout_2 = r"""
define void @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
call void @NRT_incref(i8* %ptr)
br i1 %cond, label %bb_B, label %bb_C
bb_B:
call void @NRT_decref(i8* %ptr)
ret void
bb_C:
call void @NRT_decref(i8* %ptr)
br label %bb_B ; illegal jump to other decref
}
"""

    def test_fanout_2(self):
        _, stats = self.check(self.fanout_2)
        self.assertEqual(stats.fanout, 0)

    # Two increfs in bb_A; three decrefs in bb_B and two in bb_C.
    fanout_3 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
call void @NRT_incref(i8* %ptr)
call void @NRT_incref(i8* %ptr)
br i1 %cond, label %bb_B, label %bb_C
bb_B:
call void @NRT_decref(i8* %ptr)
call void @NRT_decref(i8* %ptr)
call void @NRT_decref(i8* %ptr)
ret void
bb_C:
call void @NRT_decref(i8* %ptr)
call void @NRT_decref(i8* %ptr)
ret void
}
"""

    def test_fanout_3(self):
        _, stats = self.check(self.fanout_3)
        self.assertEqual(stats.fanout, 6)

    def test_fanout_3_limited(self):
        # A subgraph limit of 1 effectively disables the fanout pruner.
        _, stats = self.check(self.fanout_3, subgraph_limit=1)
        self.assertEqual(stats.fanout, 0)
class TestFanoutRaise(BaseTestByIR):
    """Tests for the FANOUT_RAISE refprune subpass."""

    refprune_bitmask = llvm.RefPruneSubpasses.FANOUT_RAISE

    # One path decrefs; the other stores to %excinfo tagged with
    # !numba_exception_output.  The incref and decref are pruned.
    fanout_raise_1 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
call void @NRT_incref(i8* %ptr)
br i1 %cond, label %bb_B, label %bb_C
bb_B:
call void @NRT_decref(i8* %ptr)
ret i32 0
bb_C:
store i8* null, i8** %excinfo, !numba_exception_output !0
ret i32 1
}
!0 = !{i1 true}
"""

    def test_fanout_raise_1(self):
        _, stats = self.check(self.fanout_raise_1)
        self.assertEqual(stats.fanout_raise, 2)

    # Same shape, but the metadata name is misspelled.
    fanout_raise_2 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
call void @NRT_incref(i8* %ptr)
br i1 %cond, label %bb_B, label %bb_C
bb_B:
call void @NRT_decref(i8* %ptr)
ret i32 0
bb_C:
store i8* null, i8** %excinfo, !numba_exception_typo !0 ; bad metadata
ret i32 1
}
!0 = !{i1 true}
"""

    def test_fanout_raise_2(self):
        # fanout_raise must not prune when the raise-marking metadata has the
        # wrong name.
        _, stats = self.check(self.fanout_raise_2)
        self.assertEqual(stats.fanout_raise, 0)

    # Metadata operand is ``i32 1`` instead of ``i1 true``; still accepted.
    fanout_raise_3 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
call void @NRT_incref(i8* %ptr)
br i1 %cond, label %bb_B, label %bb_C
bb_B:
call void @NRT_decref(i8* %ptr)
ret i32 0
bb_C:
store i8* null, i8** %excinfo, !numba_exception_output !0
ret i32 1
}
!0 = !{i32 1} ; ok; use i32
"""

    def test_fanout_raise_3(self):
        _, stats = self.check(self.fanout_raise_3)
        self.assertEqual(stats.fanout_raise, 2)

    # Both tails return without any decref: nothing is pruned.
    fanout_raise_4 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
call void @NRT_incref(i8* %ptr)
br i1 %cond, label %bb_B, label %bb_C
bb_B:
ret i32 1 ; BAD; all tails are raising without decref
bb_C:
ret i32 1 ; BAD; all tails are raising without decref
}
!0 = !{i1 1}
"""

    def test_fanout_raise_4(self):
        _, stats = self.check(self.fanout_raise_4)
        self.assertEqual(stats.fanout_raise, 0)

    # Both paths merge into a shared return block via a phi; still pruned.
    fanout_raise_5 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
call void @NRT_incref(i8* %ptr)
br i1 %cond, label %bb_B, label %bb_C
bb_B:
call void @NRT_decref(i8* %ptr)
br label %common.ret
bb_C:
store i8* null, i8** %excinfo, !numba_exception_output !0
br label %common.ret
common.ret:
%common.ret.op = phi i32 [ 0, %bb_B ], [ 1, %bb_C ]
ret i32 %common.ret.op
}
!0 = !{i1 1}
"""

    def test_fanout_raise_5(self):
        _, stats = self.check(self.fanout_raise_5)
        self.assertEqual(stats.fanout_raise, 2)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main(module="__main__")