Spaces:
Sleeping
Sleeping
Update tasks/text.py
Browse files- tasks/text.py +34 -30
tasks/text.py
CHANGED
@@ -1,16 +1,16 @@
|
|
1 |
-
from
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
import random
|
6 |
|
7 |
-
|
8 |
-
|
|
|
|
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
ROUTE = "/text"
|
14 |
|
15 |
@router.post(ROUTE, tags=["Text Task"],
|
16 |
description=DESCRIPTION)
|
@@ -18,9 +18,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
18 |
"""
|
19 |
Evaluate text classification for climate disinformation detection.
|
20 |
|
21 |
-
Current Model:
|
22 |
-
- Makes random predictions from the label space (0-7)
|
23 |
-
- Used as a baseline for comparison
|
24 |
"""
|
25 |
# Get space info
|
26 |
username, space_url = get_space_info()
|
@@ -46,31 +44,37 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
46 |
# Split dataset
|
47 |
train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
|
48 |
test_dataset = train_test["test"]
|
|
|
|
|
|
|
|
|
49 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
# Start tracking emissions
|
51 |
tracker.start()
|
52 |
tracker.start_task("inference")
|
53 |
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
true_labels = test_dataset["label"]
|
61 |
-
predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
|
62 |
-
|
63 |
-
#--------------------------------------------------------------------------------------------
|
64 |
-
# YOUR MODEL INFERENCE STOPS HERE
|
65 |
-
#--------------------------------------------------------------------------------------------
|
66 |
|
67 |
-
|
68 |
# Stop tracking emissions
|
69 |
emissions_data = tracker.stop_task()
|
70 |
-
|
71 |
# Calculate accuracy
|
72 |
accuracy = accuracy_score(true_labels, predictions)
|
73 |
-
|
74 |
# Prepare results dictionary
|
75 |
results = {
|
76 |
"username": username,
|
@@ -89,4 +93,4 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
89 |
}
|
90 |
}
|
91 |
|
92 |
-
return results
|
|
|
1 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
2 |
+
import torch
|
3 |
+
import os
|
4 |
+
import zipfile
|
|
|
5 |
|
6 |
+
# Unzip the uploaded file
|
7 |
+
model_dir = "/tasks/bert_fine_tuned_model"
|
8 |
+
with zipfile.ZipFile('/tasks/bert_fine_tuned_model-20250107T090607Z-001.zip', 'r') as zip_ref:
|
9 |
+
zip_ref.extractall(model_dir)
|
10 |
|
11 |
+
# Load the fine-tuned BERT model and tokenizer
|
12 |
+
tokenizer = AutoTokenizer.from_pretrained(model_dir)
|
13 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_dir)
|
|
|
14 |
|
15 |
@router.post(ROUTE, tags=["Text Task"],
|
16 |
description=DESCRIPTION)
|
|
|
18 |
"""
|
19 |
Evaluate text classification for climate disinformation detection.
|
20 |
|
21 |
+
Current Model: Fine-Tuned BERT
|
|
|
|
|
22 |
"""
|
23 |
# Get space info
|
24 |
username, space_url = get_space_info()
|
|
|
44 |
# Split dataset
|
45 |
train_test = dataset["train"].train_test_split(test_size=request.test_size, seed=request.test_seed)
|
46 |
test_dataset = train_test["test"]
|
47 |
+
|
48 |
+
# Preprocess the test dataset
|
49 |
+
def preprocess_function(examples):
|
50 |
+
return tokenizer(examples["text"], truncation=True, padding=True, max_length=512)
|
51 |
|
52 |
+
test_dataset = test_dataset.map(preprocess_function, batched=True)
|
53 |
+
|
54 |
+
# Convert to PyTorch dataset
|
55 |
+
test_dataset = test_dataset.with_format("torch")
|
56 |
+
|
57 |
+
# Assign device
|
58 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
59 |
+
model.to(device)
|
60 |
+
|
61 |
# Start tracking emissions
|
62 |
tracker.start()
|
63 |
tracker.start_task("inference")
|
64 |
|
65 |
+
# Perform inference
|
66 |
+
true_labels = test_dataset["label"].tolist()
|
67 |
+
inputs = {key: test_dataset[key].to(device) for key in ["input_ids", "attention_mask"]}
|
68 |
+
with torch.no_grad():
|
69 |
+
logits = model(**inputs).logits
|
70 |
+
predictions = torch.argmax(logits, dim=1).tolist()
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
|
|
|
72 |
# Stop tracking emissions
|
73 |
emissions_data = tracker.stop_task()
|
74 |
+
|
75 |
# Calculate accuracy
|
76 |
accuracy = accuracy_score(true_labels, predictions)
|
77 |
+
|
78 |
# Prepare results dictionary
|
79 |
results = {
|
80 |
"username": username,
|
|
|
93 |
}
|
94 |
}
|
95 |
|
96 |
+
return results
|