File size: 2,070 Bytes
52150b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
import torch.nn as nn
import os

# ✅ Define Lightweight CNN Model (Same as trained)
class SmallCNN(nn.Module):
    """Compact two-convolution image classifier that emits 10 class scores.

    Each 3x3 convolution (stride 1, padding 1) is followed by a ReLU and a
    2x2 max-pool. The pooled feature maps are then flattened and fed to a
    single linear layer. That layer is sized for ``32 * 8 * 8`` features,
    which is what a 3-channel 32x32 input produces after the two pooling
    steps.

    The attribute names ``conv1``, ``conv2`` and ``fc1`` double as the
    state-dict keys, so they must stay as they are for saved weights to load.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so parameter
        # initialisation draws from the random generator identically.
        self.conv1 = nn.Conv2d(
            in_channels=3, out_channels=16, kernel_size=3, stride=1, padding=1
        )
        self.conv2 = nn.Conv2d(
            in_channels=16, out_channels=32, kernel_size=3, stride=1, padding=1
        )
        self.fc1 = nn.Linear(in_features=32 * 8 * 8, out_features=10)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for a batch ``x``."""
        # Both stages share the same shape: convolve, rectify, halve H and W.
        for conv in (self.conv1, self.conv2):
            x = torch.max_pool2d(torch.relu(conv(x)), 2)
        batch_size = x.size(0)
        flattened = x.view(batch_size, -1)
        return self.fc1(flattened)

# ✅ Load the trained model from Hugging Face Space
# The checkpoint file name is joined onto SPACE_ROOT when that environment
# variable is set; otherwise the prefix is empty and the path is relative.
model_path = os.path.join(
    os.environ.get("SPACE_ROOT", ""),
    "light_cnn_model.pth",
)

# Build the network, restore the saved parameters onto the CPU, and switch
# the module to evaluation mode before it is used for predictions.
model = SmallCNN()
model.load_state_dict(
    torch.load(model_path, map_location=torch.device("cpu"))
)
model.eval()

# ✅ Define Image Transformation
# Preprocessing pipeline, applied in order: resize to 32x32 pixels, then
# convert the image to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.ToTensor(),
    ]
)

# ✅ Define Prediction Function
def predict(image) -> str:
    """Classify one uploaded image and report the predicted class index.

    Args:
        image: The PIL image to classify, or ``None`` when the request was
            submitted without an image.

    Returns:
        ``"Predicted Class: <index>"``, where ``<index>`` is the position of
        the largest score in the model output, or a short prompt asking for
        an image when ``image`` is ``None``.
    """
    # An empty submission arrives as None; answer with a readable message
    # instead of letting the transform raise on it.
    if image is None:
        return "No image provided. Please upload an image to classify."

    # The network's first convolution takes 3 input channels, so grayscale,
    # palette or RGBA images are converted to RGB before preprocessing.
    rgb_image = image.convert("RGB")

    # Preprocess and add the leading batch dimension the model expects.
    batch = transform(rgb_image).unsqueeze(0)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        scores = model(batch)

    predicted_index = torch.argmax(scores, dim=1).item()
    return f"Predicted Class: {predicted_index}"

# ✅ Create Gradio Interface
interface = gr.Interface(
    title="Lightweight CNN Image Classification",
    description="Upload an image to classify using the trained CNN model.",
    # Callback that turns the uploaded image into the result text.
    fn=predict,
    # The upload is handed to the callback as a PIL image.
    inputs=gr.Image(type="pil"),
    # The callback's return value is shown as plain text.
    outputs="text",
)

# Start serving the app.
interface.launch()