Tngarg commited on
Commit
52150b5
Β·
verified Β·
1 Parent(s): 26f451c

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +57 -0
  2. light_cnn_model.pth +3 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import torchvision.transforms as transforms
4
+ from PIL import Image
5
+ import torch.nn as nn
6
+ import os
7
+
8
+ # βœ… Define Lightweight CNN Model (Same as trained)
9
+ class SmallCNN(nn.Module):
10
+ def __init__(self):
11
+ super(SmallCNN, self).__init__()
12
+ self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1) # Reduced filters
13
+ self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1) # Reduced filters
14
+ self.fc1 = nn.Linear(32 * 8 * 8, 10) # 10-class classification (CIFAR-10)
15
+
16
+ def forward(self, x):
17
+ x = torch.relu(self.conv1(x))
18
+ x = torch.max_pool2d(x, 2)
19
+ x = torch.relu(self.conv2(x))
20
+ x = torch.max_pool2d(x, 2)
21
+ x = x.view(x.size(0), -1)
22
+ x = self.fc1(x)
23
+ return x
24
+
25
+ # βœ… Load the trained model from Hugging Face Space
26
+ model_path = os.path.join(os.getenv("SPACE_ROOT", ""), "light_cnn_model.pth")
27
+
28
+ # βœ… Initialize model and load weights
29
+ model = SmallCNN()
30
+ model.load_state_dict(torch.load(model_path, map_location=torch.device("cpu")))
31
+ model.eval()
32
+
33
+ # βœ… Define Image Transformation
34
+ transform = transforms.Compose([
35
+ transforms.Resize((32, 32)), # Resize image to 32x32 pixels
36
+ transforms.ToTensor(), # Convert image to tensor
37
+ ])
38
+
39
+ # βœ… Define Prediction Function
40
+ def predict(image):
41
+ image = transform(image).unsqueeze(0) # Convert image to tensor and add batch dimension
42
+ with torch.no_grad():
43
+ output = model(image) # Forward pass through model
44
+ prediction = torch.argmax(output, dim=1).item() # Get predicted class
45
+ return f"Predicted Class: {prediction}"
46
+
47
+ # βœ… Create Gradio Interface
48
+ interface = gr.Interface(
49
+ fn=predict, # Function to process image
50
+ inputs=gr.Image(type="pil"), # User uploads an image
51
+ outputs="text", # Model returns a text output
52
+ title="Lightweight CNN Image Classification",
53
+ description="Upload an image to classify using the trained CNN model.",
54
+ )
55
+
56
+ # βœ… Launch the Gradio App
57
+ interface.launch()
light_cnn_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2afca200cc840345889a7174a85ab796df3ebf4ae681522f8ed43df4b39ed756
3
+ size 104984
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ Pillow