File size: 14,235 Bytes
a7e4fab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
import unittest
from llvmlite import ir
from llvmlite import binding as llvm
from llvmlite.tests import TestCase

from . import refprune_proto as proto


def _iterate_cases(generate_test):
    """Yield ``(test_name, test_method)`` pairs, one per prototype case.

    Every attribute of the ``proto`` module whose name starts with ``case``
    is turned into a method that calls ``generate_test(self, case_fn)``; the
    method is yielded under the name ``test_<attribute name>``.

    generate_test: callable taking ``(self, case_fn)``.
    """
    def bind_case(fn):
        # A factory function gives each generated method its own ``fn``
        # rather than a reference to the loop variable below.
        def wrapped(self):
            return generate_test(self, fn)
        wrapped.__doc__ = f"generated test for {fn.__module__}.{fn.__name__}"
        return wrapped

    for attr_name, attr in vars(proto).items():
        if not attr_name.startswith('case'):
            continue
        yield f'test_{attr_name}', bind_case(attr)


class TestRefPrunePrototype(TestCase):
    """
    Test that the prototype is working.

    One ``test_<case>`` method is generated for every ``case*`` attribute of
    the ``proto`` module.  Each generated test runs ``proto.FanoutAlgorithm``
    on the case's graph and compares the result with the case's expected
    value.
    """
    def generate_test(self, case_gen):
        # ``case_gen()`` returns the graph description (nodes, edges) together
        # with the result the algorithm is expected to produce for it.
        nodes, edges, expected = case_gen()
        got = proto.FanoutAlgorithm(nodes, edges).run()
        self.assertEqual(expected, got)

    # Generate tests
    # ``generate_test`` is still a plain function at this point (the class
    # body is executing), so it is handed over unbound.  Writing through
    # ``locals()`` relies on it being the class namespace inside a class body,
    # which is how the generated methods end up as attributes of the class.
    for name, case in _iterate_cases(generate_test):
        locals()[name] = case


# Pointer-to-i8 IR type; used below as the single argument type of the
# NRT_incref / NRT_decref declarations and of the generated ``main`` function.
ptr_ty = ir.IntType(8).as_pointer()


