Princess3 committed on
Commit
2d8ca2e
·
verified ·
1 Parent(s): f90d88f

Update x.py

Browse files
Files changed (1) hide show
  1. x.py +6 -6
x.py CHANGED
@@ -94,15 +94,15 @@ class DynamicModel(nn.Module):
94
  layers.append(nn.LeakyReLU(negative_slope=0.01, inplace=True))
95
  elif activation == 'elu':
96
  layers.append(nn.ELU(alpha=1.0, inplace=True))
97
- if dropout := lp.get('dropout', 0.0):
98
  layers.append(nn.Dropout(p=dropout))
99
- if lp.get('memory_augmentation', False):
100
  layers.append(MemoryAugmentationLayer(lp['output_size']))
101
- if lp.get('hybrid_attention', False):
102
  layers.append(HybridAttentionLayer(lp['output_size']))
103
- if lp.get('dynamic_flash_attention', False):
104
  layers.append(DynamicFlashAttentionLayer(lp['output_size']))
105
- if lp.get('magic_state', False):
106
  layers.append(MagicStateLayer(lp['output_size']))
107
  return nn.Sequential(*layers)
108
 
@@ -118,7 +118,7 @@ class DynamicModel(nn.Module):
118
 
119
  def parse_xml_file(file_path):
120
  tree, root, layers = ET.parse(file_path), ET.parse(file_path).getroot(), []
121
- for layer in root.findall('.//layer'):
122
  lp = {
123
  'input_size': int(layer.get('input_size', 128)),
124
  'output_size': int(layer.get('output_size', 256)),
 
94
  layers.append(nn.LeakyReLU(negative_slope=0.01, inplace=True))
95
  elif activation == 'elu':
96
  layers.append(nn.ELU(alpha=1.0, inplace=True))
97
+ if dropout := lp.get('dropout', 0.1):
98
  layers.append(nn.Dropout(p=dropout))
99
+ if lp.get('memory_augmentation', True):
100
  layers.append(MemoryAugmentationLayer(lp['output_size']))
101
+ if lp.get('hybrid_attention', True):
102
  layers.append(HybridAttentionLayer(lp['output_size']))
103
+ if lp.get('dynamic_flash_attention', True):
104
  layers.append(DynamicFlashAttentionLayer(lp['output_size']))
105
+ if lp.get('magic_state', True):
106
  layers.append(MagicStateLayer(lp['output_size']))
107
  return nn.Sequential(*layers)
108
 
 
118
 
119
  def parse_xml_file(file_path):
120
  tree, root, layers = ET.parse(file_path), ET.parse(file_path).getroot(), []
121
+ for layer in root.findall('.//label'):
122
  lp = {
123
  'input_size': int(layer.get('input_size', 128)),
124
  'output_size': int(layer.get('output_size', 256)),