Link to the distilbert spam defender
Find the v1 (TensorFlow) model, in SavedModel format, on this page. The v1 model is licensed under Apache 2.0.
| | v3 | v1 |
|---|---|---|
| Base Model | bert-base-multilingual-cased | nlpaueb/legal-bert-small-uncased |
| Base Tokenizer | bert-base-multilingual-cased | bert-base-multilingual-cased |
| Framework | PyTorch | TensorFlow |
| Dataset Size | 3.0M | 2.68M |
| Train Split | 80% English; 20% English + 100% Multilingual | None |
| English Train Accuracy | 99.5% | N/A (≈97.5%) |
| Other Train Accuracy | 98.6% | 96.6% |
| Final Val Accuracy | 96.8% | 94.6% |
| Languages | 55 | N/A (≈35) |
| Hyperparameters | maxlen=208, padding='max_length', batch_size=112, optimizer=AdamW, learning_rate=1e-5, loss=BCEWithLogitsLoss() | maxlen=192, padding='max_length', batch_size=16, optimizer=Adam, learning_rate=1e-5, loss="binary_crossentropy" |
| Training Stopped | 7/20/2023 | 9/05/2022 |
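The v3 hyperparameters list `loss=BCEWithLogitsLoss()`. To illustrate what that loss computes, here is a minimal pure-Python sketch of the numerically stable sigmoid + binary cross-entropy formula that PyTorch documents for `BCEWithLogitsLoss`, assuming the default mean reduction and no `pos_weight`. The function names `bce_with_logits` and `naive_bce` are hypothetical and are not part of the training code.

```python
import math


def bce_with_logits(logits, targets):
    """Mean binary cross-entropy computed directly from raw logits (hypothetical helper).

    Uses the stable form max(x, 0) - x*y + log(1 + exp(-|x|)), which equals
    -[y*log(sigmoid(x)) + (1 - y)*log(1 - sigmoid(x))] but never overflows.
    """
    if len(logits) != len(targets):
        raise ValueError("logits and targets must have the same length")
    total = 0.0
    for x, y in zip(logits, targets):
        total += max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
    return total / len(logits)  # "mean" reduction, PyTorch's default


def naive_bce(logits, targets):
    """Textbook sigmoid + BCE, for comparison only; unstable for large |x|."""
    total = 0.0
    for x, y in zip(logits, targets):
        p = 1.0 / (1.0 + math.exp(-x))
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(logits)


# Small logits: both versions agree
print(bce_with_logits([0.5, -1.0, 2.0], [1.0, 0.0, 1.0]))
print(naive_bce([0.5, -1.0, 2.0], [1.0, 0.0, 1.0]))

# A very confident wrong prediction: the stable form still returns a finite loss
print(bce_with_logits([1000.0], [0.0]))
```

Working from logits instead of probabilities is why the v3 model can output raw scores and leave the sigmoid to the loss function.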
I manually annotated more data on top of Toxi Text 3M and added them to the training set. Training on Toxi Text 3M alone results in a biased model that classifies short text with lower precision.
Models tested for v2: roberta, xlm-roberta, bert-small, bert-base-cased/uncased, bert-base-multilingual-cased/uncased, and albert-large-v2. Of these, I chose bert-base-multilingual-cased because it performed best on this particular task for the same amount of compute.
PyTorch
```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Use a GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

tokenizer = AutoTokenizer.from_pretrained("FredZhang7/one-for-all-toxicity-v3")
model = AutoModelForSequenceClassification.from_pretrained("FredZhang7/one-for-all-toxicity-v3").to(device)

text = "hello world!"

# Tokenize with the same settings used in training (maxlen=208, padded to max_length)
encoding = tokenizer.encode_plus(
    text,
    add_special_tokens=True,
    max_length=208,
    padding="max_length",
    truncation=True,
    return_tensors="pt",
)

input_ids = encoding["input_ids"].to(device)
attention_mask = encoding["attention_mask"].to(device)

# Inference only, so skip gradient tracking
with torch.no_grad():
    outputs = model(input_ids, attention_mask=attention_mask)
    logits = outputs.logits
    # Index of the highest-scoring class for each input in the batch
    predicted_labels = torch.argmax(logits, dim=1)

print(predicted_labels)
```
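The example above returns only the winning class index. If you also want a confidence score, one common option is to normalize each row of `logits` with a softmax before taking the argmax. The pure-Python sketch below assumes a two-class output; the helpers `softmax` and `predict` and the example logits are hypothetical. Which index means "toxic" is not stated here, so check `model.config.id2label`. Because v3 was trained with `BCEWithLogitsLoss`, applying a sigmoid to each logit is an equally reasonable reading. Both transforms are monotonic, so the argmax matches `predicted_labels` either way.

```python
import math


def softmax(logits):
    """Convert one row of raw logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max so exp() cannot overflow
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits_row):
    """Return (predicted class index, its probability) for one row of logits."""
    probs = softmax(logits_row)
    label = max(range(len(probs)), key=probs.__getitem__)
    return label, probs[label]


# Hypothetical logits for one input; real values come from outputs.logits[i].tolist()
example = [-1.2, 2.3]
label, confidence = predict(example)
print(label, round(confidence, 3))
```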
Attribution
- If you distribute, remix, adapt, or build upon One-for-all Toxicity v3, please credit "AIstrova Technologies Inc." in your README.md, application description, research, or website.