class TestRefPrunePass(TestCase):
    """
    Test that the C++ implementation matches the expected behavior as for
    the prototype.

    This generates a LLVM module for each test case, runs the pruner and checks
    that the expected results are achieved.
    """

    def make_incref(self, m):
        """Declare ``void NRT_incref(i8*)`` in module *m* and return it."""
        fnty = ir.FunctionType(ir.VoidType(), [ptr_ty])
        return ir.Function(m, fnty, name='NRT_incref')

    def make_decref(self, m):
        """Declare ``void NRT_decref(i8*)`` in module *m* and return it."""
        fnty = ir.FunctionType(ir.VoidType(), [ptr_ty])
        return ir.Function(m, fnty, name='NRT_decref')

    def make_switcher(self, m):
        """Declare ``i32 switcher()`` in module *m* and return it.

        Its return value is used as the selector of generated ``switch``
        terminators.
        """
        fnty = ir.FunctionType(ir.IntType(32), ())
        return ir.Function(m, fnty, name='switcher')

    def make_brancher(self, m):
        """Declare ``i1 brancher()`` in module *m* and return it.

        Its return value is used as the condition of generated conditional
        branches.
        """
        fnty = ir.FunctionType(ir.IntType(1), ())
        return ir.Function(m, fnty, name='brancher')

    def generate_ir(self, nodes, edges):
        """Build an IR module whose ``main`` function mirrors the given CFG.

        nodes: mapping of block name to the sequence of ``'incref'`` /
            ``'decref'`` actions to emit in that block, in order.
        edges: mapping of block name to its jump targets.  The keys decide
            which blocks are created and in which order they are appended.

        Returns the ``ir.Module``.  Raises ``AssertionError`` if a block
        lists an action other than ``'incref'`` or ``'decref'``.
        """
        # Build LLVM module for the CFG
        m = ir.Module()

        incref_fn = self.make_incref(m)
        decref_fn = self.make_decref(m)
        switcher_fn = self.make_switcher(m)
        brancher_fn = self.make_brancher(m)

        fnty = ir.FunctionType(ir.VoidType(), [ptr_ty])
        fn = ir.Function(m, fnty, name='main')
        [ptr] = fn.args
        ptr.name = 'mem'
        # Create every basic block up front so that terminators can refer to
        # blocks whose contents are only emitted later.
        bbmap = {bb: fn.append_basic_block(bb) for bb in edges}
        # populate the BB
        builder = ir.IRBuilder()
        for bb, jump_targets in edges.items():
            builder.position_at_end(bbmap[bb])
            # Insert increfs and decrefs
            for action in nodes[bb]:
                if action == 'incref':
                    builder.call(incref_fn, [ptr])
                elif action == 'decref':
                    builder.call(decref_fn, [ptr])
                else:
                    raise AssertionError('unreachable')

            # Insert the terminator.
            # Switch base on the number of jump targets.
            n_targets = len(jump_targets)
            if n_targets == 0:
                builder.ret_void()
            elif n_targets == 1:
                [dst] = jump_targets
                builder.branch(bbmap[dst])
            elif n_targets == 2:
                [left, right] = jump_targets
                sel = builder.call(brancher_fn, ())
                builder.cbranch(sel, bbmap[left], bbmap[right])
            elif n_targets > 2:
                sel = builder.call(switcher_fn, ())
                # The first target is the switch default; the remaining
                # targets become cases numbered from 0.
                [head, *tail] = jump_targets

                sw = builder.switch(sel, default=bbmap[head])
                for i, dst in enumerate(tail):
                    sw.add_case(sel.type(i), bbmap[dst])
            else:
                raise AssertionError('unreachable')

        return m

    def apply_refprune(self, irmod):
        """Parse *irmod* (via its textual form), run the refprune pass on it
        with default arguments, and return the resulting binding module.
        """
        mod = llvm.parse_assembly(str(irmod))
        pm = llvm.ModulePassManager()
        pm.add_refprune_pass()
        pm.run(mod)
        return mod

    def check(self, mod, expected, nodes):
        """Assert that the pruned module has the expected calls per block.

        mod: module returned by ``apply_refprune``.
        expected: mapping of a block name to the names of the blocks whose
            decref is pruned together with one incref of that block; blocks
            that are missing or map to an empty value are left untouched.
        nodes: the original per-block action lists used to build the IR.

        Fails the test if ``main`` is missing from *mod* or if any block's
        number of ``NRT_incref`` / ``NRT_decref`` occurrences differs from
        the computed expectation.
        """
        # preprocess incref/decref locations
        d = {
            k: {'incref': vs.count('incref'), 'decref': vs.count('decref')}
            for k, vs in nodes.items()
        }
        # Only the values of ``d`` are mutated here, never its keys, so
        # iterating over ``d.items()`` at the same time is safe.
        for k, stats in d.items():
            if expected.get(k):
                stats['incref'] -= 1
                for dec_bb in expected[k]:
                    d[dec_bb]['decref'] -= 1

        # find the main function; fail loudly instead of silently checking
        # whichever function happened to be iterated last.
        main_fn = next((f for f in mod.functions if f.name == 'main'), None)
        self.assertIsNotNone(
            main_fn, msg="function 'main' not found in pruned module")
        # check each BB
        for bb in main_fn.blocks:
            stats = d[bb.name]
            # Count textual occurrences of the callee names in the block.
            text = str(bb)
            n_incref = text.count('NRT_incref')
            n_decref = text.count('NRT_decref')
            self.assertEqual(stats['incref'], n_incref, msg=f'BB {bb}')
            self.assertEqual(stats['decref'], n_decref, msg=f'BB {bb}')

    def generate_test(self, case_gen):
        """Run one prototype case end to end: build IR, prune, verify."""
        nodes, edges, expected = case_gen()
        irmod = self.generate_ir(nodes, edges)
        outmod = self.apply_refprune(irmod)
        self.check(outmod, expected, nodes)

    # Generate tests
    # ``generate_test`` is passed as a plain function because the class body
    # is still executing; writing through ``locals()`` relies on it being the
    # class namespace here, which turns the generated callables into methods.
    for name, case in _iterate_cases(generate_test):
        locals()[name] = case


