์‚ฌ์šฉ์˜ˆ์‹œ

import onnxruntime as ort
import numpy as np
from transformers import AutoFeatureExtractor
from PIL import Image

# ONNX ๋ชจ๋ธ ๊ฒฝ๋กœ
onnx_model_path = r'C:\mobilevit_model.onnx'

# ONNX ๋Ÿฐํƒ€์ž„ ์„ธ์…˜ ์ดˆ๊ธฐํ™”
ort_session = ort.InferenceSession(onnx_model_path)

# ์ƒˆ๋กœ์šด ์ด๋ฏธ์ง€ ์˜ˆ์ธก ํ•จ์ˆ˜ ์ •์˜
def predict_image(image_path):
    # MobileViT ๋ชจ๋ธ์— ๋งž๋Š” ํŠน์ง• ์ถ”์ถœ๊ธฐ ๋กœ๋“œ
    feature_extractor = AutoFeatureExtractor.from_pretrained("apple/mobilevit-small")
    
    # ์ด๋ฏธ์ง€๋ฅผ ๋กœ๋“œํ•˜๊ณ  RGB๋กœ ๋ณ€ํ™˜
    image = Image.open(image_path).convert("RGB")
    
    # ์ด๋ฏธ์ง€๋ฅผ ํŠน์ง• ์ถ”์ถœ๊ธฐ๋กœ ์ „์ฒ˜๋ฆฌ
    inputs = feature_extractor(images=image, return_tensors="np")
    input_array = inputs['pixel_values']  # ONNX๋Š” Numpy ํ˜•์‹์„ ์‚ฌ์šฉ
    
    # ONNX ๋ชจ๋ธ์— ์ž…๋ ฅ ์ „๋‹ฌ ๋ฐ ์ถ”๋ก 
    ort_inputs = {ort_session.get_inputs()[0].name: input_array}
    ort_outputs = ort_session.run(None, ort_inputs)
    
    # ๊ฒฐ๊ณผ ํ•ด์„
    logits = ort_outputs[0]
    predicted_class = np.argmax(logits, axis=-1).item()
    
    return "๊ทธ๋ƒฅ ์‚ฌ์ง„" if predicted_class == 1 else "๋กœ๋งจ์Šค ์Šค์บ  ์‚ฌ์ง„"

# ์˜ˆ์ธก ์˜ˆ์‹œ
image_path = r'C:\1234567.jpg'
result = predict_image(image_path)
print(result)
Downloads last month
6
Safetensors
Model size
4.95M params
Tensor type
F32
ยท
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API: The model has no library tag.