anicolson committed (verified)
Commit dca7dcb · 1 Parent(s): 9d4087e

Upload model

Files changed (4)
  1. config.json +1 -1
  2. dataset.py +1 -1
  3. generation_config.json +1 -1
  4. modelling_cxrmate_ed.py +21 -0
config.json CHANGED
@@ -105,7 +105,7 @@
   },
   "time_delta_monotonic_inversion": true,
   "torch_dtype": "float32",
-  "transformers_version": "4.46.3",
+  "transformers_version": "4.49.0",
   "vision_config": {
     "_attn_implementation_autoset": false,
     "_name_or_path": "aehrc/uniformer_base_tl_384",
dataset.py CHANGED
@@ -33,4 +33,4 @@ class PriorsDataset:
         batch = self.__getitem__(keys)
         n_examples = len(batch[next(iter(batch))])
         return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]
-
+
generation_config.json CHANGED
@@ -3,5 +3,5 @@
   "bos_token_id": 1,
   "eos_token_id": 2,
   "pad_token_id": 4,
-  "transformers_version": "4.46.3"
+  "transformers_version": "4.49.0"
 }
modelling_cxrmate_ed.py CHANGED
@@ -796,8 +796,29 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
         position_ids[row_indices, col_indices.flatten()] = torch.arange(num_cols, device=time_deltas.device)[None, :].expand(num_rows, -1).flatten()
         position_ids.masked_fill_(attention_mask == 0, 1)  # Following: https://github.com/huggingface/transformers/blob/c5f0288bc7d76f65996586f79f69fba8867a0e67/src/transformers/models/llama/modeling_llama.py#L1285
 
+        for i in range(position_ids.shape[0]):
+            assert self.validate_position_ids(position_ids[i])
+
         return position_ids
 
+    @staticmethod
+    def validate_position_ids(tensor, repeat_value=1):
+        unique, counts = torch.unique(tensor, return_counts=True)
+
+        # Check if all integers from 0 to tensor.max() exist:
+        full_range = torch.arange(0, tensor.max() + 1, device=tensor.device)
+        if not torch.equal(unique.sort()[0], full_range):
+            return False
+
+        # Check for repeated values except for repeat_value:
+        repeated = unique[counts > 1]
+        if repeated.nelement() == 0:
+            return True
+        if not (repeated.numel() == 1 and repeated.item() == repeat_value):
+            return False
+
+        return True
+
     def prepare_index_value_feats(self, table, batch):
 
         index_value_columns = (self.tables[table].get('index_columns', []) + self.tables[table].get('value_columns', []))
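
For reference, a minimal standalone sketch of what the added position-ID check accepts and rejects. The validation logic is copied here for illustration rather than imported from CXRMateEDModel, and the example tensors are made up:

```python
# Standalone copy of the validation logic added in this commit (illustrative only).
import torch


def validate_position_ids(tensor, repeat_value=1):
    unique, counts = torch.unique(tensor, return_counts=True)

    # All integers from 0 to tensor.max() must be present.
    full_range = torch.arange(0, tensor.max() + 1, device=tensor.device)
    if not torch.equal(unique.sort()[0], full_range):
        return False

    # Only repeat_value may appear more than once (padded positions are filled with 1).
    repeated = unique[counts > 1]
    if repeated.nelement() == 0:
        return True
    return repeated.numel() == 1 and repeated.item() == repeat_value


print(validate_position_ids(torch.tensor([0, 1, 2, 3, 1, 1])))  # True: only 1 repeats
print(validate_position_ids(torch.tensor([0, 1, 3, 4])))        # False: 2 is missing
print(validate_position_ids(torch.tensor([0, 1, 2, 2, 3])))     # False: 2 repeats
```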