import difflib

import torch


def get_layer(l_name, library=torch.nn):
    """Return the layer object handler named ``l_name`` from ``library``.

    The lookup is case-insensitive, e.g. ``get_layer("elu")`` returns
    ``torch.nn.ELU``.

    Args:
        l_name (str): Case-insensitive name of the layer in ``library``
            (e.g. ``"elu"``).
        library (module): Module (or any object) whose attributes are
            searched for ``l_name``. Defaults to ``torch.nn``.

    Returns:
        object: The attribute of ``library`` whose name matches ``l_name``
        case-insensitively (e.g. ``torch.nn.ELU``).

    Raises:
        NotImplementedError: If no attribute of ``library`` matches
            ``l_name`` (the message lists the closest existing names), or
            if several attributes match because they differ only in case
            (the message lists all of them).
    """
    # Group the library's attribute names by their lower-cased form so the
    # case-insensitive lookup and the suggestions share a single index.
    by_lower = {}
    for attr_name in dir(library):
        by_lower.setdefault(attr_name.lower(), []).append(attr_name)

    match = by_lower.get(l_name.lower(), [])
    if not match:
        # Compare lower-cased forms on both sides, then report the real
        # (original-case) attribute names so users can copy them directly.
        close_keys = difflib.get_close_matches(l_name.lower(), list(by_lower))
        close_matches = [name for key in close_keys for name in by_lower[key]]
        raise NotImplementedError(
            "Layer with name {} not found in {}.\n Closest matches: {}".format(
                l_name, str(library), close_matches
            )
        )
    if len(match) > 1:
        # Several attributes differ only by case; the lookup is ambiguous.
        raise NotImplementedError(
            "Multiple matches for layer with name {} in {}.\n "
            "All matches: {}".format(l_name, str(library), match)
        )
    return getattr(library, match[0])