davebulaval commited on
Commit
ba092fc
1 Parent(s): 1541189

add torch.no_grad()

Browse files
Files changed (1) hide show
  1. meaningbert.py +32 -30
meaningbert.py CHANGED
@@ -19,6 +19,7 @@ from typing import List, Dict
19
 
20
  import datasets
21
  import evaluate
 
22
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
23
 
24
 
@@ -128,34 +129,35 @@ class MeaningBERT(evaluate.Metric):
128
  )
129
  scorer.eval()
130
 
131
- # We load MeaningBERT tokenizer
132
- tokenizer = AutoTokenizer.from_pretrained("davebulaval/MeaningBERT")
133
-
134
- # We tokenize the text as a pair and return Pytorch Tensors
135
- tokenize_text = tokenizer(
136
- references,
137
- predictions,
138
- truncation=True,
139
- padding=True,
140
- return_tensors="pt",
141
- ).to(device)
142
-
143
- with filter_logging_context():
144
- # We process the text
145
- scores = scorer(**tokenize_text)
146
-
147
- scores = scores.logits.tolist()
148
-
149
- # Flatten the list of list of logits
150
- scores = list(chain(*scores))
151
-
152
- # Handle case of perfect match
153
- if len(matching_index) > 0:
154
- for matching_element_index in matching_index:
155
- scores[matching_element_index] = 100
156
-
157
- output_dict = {
158
- "scores": scores,
159
- "hashcode": hashcode,
160
- }
 
161
  return output_dict
 
19
 
20
  import datasets
21
  import evaluate
22
+ import torch
23
  from transformers import AutoModelForSequenceClassification, AutoTokenizer
24
 
25
 
 
129
  )
130
  scorer.eval()
131
 
132
+ with torch.no_grad():
133
+ # We load MeaningBERT tokenizer
134
+ tokenizer = AutoTokenizer.from_pretrained("davebulaval/MeaningBERT")
135
+
136
+ # We tokenize the text as a pair and return Pytorch Tensors
137
+ tokenize_text = tokenizer(
138
+ references,
139
+ predictions,
140
+ truncation=True,
141
+ padding=True,
142
+ return_tensors="pt",
143
+ ).to(device)
144
+
145
+ with filter_logging_context():
146
+ # We process the text
147
+ scores = scorer(**tokenize_text)
148
+
149
+ scores = scores.logits.tolist()
150
+
151
+ # Flatten the list of list of logits
152
+ scores = list(chain(*scores))
153
+
154
+ # Handle case of perfect match
155
+ if len(matching_index) > 0:
156
+ for matching_element_index in matching_index:
157
+ scores[matching_element_index] = 100
158
+
159
+ output_dict = {
160
+ "scores": scores,
161
+ "hashcode": hashcode,
162
+ }
163
  return output_dict