Roaoch commited on
Commit
82ad620
·
1 Parent(s): ddde8f2
Files changed (1) hide show
  1. src/cyberclaasic.py +2 -2
src/cyberclaasic.py CHANGED
@@ -40,8 +40,8 @@ class CyberClassic(torch.nn.Module):
40
  decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
41
 
42
  decoded_tokens = self.discriminator_tokenizer(decoded, return_tensors='pt', padding=True, truncation=True)
43
- score = self.discriminator(decoded_tokens)
44
- index = int(torch.argmax(score))
45
 
46
  return decoded[index]
47
 
 
40
  decoded = self.tokenizer.batch_decode(generated, skip_special_tokens=True)
41
 
42
  decoded_tokens = self.discriminator_tokenizer(decoded, return_tensors='pt', padding=True, truncation=True)
43
+ score = self.discriminator(**decoded_tokens)
44
+ index = int(torch.argmax(score.logits))
45
 
46
  return decoded[index]
47