"""Gradio demo that classifies weed images with an EfficientNetV2-S Lightning model.

The app offers two tabs: one for uploaded/example images and one for webcam
input. Both feed the same ``predict`` function.
"""
from typing import Any
import pytorch_lightning as pl
from torchvision.models import efficientnet_v2_s, EfficientNet_V2_S_Weights
import torch
from torch import nn
from torchvision import transforms
from torch.nn import functional as F
import yaml
from yaml.loader import SafeLoader
from PIL import Image
import gradio as gr
import os


class WeedModel(pl.LightningModule):
    """LightningModule wrapping a torchvision EfficientNetV2-S classifier.

    Args:
        params: Configuration mapping. Keys read by this class:
            ``"model"`` (only ``"efficientnet"``, case-insensitive, is supported),
            ``"pretrained"`` (load ImageNet weights when truthy),
            ``"n_class"`` (size of the final classification layer),
            ``"optimizer"`` (``"Adam"`` or ``"SGD"``) and ``"Lr"`` (learning rate).

    Raises:
        NotImplementedError: if ``params["model"]`` names an unsupported backbone.
    """

    def __init__(self, params):
        super().__init__()
        self.params = params
        model_name = self.params["model"]
        if model_name.lower() != "efficientnet":
            # Fail fast: without a base model every forward() call would crash.
            raise NotImplementedError(f"not prepared model yet!! ({model_name})")
        weights = EfficientNet_V2_S_Weights.IMAGENET1K_V1 if self.params["pretrained"] else None
        self.base_model = efficientnet_v2_s(weights=weights)
        # Replace the last classifier layer with one sized to this task's classes.
        num_ftrs = self.base_model.classifier[-1].in_features
        self.base_model.classifier[-1] = nn.Linear(num_ftrs, self.params["n_class"])

    def forward(self, x):
        """Return the raw (pre-softmax) class logits produced by the backbone."""
        logits = self.base_model(x)
        return logits

    def configure_optimizers(self):
        """Build the optimizer named by ``params["optimizer"]`` with lr ``params["Lr"]``.

        Raises:
            ValueError: if the optimizer name is neither ``"Adam"`` nor ``"SGD"``.
        """
        optimizers = {"Adam": torch.optim.Adam, "SGD": torch.optim.SGD}
        name = self.params["optimizer"]
        if name not in optimizers:
            raise ValueError(f"Unsupported optimizer {name!r}; expected 'Adam' or 'SGD'")
        return optimizers[name](self.parameters(), lr=self.params["Lr"])

    def training_step(self, train_batch, batch_idx):
        """Compute, log and return the cross-entropy loss for one training batch.

        ``train_batch`` is indexed with ``"image"`` and ``"label"`` keys.
        """
        x = train_batch["image"]
        y = train_batch["label"]
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('metrics/batch/train_loss', loss, prog_bar=False)
        return loss

    def validation_step(self, val_batch, batch_idx):
        """Compute and log the cross-entropy loss for one validation batch.

        ``val_batch`` is indexed with ``"image"`` and ``"label"`` keys.
        """
        x = val_batch["image"]
        y = val_batch["label"]
        logits = self(x)
        loss = F.cross_entropy(logits, y)
        self.log('metrics/batch/val_loss', loss)

    def predict_step(self, batch: Any, batch_idx: int = 0, dataloader_idx: int = 0) -> Any:
        """Return softmax class probabilities for ``batch`` as nested Python lists.

        ``batch`` is passed directly to ``forward`` (an image tensor, not a dict).
        """
        logits = self(batch)
        return torch.softmax(logits, dim=-1).tolist()


def predict(image):
    """Classify a PIL image and return ``{class_name: probability}`` for ``gr.Label``.

    Relies on the module-level ``model``, ``transform`` and ``class_names``
    globals, which are set up before ``demo.launch()`` is called.

    Returns ``None`` (which clears the Label output) when no image is given,
    e.g. when the webcam image is cleared or "predict" is pressed with no input.
    """
    if image is None:
        return None
    # Normalize() uses 3-channel statistics; RGBA or grayscale inputs would fail.
    tensor_image = transform(image.convert("RGB"))
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        outs = model.predict_step(tensor_image.unsqueeze(0))
    # NOTE(review): the last model output is deliberately excluded from the
    # displayed labels -- presumably a background/"negative" class; confirm
    # against class_names.txt.
    labels = {class_names[k]: float(v) for k, v in enumerate(outs[0][:-1])}
    return labels


title = " AISeed AI Application Demo "
description = "# A Demo of Deep Learning for Weed Classification"
# Sorted so the examples gallery has a stable order across platforms.
example_list = [["examples/" + example] for example in sorted(os.listdir("examples"))]

with open("class_names.txt", "r", encoding='utf-8') as f:
    class_names = f.read().splitlines()

with gr.Blocks() as demo:
    demo.title = title
    gr.Markdown(description)
    with gr.Tabs():
        with gr.TabItem("for Images"):
            with gr.Row():
                with gr.Column():
                    im = gr.Image(type="pil", label="input image")
                with gr.Column():
                    label_conv = gr.Label(label="Predictions", num_top_classes=4)
                    btn = gr.Button(value="predict")
            btn.click(predict, inputs=im, outputs=[label_conv])
            gr.Examples(examples=example_list, inputs=[im], outputs=[label_conv])
        with gr.TabItem("for Webcam"):
            with gr.Row():
                with gr.Column():
                    webcam = gr.Image(type="pil", label="input image", source="webcam")
                with gr.Column():
                    label = gr.Label(label="Predictions", num_top_classes=4)
            webcam.change(predict, inputs=webcam, outputs=[label])

if __name__ == '__main__':
    with open('config.yaml', encoding='utf-8') as f:
        PARAMS = yaml.load(f, Loader=SafeLoader)
    print(PARAMS)
    # Build the path portably: a literal backslash is only a separator on Windows
    # (and "\e" is an invalid escape sequence). map_location lets a checkpoint
    # saved on GPU load on a CPU-only host.
    model = WeedModel.load_from_checkpoint(
        os.path.join("model", "epoch=08.ckpt"), map_location="cpu", params=PARAMS
    ).cpu()
    model.eval()
    # Resize/crop to 224x224 and normalize with the ImageNet mean/std.
    transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    demo.launch()