seronk committed
Commit 8fe645e · verified · 1 Parent(s): 766da2d

Create distillbert-baseline.py

Files changed (1)
  1. distillbert-baseline.py +68 -0
distillbert-baseline.py ADDED
@@ -0,0 +1,68 @@
+ from datasets import load_dataset
+ from transformers import TrainingArguments, Trainer
+ from transformers import DistilBertForSequenceClassification, DistilBertTokenizerFast
+
+ dataset = load_dataset("quotaclimat/frugalaichallenge-text-train")
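+ # The dataset provides train and test splits with a "quote" text column and
+ # a string "label" column; both are used below.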
+
+ # %% [markdown]
+ # Map the dataset's string class names to integer label ids.
+
+ # %%
+ LABEL_MAPPING = {
+     "0_not_relevant": 0,
+     "1_not_happening": 1,
+     "2_not_human": 2,
+     "3_not_bad": 3,
+     "4_solutions_harmful_unnecessary": 4,
+     "5_science_unreliable": 5,
+     "6_proponents_biased": 6,
+     "7_fossil_fuels_needed": 7
+ }
+
+ # %%
+ dataset = dataset.map(lambda x: {"label": LABEL_MAPPING[x["label"]]})
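+ # (DatasetDict.map applies the function to every split, so both train and
+ # test labels are converted in one call.)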
+
+ # %%
+ print(dataset)
+
+ # %%
+ tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
+
+ # Tokenize the datasets
+ def tokenize_function(examples):
+     return tokenizer(examples["quote"], padding="max_length", truncation=True)
+
+ train_dataset = dataset["train"].map(tokenize_function, batched=True)
+ test_dataset = dataset["test"].map(tokenize_function, batched=True)
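+ # Note: padding="max_length" pads every example to the model maximum (512
+ # tokens for DistilBERT); dynamic per-batch padding via transformers'
+ # DataCollatorWithPadding would be a lighter alternative.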
+
+ model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=8)  # Set num_labels for your classification task
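+ # (The classification head on top of the pretrained encoder is newly
+ # initialized, so transformers warns about uninitialized weights; that is
+ # expected before fine-tuning.)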
+
+ # %%
+ # Define training arguments
+ training_args = TrainingArguments(
+     output_dir="./results",          # Output directory for saved models
+     eval_strategy="epoch",           # Evaluation strategy (can be "steps" or "epoch")
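+     # ("eval_strategy" is the argument's name in recent transformers
+     # releases; older versions spell it "evaluation_strategy".)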
+     per_device_train_batch_size=16,  # Batch size for training
+     per_device_eval_batch_size=64,   # Batch size for evaluation
+     num_train_epochs=3,              # Number of training epochs
+     logging_dir="./logs",            # Directory for logs
+     logging_steps=10,                # How often to log
+ )
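+
+ # %%
+ # The Trainer below is built without a compute_metrics callback, so
+ # trainer.evaluate() reports only the evaluation loss. A minimal accuracy
+ # metric is sketched here as an assumption (it is not part of the original
+ # script); it could be wired in with Trainer(..., compute_metrics=compute_metrics).
+ import numpy as np
+
+ def compute_metrics(eval_pred):
+     # The Trainer hands over a (logits, labels) pair for the eval set.
+     logits, labels = eval_pred
+     predictions = np.argmax(logits, axis=-1)
+     return {"accuracy": float((predictions == labels).mean())}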
+
+ # %%
+ trainer = Trainer(
+     model=model,                  # The model to train
+     args=training_args,           # The training arguments
+     train_dataset=train_dataset,  # The training dataset
+     eval_dataset=test_dataset     # The evaluation dataset
+ )
+ trainer.train()
+ results = trainer.evaluate()
+ print(results)