Upload 7 files
Browse files- modeling_indictrans.py +9 -6
modeling_indictrans.py
CHANGED
@@ -61,20 +61,23 @@ def shift_tokens_right(input_ids: torch.Tensor, pad_token_id: int, decoder_start
|
|
61 |
|
62 |
|
63 |
def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
|
64 |
-
|
65 |
-
|
|
|
|
|
|
|
66 |
|
67 |
labels_mask = labels == 1
|
68 |
labels[labels_mask] = -100
|
69 |
|
70 |
-
mask = (
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
|
75 |
return decoder_input_ids, decoder_attention_mask, labels
|
76 |
|
77 |
|
|
|
78 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
79 |
def _make_causal_mask(
|
80 |
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|
|
|
61 |
|
62 |
|
63 |
def prepare_decoder_input_ids_label(decoder_input_ids, decoder_attention_mask):
|
64 |
+
new_decoder_input_ids = decoder_input_ids.clone().detach()
|
65 |
+
new_decoder_attention_mask = decoder_attention_mask.clone().detach()
|
66 |
+
|
67 |
+
labels = torch.full(new_decoder_input_ids.size(),-100)
|
68 |
+
labels[:, :-1] = new_decoder_input_ids[:, 1:]
|
69 |
|
70 |
labels_mask = labels == 1
|
71 |
labels[labels_mask] = -100
|
72 |
|
73 |
+
mask = (new_decoder_input_ids == eos_token_id)
|
74 |
+
new_decoder_input_ids[mask] = 1
|
75 |
+
new_decoder_attention_mask[mask] = 0
|
|
|
76 |
|
77 |
return decoder_input_ids, decoder_attention_mask, labels
|
78 |
|
79 |
|
80 |
+
|
81 |
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
|
82 |
def _make_causal_mask(
|
83 |
input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
|