Update tasks/text.py

tasks/text.py (+25 -6)
@@ -11,7 +11,7 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
 router = APIRouter()
 
-DESCRIPTION = "
+DESCRIPTION = "Evaluate text classification for climate disinformation detection"
 ROUTE = "/text"
 
 @router.post(ROUTE, tags=["Text Task"],
@@ -97,14 +97,33 @@ async def evaluate_text(request: TextEvaluationRequest):
     test_dataset = TextDataset(texts, labels, tokenizer)
     test_loader = DataLoader(test_dataset, batch_size=16)
 
+    # model.eval()
+    # predictions = []
+    # with torch.no_grad():
+    #     for inputs, labels in test_loader:
+    #         inputs, labels = inputs.to('cpu'), labels.to('cpu')
+    #         outputs = model(inputs)
+    #         _, predicted = torch.max(outputs, 1)
+    #         predictions.extend(predicted.cpu().numpy())
     model.eval()
     predictions = []
+    ground_truth = []
+    DEVICE = 'cpu'
     with torch.no_grad():
-        for inputs, labels in test_loader:
-            inputs, labels = inputs.to('cpu'), labels.to('cpu')
-            outputs = model(inputs)
-            _, predicted = torch.max(outputs, 1)
+        for batch in test_loader:
+            # Access each component of the batch dictionary
+            input_ids = batch['input_ids'].to(DEVICE)
+            attention_mask = batch['attention_mask'].to(DEVICE)
+            labels = batch['labels'].to(DEVICE)
+
+            # Forward pass
+            outputs = model(input_ids=input_ids, attention_mask=attention_mask)
+            _, predicted = torch.max(outputs.logits, 1)
+
+            # Store predictions and ground truth
             predictions.extend(predicted.cpu().numpy())
+            ground_truth.extend(labels.cpu().numpy())
+
 
     #--------------------------------------------------------------------------------------------
     # YOUR MODEL INFERENCE STOPS HERE
@@ -135,4 +154,4 @@ async def evaluate_text(request: TextEvaluationRequest):
         }
     }
 
-    return results
+    return results
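The new loop indexes each batch as a dictionary (input_ids, attention_mask, labels) and reads outputs.logits, which is the calling convention of a Hugging Face sequence-classification model. For that to work, TextDataset.__getitem__ must return a dict of fixed-length tensors that the default DataLoader collate can stack. The class itself is outside this diff, so the following is only a minimal sketch of the shape it presumably has: the key names come from the loop above, while max_length=128 and the exact tokenizer arguments are assumptions.

import torch
from torch.utils.data import Dataset

class TextDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length=128):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        # Tokenize one text to a fixed length so the default collate can
        # stack samples into (batch_size, max_length) tensors.
        encoding = self.tokenizer(
            self.texts[idx],
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0),
            # Key name matches what the evaluation loop reads.
            "labels": torch.tensor(self.labels[idx], dtype=torch.long),
        }

With a dataset of this shape, batch['input_ids'] arrives as a (batch_size, max_length) tensor, model(input_ids=..., attention_mask=...) matches the transformers forward signature, and torch.max(outputs.logits, 1) yields one predicted class index per sample.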
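The commit also starts collecting ground_truth alongside predictions, but the code between this hunk and the final return results is not part of the diff, so how the two lists are consumed is not visible here. Presumably they feed the metric placed in the results payload; a minimal sketch, assuming scikit-learn's accuracy_score as that metric:

from sklearn.metrics import accuracy_score

def score(predictions, ground_truth):
    # Both arguments are the lists of numpy ints built in the loop above;
    # accuracy_score takes (y_true, y_pred).
    return float(accuracy_score(ground_truth, predictions))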