VanshK04 commited on
Commit
5db411e
·
verified ·
1 Parent(s): 2a2b02c

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +34 -30
tasks/text.py CHANGED
@@ -1,16 +1,16 @@
1
- from fastapi import APIRouter
2
- from datetime import datetime
3
- from datasets import load_dataset
4
- from sklearn.metrics import accuracy_score
5
- import random
6
 
7
- from .utils.evaluation import TextEvaluationRequest
8
- from .utils.emissions import tracker, clean_emissions_data, get_space_info
 
 
9
 
10
- router = APIRouter()
11
-
12
- DESCRIPTION = "Random Baseline"
13
- ROUTE = "/text"
14
 
15
  @router.post(ROUTE, tags=["Text Task"],
16
  description=DESCRIPTION)
@@ -18,9 +18,7 @@ async def evaluate_text(request: TextEvaluationRequest):
18
  """
19
  Evaluate text classification for climate disinformation detection.
20
 
21
- Current Model: Random Baseline
22
- - Makes random predictions from the label space (0-7)
23
- - Used as a baseline for comparison
24
  """
25
  # Get space info
26
  username, space_url = get_space_info()
@@ -46,31 +44,37 @@ async def evaluate_text(request: TextEvaluationRequest):
46
  # Split dataset
47
  train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
48
  test_dataset = train_test["test"]
 
 
 
 
49
 
 
 
 
 
 
 
 
 
 
50
  # Start tracking emissions
51
  tracker.start()
52
  tracker.start_task("inference")
53
 
54
- #--------------------------------------------------------------------------------------------
55
- # YOUR MODEL INFERENCE CODE HERE
56
- # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
57
- #--------------------------------------------------------------------------------------------
58
-
59
- # Make random predictions (placeholder for actual model inference)
60
- true_labels = test_dataset["label"]
61
- predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
62
-
63
- #--------------------------------------------------------------------------------------------
64
- # YOUR MODEL INFERENCE STOPS HERE
65
- #--------------------------------------------------------------------------------------------
66
 
67
-
68
  # Stop tracking emissions
69
  emissions_data = tracker.stop_task()
70
-
71
  # Calculate accuracy
72
  accuracy = accuracy_score(true_labels, predictions)
73
-
74
  # Prepare results dictionary
75
  results = {
76
  "username": username,
@@ -89,4 +93,4 @@ async def evaluate_text(request: TextEvaluationRequest):
89
  }
90
  }
91
 
92
- return results
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ import torch
3
+ import os
4
+ import zipfile
 
5
 
6
+ # Unzip the uploaded file
7
+ model_dir = "/tasks/bert_fine_tuned_model"
8
+ with zipfile.ZipFile('/tasks/bert_fine_tuned_model-20250107T090607Z-001.zip', 'r') as zip_ref:
9
+ zip_ref.extractall(model_dir)
10
 
11
+ # Load the fine-tuned BERT model and tokenizer
12
+ tokenizer = AutoTokenizer.from_pretrained(model_dir)
13
+ model = AutoModelForSequenceClassification.from_pretrained(model_dir)
 
14
 
15
  @router.post(ROUTE, tags=["Text Task"],
16
  description=DESCRIPTION)
 
18
  """
19
  Evaluate text classification for climate disinformation detection.
20
 
21
+ Current Model: Fine-Tuned BERT
 
 
22
  """
23
  # Get space info
24
  username, space_url = get_space_info()
 
44
  # Split dataset
45
  train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
46
  test_dataset = train_test["test"]
47
+
48
+ # Preprocess the test dataset
49
+ def preprocess_function(examples):
50
+ return tokenizer(examples["text"], truncation=True, padding=True, max_length=512)
51
 
52
+ test_dataset = test_dataset.map(preprocess_function, batched=True)
53
+
54
+ # Convert to PyTorch dataset
55
+ test_dataset = test_dataset.with_format("torch")
56
+
57
+ # Assign device
58
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
59
+ model.to(device)
60
+
61
  # Start tracking emissions
62
  tracker.start()
63
  tracker.start_task("inference")
64
 
65
+ # Perform inference
66
+ true_labels = test_dataset["label"].tolist()
67
+ inputs = {key: test_dataset[key].to(device) for key in ["input_ids", "attention_mask"]}
68
+ with torch.no_grad():
69
+ logits = model(**inputs).logits
70
+ predictions = torch.argmax(logits, dim=1).tolist()
 
 
 
 
 
 
71
 
 
72
  # Stop tracking emissions
73
  emissions_data = tracker.stop_task()
74
+
75
  # Calculate accuracy
76
  accuracy = accuracy_score(true_labels, predictions)
77
+
78
  # Prepare results dictionary
79
  results = {
80
  "username": username,
 
93
  }
94
  }
95
 
96
+ return results