LukasT9 commited on
Commit
b008788
·
verified ·
1 Parent(s): 678828d

Create inference.py

Browse files
Files changed (1) hide show
  1. inference.py +41 -0
inference.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models, transforms
4
+ from PIL import Image
5
+
6
+ def load_model(model_path, device):
7
+ """Loads the TorchScript model."""
8
+ model = torch.jit.load(model_path, map_location=device)
9
+ model.to(device).eval()
10
+ return model
11
+
12
+ def preprocess_image(image_path):
13
+ """Pre-processes the image for feeding into the model."""
14
+ IMG_SIZE = 1024
15
+ transform = transforms.Compose([
16
+ transforms.Resize(IMG_SIZE + 32),
17
+ transforms.CenterCrop(IMG_SIZE),
18
+ transforms.ToTensor(),
19
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
20
+ ])
21
+ img = Image.open(image_path).convert("RGB")
22
+ return transform(img).unsqueeze(0)
23
+
24
+ def predict(model, image_tensor, device, threshold=0.5):
25
+ """Performs model prediction."""
26
+ with torch.no_grad():
27
+ outputs = model(image_tensor.to(device))
28
+ prob = torch.sigmoid(outputs).item()
29
+ label = "Real" if prob >= threshold else "AI"
30
+ return prob, label
31
+
32
+ if __name__ == "__main__":
33
+ model_path = "model.pt" # Path to Flux-Detector
34
+ image_path = "test_image.jpg" # Path to test image
35
+
36
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
37
+ model = load_model(model_path, device)
38
+ image_tensor = preprocess_image(image_path)
39
+ prob, label = predict(model, image_tensor, device)
40
+
41
+ print(f"Model Prediction: {prob:.4f} -> {label}")