VanshK04 commited on
Commit
eba6cfa
·
verified ·
1 Parent(s): 1a2598d

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +25 -6
tasks/text.py CHANGED
@@ -11,7 +11,7 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info
11
 
12
  router = APIRouter()
13
 
14
- DESCRIPTION = "BERT Fine tuned"
15
  ROUTE = "/text"
16
 
17
  @router.post(ROUTE, tags=["Text Task"],
@@ -97,14 +97,33 @@ async def evaluate_text(request: TextEvaluationRequest):
97
  test_dataset = TextDataset(texts, labels, tokenizer)
98
  test_loader = DataLoader(test_dataset, batch_size=16)
99
 
 
 
 
 
 
 
 
 
100
  model.eval()
101
  predictions = []
 
 
102
  with torch.no_grad():
103
- for inputs, labels in test_loader:
104
- inputs, labels = inputs.to('cpu'), labels.to('cpu')
105
- outputs = model(inputs)
106
- _, predicted = torch.max(outputs, 1)
 
 
 
 
 
 
 
107
  predictions.extend(predicted.cpu().numpy())
 
 
108
 
109
  #--------------------------------------------------------------------------------------------
110
  # YOUR MODEL INFERENCE STOPS HERE
@@ -135,4 +154,4 @@ async def evaluate_text(request: TextEvaluationRequest):
135
  }
136
  }
137
 
138
- return results
 
11
 
12
  router = APIRouter()
13
 
14
+ DESCRIPTION = "Evaluate text classification for climate disinformation detection"
15
  ROUTE = "/text"
16
 
17
  @router.post(ROUTE, tags=["Text Task"],
 
97
  test_dataset = TextDataset(texts, labels, tokenizer)
98
  test_loader = DataLoader(test_dataset, batch_size=16)
99
 
100
+ # model.eval()
101
+ # predictions = []
102
+ # with torch.no_grad():
103
+ # for inputs, labels in test_loader:
104
+ # inputs, labels = inputs.to('cpu'), labels.to('cpu')
105
+ # outputs = model(inputs)
106
+ # _, predicted = torch.max(outputs, 1)
107
+ # predictions.extend(predicted.cpu().numpy())
108
  model.eval()
109
  predictions = []
110
+ ground_truth = []
111
+ DEVICE='cpu'
112
  with torch.no_grad():
113
+ for batch in test_loader:
114
+ # Access each component of the batch dictionary
115
+ input_ids = batch['input_ids'].to(DEVICE)
116
+ attention_mask = batch['attention_mask'].to(DEVICE)
117
+ labels = batch['labels'].to(DEVICE)
118
+
119
+ # Forward pass
120
+ outputs = model(input_ids=input_ids, attention_mask=attention_mask)
121
+ _, predicted = torch.max(outputs.logits, 1)
122
+
123
+ # Store predictions and ground truth
124
  predictions.extend(predicted.cpu().numpy())
125
+ ground_truth.extend(labels.cpu().numpy())
126
+
127
 
128
  #--------------------------------------------------------------------------------------------
129
  # YOUR MODEL INFERENCE STOPS HERE
 
154
  }
155
  }
156
 
157
+ return results