"""Streamlit app that classifies an image fetched from a URL.

The image URL is read from the ``image_url`` query parameter, e.g.
``?image_url=https://example.com/photo.jpg``, and classified with the
``AdamCodd/vit-base-nsfw-detector`` Hugging Face model.
"""

from io import BytesIO

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

MODEL_ID = "AdamCodd/vit-base-nsfw-detector"
# Seconds to wait on the remote image server before giving up.
REQUEST_TIMEOUT = 10


# Load the model and processor.
# cache_resource (not cache_data) keeps a single shared instance across
# reruns; cache_data would pickle and copy the model on every call.
@st.cache_resource
def load_model():
    """Load and return the ``(processor, model)`` pair for MODEL_ID."""
    processor = ViTImageProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
    return processor, model


processor, model = load_model()


def predict_image(image):
    """Return the predicted class label for a PIL *image*.

    Errors from preprocessing or inference propagate to the caller instead
    of being returned as a string that would be displayed as a label.
    """
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: no need to build an autograd graph.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


def _load_image_from_url(url):
    """Download *url* and decode the response body as an RGB PIL image.

    Raises requests exceptions on network failure / HTTP error status and
    PIL exceptions if the body is not a decodable image.
    """
    # NOTE(review): url comes straight from the query string, so the server
    # fetches arbitrary user-supplied addresses (SSRF risk) — consider an
    # allow-list if this app is publicly exposed.
    response = requests.get(url, timeout=REQUEST_TIMEOUT)
    # Fail on 4xx/5xx instead of handing an HTML error page to PIL.
    response.raise_for_status()
    # convert() forces a full decode and normalizes palette/RGBA/greyscale
    # images to RGB before preprocessing.
    return Image.open(BytesIO(response.content)).convert("RGB")


# Streamlit app
st.title("NSFW Image Classifier")

# st.query_params is a dict-like mapping (not callable); .get() returns the
# value for the key as a str, or None when the key is absent.
image_url = st.query_params.get("image_url")

if image_url:
    try:
        image = _load_image_from_url(image_url)
        st.image(image, caption='Image from URL', use_column_width=True)
        st.write("")
        st.write("Classifying...")

        # Predict and display result
        prediction = predict_image(image)
        st.write(f"Predicted Class: {prediction}")
    except Exception as e:
        # Top-level UI boundary: surface any fetch/decode/inference failure.
        st.write(f"Error: {e}")
else:
    st.write("Please provide an image URL using the 'image_url' query parameter.")