laeeqkhan committed on
Commit 6a5a9c8 · verified · 1 Parent(s): eaa2b79

Update main.py

Files changed (1)
  1. main.py +65 -0
main.py CHANGED
@@ -2,3 +2,68 @@ from datasets import load_dataset
 
  dataset = load_dataset("mteb/tweet_sentiment_extraction")
  df = pd.DataFrame(dataset['train'])
+
+ from transformers import GPT2Tokenizer
+
+ # Loading the dataset to train our model
+ dataset = load_dataset("mteb/tweet_sentiment_extraction")
+
+ # GPT-2 has no padding token by default, so reuse the end-of-sequence token
+ tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
+ tokenizer.pad_token = tokenizer.eos_token
+
+ def tokenize_function(examples):
+     return tokenizer(examples["text"], padding="max_length", truncation=True)
+
+ tokenized_datasets = dataset.map(tokenize_function, batched=True)
+
+ # Small subsets keep the fine-tuning run short
+ small_train_dataset = tokenized_datasets["train"].shuffle(seed=42).select(range(1000))
+ small_eval_dataset = tokenized_datasets["test"].shuffle(seed=42).select(range(1000))
+
+ from transformers import GPT2ForSequenceClassification
+
+ model = GPT2ForSequenceClassification.from_pretrained("gpt2", num_labels=3)
+ # Tell the classification head which token is padding so it can locate the last real token
+ model.config.pad_token_id = tokenizer.pad_token_id
+
+ import evaluate
+ import numpy as np
+
+ metric = evaluate.load("accuracy")
+
+ def compute_metrics(eval_pred):
+     logits, labels = eval_pred
+     predictions = np.argmax(logits, axis=-1)
+     return metric.compute(predictions=predictions, references=labels)
+
+ from transformers import TrainingArguments, Trainer
+
+ training_args = TrainingArguments(
+     output_dir="test_trainer",
+     #evaluation_strategy="epoch",
+     per_device_train_batch_size=1,   # Reduce batch size here
+     per_device_eval_batch_size=1,    # Optionally, reduce for evaluation as well
+     gradient_accumulation_steps=4
+ )
+
+ trainer = Trainer(
+     model=model,
+     args=training_args,
+     train_dataset=small_train_dataset,
+     eval_dataset=small_eval_dataset,
+     compute_metrics=compute_metrics,
+ )
+
+ trainer.train()
+
+ trainer.evaluate()
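
Not part of this commit, but as a minimal sketch of how the fine-tuned model above could be queried once training finishes (it assumes the model and tokenizer objects from the script are still in memory; the example text and the id-to-sentiment mapping are assumptions based on the dataset's label_text column):

import torch

text = "soooo excited for the weekend"  # hypothetical example tweet
inputs = tokenizer(text, return_tensors="pt", truncation=True).to(model.device)
with torch.no_grad():
    logits = model(**inputs).logits
predicted_label = logits.argmax(dim=-1).item()
# Assumption: ids 0/1/2 correspond to negative/neutral/positive in this dataset
print(predicted_label)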