File size: 1,161 Bytes
21be2de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import gradio as gr
import torch
import torchvision.transforms as transforms
from PIL import Image
from torchvision import models

# Load a pre-trained ResNet model
# The weights enum replaces the deprecated ``pretrained=True`` flag;
# IMAGENET1K_V1 is the checkpoint that flag selected.
weights = models.ResNet50_Weights.IMAGENET1K_V1
model = models.resnet50(weights=weights)
model.eval()  # inference mode: fixes batch-norm statistics, disables dropout

# Human-readable labels shipped with the weights, indexed by the class id the
# model outputs. classify_image looks predictions up in this list.
class_names = weights.meta["categories"]

# Preprocessing applied to every image before it is fed to the model:
# resize, crop the central 224x224 patch, convert to a tensor, then normalize
# each of the three channels with the mean/std given below.
transform = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Define a function to classify an image
def classify_image(input_image):
    """Return the predicted class name for an image.

    The image is opened with PIL, forced to RGB, run through the module-level
    ``transform`` and ``model``, and the highest-scoring class index is looked
    up in the module-level ``class_names``.

    Args:
        input_image: Anything ``PIL.Image.open`` accepts (a path or a binary
            file object), or ``None`` when no image has been supplied.

    Returns:
        The entry of ``class_names`` for the top-scoring class, or an empty
        string when ``input_image`` is ``None``.
    """
    if input_image is None:
        # NOTE(review): presumably Gradio passes None when the image is
        # cleared while live=True; returning "" avoids a crash in Image.open.
        return ""

    with Image.open(input_image) as img:
        # The transform normalizes exactly three channels, so grayscale,
        # palette and RGBA images must be converted first. convert() loads the
        # pixel data, so the result stays valid after the file is closed.
        rgb_image = img.convert("RGB")

    # The model expects a batch dimension; add one of size 1.
    batch = transform(rgb_image).unsqueeze(0)

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(batch)

    # max over dim 1 yields (scores, indices); only the index is needed.
    _, predicted_class = outputs.max(1)
    return class_names[predicted_class.item()]

# Create a Gradio interface
# Components are taken from the top-level ``gr`` namespace: the ``gr.inputs``
# and ``gr.outputs`` namespaces are deprecated in Gradio 3 and absent from
# Gradio 4.
iface = gr.Interface(
    fn=classify_image,
    # "filepath" hands classify_image a path string, which Image.open accepts;
    # it supersedes the deprecated "file" type.
    inputs=gr.Image(type="filepath", label="Upload an Image"),
    outputs=gr.Textbox(label="Predicted Class"),
    live=True,  # re-run the classifier automatically when the input changes
    theme="default",
    title="Image Classification with ResNet",
)

# Launch the Gradio interface
iface.launch()