Bug&Fix due to Tensors Located on Different Files

#1
by ArthurPan - opened
Files changed (1) hide show
  1. handler.py +1 -2
handler.py CHANGED
@@ -26,8 +26,7 @@ class EndpointHandler():
26
  def get_prediction(self, instance):
27
  instance["attention_mask"] = [[1] * len(instance["input_ids"])]
28
  for key in ["input_ids", "token_type_ids", "attention_mask"]:
29
- instance[key] = torch.tensor(instance[key]).unsqueeze(0) # Batch size = 1
30
- instance[key].to(self.device)
31
 
32
  output = self.model(input_ids=instance["input_ids"],
33
  attention_mask=instance["attention_mask"],
 
26
  def get_prediction(self, instance):
27
  instance["attention_mask"] = [[1] * len(instance["input_ids"])]
28
  for key in ["input_ids", "token_type_ids", "attention_mask"]:
29
+ instance[key] = torch.tensor(instance[key]).unsqueeze(0).to(self.device) # Batch size = 1
 
30
 
31
  output = self.model(input_ids=instance["input_ids"],
32
  attention_mask=instance["attention_mask"],