class BaseTestByIR(TestCase):
    """Common driver for tests that feed hand-written IR to the pruner.

    Subclasses in this file override ``refprune_bitmask`` with the subpass
    they exercise.
    """
    refprune_bitmask = 0

    # Declarations prepended to every IR snippet passed to ``check``.
    prologue = r"""
declare void @NRT_incref(i8* %ptr)
declare void @NRT_decref(i8* %ptr)
"""

    def check(self, irmod, subgraph_limit=None):
        """Run the refprune pass over *irmod* and report what changed.

        irmod: textual IR; ``prologue`` is placed in front of it.
        subgraph_limit: forwarded to ``add_refprune_pass`` only when it is
            not ``None``.

        Returns ``(module, stats_delta)`` where ``stats_delta`` is the
        refprune statistics dumped after the run minus those dumped before.
        """
        module = llvm.parse_assembly(f"{self.prologue}\n{irmod}")
        pass_manager = llvm.ModulePassManager()

        # Leave the keyword out entirely when no limit was requested so the
        # pass is registered with its own default.
        pass_kwargs = {}
        if subgraph_limit is not None:
            pass_kwargs['subgraph_limit'] = subgraph_limit
        pass_manager.add_refprune_pass(self.refprune_bitmask, **pass_kwargs)

        stats_before = llvm.dump_refprune_stats()
        pass_manager.run(module)
        stats_after = llvm.dump_refprune_stats()
        return module, stats_after - stats_before


class TestPerBB(BaseTestByIR):
    """Cases run with only the PER_BB subpass selected."""
    refprune_bitmask = llvm.RefPruneSubpasses.PER_BB

    # A single incref/decref pair on the same pointer.
    per_bb_ir_1 = r"""
define void @main(i8* %ptr) {
    call void @NRT_incref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_bb_1(self):
        _, delta = self.check(self.per_bb_ir_1)
        self.assertEqual(delta.basicblock, 2)

    # Three increfs but only two decrefs on the same pointer.
    per_bb_ir_2 = r"""
define void @main(i8* %ptr) {
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_bb_2(self):
        pruned, delta = self.check(self.per_bb_ir_2)
        self.assertEqual(delta.basicblock, 4)
        # the incref without a partner must still be present
        self.assertIn("call void @NRT_incref(i8* %ptr)", str(pruned))

    # The second decref targets a different pointer (%other).
    per_bb_ir_3 = r"""
define void @main(i8* %ptr, i8* %other) {
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %other)
    ret void
}
"""

    def test_per_bb_3(self):
        pruned, delta = self.check(self.per_bb_ir_3)
        self.assertEqual(delta.basicblock, 2)
        # the decref on %other must still be present
        self.assertIn("call void @NRT_decref(i8* %other)", str(pruned))

    # Same calls as a mix of the above, with an incref placed last.
    per_bb_ir_4 = r"""
; reordered
define void @main(i8* %ptr, i8* %other) {
    call void @NRT_incref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %other)
    call void @NRT_incref(i8* %ptr)
    ret void
}
"""

    def test_per_bb_4(self):
        pruned, delta = self.check(self.per_bb_ir_4)
        self.assertEqual(delta.basicblock, 4)
        # the decref on %other must still be present
        self.assertIn("call void @NRT_decref(i8* %other)", str(pruned))


class TestDiamond(BaseTestByIR):
    """Cases run with only the DIAMOND subpass selected."""
    refprune_bitmask = llvm.RefPruneSubpasses.DIAMOND

    # Straight line: incref in bb_A, decref in its only successor bb_B.
    per_diamond_1 = r"""
define void @main(i8* %ptr) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br label %bb_B
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_1(self):
        _, delta = self.check(self.per_diamond_1)
        self.assertEqual(delta.diamond, 2)

    # Two-armed diamond with empty arms between the incref and the decref.
    per_diamond_2 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    br label %bb_D
bb_C:
    br label %bb_D
bb_D:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_2(self):
        _, delta = self.check(self.per_diamond_2)
        self.assertEqual(delta.diamond, 2)

    # A decref inside one arm: nothing is expected to be pruned.
    per_diamond_3 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    br label %bb_D
bb_C:
    call void @NRT_decref(i8* %ptr)  ; reject because of decref in diamond
    br label %bb_D
bb_D:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_3(self):
        _, delta = self.check(self.per_diamond_3)
        self.assertEqual(delta.diamond, 0)

    # An additional incref inside one arm: the outer pair is still expected
    # to be pruned.
    per_diamond_4 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_incref(i8* %ptr)     ; extra incref will not affect prune
    br label %bb_D
bb_C:
    br label %bb_D
bb_D:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_4(self):
        _, delta = self.check(self.per_diamond_4)
        self.assertEqual(delta.diamond, 2)

    # Two increfs at the top and two decrefs at the bottom of the diamond.
    per_diamond_5 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    br label %bb_D
bb_C:
    br label %bb_D
bb_D:
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_per_diamond_5(self):
        _, delta = self.check(self.per_diamond_5)
        self.assertEqual(delta.diamond, 4)


class TestFanout(BaseTestByIR):
    """Small hand-written cases for the FANOUT subpass.

    More complex cases are tested in TestRefPrunePass.
    """

    refprune_bitmask = llvm.RefPruneSubpasses.FANOUT

    # One incref that fans out to a decref on each of two returning branches.
    fanout_1 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret void
bb_C:
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_fanout_1(self):
        _, delta = self.check(self.fanout_1)
        self.assertEqual(delta.fanout, 3)

    # bb_C jumps into bb_B after its own decref: nothing is expected to be
    # pruned.
    fanout_2 = r"""
define void @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret void
bb_C:
    call void @NRT_decref(i8* %ptr)
    br label %bb_B                      ; illegal jump to other decref
}
"""

    def test_fanout_2(self):
        _, delta = self.check(self.fanout_2)
        self.assertEqual(delta.fanout, 0)

    # Two increfs fanning out to three decrefs in bb_B and two in bb_C.
    fanout_3 = r"""
