Spaces:
Runtime error
Runtime error
File size: 900 Bytes
04a30fc 73834eb 04a30fc 73834eb 04a30fc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 |
from substra_helpers.substra_runner import SubstraRunner, algo_generator
from substra_helpers.model import CNN
from substra_helpers.dataset import TorchDataset
from substrafl.strategies import FedAvg
import torch
from dotenv import load_dotenv
import os
# --- Configuration ---------------------------------------------------------
# load_dotenv() runs before the lookup below, so NUM_CLIENTS may come from a
# .env file (per python-dotenv) as well as from the real environment.
load_dotenv()
# Required setting: a missing variable raises KeyError and a non-integer
# value raises ValueError.
NUM_CLIENTS = int(os.environ["NUM_CLIENTS"])

# --- Reproducibility -------------------------------------------------------
# The same seed is given to torch's global RNG and, further down, to the
# algorithm factory. Seeding happens before the model is instantiated.
seed = 42
torch.manual_seed(seed)

# --- Training components ---------------------------------------------------
model = CNN()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# --- Runner preparation ----------------------------------------------------
# The steps are invoked in this fixed order; what each one does is defined in
# substra_helpers.substra_runner, which is not visible from this script.
runner = SubstraRunner(num_clients=NUM_CLIENTS)
runner.set_up_clients()
runner.prepare_data()
runner.register_data()
runner.register_metric()

# --- Algorithm and strategy ------------------------------------------------
# algo_generator returns a callable; calling it with no arguments yields the
# object stored on the runner as its algorithm.
_make_algorithm = algo_generator(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    index_generator=runner.index_generator,
    dataset=TorchDataset,
    seed=seed,
)
runner.algorithm = _make_algorithm()
runner.strategy = FedAvg()

# --- Execution -------------------------------------------------------------
# Aggregation and testing are configured after the strategy is assigned, then
# the compute plan is launched.
runner.set_aggregation()
runner.set_testing()
runner.run_compute_plan()
|