Image Classification
image
fake-detection
ai-detection
flux-detector / inference.py
LukasT9's picture
Update inference.py
f83bc3e verified
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
def load_model(model_path, device):
    """Deserialize a TorchScript module and prepare it for inference.

    Args:
        model_path: Location of the serialized TorchScript file.
        device: Device the weights are mapped to while loading and moved
            to afterwards.

    Returns:
        The loaded module, placed on ``device`` and switched to eval mode.
    """
    scripted = torch.jit.load(model_path, map_location=device)
    scripted = scripted.to(device)
    # Inference only: switch the module out of training mode.
    scripted.eval()
    return scripted
def preprocess_image(image_path, img_size=1024):
    """Load an image from disk and turn it into a batched model input.

    The image is converted to RGB, resized to ``img_size + 32``,
    center-cropped to ``img_size`` x ``img_size``, converted to a tensor
    and normalized per channel.

    Args:
        image_path: Path of the image file to read.
        img_size: Side length of the square crop fed to the model.
            Defaults to 1024, the size this script previously hard-coded.

    Returns:
        A tensor of shape ``(1, 3, img_size, img_size)``; the leading
        dimension is the batch axis added by ``unsqueeze(0)``.
    """
    pipeline = transforms.Compose([
        # Resize slightly larger than the target so the center crop trims
        # a 32-pixel margin instead of touching the image border.
        transforms.Resize(img_size + 32),
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        # Per-channel mean/std commonly used with ImageNet-pretrained
        # backbones.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    # The context manager releases the file handle deterministically;
    # convert() returns an independent in-memory copy that remains usable
    # after the source image is closed.
    with Image.open(image_path) as source:
        rgb = source.convert("RGB")
    return pipeline(rgb).unsqueeze(0)
def predict(model, image_tensor, device, threshold=0.5):
    """Run the model on one image tensor and classify the result.

    Args:
        model: Callable that maps the input tensor to a single-element
            output tensor.
        image_tensor: Input tensor; it is moved to ``device`` before the
            forward pass.
        device: Device on which the forward pass runs.
        threshold: Cut-off applied to the sigmoid score.

    Returns:
        A ``(score, verdict)`` tuple: ``score`` is the sigmoid of the model
        output as a Python float, and ``verdict`` is ``"Real"`` when
        ``score >= threshold`` and ``"AI"`` otherwise.
    """
    # Gradients are not needed for inference.
    with torch.no_grad():
        raw_output = model(image_tensor.to(device))
        # .item() requires the output to hold exactly one element.
        score = torch.sigmoid(raw_output).item()

    if score >= threshold:
        verdict = "Real"
    else:
        verdict = "AI"
    return score, verdict
if __name__ == "__main__":
    # Prefer the GPU when CUDA is available; otherwise run on the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    model_path = r"model.pt"  # Path to Flux-Detector
    image_path = r"test_image.png"  # Path to test image

    # Load the model, then preprocess the image and classify it.
    model = load_model(model_path, device)
    image_tensor = preprocess_image(image_path)
    prob, label = predict(model, image_tensor, device)

    print(f"Model Prediction: {prob:.4f} -> {label}")