calbors commited on
Commit
12583d4
·
verified ·
1 Parent(s): 40ca248

Upload model

Browse files
Files changed (2) hide show
  1. model.safetensors +1 -1
  2. modeling_phylogpn.py +34 -19
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:e83fde25faac20fb64c1964c7c8f2779059260d11acdc707d71014716880cb77
3
  size 332799280
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65f05a93d49be782d608ddaddd3ed056077922e26890d7acd53b35ad8e7fe540
3
  size 332799280
modeling_phylogpn.py CHANGED
@@ -31,12 +31,20 @@ class RCEWeight(nn.Module):
31
  )
32
 
33
  super().__init__()
34
- self.input_involution_indices = input_involution_indices
35
- self.output_involution_indices = output_involution_indices
 
 
 
36
 
37
  def forward(self, x: torch.Tensor) -> torch.Tensor:
38
- output_involution_indices = torch.tensor(self.output_involution_indices, device=x.device)
39
- input_involution_indices = torch.tensor(self.input_involution_indices, device=x.device)
 
 
 
 
 
40
  return (x + x[output_involution_indices][:, input_involution_indices].flip(2)) / 2
41
 
42
 
@@ -46,10 +54,16 @@ class IEBias(nn.Module):
46
  raise ValueError("`involution_indices` must be an involution")
47
 
48
  super().__init__()
49
- self.involution_indices = involution_indices
 
 
50
 
51
  def forward(self, x: torch.Tensor) -> torch.Tensor:
52
- involution_indices = torch.tensor(self.involution_indices, device=x.device)
 
 
 
 
53
  return (x + x[involution_indices]) / 2
54
 
55
 
@@ -64,23 +78,25 @@ class IEWeight(nn.Module):
64
  )
65
 
66
  super().__init__()
67
- self.input_involution_indices = input_involution_indices
68
- self.output_involution_indices = output_involution_indices
 
 
 
69
 
70
  def forward(self, x: torch.Tensor) -> torch.Tensor:
71
- input_involution_indices = torch.tensor(self.input_involution_indices, device=x.device)
72
- output_involution_indices = torch.tensor(self.output_involution_indices, device=x.device)
 
 
 
 
 
73
  return (x + x[input_involution_indices][:, output_involution_indices]) / 2
74
-
75
 
76
  class RCEByteNetBlock(nn.Module):
77
- def __init__(
78
- self,
79
- outer_involution_indices: List[int],
80
- inner_dim: int,
81
- kernel_size: int,
82
- dilation_rate: int = 1
83
- ):
84
  outer_dim = len(outer_involution_indices)
85
 
86
  if outer_dim % 2 != 0:
@@ -130,7 +146,6 @@ class RCEByteNetBlock(nn.Module):
130
  layers[8], "bias",
131
  IEBias(outer_involution_indices)
132
  )
133
-
134
  self.layers = nn.Sequential(*layers)
135
  self._kernel_size = kernel_size
136
  self._dilation_rate = dilation_rate
 
31
  )
32
 
33
  super().__init__()
34
+ self._input_involution_indices = input_involution_indices
35
+ self._output_involution_indices = output_involution_indices
36
+ self._input_involution_index_tensor = None
37
+ self._output_involution_index_tensor = None
38
+ self._device = None
39
 
40
  def forward(self, x: torch.Tensor) -> torch.Tensor:
41
+ if self._device != x.device:
42
+ self._input_involution_index_tensor = torch.tensor(self._input_involution_indices, device=x.device)
43
+ self._output_involution_index_tensor = torch.tensor(self._output_involution_indices, device=x.device)
44
+ self._device = x.device
45
+
46
+ output_involution_indices = self._output_involution_index_tensor
47
+ input_involution_indices = self._input_involution_index_tensor
48
  return (x + x[output_involution_indices][:, input_involution_indices].flip(2)) / 2
49
 
50
 
 
54
  raise ValueError("`involution_indices` must be an involution")
55
 
56
  super().__init__()
57
+ self._involution_indices = involution_indices
58
+ self._involution_index_tensor = None
59
+ self._device = None
60
 
61
  def forward(self, x: torch.Tensor) -> torch.Tensor:
62
+ if self._device != x.device:
63
+ self._involution_index_tensor = torch.tensor(self._involution_indices, device=x.device)
64
+ self._device = x.device
65
+
66
+ involution_indices = self._involution_index_tensor
67
  return (x + x[involution_indices]) / 2
68
 
69
 
 
78
  )
79
 
80
  super().__init__()
81
+ self._input_involution_indices = input_involution_indices
82
+ self._output_involution_indices = output_involution_indices
83
+ self._input_involution_index_tensor = None
84
+ self._output_involution_index_tensor = None
85
+ self._device = None
86
 
87
  def forward(self, x: torch.Tensor) -> torch.Tensor:
88
+ if self._device != x.device:
89
+ self._input_involution_index_tensor = torch.tensor(self._input_involution_indices, device=x.device)
90
+ self._output_involution_index_tensor = torch.tensor(self._output_involution_indices, device=x.device)
91
+ self._device = x.device
92
+
93
+ output_involution_indices = self._output_involution_index_tensor
94
+ input_involution_indices = self._input_involution_index_tensor
95
  return (x + x[input_involution_indices][:, output_involution_indices]) / 2
96
+
97
 
98
  class RCEByteNetBlock(nn.Module):
99
+ def __init__(self, outer_involution_indices: List[int], inner_dim: int, kernel_size: int, dilation_rate: int = 1):
 
 
 
 
 
 
100
  outer_dim = len(outer_involution_indices)
101
 
102
  if outer_dim % 2 != 0:
 
146
  layers[8], "bias",
147
  IEBias(outer_involution_indices)
148
  )
 
149
  self.layers = nn.Sequential(*layers)
150
  self._kernel_size = kernel_size
151
  self._dilation_rate = dilation_rate