define void @main(i8* %ptr, i1 %cond) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
bb_C:
    call void @NRT_decref(i8* %ptr)
    call void @NRT_decref(i8* %ptr)
    ret void
}
"""

    def test_fanout_3(self):
        _, delta = self.check(self.fanout_3)
        self.assertEqual(delta.fanout, 6)

    def test_fanout_3_limited(self):
        # A subgraph limit of 1 effectively switches the fanout pruner off,
        # so the same IR as above must report no pruning at all.
        _, delta = self.check(self.fanout_3, subgraph_limit=1)
        self.assertEqual(delta.fanout, 0)


class TestFanoutRaise(BaseTestByIR):
    """Cases run with only the FANOUT_RAISE subpass selected."""
    refprune_bitmask = llvm.RefPruneSubpasses.FANOUT_RAISE

    # bb_B decrefs and returns 0; bb_C has no decref, only a store tagged
    # with !numba_exception_output, and returns 1.
    fanout_raise_1 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret i32 0
bb_C:
    store i8* null, i8** %excinfo, !numba_exception_output !0
    ret i32 1
}
!0 = !{i1 true}
"""

    def test_fanout_raise_1(self):
        _, delta = self.check(self.fanout_raise_1)
        self.assertEqual(delta.fanout_raise, 2)

    # Same shape, but the metadata name on the store is misspelled.
    fanout_raise_2 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret i32 0
bb_C:
    store i8* null, i8** %excinfo, !numba_exception_typo !0      ; bad metadata
    ret i32 1
}

!0 = !{i1 true}
"""

    def test_fanout_raise_2(self):
        # With the metadata wrongly named, fanout_raise must leave the
        # function alone.
        _, delta = self.check(self.fanout_raise_2)
        self.assertEqual(delta.fanout_raise, 0)

    # Same shape as the first case with an i32 metadata operand.
    fanout_raise_3 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    ret i32 0
bb_C:
    store i8* null, i8** %excinfo, !numba_exception_output !0
    ret i32 1
}

!0 = !{i32 1}       ; ok; use i32
"""

    def test_fanout_raise_3(self):
        _, delta = self.check(self.fanout_raise_3)
        self.assertEqual(delta.fanout_raise, 2)

    # Neither branch contains a decref: nothing is expected to be pruned.
    fanout_raise_4 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    ret i32 1    ; BAD; all tails are raising without decref
bb_C:
    ret i32 1    ; BAD; all tails are raising without decref
}

!0 = !{i1 1}
"""

    def test_fanout_raise_4(self):
        _, delta = self.check(self.fanout_raise_4)
        self.assertEqual(delta.fanout_raise, 0)

    # Both branches join in a shared return block that selects the return
    # value with a phi node.
    fanout_raise_5 = r"""
define i32 @main(i8* %ptr, i1 %cond, i8** %excinfo) {
bb_A:
    call void @NRT_incref(i8* %ptr)
    br i1 %cond, label %bb_B, label %bb_C
bb_B:
    call void @NRT_decref(i8* %ptr)
    br label %common.ret
bb_C:
    store i8* null, i8** %excinfo, !numba_exception_output !0
    br label %common.ret
common.ret:
    %common.ret.op = phi i32 [ 0, %bb_B ], [ 1, %bb_C ]
    ret i32 %common.ret.op
}
!0 = !{i1 1}
"""

    def test_fanout_raise_5(self):
        _, delta = self.check(self.fanout_raise_5)
        self.assertEqual(delta.fanout_raise, 2)


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()