---
license: apache-2.0
metrics:
  - accuracy
base_model:
  - google/efficientnet-b4
pipeline_tag: image-classification
library_name: timm
tags:
  - art
  - pytorch
  - images
  - ai
---

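AI Image Detection

An EfficientNet-B4 image classifier that labels an image as "ai" (AI-generated) or "human", together with a confidence score.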

Dataset

  • AI: ≈100,000 images
  • Human: ≈100,000 images

Model

  • Architecture: EfficientNet-B4 (built with timm, 2 output classes)
  • Framework: PyTorch
  • Input size: 380 × 380 RGB

Evaluation Metrics

  • Training Accuracy: 99.75%
  • Validation Accuracy: 98.59%
  • Training Loss: 0.0072
  • Validation Loss: 0.0553
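If both losses are mean cross-entropy values in natural-log units (an assumption: the card does not name the loss function), they can be read as a typical probability the model assigns to the correct class, because exp(−mean(−log p)) is the geometric mean of p over the evaluated images. A quick check in plain Python:

```python
import math

# Assumption: both values are mean per-image cross-entropy (natural log).
# The card does not state which loss was used.
losses = {"training": 0.0072, "validation": 0.0553}

for split, loss in losses.items():
    # exp(-mean(-log p)) = geometric mean of p, the probability given to the true class
    print(f"{split}: geometric-mean probability of the correct class ≈ {math.exp(-loss):.3f}")
```

Under that assumption this prints about 0.993 for training and 0.946 for validation.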

Usage

pip install torch torchvision timm huggingface_hub pillow

Example Code

import torch
from torchvision import transforms
from PIL import Image
from timm import create_model
from huggingface_hub import hf_hub_download

# Parameters
IMG_SIZE = 380  # input resolution expected by the model
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LABEL_MAPPING = {0: "ai", 1: "human"}  # class index -> label

# Download the model weights from the Hugging Face Hub (cached locally after the first call)
MODEL_PATH = hf_hub_download(repo_id="Dafilab/ai-image-detector", filename="model_epoch_8_acc_0.9859.pth")

# Preprocessing: resize the shorter side to IMG_SIZE + 20, center-crop to IMG_SIZE x IMG_SIZE,
# convert to a [0, 1] tensor, and normalize with the ImageNet mean/std
transform = transforms.Compose([
    transforms.Resize(IMG_SIZE + 20),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Load model: build an EfficientNet-B4 with a 2-class head, then load the fine-tuned weights
model = create_model("efficientnet_b4", pretrained=False, num_classes=2)
# weights_only=True loads tensors only (safer); it requires PyTorch >= 1.13
state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
model.load_state_dict(state_dict)
model.to(DEVICE).eval()

# Prediction function: returns (label, confidence) for a single image file
def predict_image(image_path):
    with Image.open(image_path) as img:
        # Shape after unsqueeze: [1, 3, IMG_SIZE, IMG_SIZE]
        tensor = transform(img.convert("RGB")).unsqueeze(0).to(DEVICE)
    with torch.no_grad():
        logits = model(tensor)
        probs = torch.softmax(logits, dim=1)
        predicted_class = torch.argmax(probs, dim=1).item()
        confidence = probs[0, predicted_class].item()  # softmax probability of the predicted class
    return LABEL_MAPPING[predicted_class], confidence

# Example usage
image_path = "path/to/image.jpg"
label, confidence = predict_image(image_path)
print(f"Label: {label}, Confidence: {confidence:.2f}")
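Two steps of the example are easy to misread: which part of an image the model actually sees, and how the logits become a label and a confidence. The plain-Python sketch below reproduces both without PyTorch. It is an illustration, not part of the model's code: `center_crop_box` and `softmax_confidence` are hypothetical helper names, and `center_crop_box` is written to follow torchvision's documented behavior for an integer `Resize` (shorter side matched, aspect ratio kept) followed by `CenterCrop`, though rounding could differ between torchvision versions.

```python
import math

IMG_SIZE = 380
LABEL_MAPPING = {0: "ai", 1: "human"}


def center_crop_box(width, height, img_size=IMG_SIZE, margin=20):
    """Approximate the region kept by Resize(img_size + margin) + CenterCrop(img_size).

    Returns (resized_width, resized_height, crop_left, crop_top); the crop itself
    is img_size x img_size. Assumption: mirrors torchvision's int-size Resize.
    """
    short_target = img_size + margin
    if width <= height:
        new_w, new_h = short_target, int(short_target * height / width)
    else:
        new_w, new_h = int(short_target * width / height), short_target
    # Offsets of the centered crop inside the resized image
    left = int(round((new_w - img_size) / 2))
    top = int(round((new_h - img_size) / 2))
    return new_w, new_h, left, top


def softmax_confidence(logits):
    """Turn raw logits into (label, confidence), like predict_image does with torch."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    predicted_class = max(range(len(probs)), key=probs.__getitem__)
    return LABEL_MAPPING[predicted_class], probs[predicted_class]


# A 1920x1080 image is resized to 711x400 and the central 380x380 is kept,
# so 165 to 166 px are dropped from each side of the resized image.
print(center_crop_box(1920, 1080))  # (711, 400, 166, 10)
print(softmax_confidence([2.0, -1.0]))  # label 'ai', confidence ≈ 0.953
```

One practical consequence of the center crop: on wide or tall images, content near the edges never reaches the model.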