File size: 1,301 Bytes
ad93086
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import torch

from modules import shared, ui_gradio_extensions


class Profiler:
    """Context manager that optionally runs its body under ``torch.profiler.profile``.

    Configuration is read from ``shared.opts`` when the manager is constructed:

    * ``profiling_enable``: master switch. When falsy, the manager does nothing.
    * ``profiling_activities``: a container tested for ``"CPU"`` and ``"CUDA"``.
      Selected activities are passed to torch in the fixed order CPU, CUDA.
      If neither name is present, the manager does nothing.
    * ``profiling_record_shapes``, ``profiling_profile_memory``,
      ``profiling_with_stack``: forwarded unchanged to ``torch.profiler.profile``.

    ``profiling_filename`` is read on exit as the destination of the Chrome trace.

    ``self.profiler`` holds the underlying torch profiler, or ``None`` when
    profiling is inactive.
    """

    def __init__(self):
        opts = shared.opts
        self.profiler = None

        if not opts.profiling_enable:
            return

        wanted = opts.profiling_activities
        # Iterate in a fixed order so CPU always precedes CUDA, whatever order
        # the option lists them in; each enum member is looked up only if selected.
        activities = [
            getattr(torch.profiler.ProfilerActivity, name)
            for name in ("CPU", "CUDA")
            if name in wanted
        ]
        if not activities:
            return

        self.profiler = torch.profiler.profile(
            activities=activities,
            record_shapes=opts.profiling_record_shapes,
            profile_memory=opts.profiling_profile_memory,
            with_stack=opts.profiling_with_stack,
        )

    def __enter__(self):
        """Start the wrapped torch profiler (if any) and return this manager."""
        if self.profiler is not None:
            self.profiler.__enter__()
        return self

    def __exit__(self, exc_type, exc, exc_tb):
        """Stop the wrapped profiler and export its Chrome trace.

        Does nothing when profiling is inactive. Returns ``None``, so an
        exception raised inside the ``with`` body still propagates. The trace
        is exported in that case too.
        """
        profiler = self.profiler
        if profiler is None:
            return

        # Status text set before the stop/export work begins (presumably shown
        # in the UI; confirm against readers of shared.state.textinfo).
        shared.state.textinfo = "Finishing profile..."
        profiler.__exit__(exc_type, exc, exc_tb)
        profiler.export_chrome_trace(shared.opts.profiling_filename)


def webpath():
    """Return ``ui_gradio_extensions.webpath`` applied to the configured trace file.

    The argument is ``shared.opts.profiling_filename``, the same file the
    profiler's Chrome trace is exported to. The result is presumably a
    browser-servable path for that file; confirm against
    ``ui_gradio_extensions``.
    """
    trace_file = shared.opts.profiling_filename
    return ui_gradio_extensions.webpath(trace_file)