|
import json |
|
from dataclasses import dataclass, field, asdict |
|
from datetime import datetime |
|
from pathlib import Path |
|
from typing import Optional, Set |
|
import warnings |
|
|
|
import torch |
|
|
|
|
|
@dataclass
class EvaluationState:
    """Resumable state of an evaluation that runs a set of named attacks.

    Tracks which attacks are planned, which have already been run, the
    robustness flags and the clean accuracy. When ``path`` is set, the state
    is written there as JSON by :meth:`to_disk` and can be restored with
    :meth:`from_disk`.
    """

    # Names of the attacks to run; read-only after construction (the
    # ``attacks_to_run`` setter raises).
    _attacks_to_run: Set[str]
    # JSON file the state is persisted to; ``None`` disables persistence.
    path: Optional[Path] = None
    # Names of the attacks already recorded via ``add_run_attack``.
    _run_attacks: Set[str] = field(default_factory=set)
    # Robustness flags; ``robust_accuracy`` is their mean. Restored from disk
    # as a bool tensor.
    _robust_flags: Optional[torch.Tensor] = None
    # Time of the last write; the far-past default means the first
    # non-forced save is never throttled.
    _last_saved: datetime = datetime(1, 1, 1)
    # Minimum number of seconds between two non-forced saves.
    _SAVE_TIMEOUT: int = 60
    # Clean accuracy as set by the caller; NaN until it is set.
    _clean_accuracy: float = float("nan")

    def to_disk(self, force: bool = False) -> None:
        """Write the state to ``self.path`` as JSON.

        Does nothing when ``path`` is ``None``. Unless ``force`` is true, the
        write is also skipped if the previous one happened less than
        ``_SAVE_TIMEOUT`` seconds ago.

        Sets are stored as sorted lists and the flags tensor as a nested list,
        so that :meth:`from_disk` can rebuild every field with its original
        type. The JSON is first written to a sibling ``.tmp`` file and then
        moved over the target, so an interrupted save never leaves a
        truncated state file behind.

        Args:
            force: Save even if the save timeout has not elapsed yet.
        """
        if self.path is None:
            return
        elapsed = (datetime.now() - self._last_saved).total_seconds()
        if elapsed < self._SAVE_TIMEOUT and not force:
            return
        self._last_saved = datetime.now()

        d = asdict(self)
        # JSON has no set type; storing lists (instead of relying on
        # ``default=str``, which wrote the set's repr) keeps the file
        # reloadable.
        d["_attacks_to_run"] = sorted(self._attacks_to_run)
        d["_run_attacks"] = sorted(self._run_attacks)
        if self._robust_flags is not None:
            d["_robust_flags"] = self._robust_flags.cpu().tolist()
        d["path"] = str(self.path)
        d["_last_saved"] = self._last_saved.isoformat()

        tmp_path = self.path.with_name(self.path.name + ".tmp")
        with tmp_path.open("w", encoding="utf-8") as f:
            # ``default=str`` remains only as a fallback for unexpected value
            # types (e.g. a non-float accuracy) so that saving never crashes.
            json.dump(d, f, default=str)
        # Path.replace overwrites the destination (atomically on POSIX).
        tmp_path.replace(self.path)

    @classmethod
    def from_disk(cls, path: Path) -> "EvaluationState":
        """Restore a state previously written by :meth:`to_disk`.

        Files produced by earlier versions are accepted too: those stored
        ``_attacks_to_run`` as the repr of a set and ``null`` when no flags
        had been set yet.

        Args:
            path: JSON state file to read (a ``str`` is also accepted).

        Returns:
            A new ``EvaluationState``. A ``UserWarning`` is emitted when the
            path recorded inside the file differs from ``path``.
        """
        path = Path(path)
        with path.open("r", encoding="utf-8") as f:
            d = json.load(f)

        flags = d.get("_robust_flags")
        d["_robust_flags"] = (None if flags is None else
                              torch.tensor(flags, dtype=torch.bool))
        d["_attacks_to_run"] = cls._to_str_set(d["_attacks_to_run"])
        d["_run_attacks"] = cls._to_str_set(d.get("_run_attacks"))

        stored_path = d.get("path")
        d["path"] = path if stored_path is None else Path(stored_path)
        if path != d["path"]:
            warnings.warn(
                UserWarning(
                    "The given path is different from the one found in the state file."
                ))
        d["_last_saved"] = datetime.fromisoformat(d["_last_saved"])
        return cls(**d)

    @staticmethod
    def _to_str_set(value) -> Set[str]:
        """Turn a stored collection of attack names back into a set."""
        if value is None:
            return set()
        if isinstance(value, str):
            # Legacy files hold the set's repr, e.g. "{'apgd-ce', 'square'}".
            # literal_eval only evaluates Python literals, never arbitrary
            # code, so it is safe on file contents.
            import ast
            return set() if value == "set()" else set(ast.literal_eval(value))
        return set(value)

    @property
    def robust_flags(self) -> Optional[torch.Tensor]:
        """Robustness flags, or ``None`` if they have not been set yet."""
        return self._robust_flags

    @robust_flags.setter
    def robust_flags(self, robust_flags: torch.Tensor) -> None:
        # Always persisted immediately, bypassing the save timeout.
        self._robust_flags = robust_flags
        self.to_disk(force=True)

    @property
    def run_attacks(self) -> Set[str]:
        """Names of the attacks recorded as already run."""
        return self._run_attacks

    def add_run_attack(self, attack: str) -> None:
        """Record ``attack`` as run.

        The save is not forced, so it may be skipped by the save timeout and
        only reach disk with a later save.
        """
        self._run_attacks.add(attack)
        self.to_disk()

    @property
    def attacks_to_run(self) -> Set[str]:
        """Names of the attacks this evaluation was constructed with."""
        return self._attacks_to_run

    @attacks_to_run.setter
    def attacks_to_run(self, _: Set[str]) -> None:
        raise ValueError("attacks_to_run cannot be set outside of the constructor")

    @property
    def clean_accuracy(self) -> float:
        """Clean accuracy set by the caller (NaN until set)."""
        return self._clean_accuracy

    @clean_accuracy.setter
    def clean_accuracy(self, accuracy: float) -> None:
        # Always persisted immediately, bypassing the save timeout.
        self._clean_accuracy = accuracy
        self.to_disk(force=True)

    @property
    def robust_accuracy(self) -> float:
        """Mean of ``robust_flags`` as a float, i.e. the fraction that are true.

        Raises:
            ValueError: If ``robust_flags`` has not been set yet.

        Emits a warning when some of ``attacks_to_run`` are not yet recorded
        in ``run_attacks``.
        """
        if self.robust_flags is None:
            raise ValueError("robust_flags is not set yet. Start the attack first.")
        if self.attacks_to_run - self.run_attacks:
            warnings.warn("You are checking `robust_accuracy` before all the attacks"
                          " have been run.")
        return self.robust_flags.float().mean().item()