# substra / substra_template / run_compute_plan.py
# Commit 73834eb by NimaBoscarino: "WIP: Substra orchestrator" (900 bytes).
from substra_helpers.substra_runner import SubstraRunner, algo_generator
from substra_helpers.model import CNN
from substra_helpers.dataset import TorchDataset
from substrafl.strategies import FedAvg
import torch
from dotenv import load_dotenv
import os
def _read_num_clients(var_name: str = "NUM_CLIENTS") -> int:
    """Read the number of federated clients from the environment and validate it.

    Args:
        var_name: Name of the environment variable holding the client count.

    Returns:
        The client count as a positive ``int``.

    Raises:
        KeyError: If the variable is not set.
        ValueError: If the value is not an integer, or is smaller than 1.
    """
    try:
        raw_value = os.environ[var_name]
    except KeyError:
        raise KeyError(
            f"{var_name} is not set; export it or define it in a .env file"
        ) from None
    try:
        num_clients = int(raw_value)
    except ValueError as err:
        raise ValueError(
            f"{var_name} must be an integer, got {raw_value!r}"
        ) from err
    if num_clients < 1:
        raise ValueError(f"{var_name} must be >= 1, got {num_clients}")
    return num_clients


# Load variables from a local .env file (if present) into os.environ before
# reading configuration from it.
load_dotenv()
NUM_CLIENTS = _read_num_clients()

seed = 42
# Seed torch's global RNG before building the model so weight initialisation is
# reproducible (assumes CNN draws from torch's default generator — confirm).
torch.manual_seed(seed)
model = CNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()

# The SubstraRunner steps below are executed in this exact order; their
# behaviour is defined in substra_helpers.substra_runner.
runner = SubstraRunner(num_clients=NUM_CLIENTS)
runner.set_up_clients()
runner.prepare_data()
runner.register_data()
runner.register_metric()

# algo_generator returns a callable (presumably an algorithm class); the
# trailing () instantiates it with no arguments.
runner.algorithm = algo_generator(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    index_generator=runner.index_generator,
    dataset=TorchDataset,
    seed=seed,
)()
runner.strategy = FedAvg()
runner.set_aggregation()
runner.set_testing()
runner.run_compute_plan()