import inspect
import json
import os

import torch
import torch.nn as nn


class AestheticScorer(nn.Module):
    """MLP head that maps an embedding vector to a single scalar score.

    Every constructor option is collected into ``self.config``; entries in the
    optional ``config`` dict override the individual keyword arguments. The
    merged ``self.config`` is what :meth:`save` writes next to the weights, and
    what :func:`load_model` feeds back in to rebuild the network.

    Args:
        input_size: Width of the input feature vector.
        use_activation: Insert ``nn.ReLU`` after each hidden ``nn.Linear``.
        dropout: Dropout probability used after the first three hidden layers.
        config: Optional dict overriding any of the keyword arguments above.
        hidden_dim: Width of the first hidden layer.
        reduce_dims: If true, hidden widths shrink to 1/2, 1/4 and 1/8 of
            ``hidden_dim``; otherwise all hidden layers are ``hidden_dim`` wide.
        output_activation: ``"sigmoid"`` to squash the output into the range
            ``[1, 10]`` (see :meth:`forward`); any other value leaves it linear.
    """

    def __init__(
        self,
        input_size=0,
        use_activation=False,
        dropout=0.2,
        config=None,
        hidden_dim=1024,
        reduce_dims=False,
        output_activation=None,
    ):
        super().__init__()
        self.config = {
            "input_size": input_size,
            "use_activation": use_activation,
            "dropout": dropout,
            "hidden_dim": hidden_dim,
            "reduce_dims": reduce_dims,
            "output_activation": output_activation,
        }
        if config is not None:
            self.config.update(config)

        cfg = self.config
        hidden = cfg["hidden_dim"]
        # Read reduce_dims from the merged config, not the keyword argument, so a
        # model rebuilt from a saved config (load_model) gets identical widths.
        divisors = (1, 2, 4, 8) if cfg["reduce_dims"] else (1, 1, 1, 1)
        widths = [round(hidden / d) for d in divisors]

        # Module order must stay exactly as before: nn.Sequential indices become
        # the state_dict keys ("layers.0.weight", ...) of existing checkpoints.
        layers = []
        in_dim = cfg["input_size"]
        last = len(widths) - 1
        for i, out_dim in enumerate(widths):
            layers.append(nn.Linear(in_dim, out_dim))
            if cfg["use_activation"]:
                layers.append(nn.ReLU())
            if i < last:  # no dropout directly before the output projection
                layers.append(nn.Dropout(cfg["dropout"]))
            in_dim = out_dim
        layers.append(nn.Linear(in_dim, 1))
        if cfg["output_activation"] == "sigmoid":
            layers.append(nn.Sigmoid())

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        """Return the score for each row of ``x`` (last dim of size 1).

        With ``output_activation == "sigmoid"`` the sigmoid output (0..1) is
        linearly rescaled into the 1..10 range.
        """
        out = self.layers(x)
        if self.config["output_activation"] == "sigmoid":
            lower, upper = 1, 10
            out = out * (upper - lower) + lower
        return out

    def save(self, save_name):
        """Write the config JSON and the state_dict to disk.

        The config goes to ``<save_name without extension>.config`` and the
        weights to ``save_name``. ``torch.save`` is retried up to 5 extra times
        when it raises a RuntimeError mentioning "cannot be opened" (observed
        intermittently, possibly a Windows file-locking issue); any other
        error, or the last failed attempt, is re-raised.
        """
        base, _ = os.path.splitext(save_name)
        with open(f"{base}.config", "w") as outfile:
            outfile.write(json.dumps(self.config, indent=4))

        max_retries = 5
        for attempt in range(max_retries + 1):
            try:
                torch.save(self.state_dict(), save_name)
                break
            except RuntimeError as e:
                if "cannot be opened" in str(e) and attempt < max_retries:
                    print("Model save failed, retrying...")
                else:
                    raise


def preprocess(embeddings):
    """L2-normalise ``embeddings`` along the last dimension.

    Rows with zero norm produce NaN/inf, since no epsilon is added.
    """
    return embeddings / embeddings.norm(p=2, dim=-1, keepdim=True)


def load_model(weight_name, device="cuda" if torch.cuda.is_available() else "cpu"):
    """Load ``weights/<weight_name>.pth`` (+ ``.config``) and return it in eval mode.

    The ``weights`` folder is resolved relative to the file that defines this
    function.

    NOTE(review): ``device`` is only used as ``map_location`` for the loaded
    tensors; the values are copied into a freshly built (CPU) module, and the
    module itself is never moved with ``.to(device)``. Callers wanting the
    model on ``device`` must move it themselves.
    """
    weight_folder = os.path.abspath(
        os.path.join(
            inspect.getfile(load_model),
            "../weights",
        )
    )
    weight_path = os.path.join(weight_folder, f"{weight_name}.pth")
    config_path = os.path.join(weight_folder, f"{weight_name}.config")
    with open(config_path, "r") as config_file:
        config = json.load(config_file)
    model = AestheticScorer(config=config)
    model.load_state_dict(torch.load(weight_path, map_location=device))
    model.eval()
    return model