import pathlib import gradio from captum.attr import visualization class Markdown(gradio.Markdown): def __init__(self, value, *args, **kwargs): if isinstance(value, pathlib.Path): value = value.read_text() elif isinstance(value, io.TextIOWrapper): value = value.read() super().__init__(value, *args, **kwargs) # from https://discuss.pytorch.org/t/using-scikit-learns-scalers-for-torchvision/53455 class PyTMinMaxScalerVectorized(object): """ Transforms each channel to the range [0, 1]. """ def __init__(self, dimension=-1): self.d = dimension def __call__(self, tensor): d = self.d scale = 1.0 / ( tensor.max(dim=d, keepdim=True)[0] - tensor.min(dim=d, keepdim=True)[0] ) tensor.mul_(scale).sub_(tensor.min(dim=d, keepdim=True)[0]) return tensor # copied out of captum because we need raw html instead of a jupyter widget def visualize_text(datarecords, legend=False): dom = ["
True Label | " "Predicted Label | " "Attribution Label | " # "Attribution Score | " "Word Importance | " ] for datarecord in datarecords: rows.append( "".join( [ "
---|---|---|---|---|