vit_modle / README.md
gihakkk's picture
Update README.md
df04e8d verified
metadata
license: unknown

๋กœ๋งจ์Šค ์Šค์บ  ์‚ฌ์ง„๊ณผ, ๊ทธ๋ƒฅ ์‚ฌ์ง„์„ ๊ตฌ๋ณ„ํ•  ์ˆ˜ ์žˆ๋Š” ViT ๋ชจ๋ธ ์ž…๋‹ˆ๋‹ค.
๊ธฐ์กด์˜ CNN ๋ชจ๋ธ์— ๋น„ํ•ด ํ›จ์‹  ์„ฑ๋Šฅ์ด ์ข‹์Šต๋‹ˆ๋‹ค.
์ถ”ํ›„ ๋ฐ์ดํ„ฐ๋ฅผ ์ถ”๊ฐ€ํ•ด ์„ฑ๋Šฅ์„ ๋”์šฑ ๋Š˜๋ฆด๊ฒƒ ์ž…๋‹ˆ๋‹ค. ์‚ฌ์šฉ ์ฝ”๋“œ๋Š” ๋‹ค์Œ๊ณผ ๊ฐ™์Šต๋‹ˆ๋‹ค.

import torch
from transformers import ViTForImageClassification, ViTFeatureExtractor
from PIL import Image

# Hugging Face์—์„œ ๋ชจ๋ธ ๋ฐ ํŠน์ง• ์ถ”์ถœ๊ธฐ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ
model = ViTForImageClassification.from_pretrained("gihakkk/vit_modle")
feature_extractor = ViTFeatureExtractor.from_pretrained("gihakkk/vit_modle")

# ์ƒˆ๋กœ์šด ์ด๋ฏธ์ง€ ์˜ˆ์ธก ํ•จ์ˆ˜ ์ •์˜
def predict_image(image_path):
    # ์ด๋ฏธ์ง€๋ฅผ ๋กœ๋“œํ•˜๊ณ  RGB๋กœ ๋ณ€ํ™˜
    image = Image.open(image_path).convert("RGB")
    
    # ์ด๋ฏธ์ง€๋ฅผ ํŠน์ง• ์ถ”์ถœ๊ธฐ๋กœ ์ „์ฒ˜๋ฆฌํ•˜์—ฌ ๋ชจ๋ธ ์ž…๋ ฅ ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜
    inputs = feature_extractor(images=image, return_tensors="pt")
    
    # ์˜ˆ์ธก ์ˆ˜ํ–‰
    with torch.no_grad():
        outputs = model(**inputs).logits
    predicted_class = torch.argmax(outputs, dim=-1).item()

    return "๊ทธ๋ƒฅ ์‚ฌ์ง„" if predicted_class == 1 else "๋กœ๋งจ์Šค ์Šค์บ  ์‚ฌ์ง„"

# ์˜ˆ์ธก ์˜ˆ์‹œ
image_path = r'path\to\your\img.jpg'
result = predict_image(image_path)
print(result)