VanshK04 commited on
Commit
250433e
·
verified ·
1 Parent(s): 9685f7b

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +27 -4
tasks/text.py CHANGED
@@ -9,7 +9,13 @@ from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
  router = APIRouter()
11
 
12
- DESCRIPTION = "Random Baseline"
 
 
 
 
 
 
13
  ROUTE = "/text"
14
 
15
  @router.post(ROUTE, tags=["Text Task"],
@@ -56,10 +62,27 @@ async def evaluate_text(request: TextEvaluationRequest):
56
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
57
  #--------------------------------------------------------------------------------------------
58
 
59
- # Make random predictions (placeholder for actual model inference)
60
- true_labels = test_dataset["label"]
61
- predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
 
 
 
63
  #--------------------------------------------------------------------------------------------
64
  # YOUR MODEL INFERENCE STOPS HERE
65
  #--------------------------------------------------------------------------------------------
 
9
 
10
  router = APIRouter()
11
 
12
+ import numpy as np
13
+
14
+ from transformers import AutoTokenizer,BertForSequenceClassification,AutoModelForSequenceClassification,Trainer, TrainingArguments,DataCollatorWithPadding
15
+ from datasets import Dataset
16
+ import torch
17
+
18
+ DESCRIPTION = "BERT-Fine Tune"
19
  ROUTE = "/text"
20
 
21
  @router.post(ROUTE, tags=["Text Task"],
 
62
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
63
  #--------------------------------------------------------------------------------------------
64
 
65
+ model = BertForSequenceClassification.from_pretrained("Oriaz/climate_change_bert_classif")
66
+ tokenizer = AutoTokenizer.from_pretrained("Oriaz/climate_change_bert_classif")
67
+
68
+ ## Data prep
69
+ def preprocess_function(df):
70
+ return tokenizer(df["quote"], truncation=True)
71
+ tokenized_test = test_dataset.map(preprocess_function, batched=True)
72
+
73
+ ## Modify inference model
74
+ training_args = torch.load("./tasks/utils/training_args.bin")
75
+ training_args.eval_strategy='no'
76
+
77
+ trainer = Trainer(
78
+ model=model,
79
+ args=training_args,
80
+ tokenizer=tokenizer
81
+ )
82
 
83
+ ## prediction
84
+ preds = trainer.predict(tokenized_test)
85
+ predictions = np.array([np.argmax(x) for x in preds[0]])
86
  #--------------------------------------------------------------------------------------------
87
  # YOUR MODEL INFERENCE STOPS HERE
88
  #--------------------------------------------------------------------------------------------