Upload model
Browse files- config.json +1 -1
- dataset.py +1 -1
- generation_config.json +1 -1
- modelling_cxrmate_ed.py +21 -0
config.json
CHANGED
@@ -105,7 +105,7 @@
|
|
105 |
},
|
106 |
"time_delta_monotonic_inversion": true,
|
107 |
"torch_dtype": "float32",
|
108 |
-
"transformers_version": "4.
|
109 |
"vision_config": {
|
110 |
"_attn_implementation_autoset": false,
|
111 |
"_name_or_path": "aehrc/uniformer_base_tl_384",
|
|
|
105 |
},
|
106 |
"time_delta_monotonic_inversion": true,
|
107 |
"torch_dtype": "float32",
|
108 |
+
"transformers_version": "4.49.0",
|
109 |
"vision_config": {
|
110 |
"_attn_implementation_autoset": false,
|
111 |
"_name_or_path": "aehrc/uniformer_base_tl_384",
|
dataset.py
CHANGED
@@ -33,4 +33,4 @@ class PriorsDataset:
|
|
33 |
batch = self.__getitem__(keys)
|
34 |
n_examples = len(batch[next(iter(batch))])
|
35 |
return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]
|
36 |
-
|
|
|
33 |
batch = self.__getitem__(keys)
|
34 |
n_examples = len(batch[next(iter(batch))])
|
35 |
return [{col: array[i] for col, array in batch.items()} for i in range(n_examples)]
|
36 |
+
|
generation_config.json
CHANGED
@@ -3,5 +3,5 @@
|
|
3 |
"bos_token_id": 1,
|
4 |
"eos_token_id": 2,
|
5 |
"pad_token_id": 4,
|
6 |
-
"transformers_version": "4.
|
7 |
}
|
|
|
3 |
"bos_token_id": 1,
|
4 |
"eos_token_id": 2,
|
5 |
"pad_token_id": 4,
|
6 |
+
"transformers_version": "4.49.0"
|
7 |
}
|
modelling_cxrmate_ed.py
CHANGED
@@ -796,8 +796,29 @@ class CXRMateEDModel(transformers.LlavaForConditionalGeneration):
|
|
796 |
position_ids[row_indices, col_indices.flatten()] = torch.arange(num_cols, device=time_deltas.device)[None, :].expand(num_rows, -1).flatten()
|
797 |
position_ids.masked_fill_(attention_mask == 0, 1) # Following: https://github.com/huggingface/transformers/blob/c5f0288bc7d76f65996586f79f69fba8867a0e67/src/transformers/models/llama/modeling_llama.py#L1285
|
798 |
|
|
|
|
|
|
|
799 |
return position_ids
|
800 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
801 |
def prepare_index_value_feats(self, table, batch):
|
802 |
|
803 |
index_value_columns = (self.tables[table].get('index_columns', []) + self.tables[table].get('value_columns', []))
|
|
|
796 |
position_ids[row_indices, col_indices.flatten()] = torch.arange(num_cols, device=time_deltas.device)[None, :].expand(num_rows, -1).flatten()
|
797 |
position_ids.masked_fill_(attention_mask == 0, 1) # Following: https://github.com/huggingface/transformers/blob/c5f0288bc7d76f65996586f79f69fba8867a0e67/src/transformers/models/llama/modeling_llama.py#L1285
|
798 |
|
799 |
+
for i in range(position_ids.shape[0]):
|
800 |
+
assert self.validate_position_ids(position_ids[i])
|
801 |
+
|
802 |
return position_ids
|
803 |
|
804 |
+
@staticmethod
|
805 |
+
def validate_position_ids(tensor, repeat_value=1):
|
806 |
+
unique, counts = torch.unique(tensor, return_counts=True)
|
807 |
+
|
808 |
+
# Check if all integers from 0 to tensor.max() exist:
|
809 |
+
full_range = torch.arange(0, tensor.max() + 1, device=tensor.device)
|
810 |
+
if not torch.equal(unique.sort()[0], full_range):
|
811 |
+
return False
|
812 |
+
|
813 |
+
# Check for repeated values except for repeat_value:
|
814 |
+
repeated = unique[counts > 1]
|
815 |
+
if repeated.nelement() == 0:
|
816 |
+
return True
|
817 |
+
if not (repeated.numel() == 1 and repeated.item() == repeat_value):
|
818 |
+
return False
|
819 |
+
|
820 |
+
return True
|
821 |
+
|
822 |
def prepare_index_value_feats(self, table, batch):
|
823 |
|
824 |
index_value_columns = (self.tables[table].get('index_columns', []) + self.tables[table].get('value_columns', []))
|