hlydecker's picture
Update app.py
075db70 verified
raw
history blame
2.67 kB
import transformers
import torch
import torchvision
from transformers import TrainingArguments, Trainer
from transformers import ViTImageProcessor
from transformers import ViTForImageClassification
from torch.utils.data import DataLoader
from datasets import load_dataset
from torchvision.transforms import (CenterCrop,
Compose,
Normalize,
RandomHorizontalFlip,
RandomResizedCrop,
Resize,
ToTensor)
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch
import torch.nn.functional as F
import time
import gradio as gr
import glob
import os

# Run inference on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Both the image preprocessor and the classifier are loaded from
# "ViT_LCZs_v3"; local_files_only=True means nothing is downloaded.
processor = ViTImageProcessor.from_pretrained(
    "ViT_LCZs_v3",
    local_files_only=True,
)
model = ViTForImageClassification.from_pretrained(
    "ViT_LCZs_v3",
    local_files_only=True,
)
model = model.to(device)

# Collect every .jpg file found directly inside the samples directory.
examples_dir = './samples'
example_files = glob.glob(os.path.join(examples_dir, '*.jpg'))
def classify_image(image):
    """Classify one image and return its highest-probability labels.

    Args:
        image: The image to classify, in a form accepted by the module-level
            ``processor`` (the Gradio input in this file supplies a PIL image).
            ``None`` is accepted and yields an empty result.

    Returns:
        A dict mapping label name (from ``model.config.id2label``) to its
        softmax probability as a ``float``, for at most the 10 most probable
        classes of the first (only) item in the batch, most probable first.
    """
    # Gradio passes None when "Infer" is clicked with no image loaded.
    if image is None:
        return {}

    model.eval()
    with torch.no_grad():
        # Use the module-level `processor`; move the tensors to the same
        # device as the model so the forward pass also works on CUDA.
        inputs = processor(images=image, return_tensors="pt").to(device)
        logits = model(**inputs).logits
        probs = torch.nn.functional.softmax(logits, dim=1)

        # Never ask topk for more entries than there are classes.
        k = min(10, probs.shape[1])
        top_probs, top_indices = torch.topk(probs, k)

    id2label = model.config.id2label
    return {
        id2label[int(index)]: float(prob)
        for index, prob in zip(top_indices[0], top_probs[0])
    }
# Gradio UI: an image input next to a label output, an "Infer" button that
# runs classify_image, and a row of clickable sample images.
with gr.Blocks(title="ViT LCZ Classification - ClassCat",
               css=".gradio-container {background:white;}"
               ) as demo:
    gr.HTML("""<div style="font-family:'Times New Roman', 'Serif'; font-size:16pt; font-weight:bold; text-align:center; color:royalblue;">LCZ Classification with ViT</div>""")

    with gr.Row():
        # NOTE(review): the `shape` argument was removed from gr.Image in
        # Gradio 4 — confirm the pinned Gradio version still accepts it.
        input_image = gr.Image(type="pil", image_mode="RGB", shape=(224, 224))
        # classify_image returns up to 10 classes; only the top 3 are shown.
        output_label = gr.Label(label="Probabilities", num_top_classes=3)

    send_btn = gr.Button("Infer")
    send_btn.click(fn=classify_image, inputs=input_image, outputs=output_label)

    with gr.Row():
        # Each sample gets its own Examples widget, captioned after the image
        # file it loads (the previous captions were unrelated animal/food
        # names left over from another demo).
        sample_images = (
            ('data/closed_highrise.png', 'Sample images : closed highrise'),
            ('data/open_lowrise.png', 'open lowrise'),
            ('data/dense_trees.png', 'dense trees'),
            ('data/large_lowrise.png', 'large lowrise'),
        )
        for sample_path, sample_label in sample_images:
            gr.Examples([sample_path], label=sample_label, inputs=input_image)

demo.launch(debug=True)