Sadjad Alikhani
commited on
Commit
•
da21525
1
Parent(s):
fcd80b2
Update input_preprocess.py
Browse files- input_preprocess.py +2 -1
input_preprocess.py
CHANGED
@@ -225,6 +225,7 @@ def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_ra
|
|
225 |
Returns:
|
226 |
sample (list): Generated sample for the user.
|
227 |
"""
|
|
|
228 |
|
229 |
tokens = patch[user_idx]
|
230 |
input_ids = np.vstack((word2id['[CLS]'], tokens))
|
@@ -244,7 +245,7 @@ def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_ra
|
|
244 |
input_ids[pos] = np.random.rand(patch_size)
|
245 |
elif rnd_num < 0.9:
|
246 |
input_ids[pos] = word2id['[MASK]']
|
247 |
-
|
248 |
return [input_ids, masked_tokens, masked_pos]
|
249 |
|
250 |
|
|
|
225 |
Returns:
|
226 |
sample (list): Generated sample for the user.
|
227 |
"""
|
228 |
+
set_random_seed()
|
229 |
|
230 |
tokens = patch[user_idx]
|
231 |
input_ids = np.vstack((word2id['[CLS]'], tokens))
|
|
|
245 |
input_ids[pos] = np.random.rand(patch_size)
|
246 |
elif rnd_num < 0.9:
|
247 |
input_ids[pos] = word2id['[MASK]']
|
248 |
+
print(f'masked_pos: {masked_pos}')
|
249 |
return [input_ids, masked_tokens, masked_pos]
|
250 |
|
251 |
|