|
import unittest |
|
from llvmlite import ir |
|
from llvmlite import binding as llvm |
|
from llvmlite.tests import TestCase |
|
|
|
from . import refprune_proto as proto |
|
|
|
|
|
def _iterate_cases(generate_test):
    """Yield ``(test_name, test_method)`` pairs for the prototype cases.

    Every attribute of ``proto`` whose name starts with ``case`` is wrapped
    in a method that forwards it to *generate_test*; the pair is yielded
    under the name ``test_<attribute name>``.
    """
    def wrap(fn):
        # A separate closure per case keeps every generated method bound to
        # its own case function.
        def wrapped(self):
            return generate_test(self, fn)

        wrapped.__doc__ = f"generated test for {fn.__module__}.{fn.__name__}"
        return wrapped

    for attr_name, attr in vars(proto).items():
        if not attr_name.startswith('case'):
            continue
        yield f'test_{attr_name}', wrap(attr)
|
|
|
|
|
class TestRefPrunePrototype(TestCase):
    """
    Run every case from ``proto`` through ``proto.FanoutAlgorithm`` and
    compare the outcome with the result the case itself declares.
    """
    def generate_test(self, case_gen):
        # A case generator returns the graph description together with the
        # value FanoutAlgorithm.run() is expected to produce for it.
        nodes, edges, expected = case_gen()
        algorithm = proto.FanoutAlgorithm(nodes, edges)
        self.assertEqual(expected, algorithm.run())

    # Install one test method per prototype case into the class namespace.
    for name, case in _iterate_cases(generate_test):
        locals()[name] = case
