metadata
license: cc-by-nc-nd-4.0
language:
  - en
model-index:
  - name: roberta-large Image Prompt Classifier
    results:
      - task:
          type: text-classification
        dataset:
          name: nsfw-text-detection
          type: custom
        metrics:
          - name: Accuracy
            type: self-reported
            value: 93%
          - name: Precision
            type: self-reported
            value: 88%
          - name: Recall
            type: self-reported
            value: 90%

roberta-large Image Prompt Classifier

Model Overview

This model is a fine-tuned version of roberta-large that classifies image generation prompts into three categories: SAFE, QUESTIONABLE, and UNSAFE. It builds on the roberta-large encoder to identify the nature of a prompt before that prompt is used to generate an image.

Model Details

  • Model Name: roberta-large Image Prompt Classifier
  • Base Model: roberta-large
  • Fine-tuned By: Michał Młodawski
  • Categories:
    • 0: SAFE
    • 1: QUESTIONABLE
    • 2: UNSAFE

Use Cases

This model is particularly useful for platforms and applications involving AI-generated content, where it is crucial to filter and classify prompts to maintain content safety and appropriateness. Some potential applications include:

  • Content Moderation: Automatically classify and filter prompts to prevent the generation of inappropriate or harmful images.
  • User Safety: Help keep generated content within safety guidelines, improving the user experience.
  • Compliance: Help platforms comply with regulatory requirements by identifying and flagging potentially unsafe prompts.
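A minimal sketch of how the classifier's output could drive a content-moderation decision, assuming the 0/1/2 label scheme listed under Model Details. The function name `moderation_action` and the allow/review/block policy are hypothetical choices for illustration; the model itself only returns a class id.

```python
# Class ids as documented in the model card.
LABELS = {0: "SAFE", 1: "QUESTIONABLE", 2: "UNSAFE"}

# Hypothetical policy: each platform would choose its own actions.
ACTIONS = {"SAFE": "allow", "QUESTIONABLE": "review", "UNSAFE": "block"}


def moderation_action(predicted_class: int) -> str:
    """Return 'allow', 'review', or 'block' for a predicted class id."""
    if predicted_class not in LABELS:
        raise ValueError(f"unknown class id: {predicted_class}")
    return ACTIONS[LABELS[predicted_class]]


if __name__ == "__main__":
    for class_id in (0, 1, 2):
        print(class_id, LABELS[class_id], moderation_action(class_id))
```

Routing QUESTIONABLE prompts to human review rather than blocking them outright follows the category description ("may require further review"), but the exact policy is a design choice left to the integrating platform.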

How It Works

The model takes an input prompt and classifies it into one of three categories:

  1. SAFE: Prompts that are deemed appropriate and free from harmful content.
  2. QUESTIONABLE: Prompts that may require further review due to potential ambiguity or slight risk.
  3. UNSAFE: Prompts that are likely to generate inappropriate or harmful content.

The classification is based on the semantic understanding and contextual analysis provided by the roberta-large architecture, fine-tuned on a curated dataset tailored for this specific task.
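The final classification step reduces to picking the highest-scoring of three logits. The plain-Python sketch below uses made-up example logits (not real model output) to show how a softmax turns the three logits into probabilities and how argmax selects the class; `torch.argmax` in the usage code below performs the same selection on tensors. The helper names `softmax` and `predict` are illustrative.

```python
import math

LABELS = ["SAFE", "QUESTIONABLE", "UNSAFE"]  # list index = class id


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predict(logits):
    """Return (class_id, label, probabilities) for one prompt's logits."""
    probs = softmax(logits)
    # argmax: index of the largest probability (same as the largest logit)
    class_id = max(range(len(probs)), key=probs.__getitem__)
    return class_id, LABELS[class_id], probs


if __name__ == "__main__":
    # Made-up logits for illustration; real values come from outputs.logits.
    class_id, label, probs = predict([0.3, 2.1, -1.0])
    print(class_id, label, [round(p, 3) for p in probs])
```

Because softmax is monotonic, the argmax of the probabilities equals the argmax of the raw logits; the probabilities are only needed if you want a confidence score, for example to send low-confidence predictions to review.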

Performance

The following self-reported evaluation metrics are given for the model:

  • Accuracy: 93%
  • Precision: 88%
  • Recall: 90%

These figures indicate that the model distinguishes well between the three prompt categories.
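The card does not state how precision and recall were averaged over the three classes. One common convention, assumed here, is macro-averaging: compute each metric per class and take the unweighted mean. A minimal sketch with toy labels (not the author's evaluation data); `evaluate` is a hypothetical helper name:

```python
def evaluate(y_true, y_pred, num_classes=3):
    """Accuracy plus macro-averaged precision and recall for class ids 0..num_classes-1."""
    if len(y_true) != len(y_pred) or not y_true:
        raise ValueError("need two non-empty sequences of equal length")
    pairs = list(zip(y_true, y_pred))
    accuracy = sum(t == p for t, p in pairs) / len(pairs)
    precisions, recalls = [], []
    for c in range(num_classes):
        tp = sum(t == c and p == c for t, p in pairs)  # correctly predicted as c
        fp = sum(t != c and p == c for t, p in pairs)  # wrongly predicted as c
        fn = sum(t == c and p != c for t, p in pairs)  # c predicted as something else
        # Convention: a class with no predictions (or no true examples) scores 0.
        precisions.append(tp / (tp + fp) if tp + fp else 0.0)
        recalls.append(tp / (tp + fn) if tp + fn else 0.0)
    return accuracy, sum(precisions) / num_classes, sum(recalls) / num_classes


if __name__ == "__main__":
    # Toy data: 0 = SAFE, 1 = QUESTIONABLE, 2 = UNSAFE
    acc, prec, rec = evaluate([0, 0, 1, 1, 2, 2], [0, 1, 1, 1, 2, 0])
    print(f"accuracy={acc:.3f} macro_precision={prec:.3f} macro_recall={rec:.3f}")
```

For a moderation filter, recall on the UNSAFE class is often the number to watch most closely, since a missed UNSAFE prompt is usually costlier than an extra review; the per-class lists inside `evaluate` expose that value before averaging.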

Getting Started

To start using the roberta-large Image Prompt Classifier in your existing workflows, follow these steps:

Installation

pip install transformers==4.37.2
pip install torch==2.3.1

Usage

from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch

# Hugging Face Hub repository of the fine-tuned classifier
model_repo = "MichalMlodawski/nsfw-text-detection-large"

tokenizer = AutoTokenizer.from_pretrained(model_repo)
model = AutoModelForSequenceClassification.from_pretrained(model_repo)

# Class ids as documented under Model Details
LABELS = {0: "SAFE", 1: "QUESTIONABLE", 2: "UNSAFE"}

def classify_text(text):
    # Tokenize the prompt; inputs longer than 512 tokens are truncated
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)

    # Forward pass without gradient tracking (inference only)
    with torch.no_grad():
        outputs = model(**inputs)

    # The index of the highest-scoring logit is the predicted class id
    logits = outputs.logits
    predicted_class = torch.argmax(logits, dim=1).item()

    return predicted_class

if __name__ == "__main__":
    examples = [
        {"text": "wild west cowboy caravan through Arizona Desert , Men Woman and children, horse drawn carriages", "expected_class": 0},
        {"text": "A person with questionable intentions in a dimly lit alley.", "expected_class": 1},
        {"text": "An inappropriate scene involving explicit content.", "expected_class": 2}
    ]

    for example in examples:
        text = example["text"]
        expected_class = example["expected_class"]
        predicted_class = classify_text(text)
        print(f"Prompt: {text}")
        print(f"Predicted class: {predicted_class} ({LABELS[predicted_class]})")
        print(f"Expected class: {expected_class} ({LABELS[expected_class]})")
        print("---")

Disclaimer

The dataset includes actual prompts for image generation that could be perceived as abusive, offensive, or obscene. Furthermore, the examples and data might contain unfavorable information about certain businesses. We merely collect this data and do not assume any legal responsibility. Please note: a portion of the data was created using large language models (LLMs).