license: mit
language:
  - en
base_model:
  - distilbert/distilbert-base-uncased
pipeline_tag: text-classification

Topic Classifier

This repository contains the Topic Classifier, a text classification model developed by DAXA.AI. It assigns text documents to one of four domains: corporate documents, financial texts, harmful content, and medical documents.

Model Details

Model Description

The Topic Classifier is a BERT-based model fine-tuned from distilbert-base-uncased. It assigns each input text to one of four topics: CORPORATE_DOCUMENTS, FINANCIAL, HARMFUL, and MEDICAL. Because these categories span several sectors, the same model can serve a range of business document-classification tasks.

  • Developed by: DAXA.AI
  • Funded by: Open Source
  • Model type: Text classification
  • Language(s): English
  • License: MIT
  • Fine-tuned from: distilbert-base-uncased

Model Sources

Usage

How to Get Started with the Model

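The example below loads the model and tokenizer from the Hugging Face Hub, classifies one text, and maps the predicted class index back to its label name using the label encoder stored in the repository. It requires transformers, torch, joblib, and huggingface_hub: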

# Import necessary libraries
import joblib
import torch
from huggingface_hub import hf_hub_download
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hugging Face model repository and the label encoder file stored in it
REPO_NAME = "daxa-ai/topic-classifier"
LABEL_ENCODER_FILE = "label_encoder.joblib"

# Load the tokenizer and model
tokenizer = AutoTokenizer.from_pretrained(REPO_NAME)
model = AutoModelForSequenceClassification.from_pretrained(REPO_NAME)

# Example text, truncated to the model's 512-token input limit
text = "Please enter your text here."
encoded_input = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)

# Run the forward pass without tracking gradients (inference only)
with torch.no_grad():
    output = model(**encoded_input)

# Apply softmax to the logits to get class probabilities
probabilities = torch.nn.functional.softmax(output.logits, dim=-1)

# Get the index of the most probable class
predicted_label = torch.argmax(probabilities, dim=-1)

# Download the label encoder file (cached locally after the first call).
# hf_hub_download replaces hf_hub_url + cached_download, which are deprecated
# and removed in recent huggingface_hub releases.
filename = hf_hub_download(repo_id=REPO_NAME, filename=LABEL_ENCODER_FILE)

# Load the label encoder
label_encoder = joblib.load(filename)

# Decode the predicted class index into its label name
decoded_label = label_encoder.inverse_transform(predicted_label.numpy())

print(decoded_label)
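To make the post-processing concrete, here is a dependency-free sketch of the same three steps (softmax, argmax, label decoding) on one row of logits. The logit values are made up. The `LABELS` order is an assumption: it presumes the repository's encoder is a scikit-learn `LabelEncoder`, which keeps `classes_` in sorted order. Check `label_encoder.classes_` to confirm.

```python
import math

# ASSUMPTION: class order of the label encoder. A scikit-learn LabelEncoder
# sorts its classes_, so these four names would appear alphabetically.
LABELS = ["CORPORATE_DOCUMENTS", "FINANCIAL", "HARMFUL", "MEDICAL"]


def softmax(logits):
    """Numerically stable softmax: subtract the max logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def decode(logits):
    """Return (label, probability) for the highest-scoring class."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)  # argmax
    return LABELS[best], probs[best]


# Hypothetical logits for one input text (not real model output)
label, prob = decode([0.2, 4.1, -1.3, 0.5])
print(label, round(prob, 3))
```

Because softmax is monotonic, taking the argmax over the raw logits selects the same class. The softmax step only matters when you also want the probability.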

Training Details

Training Data

The training dataset consists of 29,286 entries, categorized into four distinct labels. The distribution of these labels is presented below:

| Document Type | Instances |
|---|---|
| CORPORATE_DOCUMENTS | 17,649 |
| FINANCIAL | 3,385 |
| HARMFUL | 2,388 |
| MEDICAL | 5,864 |
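The counts add up to the stated 29,286 entries, and the training set is imbalanced: CORPORATE_DOCUMENTS accounts for roughly 60% of it, HARMFUL for about 8%. The short sketch below simply recomputes each label's share from the table; it makes no claim about how the model was trained on this imbalance.

```python
# Training-set label counts, copied from the table above
train_counts = {
    "CORPORATE_DOCUMENTS": 17_649,
    "FINANCIAL": 3_385,
    "HARMFUL": 2_388,
    "MEDICAL": 5_864,
}

# Should match the 29,286 entries stated above
total = sum(train_counts.values())

# Share of each label in the training set, as a percentage
shares = {label: 100 * n / total for label, n in train_counts.items()}

for label, pct in shares.items():
    print(f"{label:<20} {pct:5.1f}%")
```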

Evaluation

Testing Data & Metrics

The model was evaluated on a dataset consisting of 4,565 entries. The distribution of labels in the evaluation set is shown below:

| Document Type | Instances |
|---|---|
| CORPORATE_DOCUMENTS | 3,051 |
| FINANCIAL | 409 |
| HARMFUL | 246 |
| MEDICAL | 859 |

The evaluation metrics include precision, recall, and F1-score, calculated for each label:

| Document Type | Precision | Recall | F1-Score | Support |
|---|---|---|---|---|
| CORPORATE_DOCUMENTS | 1.00 | 1.00 | 1.00 | 3,051 |
| FINANCIAL | 0.95 | 0.96 | 0.96 | 409 |
| HARMFUL | 0.95 | 0.95 | 0.95 | 246 |
| MEDICAL | 0.99 | 1.00 | 0.99 | 859 |
| Accuracy | | | 0.99 | 4,565 |
| Macro Avg | 0.97 | 0.98 | 0.97 | 4,565 |
| Weighted Avg | 0.99 | 0.99 | 0.99 | 4,565 |
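The two summary rows aggregate the per-class scores differently. The macro average is the unweighted mean over the four labels. The weighted average weights each label by its support, so it is pulled toward CORPORATE_DOCUMENTS, which makes up about two thirds of the evaluation set. This sketch recomputes both F1 averages from the table's per-class values; since those inputs are already rounded, the results are approximate.

```python
# Per-class F1-scores and support, copied from the (rounded) table above
f1 = {"CORPORATE_DOCUMENTS": 1.00, "FINANCIAL": 0.96, "HARMFUL": 0.95, "MEDICAL": 0.99}
support = {"CORPORATE_DOCUMENTS": 3051, "FINANCIAL": 409, "HARMFUL": 246, "MEDICAL": 859}

# Macro average: every label counts equally
macro_f1 = sum(f1.values()) / len(f1)

# Weighted average: each label counts in proportion to its support
total_support = sum(support.values())
weighted_f1 = sum(f1[label] * support[label] for label in f1) / total_support

print(f"macro F1    ~ {macro_f1:.3f}")
print(f"weighted F1 ~ {weighted_f1:.3f}")
```

From these rounded inputs the macro F1 comes out near 0.975 and the weighted F1 near 0.992, in line with the reported 0.97 and 0.99.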

Test Data Evaluation Results

The model's evaluation results are as follows:

  • Evaluation Loss: 0.0233
  • Accuracy: 0.9908
  • Precision: 0.9909
  • Recall: 0.9908
  • F1-Score: 0.9908
  • Evaluation Runtime: 30.1149 seconds
  • Evaluation Samples Per Second: 151.586
  • Evaluation Steps Per Second: 2.391
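The runtime figures are mutually consistent. Dividing the 4,565 evaluation samples by the 30.1149-second runtime gives the reported 151.586 samples per second, and 2.391 steps per second over the same runtime works out to about 72 evaluation steps. The card does not state the evaluation batch size. The sketch below infers it under an explicit assumption: one step processes one batch, and only the final batch may be partial.

```python
import math

n_samples = 4565
runtime_s = 30.1149
steps_per_s = 2.391

# Throughput: samples processed per second of evaluation
samples_per_s = n_samples / runtime_s

# Total number of evaluation steps implied by the reported rate
n_steps = round(steps_per_s * runtime_s)

# ASSUMPTION: one step = one batch, and only the final batch may be smaller.
# Collect every batch size that would produce exactly n_steps batches.
candidate_batch_sizes = [b for b in range(1, n_samples + 1)
                         if math.ceil(n_samples / b) == n_steps]

print(f"{samples_per_s:.3f} samples/s, {n_steps} steps, batch size {candidate_batch_sizes}")
```

Under that assumption, 64 is the only batch size that yields 72 steps. Treat it as an inference from the reported numbers, not a documented setting.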

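Inference Code

In the handler below, model_fn loads the model and tokenizer, and predict_fn truncates the input to 512 tokens and returns a score for each of the four labels. The model_fn/predict_fn names appear to follow the custom-inference hooks of the SageMaker Hugging Face Inference Toolkit, where data is the deserialized request body.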

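from transformers import AutoModelForSequenceClassification, AutoTokenizer, pipeline


def model_fn(model_dir):
    """
    Load the model and tokenizer from the given directory.

    :param model_dir: Local directory (or Hub repo ID) containing the model and tokenizer files.
    :return: A (model, tokenizer) tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    return model, tokenizer


def predict_fn(data, model_and_tokenizer):
    """
    Classify data["inputs"] and return a score for every label.

    :param data: A dict whose "inputs" key holds the text to classify.
    :param model_and_tokenizer: The (model, tokenizer) tuple returned by model_fn.
    :return: The pipeline output, with a {"label", "score"} entry per class.
    """
    # Unpack the model and tokenizer
    model, tokenizer = model_and_tokenizer

    # top_k=None returns scores for all labels (replaces the deprecated return_all_scores=True)
    bert_pipe = pipeline("text-classification", model=model, tokenizer=tokenizer,
                         truncation=True, max_length=512, top_k=None)
    # Tokenize the input and keep only the first 512 tokens before passing it on
    tokens = tokenizer.encode(data["inputs"], add_special_tokens=False, max_length=512, truncation=True)
    input_data = tokenizer.decode(tokens)
    return bert_pipe(input_data)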
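predict_fn returns the pipeline's raw output: one {"label": ..., "score": ...} dictionary per class. A client usually needs only the best label. The helper below is a hypothetical, dependency-free sketch (top_label is not part of the original code). It also unwraps one level of list nesting, because depending on how the text-classification pipeline is called, a single-string input can come back wrapped in an extra list. The label strings come from the model config's id2label mapping, which this card does not show. The topic names and scores in the example are assumptions for illustration. If the config uses generic names such as LABEL_0, map them back with the label encoder from the getting-started example.

```python
def top_label(pipeline_output):
    """Return (label, score) for the highest-scoring entry.

    Accepts either a flat list of {"label", "score"} dicts or that list
    wrapped in one extra list.
    """
    scores = pipeline_output
    # Unwrap [[{...}, ...]] -> [{...}, ...]
    if scores and isinstance(scores[0], list):
        scores = scores[0]
    best = max(scores, key=lambda entry: entry["score"])
    return best["label"], best["score"]


# Hypothetical output for one document (labels and scores are illustrative only)
example_output = [[
    {"label": "CORPORATE_DOCUMENTS", "score": 0.02},
    {"label": "FINANCIAL", "score": 0.01},
    {"label": "HARMFUL", "score": 0.01},
    {"label": "MEDICAL", "score": 0.96},
]]
print(top_label(example_output))
```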

Conclusion

The Topic Classifier reaches 0.9908 accuracy and a 0.9908 F1-score on the 4,565-entry evaluation set. These results make it a reliable model for categorizing text across corporate documents, financial content, harmful content, and medical texts. Per-class F1 is lowest for the two smallest classes, FINANCIAL (0.96) and HARMFUL (0.95). The model is intended for direct deployment; in the evaluation run it processed roughly 152 samples per second (hardware not stated).

For more information or to try the model yourself, check out the public space here.