|
|
|
|
|
# Pointer to an 8-bit integer (LLVM ``i8*``); used in this module as the
# parameter type of the NRT_incref/NRT_decref declarations and of ``main``.
ptr_ty = ir.IntType(8).as_pointer()
|
|
|
|
|
class TestRefPrunePass(TestCase):
    """
    Test that the C++ implementation matches the expected behavior as for
    the prototype.

    This generates a LLVM module for each test case, runs the pruner and checks
    that the expected results are achieved.
    """

    def make_incref(self, m):
        """Declare ``void NRT_incref(i8*)`` in module *m* and return it."""
        fnty = ir.FunctionType(ir.VoidType(), [ptr_ty])
        return ir.Function(m, fnty, name='NRT_incref')

    def make_decref(self, m):
        """Declare ``void NRT_decref(i8*)`` in module *m* and return it."""
        fnty = ir.FunctionType(ir.VoidType(), [ptr_ty])
        return ir.Function(m, fnty, name='NRT_decref')

    def make_switcher(self, m):
        """Declare ``i32 switcher()``; its result feeds ``switch`` terminators."""
        fnty = ir.FunctionType(ir.IntType(32), ())
        return ir.Function(m, fnty, name='switcher')

    def make_brancher(self, m):
        """Declare ``i1 brancher()``; its result feeds conditional branches."""
        fnty = ir.FunctionType(ir.IntType(1), ())
        return ir.Function(m, fnty, name='brancher')

    def generate_ir(self, nodes, edges):
        """Build an IR module whose ``main`` function has the given CFG.

        *edges* maps each block name to a sized collection of successor
        block names. *nodes* maps each block name to the sequence of
        ``'incref'``/``'decref'`` actions emitted in that block; every action
        becomes a call on the single pointer argument of ``main``.

        Returns the ``ir.Module``. Raises ``AssertionError`` for an action
        other than ``'incref'`` or ``'decref'``.
        """
        m = ir.Module()

        incref_fn = self.make_incref(m)
        decref_fn = self.make_decref(m)
        switcher_fn = self.make_switcher(m)
        brancher_fn = self.make_brancher(m)

        fnty = ir.FunctionType(ir.VoidType(), [ptr_ty])
        fn = ir.Function(m, fnty, name='main')
        [ptr] = fn.args
        ptr.name = 'mem'

        # Create every block up front so a terminator can target a block
        # that is only populated in a later iteration.
        bbmap = {bb: fn.append_basic_block(bb) for bb in edges}

        builder = ir.IRBuilder()
        for bb, jump_targets in edges.items():
            builder.position_at_end(bbmap[bb])

            # Emit the refcount operations of this block.
            for action in nodes[bb]:
                if action == 'incref':
                    builder.call(incref_fn, [ptr])
                elif action == 'decref':
                    builder.call(decref_fn, [ptr])
                else:
                    raise AssertionError('unreachable')

            # Emit the terminator; its kind depends on the successor count.
            n_targets = len(jump_targets)
            if n_targets == 0:
                builder.ret_void()
            elif n_targets == 1:
                [dst] = jump_targets
                builder.branch(bbmap[dst])
            elif n_targets == 2:
                [left, right] = jump_targets
                sel = builder.call(brancher_fn, ())
                builder.cbranch(sel, bbmap[left], bbmap[right])
            elif n_targets > 2:
                sel = builder.call(switcher_fn, ())
                [head, *tail] = jump_targets

                # The first successor is the default target; the remaining
                # ones become cases 0, 1, 2, ...
                sw = builder.switch(sel, default=bbmap[head])
                for i, dst in enumerate(tail):
                    sw.add_case(sel.type(i), bbmap[dst])
            else:
                raise AssertionError('unreachable')

        return m

    def apply_refprune(self, irmod):
        """Parse *irmod* with the binding layer, run the refprune pass on it
        with default arguments and return the resulting module.
        """
        mod = llvm.parse_assembly(str(irmod))
        pm = llvm.ModulePassManager()
        pm.add_refprune_pass()
        pm.run(mod)
        return mod

    def check(self, mod, expected, nodes):
        """Assert that *mod* contains exactly the refops left after pruning.

        *nodes* is the mapping given to ``generate_ir``. *expected* maps a
        block name to the names of the blocks whose decref pairs with an
        incref of that block; entries with an empty value are ignored.

        Fails the test if *mod* has no function named ``main``.
        """
        # Count the refops every block had before pruning.
        d = {
            k: {'incref': vs.count('incref'), 'decref': vs.count('decref')}
            for k, vs in nodes.items()
        }
        # Subtract what the pruner is expected to remove: one incref in the
        # keyed block and one decref in each block listed for it.
        for k, stats in d.items():
            if expected.get(k):
                stats['incref'] -= 1
                for dec_bb in expected[k]:
                    d[dec_bb]['decref'] -= 1

        # Locate ``main``. Without the ``else`` a missing function would let
        # the checks below run against whichever function was seen last.
        for f in mod.functions:
            if f.name == 'main':
                break
        else:
            self.fail("function 'main' not found in the pruned module")

        # Compare each block's remaining calls with the computed counts.
        for bb in f.blocks:
            stats = d[bb.name]
            text = str(bb)
            n_incref = text.count('NRT_incref')
            n_decref = text.count('NRT_decref')
            self.assertEqual(stats['incref'], n_incref, msg=f'BB {bb}')
            self.assertEqual(stats['decref'], n_decref, msg=f'BB {bb}')

    def generate_test(self, case_gen):
        """Build the IR for one prototype case, prune it and check it."""
        nodes, edges, expected = case_gen()
        irmod = self.generate_ir(nodes, edges)
        outmod = self.apply_refprune(irmod)
        self.check(outmod, expected, nodes)

    # Install one test method per prototype case into the class namespace.
    for name, case in _iterate_cases(generate_test):
        locals()[name] = case
|
|
|
|
|
class BaseTestByIR(TestCase):
    """Shared helper for tests that feed textual LLVM IR to the refprune pass.

    Subclasses override ``refprune_bitmask`` to choose what is passed as the
    first argument of ``add_refprune_pass``.
    """
    refprune_bitmask = 0

    # Declarations prepended to every IR snippet handed to ``check``.
    prologue = r"""
declare void @NRT_incref(i8* %ptr)
declare void @NRT_decref(i8* %ptr)
"""

    def check(self, irmod, subgraph_limit=None):
        """Run the refprune pass over the IR text *irmod*.

        *subgraph_limit* is forwarded to ``add_refprune_pass`` only when it
        is not None. Returns the transformed module together with the
        difference between ``llvm.dump_refprune_stats()`` taken after and
        before the run.
        """
        parsed = llvm.parse_assembly(f"{self.prologue}\n{irmod}")
        pm = llvm.ModulePassManager()

        extra = {}
        if subgraph_limit is not None:
            extra['subgraph_limit'] = subgraph_limit
        pm.add_refprune_pass(self.refprune_bitmask, **extra)

        stats_before = llvm.dump_refprune_stats()
        pm.run(parsed)
        stats_after = llvm.dump_refprune_stats()
        return parsed, stats_after - stats_before
|
|
|
|
|
class TestPerBB(BaseTestByIR):
    """Cases run with ``RefPruneSubpasses.PER_BB`` as the pass bitmask."""
    refprune_bitmask = llvm.RefPruneSubpasses.PER_BB

    per_bb_ir_1 = r"""
define void @main(i8* %ptr) {
    call void @NRT_incref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_bb_1(self):
        # One incref/decref pair on the same pointer in a single block.
        _, stats = self.check(self.per_bb_ir_1)
        self.assertEqual(stats.basicblock, 2)

    per_bb_ir_2 = r"""
define void @main(i8* %ptr) {
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_bb_2(self):
        # Three increfs against two decrefs.
        pruned, stats = self.check(self.per_bb_ir_2)
        self.assertEqual(stats.basicblock, 4)

        # An incref call must still be present in the output.
        self.assertIn("call void @NRT_incref(i8* %ptr)", str(pruned))

    per_bb_ir_3 = r"""
define void @main(i8* %ptr, i8* %other) {
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %other)
    ret void
}
"""

    def test_per_bb_3(self):
        # The second decref operates on a different pointer.
        pruned, stats = self.check(self.per_bb_ir_3)
        self.assertEqual(stats.basicblock, 2)

        # The decref on %other must still be present in the output.
        self.assertIn("call void @NRT_decref(i8* %other)", str(pruned))

    per_bb_ir_4 = r"""
; reordered
define void @main(i8* %ptr, i8* %other) {
    call void @NRT_incref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %other)
    call void @NRT_incref(i8* %ptr)
    ret void
}
"""

    def test_per_bb_4(self):
        # Same pointer mix as above but with the calls in a different order.
        pruned, stats = self.check(self.per_bb_ir_4)
        self.assertEqual(stats.basicblock, 4)

        # The decref on %other must still be present in the output.
        self.assertIn("call void @NRT_decref(i8* %other)", str(pruned))
|
|
|
|
|
class TestDiamond(BaseTestByIR):
    """Cases run with ``RefPruneSubpasses.DIAMOND`` as the pass bitmask."""
    refprune_bitmask = llvm.RefPruneSubpasses.DIAMOND

    per_diamond_1 = r"""
define void @main(i8* %ptr) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br label %bb_B
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_1(self):
        # Straight line: incref in bb_A, decref in its only successor.
        _, stats = self.check(self.per_diamond_1)
        self.assertEqual(stats.diamond, 2)

    per_diamond_2 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    br label %bb_D
bb_C:
    br label %bb_D
bb_D:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_2(self):
        # Both arms between the incref and the decref are empty.
        _, stats = self.check(self.per_diamond_2)
        self.assertEqual(stats.diamond, 2)

    per_diamond_3 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    br label %bb_D
bb_C:
    call void @NRT_decref(i8* %ptr) ; reject because of decref in diamond
    br label %bb_D
bb_D:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_3(self):
        # A decref inside one arm: the diamond stat must not change.
        _, stats = self.check(self.per_diamond_3)
        self.assertEqual(stats.diamond, 0)

    per_diamond_4 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_incref(i8* %ptr) ; extra incref will not affect prune
    br label %bb_D
bb_C:
    br label %bb_D
bb_D:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_4(self):
        # An additional incref inside one arm.
        _, stats = self.check(self.per_diamond_4)
        self.assertEqual(stats.diamond, 2)

    per_diamond_5 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    br label %bb_D
bb_C:
    br label %bb_D
bb_D:
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_5(self):
        # Two increfs at the top and two decrefs at the bottom.
        _, stats = self.check(self.per_diamond_5)
        self.assertEqual(stats.diamond, 4)
|
|
|
|
|
class TestFanout(BaseTestByIR):
    """Cases run with ``RefPruneSubpasses.FANOUT`` as the pass bitmask.

    More complex cases are tested in TestRefPrunePass.
    """

    refprune_bitmask = llvm.RefPruneSubpasses.FANOUT

    fanout_1 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret void
bb_C:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_fanout_1(self):
        # One incref, then one decref in each of the two exits.
        _, stats = self.check(self.fanout_1)
        self.assertEqual(stats.fanout, 3)

    fanout_2 = r"""
define void @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret void
bb_C:
    call void @NRT_decref(i8* %ptr)
    br label %bb_B ; illegal jump to other decref
}
"""

    def test_fanout_2(self):
        # bb_C jumps into bb_B, which holds another decref: the fanout stat
        # must not change.
        _, stats = self.check(self.fanout_2)
        self.assertEqual(stats.fanout, 0)

    fanout_3 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
bb_C:
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_fanout_3(self):
        # Two increfs; three decrefs in one exit and two in the other.
        _, stats = self.check(self.fanout_3)
        self.assertEqual(stats.fanout, 6)

    def test_fanout_3_limited(self):
        # Same IR as test_fanout_3, but with subgraph_limit=1 the fanout
        # stat is expected to stay unchanged.
        _, stats = self.check(self.fanout_3, subgraph_limit=1)
        self.assertEqual(stats.fanout, 0)
|
|
|
|
|
class TestFanoutRaise(BaseTestByIR):
    """Cases run with ``RefPruneSubpasses.FANOUT_RAISE`` as the pass bitmask."""
    refprune_bitmask = llvm.RefPruneSubpasses.FANOUT_RAISE

    fanout_raise_1 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret i32 0
bb_C:
    store i8* null, i8** %excinfo, !numba_exception_output !0
    ret i32 1
}
!0 = !{i1 true}
"""

    def test_fanout_raise_1(self):
        # One exit decrefs; the other stores to %excinfo with the
        # !numba_exception_output metadata attached.
        _, stats = self.check(self.fanout_raise_1)
        self.assertEqual(stats.fanout_raise, 2)

    fanout_raise_2 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret i32 0
bb_C:
    store i8* null, i8** %excinfo, !numba_exception_typo !0 ; bad metadata
    ret i32 1
}

!0 = !{i1 true}
"""

    def test_fanout_raise_2(self):
        # The store carries a differently named metadata kind
        # (!numba_exception_typo): the fanout_raise stat must not change.
        _, stats = self.check(self.fanout_raise_2)
        self.assertEqual(stats.fanout_raise, 0)

    fanout_raise_3 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret i32 0
bb_C:
    store i8* null, i8** %excinfo, !numba_exception_output !0
    ret i32 1
}

!0 = !{i32 1} ; ok; use i32
"""

    def test_fanout_raise_3(self):
        # Same shape as fanout_raise_1, but the metadata node holds an i32.
        _, stats = self.check(self.fanout_raise_3)
        self.assertEqual(stats.fanout_raise, 2)

    fanout_raise_4 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    ret i32 1 ; BAD; all tails are raising without decref
bb_C:
    ret i32 1 ; BAD; all tails are raising without decref
}

!0 = !{i1 1}
"""

    def test_fanout_raise_4(self):
        # Neither exit contains a decref: the fanout_raise stat must not
        # change.
        _, stats = self.check(self.fanout_raise_4)
        self.assertEqual(stats.fanout_raise, 0)

    fanout_raise_5 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    br label %common.ret
bb_C:
    store i8* null, i8** %excinfo, !numba_exception_output !0
    br label %common.ret
common.ret:
    %common.ret.op = phi i32 [ 0, %bb_B ], [ 1, %bb_C ]
    ret i32 %common.ret.op
}
!0 = !{i1 1}
"""

    def test_fanout_raise_5(self):
        # Both arms branch to a shared return block that selects the
        # return value with a phi.
        _, stats = self.check(self.fanout_raise_5)
        self.assertEqual(stats.fanout_raise, 2)
|
|
|
|
|
# Run the tests of this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()
|
|