BEiT_Gradio / app.py
WangQvQ's picture
Update app.py
da6e5a6
raw
history blame contribute delete
No virus
1.08 kB
import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from transformers import BeitFeatureExtractor, BeitForImageClassification
# Load the pre-trained BEiT preprocessor and classifier once, at import time,
# from a single shared checkpoint ID so the two can never drift apart.
_CHECKPOINT = "microsoft/beit-large-patch16-512"
feature_extractor = BeitFeatureExtractor.from_pretrained(_CHECKPOINT)
model = BeitForImageClassification.from_pretrained(_CHECKPOINT)
def classify_image(input_image):
    """Classify an image with the module-level BEiT model.

    Args:
        input_image: The image as a NumPy array, as delivered by the Gradio
            ``Image(type="numpy")`` input, or ``None`` when the input is
            empty/cleared (the interface runs with ``live=True``).

    Returns:
        ``{"Predicted Class": label}``, where ``label`` is the
        ``model.config.id2label`` entry for the highest-scoring logit, or an
        empty dict when no image is provided.
    """
    if input_image is None:
        # live=True re-runs this whenever the input changes, including when
        # the user clears it; there is nothing to classify in that case.
        return {}
    # convert("RGB") coerces grayscale / RGBA arrays to 3 channels
    # (presumably what the model's preprocessing expects); it is a no-op
    # for input that is already RGB.
    image = Image.fromarray(input_image.astype("uint8")).convert("RGB")
    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: skip autograd graph construction to save memory/time.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return {"Predicted Class": model.config.id2label[predicted_class_idx]}
# Web UI: a single image input fed to classify_image, rendered as JSON.
iface = gr.Interface(
    fn=classify_image,
    # gr.inputs.* was deprecated in Gradio 3 and removed in Gradio 4;
    # gr.Image is the current component. type="numpy" hands classify_image
    # a NumPy array, which is what it converts with Image.fromarray.
    inputs=gr.Image(type="numpy"),
    outputs="json",
    # Automatically re-run classification whenever the input changes.
    live=True,
    title="BEiT Classification",
    description="Upload an image to get its predicted class label",
)

if __name__ == "__main__":
    iface.launch()