# Spaces status: Runtime error
"""Launch a federated-learning training run with SubstraFL.

Builds a CNN with an Adam optimizer (lr=0.001) and a cross-entropy loss,
seeds torch for reproducibility, then drives ``SubstraRunner`` through
client setup, data and metric registration, FedAvg aggregation and testing,
and finally launches the compute plan.

Configuration:
    NUM_CLIENTS: number of clients passed to ``SubstraRunner``. Read from the
        environment after ``load_dotenv()`` has had a chance to populate it
        from a ``.env`` file. Must be an integer >= 1.
"""

import os

import torch
from dotenv import load_dotenv
from substrafl.strategies import FedAvg

from substra_helpers.dataset import TorchDataset
from substra_helpers.model import CNN
from substra_helpers.substra_runner import SubstraRunner, algo_generator

# Populate os.environ from a local .env file (if any) before reading config.
load_dotenv()

# Fail early with an actionable message instead of a bare KeyError/ValueError.
try:
    NUM_CLIENTS = int(os.environ["NUM_CLIENTS"])
except KeyError as err:
    raise KeyError(
        "NUM_CLIENTS is not set; define it in the environment or in a .env file"
    ) from err
except ValueError as err:
    raise ValueError(
        f"NUM_CLIENTS must be an integer, got {os.environ['NUM_CLIENTS']!r}"
    ) from err
if NUM_CLIENTS < 1:
    raise ValueError(f"NUM_CLIENTS must be at least 1, got {NUM_CLIENTS}")

seed = 42
# Seed before constructing the model so its initial weights are reproducible;
# the same seed is also handed to algo_generator below.
torch.manual_seed(seed)

model = CNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

runner = SubstraRunner(num_clients=NUM_CLIENTS)
runner.set_up_clients()
runner.prepare_data()
runner.register_data()
runner.register_metric()

# algo_generator returns a callable (presumably an algo class — TODO confirm);
# calling it with no arguments yields the algorithm object given to the runner.
runner.algorithm = algo_generator(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    index_generator=runner.index_generator,
    dataset=TorchDataset,
    seed=seed,
)()
runner.strategy = FedAvg()
runner.set_aggregation()
runner.set_testing()
runner.run_compute_plan()