"""Streamlit app that tags an uploaded image with a RAM++ image classifier."""

import streamlit as st
import torch
from PIL import Image
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

# Hugging Face Hub repository the preprocessor and model are loaded from.
# NOTE(review): RAM++ checkpoints are primarily published for the
# `recognize-anything` (`ram`) package; confirm this repo actually ships a
# transformers-compatible config/preprocessor, otherwise `from_pretrained`
# in `load_model` will fail.
MODEL_ID = "xinyu1205/recognize-anything-plus-model"


# Load the model and its preprocessing config (cached across Streamlit reruns).
@st.cache_resource
def load_model():
    """Load the preprocessor and classification model once per server process.

    Returns:
        tuple ``(feature_extractor, model)`` with the model in eval mode.
    """
    feature_extractor = AutoFeatureExtractor.from_pretrained(MODEL_ID)
    model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
    model.eval()  # inference only: disables dropout / training-mode layers
    return feature_extractor, model


# Prediction function.
def predict(image, feature_extractor, model, top_k=5):
    """Return the ``top_k`` highest-scoring tag labels for ``image``.

    Args:
        image: PIL image to tag. Non-RGB images (grayscale, palette, RGBA)
            are converted to RGB so the preprocessor always receives the
            3-channel input it is built for.
        feature_extractor: preprocessor returned by ``load_model``.
        model: classification model returned by ``load_model``.
        top_k: maximum number of tags to return (default 5). Clamped to the
            number of labels the model has, so small label sets don't make
            ``torch.topk`` raise.

    Returns:
        list of label strings, highest logit first.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = feature_extractor(images=image, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits

    # torch.topk raises if k exceeds the size of the last dimension.
    k = max(0, min(top_k, logits.shape[-1]))
    top = torch.topk(logits, k=k)
    # Batch size is 1 here, so only the first row of indices is used.
    return [model.config.id2label[idx.item()] for idx in top.indices[0]]


# Streamlit app.
st.title("RAM++ Image Tagging")
feature_extractor, model = load_model()

uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])
if uploaded_file is not None:
    try:
        image = Image.open(uploaded_file)
        image.load()  # Image.open is lazy; force decoding so errors surface here
    except OSError as err:  # includes PIL.UnidentifiedImageError
        st.error(f"Could not read the uploaded file as an image: {err}")
    else:
        st.image(image, caption='Uploaded Image', use_column_width=True)

        if st.button('Get Tags'):
            tags = predict(image, feature_extractor, model)
            st.write("Predicted Tags:")
            st.write(", ".join(tags))