jemfu commited on
Commit
6adb9db
1 Parent(s): 6949f0a

fix: sagemaker import issue

Browse files
Files changed (1) hide show
  1. block.py +59 -0
block.py CHANGED
@@ -7,6 +7,7 @@ from functools import partial
7
  from typing import Optional
8
 
9
  import torch
 
10
  import torch.nn as nn
11
  import torch.nn.functional as F
12
  from torch import Tensor
@@ -21,6 +22,64 @@ except ImportError:
21
  layer_norm_fn, RMSNorm = None, None
22
 
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  class Block(nn.Module):
25
  def __init__(
26
  self,
 
7
  from typing import Optional
8
 
9
  import torch
10
+ import torch.fx
11
  import torch.nn as nn
12
  import torch.nn.functional as F
13
  from torch import Tensor
 
22
  layer_norm_fn, RMSNorm = None, None
23
 
24
 
25
+ def stochastic_depth(
26
+ input: Tensor, p: float, mode: str, training: bool = True
27
+ ) -> Tensor:
28
+ """
29
+ Implements the Stochastic Depth from `"Deep Networks with Stochastic Depth"
30
+ <https://arxiv.org/abs/1603.09382>`_ used for randomly dropping residual
31
+ branches of residual architectures.
32
+ Args:
33
+ input (Tensor[N, ...]): The input tensor or arbitrary dimensions with the first one
34
+ being its batch i.e. a batch with ``N`` rows.
35
+ p (float): probability of the input to be zeroed.
36
+ mode (str): ``"batch"`` or ``"row"``.
37
+ ``"batch"`` randomly zeroes the entire input, ``"row"`` zeroes
38
+ randomly selected rows from the batch.
39
+ training: apply stochastic depth if is ``True``. Default: ``True``
40
+ Returns:
41
+ Tensor[N, ...]: The randomly zeroed tensor.
42
+ """
43
+ if p < 0.0 or p > 1.0:
44
+ raise ValueError(f"drop probability has to be between 0 and 1, but got {p}")
45
+ if mode not in ["batch", "row"]:
46
+ raise ValueError(f"mode has to be either 'batch' or 'row', but got {mode}")
47
+ if not training or p == 0.0:
48
+ return input
49
+
50
+ survival_rate = 1.0 - p
51
+ if mode == "row":
52
+ size = [input.shape[0]] + [1] * (input.ndim - 1)
53
+ else:
54
+ size = [1] * input.ndim
55
+ noise = torch.empty(size, dtype=input.dtype, device=input.device)
56
+ noise = noise.bernoulli_(survival_rate)
57
+ if survival_rate > 0.0:
58
+ noise.div_(survival_rate)
59
+ return input * noise
60
+
61
+
62
+ torch.fx.wrap("stochastic_depth")
63
+
64
+
65
+ class StochasticDepth(nn.Module):
66
+ """
67
+ See :func:`stochastic_depth`.
68
+ """
69
+
70
+ def __init__(self, p: float, mode: str) -> None:
71
+ super().__init__()
72
+ self.p = p
73
+ self.mode = mode
74
+
75
+ def forward(self, input: Tensor) -> Tensor:
76
+ return stochastic_depth(input, self.p, self.mode, self.training)
77
+
78
+ def __repr__(self) -> str:
79
+ s = f"{self.__class__.__name__}(p={self.p}, mode={self.mode})"
80
+ return s
81
+
82
+
83
  class Block(nn.Module):
84
  def __init__(
85
  self,