chenjgtea commited on
Commit
f4f441a
·
1 Parent(s): 2aa8287

torch 更新

Browse files
Files changed (1) hide show
  1. tool/ctx.py +4 -0
tool/ctx.py CHANGED
@@ -1,13 +1,17 @@
1
  import torch
 
 
2
 
3
  class TorchSeedContext:
4
  def __init__(self, seed):
5
  self.seed = seed
6
  self.state = None
7
 
 
8
  def __enter__(self):
9
  self.state = torch.random.get_rng_state()
10
  torch.manual_seed(self.seed)
11
 
 
12
  def __exit__(self, type, value, traceback):
13
  torch.random.set_rng_state(self.state)
 
1
  import torch
2
+ import spaces
3
+
4
 
5
  class TorchSeedContext:
6
  def __init__(self, seed):
7
  self.seed = seed
8
  self.state = None
9
 
10
+ @spaces.GPU
11
  def __enter__(self):
12
  self.state = torch.random.get_rng_state()
13
  torch.manual_seed(self.seed)
14
 
15
+ @spaces.GPU
16
  def __exit__(self, type, value, traceback):
17
  torch.random.set_rng_state(self.state)