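"""Gradio demo app: weed image classification with an EfficientNetV2-S backbone
wrapped in a PyTorch Lightning module."""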
from typing import Any
import os

import torch
from torch import nn
from torch.nn import functional as F
import pytorch_lightning as pl
from torchvision import transforms
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
import yaml
from yaml.loader import SafeLoader
from PIL import Image
import gradio as gr


class WeedModel(pl.LightningModule):
    def __init__(self, params):
        super().__init__()
        self.params = params
        model = self.params["model"]
        if model.lower() == "efficientnet":
            # EfficientNetV2-S backbone, optionally with ImageNet-pretrained weights.
            if self.params["pretrained"]:
                self.base_model = efficientnet_v2_s(weights=EfficientNet_V2_S_Weights.IMAGENET1K_V1)
            else:
                self.base_model = efficientnet_v2_s(weights=None)
            # Swap the final classifier layer for one matching the number of weed classes.
            num_ftrs = self.base_model.classifier[-1].in_features
            self.base_model.classifier[-1] = nn.Linear(num_ftrs, self.params["n_class"])
        else:
            raise NotImplementedError(f"Model '{model}' is not supported yet.")

    def forward(self, x):
        return self.base_model(x)

    def configure_optimizers(self):
        if self.params["optimizer"] == "Adam":
            optimizer = torch.optim.Adam(self.parameters(), lr=self.params["Lr"])
        elif self.params["optimizer"] == "SGD":
            optimizer = torch.optim.SGD(self.parameters(), lr=self.params["Lr"])
        else:
            raise ValueError(f"Unsupported optimizer: {self.params['optimizer']}")
        return optimizer

    def training_step(self, train_batch, batch_idx):
        x = train_batch["image"]
        y = train_batch["label"]
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("metrics/batch/train_loss", loss, prog_bar=False)
        return loss

    def validation_step(self, val_batch, batch_idx):
        x = val_batch["image"]
        y = val_batch["label"]
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        self.log("metrics/batch/val_loss", loss)

    def predict_step(self, batch: Any, batch_idx: int = 0, dataloader_idx: int = 0) -> Any:
        y_hat = self(batch)
        # Return per-class probabilities rather than an argmax index, so the
        # Gradio Label component can display confidence scores.
        preds = torch.softmax(y_hat, dim=-1).tolist()
        return preds
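
# A minimal training sketch (illustrative, not part of this app; assumes dataloaders
# that yield {"image": ..., "label": ...} batches, matching the steps above):
#   trainer = pl.Trainer(max_epochs=10)
#   trainer.fit(model, train_loader, val_loader)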


def predict(image):
    tensor_image = transform(image)
    # Run inference without tracking gradients.
    with torch.inference_mode():
        outs = model.predict_step(tensor_image.unsqueeze(0))
    # Map each class name to its probability; the last output is left out of the display.
    labels = {class_names[k]: float(v) for k, v in enumerate(outs[0][:-1])}
    return labels


title = "AISeed AI Application Demo"
description = "# A Demo of Deep Learning for Weed Classification"

# Example images shipped with the app, shown under the image tab.
example_list = [["examples/" + example] for example in os.listdir("examples")]

# One class name per line, in the same order as the model's outputs.
with open("class_names.txt", "r", encoding="utf-8") as f:
    class_names = f.read().splitlines()

with gr.Blocks() as demo:
    demo.title = title
    gr.Markdown(description)
    with gr.Tabs():
        with gr.TabItem("for Images"):
            with gr.Row():
                with gr.Column():
                    im = gr.Image(type="pil", label="input image")
                with gr.Column():
                    label_conv = gr.Label(label="Predictions", num_top_classes=4)
                    btn = gr.Button(value="predict")
                    btn.click(predict, inputs=im, outputs=[label_conv])
            gr.Examples(examples=example_list, inputs=[im], outputs=[label_conv])
        with gr.TabItem("for Webcam"):
            with gr.Row():
                with gr.Column():
                    # Gradio 3.x API; Gradio 4.x replaced `source` with `sources=["webcam"]`.
                    webcam = gr.Image(type="pil", label="input image", source="webcam")
                with gr.Column():
                    label = gr.Label(label="Predictions", num_top_classes=4)
            # Re-run prediction whenever the webcam capture changes.
            webcam.change(predict, inputs=webcam, outputs=[label])


if __name__ == "__main__":
    with open("config.yaml") as f:
        PARAMS = yaml.load(f, Loader=SafeLoader)
        print(PARAMS)
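    # config.yaml must provide at least the keys read by WeedModel; illustrative values:
    #   model: efficientnet
    #   pretrained: true
    #   n_class: 5
    #   optimizer: Adam
    #   Lr: 0.001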

    # Load the trained checkpoint on CPU for inference (forward slash keeps the
    # path portable across operating systems).
    model = WeedModel.load_from_checkpoint("model/epoch=08.ckpt", params=PARAMS).cpu()
    model.eval()

    # Standard ImageNet preprocessing, matching the pretrained backbone.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    demo.launch()