import functools

import torch.cuda


class Metric:
    """
    Dumb utility to collect and report wall-time metrics.

    Measurements are stored in insertion order; ``report`` prints the
    0%..90% quantiles (in steps of 10%) of everything collected so far.
    """

    def __init__(self, label: str) -> None:
        self.label = label
        self.measurements = []

    def collect(self, measurement: float) -> None:
        """Record a single measurement."""
        self.measurements.append(measurement)

    def get_measurements(self) -> list:
        """Return a shallow copy of the measurements collected so far."""
        return self.measurements[:]

    def report(self) -> None:
        """Print the label followed by the deciles of the measurements.

        Prints a placeholder instead of raising when nothing has been
        collected yet, since ``torch.quantile`` rejects empty input.
        """
        if not self.measurements:
            print(self.label, "(no measurements)")
            return
        # torch.quantile only accepts floating-point input whose dtype
        # matches ``q``; force the default float dtype so that integer
        # measurements do not make the report crash.
        values = torch.tensor(
            self.measurements, dtype=torch.get_default_dtype()
        )
        print(
            self.label,
            torch.quantile(values, torch.arange(10) / 10.0),
        )


def monitor_method_cuda_wall_times(metric, obj, methodname):
    """
    Measure CUDA timings for a method on an object or class.

    Replaces ``obj.<methodname>`` with a wrapper that brackets every call
    with CUDA events, synchronizes the device, stores the elapsed time
    reported by ``Event.elapsed_time`` in ``metric`` and prints the
    metric's report. The wrapped method's return value and exceptions
    are passed through unchanged; the timing is recorded either way.

    For instance:

    >>> metric = Metric('!LNORM')  # doctest: +SKIP
    >>> monitor_method_cuda_wall_times(metric, LayerNorm, 'forward')  # doctest: +SKIP

    Args:
        metric: object exposing ``collect(value)`` and ``report()``.
        obj: object or class whose attribute is patched in place.
        methodname: name of the method to wrap.
    """
    oldmeth = getattr(obj, methodname)

    @functools.wraps(oldmeth)
    def newmeth(*args, **kw):
        # Fresh events per call: sharing one pair across calls lets a
        # nested or recursive invocation overwrite the outer call's
        # start event and corrupt its measurement.
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        try:
            return oldmeth(*args, **kw)
        finally:
            end_event.record()
            # elapsed_time needs both events to have completed.
            torch.cuda.synchronize()
            elapsed = start_event.elapsed_time(end_event)
            metric.collect(elapsed)
            metric.report()

    setattr(obj, methodname, newmeth)