|
import os
from functools import lru_cache
from timeit import default_timer as timer
from typing import Tuple, Dict

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import ViTImageProcessor, ViTForImageClassification
|
|
|
|
|
@lru_cache(maxsize=1)
def _load_model(model_id: str = "bazyl/gtsrb-model"):
    """Load and memoize the image processor and classifier for *model_id*.

    ``from_pretrained`` is slow and may download from the Hugging Face Hub,
    so it runs once per process instead of once per request.
    """
    processor = ViTImageProcessor.from_pretrained(model_id)
    model = ViTForImageClassification.from_pretrained(model_id)
    return processor, model


def predict(img) -> Tuple[Dict[str, float], float]:
    """Classify a traffic-sign image.

    Args:
        img: PIL image supplied by the ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without an image.

    Returns:
        A ``(confidences, seconds)`` pair. ``confidences`` maps every class
        label to its softmax probability (the dict format ``gr.Label``
        renders), and ``seconds`` is the wall-clock prediction time rounded
        to 5 decimal places (shown by the ``gr.Number`` output).

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Please provide an image to classify.")

    start_time = timer()
    processor, model = _load_model()

    # Uploaded PNGs may be RGBA, palette or grayscale. The ViT processor
    # normalizes three colour channels, so coerce to RGB first.
    inputs = processor(images=img.convert("RGB"), return_tensors="pt")

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        logits = model(**inputs).logits

    # A single image was passed, so take row 0 of the (batch, classes) scores.
    probs = logits.softmax(dim=-1)[0].tolist()
    id2label = model.config.id2label
    confidences = {id2label[idx]: prob for idx, prob in enumerate(probs)}

    pred_time = round(timer() - start_time, 5)
    return confidences, pred_time
|
|
|
# Heading shown at the top of the Gradio page.
title = "GTSRB - German Traffic Sign Recognition by Bazyl Horsey"

# Subtitle shown under the title. The app serves a ViTForImageClassification
# checkpoint, i.e. a Vision Transformer rather than a CNN.
# NOTE(review): the accuracy figure is carried over unchanged from the previous
# description; confirm it was measured for this ViT checkpoint.
description = (
    "Vision Transformer (ViT) created for the GTSRB Dataset, "
    "achieved 99.93% test accuracy"
)
|
|
|
|
|
# Lower-case file extensions treated as example images. Anything else in the
# folder (".DS_Store", ".gitkeep", READMEs, ...) cannot be loaded by gr.Image.
_EXAMPLE_IMAGE_EXTENSIONS = (
    ".png", ".jpg", ".jpeg", ".ppm", ".bmp", ".gif", ".webp", ".tif", ".tiff",
)


def list_example_images(directory: str = "examples"):
    """Return Gradio ``examples`` rows for the images found in *directory*.

    Each row is a one-element list holding an image path, one entry per
    Interface input component. Names are sorted so the gallery order is
    stable (``os.listdir`` order is arbitrary). Only regular, non-hidden
    files with an image extension are kept.

    Args:
        directory: Folder to scan, relative to the current working directory
            unless absolute.

    Returns:
        A list of ``[path]`` rows, or ``None`` when the directory is missing
        or contains no images. ``None`` means "no examples" to ``gr.Interface``.
    """
    if not os.path.isdir(directory):
        return None
    rows = [
        [os.path.join(directory, name)]
        for name in sorted(os.listdir(directory))
        if not name.startswith(".")
        and name.lower().endswith(_EXAMPLE_IMAGE_EXTENSIONS)
        and os.path.isfile(os.path.join(directory, name))
    ]
    return rows or None


example_list = list_example_images()
|
|
|
|
|
# One PIL image in. Out: the top-5 class confidences plus the inference time.
image_input = gr.Image(type="pil")
prediction_outputs = [
    gr.Label(num_top_classes=5, label="Predictions"),
    gr.Number(label="Prediction time (s)"),
]

demo = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=prediction_outputs,
    examples=example_list,
    title=title,
    description=description,
)
|
|
|
|
|
# Start the web server only when executed as a script (e.g. `python app.py`),
# so importing this module (to reuse `predict` or `demo`) has no side effects.
if __name__ == "__main__":
    demo.launch()