xmutly's picture
Upload 294 files
e1aaaac verified
raw
history blame
3.06 kB
import json
from dataclasses import dataclass, field, asdict
from datetime import datetime
from pathlib import Path
from typing import Optional, Set
import warnings
import torch
@dataclass
class EvaluationState:
    """Resumable bookkeeping for an evaluation that runs several attacks.

    The state can be persisted as JSON at ``path`` and restored with
    :meth:`from_disk`. Data lives in underscore-prefixed dataclass fields and
    is exposed through properties. The ``robust_flags`` and
    ``clean_accuracy`` setters force an immediate save. ``add_run_attack``
    saves at most once every ``_SAVE_TIMEOUT`` seconds.
    """

    # Names of all attacks to run; fixed at construction (the property
    # setter refuses to change it).
    _attacks_to_run: Set[str]
    # JSON file the state is written to; ``None`` disables saving entirely.
    path: Optional[Path] = None
    # Names of the attacks that have already been run.
    _run_attacks: Set[str] = field(default_factory=set)
    # Boolean flags (presumably one per evaluated sample -- confirm with
    # callers); ``None`` until assigned through ``robust_flags``.
    _robust_flags: Optional[torch.Tensor] = None
    # Time of the last save; the far-past default makes the first
    # non-forced ``to_disk`` call always write.
    _last_saved: datetime = datetime(1, 1, 1)
    # Minimum number of seconds between two non-forced saves.
    _SAVE_TIMEOUT: int = 60
    # NaN until assigned through ``clean_accuracy``.
    _clean_accuracy: float = float("nan")

    def to_disk(self, force: bool = False) -> None:
        """Write the state to ``self.path`` as JSON.

        Does nothing when ``path`` is ``None``, or when the previous save
        happened less than ``_SAVE_TIMEOUT`` seconds ago and ``force`` is
        false.

        The JSON text is fully built before the disk is touched. It is then
        written to a sibling ``<name>.tmp`` file, which is renamed over
        ``path``. A failed serialization or an interrupted write therefore
        never truncates an existing state file.
        """
        seconds_since_last_save = (
            datetime.now() - self._last_saved
        ).total_seconds()
        if self.path is None or (
            seconds_since_last_save < self._SAVE_TIMEOUT and not force
        ):
            return
        self._last_saved = datetime.now()
        d = asdict(self)
        # Convert every non-JSON-native value explicitly. Sets in particular
        # must become lists: stringifying them (as ``default=str`` did)
        # yields "{'a', 'b'}", which cannot be loaded back as a set.
        d["_attacks_to_run"] = sorted(self._attacks_to_run)
        d["_run_attacks"] = sorted(self._run_attacks)
        if self._robust_flags is not None:
            d["_robust_flags"] = self._robust_flags.cpu().tolist()
        d["path"] = str(self.path)
        d["_last_saved"] = self._last_saved.isoformat()
        d["_clean_accuracy"] = float(self._clean_accuracy)
        # Serialize first so a TypeError cannot leave a half-written file.
        payload = json.dumps(d)
        tmp_path = self.path.with_name(self.path.name + ".tmp")
        tmp_path.write_text(payload)
        # Rename over the target (atomic on POSIX filesystems).
        tmp_path.replace(self.path)

    @staticmethod
    def _as_str_set(value) -> Set[str]:
        """Turn a deserialized attack collection back into a set of names.

        Accepts a list (the current format) or the ``str()`` of a Python set
        (e.g. ``"{'a', 'b'}"``), as written by older versions of ``to_disk``.
        """
        if isinstance(value, str):
            import ast  # only needed for legacy state files
            value = ast.literal_eval(value)
        return set(value)

    @classmethod
    def from_disk(cls, path: Path) -> "EvaluationState":
        """Load a state previously written by :meth:`to_disk`.

        Args:
            path: JSON state file to read.

        Returns:
            The restored state. Its ``path`` is the one stored in the file.
            A ``UserWarning`` is emitted if that differs from ``path``.
        """
        with path.open("r") as f:
            d = json.load(f)
        flags = d["_robust_flags"]
        # ``robust_flags`` may legitimately be unset (saved as null) when a
        # forced save happened before any flags were assigned.
        d["_robust_flags"] = (
            None if flags is None else torch.tensor(flags, dtype=torch.bool)
        )
        d["_attacks_to_run"] = cls._as_str_set(d["_attacks_to_run"])
        d["_run_attacks"] = cls._as_str_set(d["_run_attacks"])
        d["path"] = Path(d["path"])
        if path != d["path"]:
            warnings.warn(
                UserWarning(
                    "The given path is different from the one found in the state file."
                ))
        d["_last_saved"] = datetime.fromisoformat(d["_last_saved"])
        return cls(**d)

    @property
    def robust_flags(self) -> Optional[torch.Tensor]:
        """Boolean flags tensor, or ``None`` if not set yet."""
        return self._robust_flags

    @robust_flags.setter
    def robust_flags(self, robust_flags: torch.Tensor) -> None:
        self._robust_flags = robust_flags
        self.to_disk(force=True)

    @property
    def run_attacks(self) -> Set[str]:
        """Names of the attacks that have already been run."""
        return self._run_attacks

    def add_run_attack(self, attack: str) -> None:
        """Record ``attack`` as run and save (subject to the save timeout)."""
        self._run_attacks.add(attack)
        self.to_disk()

    @property
    def attacks_to_run(self) -> Set[str]:
        """Names of all attacks this evaluation is meant to run."""
        return self._attacks_to_run

    @attacks_to_run.setter
    def attacks_to_run(self, _: Set[str]) -> None:
        raise ValueError("attacks_to_run cannot be set outside of the constructor")

    @property
    def clean_accuracy(self) -> float:
        """Clean accuracy; NaN until assigned."""
        return self._clean_accuracy

    @clean_accuracy.setter
    def clean_accuracy(self, accuracy) -> None:
        self._clean_accuracy = accuracy
        self.to_disk(force=True)

    @property
    def robust_accuracy(self) -> float:
        """Fraction of ``robust_flags`` that are true.

        Raises:
            ValueError: if ``robust_flags`` has not been set yet.

        Warns when some attacks in ``attacks_to_run`` have not been run yet.
        """
        if self.robust_flags is None:
            raise ValueError("robust_flags is not set yet. Start the attack first.")
        if self.attacks_to_run - self.run_attacks:
            warnings.warn("You are checking `robust_accuracy` before all the attacks"
                          " have been run.")
        return self.robust_flags.float().mean().item()