chenjgtea
commited on
Commit
·
f4f441a
1
Parent(s):
2aa8287
torch 更新
Browse files- tool/ctx.py +4 -0
tool/ctx.py
CHANGED
@@ -1,13 +1,17 @@
|
|
1 |
import torch
|
|
|
|
|
2 |
|
3 |
class TorchSeedContext:
|
4 |
def __init__(self, seed):
|
5 |
self.seed = seed
|
6 |
self.state = None
|
7 |
|
|
|
8 |
def __enter__(self):
|
9 |
self.state = torch.random.get_rng_state()
|
10 |
torch.manual_seed(self.seed)
|
11 |
|
|
|
12 |
def __exit__(self, type, value, traceback):
|
13 |
torch.random.set_rng_state(self.state)
|
|
|
1 |
import torch
|
2 |
+
import spaces
|
3 |
+
|
4 |
|
5 |
class TorchSeedContext:
|
6 |
def __init__(self, seed):
|
7 |
self.seed = seed
|
8 |
self.state = None
|
9 |
|
10 |
+
@spaces.GPU
|
11 |
def __enter__(self):
|
12 |
self.state = torch.random.get_rng_state()
|
13 |
torch.manual_seed(self.seed)
|
14 |
|
15 |
+
@spaces.GPU
|
16 |
def __exit__(self, type, value, traceback):
|
17 |
torch.random.set_rng_state(self.state)
|