prithivMLmods commited on
Commit
3f5e2ad
·
verified ·
1 Parent(s): e6ad0c0

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +109 -0
README.md CHANGED
@@ -113,6 +113,115 @@ gr_interface = gr.Interface(
113
  # Launch the application
114
  gr_interface.launch()
115
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
116
  ```
117
 
118
  ## **🚀 How to Train the Model**
 
113
  # Launch the application
114
  gr_interface.launch()
115
 
116
+ ```
117
+ ### Train Details
118
+
119
+ ```python
120
+
121
+ # Import necessary libraries
122
+ from datasets import load_dataset, ClassLabel
123
+ from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
124
+ import torch
125
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
126
+
127
+ # Load dataset
128
+ dataset = load_dataset("prithivMLmods/Spam-Text-Detect-Analysis", split="train")
129
+
130
+ # Encode labels as integers
131
+ label_mapping = {"ham": 0, "spam": 1}
132
+ dataset = dataset.map(lambda x: {"label": label_mapping[x["Category"]]})
133
+ dataset = dataset.rename_column("Message", "text").remove_columns(["Category"])
134
+
135
+ # Convert label column to ClassLabel for stratification
136
+ class_label = ClassLabel(names=["ham", "spam"])
137
+ dataset = dataset.cast_column("label", class_label)
138
+
139
+ # Split into train and test
140
+ dataset = dataset.train_test_split(test_size=0.2, stratify_by_column="label")
141
+ train_dataset = dataset["train"]
142
+ test_dataset = dataset["test"]
143
+
144
+ # Load BERT tokenizer
145
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
146
+
147
+ # Tokenize the data
148
+ def tokenize_function(examples):
149
+ return tokenizer(examples["text"], padding="max_length", truncation=True, max_length=128)
150
+
151
+ train_dataset = train_dataset.map(tokenize_function, batched=True)
152
+ test_dataset = test_dataset.map(tokenize_function, batched=True)
153
+
154
+ # Set format for PyTorch
155
+ train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
156
+ test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
157
+
158
+ # Load pre-trained BERT model
159
+ model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
160
+
161
+ # Move model to GPU if available
162
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
163
+ model.to(device)
164
+
165
+ # Define evaluation metric
166
+ def compute_metrics(eval_pred):
167
+ predictions, labels = eval_pred
168
+ predictions = torch.argmax(torch.tensor(predictions), dim=-1)
169
+ precision, recall, f1, _ = precision_recall_fscore_support(labels, predictions, average="binary")
170
+ acc = accuracy_score(labels, predictions)
171
+ return {"accuracy": acc, "precision": precision, "recall": recall, "f1": f1}
172
+
173
+ # Training arguments
174
+ training_args = TrainingArguments(
175
+ output_dir="./results",
176
+ evaluation_strategy="epoch", # Evaluate after every epoch
177
+ save_strategy="epoch", # Save checkpoint after every epoch
178
+ learning_rate=2e-5,
179
+ per_device_train_batch_size=16,
180
+ per_device_eval_batch_size=16,
181
+ num_train_epochs=3,
182
+ weight_decay=0.01,
183
+ logging_dir="./logs",
184
+ logging_steps=10,
185
+ load_best_model_at_end=True,
186
+ metric_for_best_model="accuracy",
187
+ greater_is_better=True
188
+ )
189
+
190
+ # Trainer
191
+ trainer = Trainer(
192
+ model=model,
193
+ args=training_args,
194
+ train_dataset=train_dataset,
195
+ eval_dataset=test_dataset,
196
+ compute_metrics=compute_metrics
197
+ )
198
+
199
+ # Train the model
200
+ trainer.train()
201
+
202
+ # Evaluate the model
203
+ results = trainer.evaluate()
204
+ print("Evaluation Results:", results)
205
+
206
+ # Save the trained model
207
+ model.save_pretrained("./saved_model")
208
+ tokenizer.save_pretrained("./saved_model")
209
+
210
+ # Load the model for inference
211
+ loaded_model = BertForSequenceClassification.from_pretrained("./saved_model").to(device)
212
+ loaded_tokenizer = BertTokenizer.from_pretrained("./saved_model")
213
+
214
+ # Test the model on a custom input
215
+ def predict(text):
216
+ inputs = loaded_tokenizer(text, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
217
+ inputs = {k: v.to(device) for k, v in inputs.items()} # Move inputs to the same device as model
218
+ outputs = loaded_model(**inputs)
219
+ prediction = torch.argmax(outputs.logits, dim=-1).item()
220
+ return "Spam" if prediction == 1 else "Ham"
221
+
222
+ # Example test
223
+ example_text = "Congratulations! You've won a $1000 Walmart gift card. Click here to claim now."
224
+ print("Prediction:", predict(example_text))
225
  ```
226
 
227
  ## **🚀 How to Train the Model**