Sadjad Alikhani commited on
Commit
da21525
1 Parent(s): fcd80b2

Update input_preprocess.py

Browse files
Files changed (1) hide show
  1. input_preprocess.py +2 -1
input_preprocess.py CHANGED
@@ -225,6 +225,7 @@ def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_ra
225
  Returns:
226
  sample (list): Generated sample for the user.
227
  """
 
228
 
229
  tokens = patch[user_idx]
230
  input_ids = np.vstack((word2id['[CLS]'], tokens))
@@ -244,7 +245,7 @@ def make_sample(user_idx, patch, word2id, n_patches, n_masks, patch_size, gen_ra
244
  input_ids[pos] = np.random.rand(patch_size)
245
  elif rnd_num < 0.9:
246
  input_ids[pos] = word2id['[MASK]']
247
-
248
  return [input_ids, masked_tokens, masked_pos]
249
 
250
 
 
225
  Returns:
226
  sample (list): Generated sample for the user.
227
  """
228
+ set_random_seed()
229
 
230
  tokens = patch[user_idx]
231
  input_ids = np.vstack((word2id['[CLS]'], tokens))
 
245
  input_ids[pos] = np.random.rand(patch_size)
246
  elif rnd_num < 0.9:
247
  input_ids[pos] = word2id['[MASK]']
248
+ print(f'masked_pos: {masked_pos}')
249
  return [input_ids, masked_tokens, masked_pos]
250
 
251