---
language:
  - en
base_model:
  - microsoft/deberta-v3-base
pipeline_tag: text-classification
license: mit
---

Binary text-classification model, built on microsoft/deberta-v3-base, for detecting advertisements in passages from question-answering (QA) systems.

Sample usage

```python
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Load the tokenizer and classifier from the Hugging Face Hub
classifier_model_path = "teknology/ad-classifier-v0.4"
tokenizer = AutoTokenizer.from_pretrained(classifier_model_path)
model = AutoModelForSequenceClassification.from_pretrained(classifier_model_path)
model.eval()  # inference mode: disables dropout

# Run on the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


def classify(passages):
    """Return one predicted class index (0 or 1) per input passage."""
    # Tokenize the whole batch; passages longer than 512 tokens are truncated
    inputs = tokenizer(
        passages, padding=True, truncation=True, max_length=512, return_tensors="pt"
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        logits = model(**inputs).logits
    # The class with the highest logit is the prediction
    predictions = torch.argmax(logits, dim=-1)
    return predictions.cpu().tolist()


preds = classify(["sample_text_1", "sample_text_2"])
```
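The `classify` function returns only the argmax class index. If you also want a confidence score, or a decision threshold other than 0.5, the logits can be turned into probabilities with a softmax. The sketch below is plain Python and works on logits converted with `.cpu().tolist()`. It is an illustration, not part of the model's published interface: `decode`, `positive_index` and `threshold` are hypothetical names, and it assumes index 1 is the "ad" class, which this card does not state. Check `model.config.id2label` for the real label names; `LABEL_0`/`LABEL_1` below are placeholders that mirror the transformers defaults.

```python
import math
from typing import Dict, List, Sequence, Tuple


def softmax(row: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]


def decode(
    logits: Sequence[Sequence[float]],
    id2label: Dict[int, str],
    positive_index: int = 1,  # ASSUMPTION: which index means "ad" is not documented
    threshold: float = 0.5,
) -> List[Tuple[str, float]]:
    """Map each row of logits to (label name, positive-class probability).

    A row gets the positive label when its positive-class probability
    reaches `threshold`; otherwise it gets the other (negative) label.
    """
    negative_index = 1 - positive_index  # binary model: the remaining class
    results = []
    for row in logits:
        p_pos = softmax(row)[positive_index]
        idx = positive_index if p_pos >= threshold else negative_index
        results.append((id2label[idx], p_pos))
    return results


# Hand-written logits for two passages, with placeholder label names
print(decode([[2.0, -1.0], [-0.5, 1.5]], {0: "LABEL_0", 1: "LABEL_1"}))
```

With the model loaded as in the sample, `decode(model(**inputs).logits.cpu().tolist(), model.config.id2label)` would stand in for the argmax step inside `classify`. At the default threshold of 0.5 it agrees with argmax (except on exact ties). Raising `threshold` flags fewer passages as the positive class, trading recall for precision.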

Version

Previous versions can be found at: