File size: 578 Bytes
1051963 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 |
import pandas as pd
import torch
from PIL import Image
from torchvision import transforms
# Run the SuSy TorchScript model on a single image patch and print its
# per-class outputs as a one-row pandas DataFrame.

# Output column labels.
# NOTE(review): assumed to match the order of the model's output dimension —
# confirm against the model card before relying on the column names.
classes = ['authentic', 'dalle-3-images', 'diffusiondb', 'midjourney-images', 'midjourney_tti', 'realisticSDXL']

# Load the model onto the CPU. Reasons:
# - the input tensor below is built on the CPU;
# - the result is converted with .numpy(), which requires a CPU tensor.
# A checkpoint saved from a GPU would otherwise fail at inference time.
model = torch.jit.load("SuSy.pt", map_location="cpu")
model.eval()

# Load the patch. The context manager closes the underlying file handle.
# convert("RGB") returns a fully loaded copy with exactly 3 channels, since PNG
# files can also be RGBA, grayscale or palette images.
# NOTE(review): assumes the model expects 3-channel RGB input — confirm.
with Image.open("midjourney-images-example-patch0.png") as img:
    rgb = img.convert("RGB")

# PILToTensor yields a uint8 (C, H, W) tensor. Add a batch dimension and scale
# the values to floats in [0, 1].
patch = transforms.PILToTensor()(rgb).unsqueeze(0) / 255.

# Predict the patch without tracking gradients.
with torch.no_grad():
    preds = model(patch)

# Print results, one column per class. .cpu() is a no-op for CPU tensors, and
# it guards .numpy() against a device-resident output.
result = pd.DataFrame(preds.cpu().numpy(), columns=classes)
print(result)