Juliofc commited on
Commit
6895a87
1 Parent(s): e4e31ec

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +56 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models, transforms
4
+ import gradio as gr
5
+ from PIL import Image
6
+
7
+ # Define the model architecture (must match the saved model)
8
+ class_names = ["cordana", "healthy", "pestalotiopsis", "sigatoka"]
9
+ def load_model():
10
+ model = models.alexnet(pretrained=False)
11
+ num_ftrs = model.classifier[6].in_features
12
+ model.classifier[6] = nn.Linear(num_ftrs, len(class_names)) # Adjust this for your number of classes
13
+
14
+ # Load the model weights
15
+ model.load_state_dict(torch.load('model_alexnet.pth'))
16
+ model.eval() # Set to evaluation mode
17
+
18
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
19
+ model.to(device)
20
+ return model, device
21
+
22
+ model_alexnet, device = load_model()
23
+
24
+ # Image transformations
25
+ transform = transforms.Compose([
26
+ transforms.Resize(256),
27
+ transforms.CenterCrop(224),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
30
+ ])
31
+
32
+ # Prediction function
33
+ def predict_image(image):
34
+ image = Image.fromarray(image.astype('uint8'), 'RGB')
35
+ image = transform(image).unsqueeze(0) # Add batch dimension
36
+ image = image.to(device)
37
+
38
+ with torch.no_grad():
39
+ outputs = model_alexnet(image)
40
+ _, predicted = torch.max(outputs, 1)
41
+ predicted = predicted.cpu().numpy()
42
+
43
+ return class_names[predicted[0]] # Adjust this if needed
44
+
45
+
46
+ iface = gr.Interface(fn=predict_image, inputs="image", outputs="label",
47
+ description="This model is a fine-tuned version of AlexNet specifically designed to identify four types of diseases in banana tree leaves. It can classify the leaves as Cordana, Healthy, Pestalotiopsis, or Sigatoka. Upload a photo of a banana leaf and the model will help you determine its health condition.",
48
+ examples=[
49
+ 'data/test/cordana/1.jpeg',
50
+ 'data/test/healthy/5.jpeg',
51
+ 'data/test/pestalotiopsis/5.jpeg',
52
+ 'data/test/sigatoka/1.jpeg'
53
+ ]
54
+ )
55
+
56
+ iface.launch()