File size: 900 Bytes
04a30fc
 
 
 
 
 
 
73834eb
 
 
 
 
 
04a30fc
 
 
 
 
 
73834eb
04a30fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
from substra_helpers.substra_runner import SubstraRunner, algo_generator
from substra_helpers.model import CNN
from substra_helpers.dataset import TorchDataset
from substrafl.strategies import FedAvg

import torch

from dotenv import load_dotenv
import os
# Load variables from a local .env file (if one exists) into os.environ so the
# configuration below can come from either the real environment or .env.
load_dotenv()

# Number of clients the runner sets up. Required; read from the environment.
# Fail early with actionable messages instead of a bare KeyError / int() error.
# The exception types (KeyError for missing, ValueError for bad values) match
# what the plain `int(os.environ["NUM_CLIENTS"])` lookup would have raised.
_num_clients_raw = os.environ.get("NUM_CLIENTS")
if _num_clients_raw is None:
    raise KeyError(
        "NUM_CLIENTS is not set; define it in the environment or in a .env file"
    )
try:
    NUM_CLIENTS = int(_num_clients_raw)
except ValueError:
    raise ValueError(
        f"NUM_CLIENTS must be an integer, got {_num_clients_raw!r}"
    ) from None
if NUM_CLIENTS < 1:
    # A run with zero (or negative) clients cannot train anything.
    raise ValueError(f"NUM_CLIENTS must be >= 1, got {NUM_CLIENTS}")

# Seed torch's RNG *before* building the model so its initial weights are
# reproducible; the same seed is also handed to the algorithm factory below.
seed = 42
torch.manual_seed(seed)
model = CNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

# Runner pipeline: client setup, then data preparation/registration, then
# metric registration. NOTE(review): the exact semantics of each step live in
# substra_helpers.substra_runner — confirm there before reordering these calls.
runner = SubstraRunner(num_clients=NUM_CLIENTS)
runner.set_up_clients()
runner.prepare_data()
runner.register_data()
runner.register_metric()

# algo_generator(...) returns a callable (presumably an algo class); the
# trailing () instantiates it. TODO confirm against substra_helpers.
runner.algorithm = algo_generator(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    index_generator=runner.index_generator,
    dataset=TorchDataset,
    seed=seed
)()

# Federated averaging with substrafl's default FedAvg settings.
runner.strategy = FedAvg()

runner.set_aggregation()
runner.set_testing()

# Submit and execute the assembled compute plan.
runner.run_compute_plan()