import os

# The Keras backend must be selected before `keras` is imported for the
# setting to take effect, so this assignment has to stay above the keras
# imports below.
os.environ["KERAS_BACKEND"] = "torch"

import pytorch_lightning as pl
import torch
import torch.nn as nn
from keras.layers import Input, Dense
from keras.models import Model

from model.tcn_module import TCN


class TCNModel(pl.LightningModule):
    """Lightning wrapper around a Keras functional model (TCN -> Dense(7)).

    The keyword arguments passed to the constructor are stored through
    ``save_hyperparameters``; ``windows_size`` and ``input_size`` are read
    back from ``self.hparams`` to define the per-sample input shape.
    """

    def __init__(self, **config):
        """Build the Keras graph from the given hyperparameters.

        Args:
            **config: Hyperparameters. Must contain ``windows_size`` and
                ``input_size``; any other keys are only stored in
                ``self.hparams``.
        """
        super().__init__()
        self.save_hyperparameters(config)

        sample_shape = (self.hparams.windows_size, self.hparams.input_size)
        input_layer = Input(shape=sample_shape)

        # `self.tcn` and `self.linear` hold the symbolic outputs of the
        # layers (the result of calling each layer), not the layer objects.
        self.tcn = TCN(input_shape=sample_shape)(input_layer)
        self.linear = Dense(7)(self.tcn)
        self.model = Model(inputs=input_layer, outputs=self.linear)

    def forward(self, x):
        """Run the Keras model on ``x`` and add a singleton axis at dim 1.

        Args:
            x: Input batch; presumably shaped
                ``(batch, windows_size, input_size)`` to match the ``Input``
                layer -- TODO confirm against the data module.

        Returns:
            The model output stacked along a new dimension 1 of size one.
        """
        output = self.model(x)
        return torch.stack([output], dim=1)


def move_custom_layers_to_device(model: nn.Module, device) -> None:
    """Move modules held in plain ``list``/``dict`` attributes to ``device``.

    Modules registered as children of ``model`` are left alone here, since
    ``model.to(device)`` already covers them. This helper covers modules
    that are stored inside an ordinary Python ``list`` or ``dict`` attribute
    and therefore are not registered as children.

    The attributes are read from ``vars(model)`` rather than from
    ``model.named_children()``: ``named_children()`` yields only
    ``nn.Module`` instances, so a scan based on it can never encounter a
    ``list`` or ``dict`` container.

    Only direct attributes of ``model`` are inspected; nested containers are
    not searched recursively.

    Args:
        model: Module whose instance attributes are scanned.
        device: Target device, forwarded unchanged to ``nn.Module.to``.
    """
    # Children already registered on `model` (this also excludes the
    # contents of the internal `_modules` dict, which is part of vars()).
    registered_ids = {id(child) for child in model.children()}

    for attr_value in vars(model).values():
        if isinstance(attr_value, dict):
            candidates = attr_value.values()
        elif isinstance(attr_value, list):
            candidates = attr_value
        else:
            continue

        for candidate in candidates:
            if not isinstance(candidate, nn.Module):
                continue
            if id(candidate) in registered_ids:
                continue
            candidate.to(device)