PeteBleackley committed on
Commit
ca642d2
·
1 Parent(s): 1a9032d

Ensure tokenizer is on GPU

Browse files
Files changed (2) hide show
  1. qarac/models/QaracEncoderModel.py +1 -2
  2. scripts.py +1 -1
qarac/models/QaracEncoderModel.py CHANGED
@@ -47,8 +47,7 @@ class QaracEncoderModel(transformers.PreTrainedModel):
47
  Vector representing the document
48
 
49
  """
50
- print('Encoder',self.encoder.device)
51
- print('Head',self.head.device)
52
  if attention_mask is None and 'attention_mask' in input_ids:
53
  (input_ids,attention_mask) = (input_ids['input_ids'],input_ids['attention_mask'])
54
  print('input_ids',input_ids.device)
 
47
  Vector representing the document
48
 
49
  """
50
+
 
51
  if attention_mask is None and 'attention_mask' in input_ids:
52
  (input_ids,attention_mask) = (input_ids['input_ids'],input_ids['attention_mask'])
53
  print('input_ids',input_ids.device)
scripts.py CHANGED
@@ -120,7 +120,7 @@ def prepare_training_datasets():
120
  def train_models(path,progress=gradio.Progress(track_tqdm=True)):
121
  device = torch.device('cuda:0')
122
  torch.cuda.empty_cache()
123
- tokenizer = tokenizers.Tokenizer.from_pretrained('roberta-base')
124
  trainer = qarac.models.QaracTrainerModel.QaracTrainerModel('roberta-base',
125
  tokenizer)
126
 
 
120
  def train_models(path,progress=gradio.Progress(track_tqdm=True)):
121
  device = torch.device('cuda:0')
122
  torch.cuda.empty_cache()
123
+ tokenizer = tokenizers.Tokenizer.from_pretrained('roberta-base').to(device)
124
  trainer = qarac.models.QaracTrainerModel.QaracTrainerModel('roberta-base',
125
  tokenizer)
126