"""Streamlit app that classifies an image fetched from a URL as NSFW or not.

The classifier is the ``AdamCodd/vit-base-nsfw-detector`` ViT checkpoint
loaded from the Hugging Face Hub via ``transformers``.
"""

from io import BytesIO

import requests
import streamlit as st
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

MODEL_ID = "AdamCodd/vit-base-nsfw-detector"
# Seconds to wait for the remote image server before giving up, so a slow or
# unresponsive host cannot hang the Streamlit script indefinitely.
REQUEST_TIMEOUT = 10


@st.cache_resource
def load_model():
    """Load and return ``(processor, model)`` once per server process.

    Streamlit re-executes the whole script on every widget interaction;
    without ``st.cache_resource`` the checkpoint would be reloaded each time.
    """
    processor = ViTImageProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
    return processor, model


processor, model = load_model()


def fetch_image(url):
    """Download ``url`` and return it as an RGB ``PIL.Image.Image``.

    Raises:
        requests.RequestException: on network failure, timeout, or a non-2xx
            HTTP status.
        PIL.UnidentifiedImageError: if the response body is not an image.
    """
    response = requests.get(url, timeout=REQUEST_TIMEOUT)
    # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL.
    response.raise_for_status()
    # ViT checkpoints expect 3-channel RGB input; RGBA, palette, or grayscale
    # images could otherwise fail during preprocessing.
    return Image.open(BytesIO(response.content)).convert("RGB")


def predict_image(image):
    """Return the label predicted by the model for ``image``.

    The label is looked up in ``model.config.id2label``. Exceptions propagate
    to the caller rather than being returned in place of a label, so an error
    message can never be mistaken for a prediction.
    """
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip building the autograd graph to save memory/time.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


st.title("NSFW Image Classifier")

# NOTE: Streamlit cannot serve custom HTTP routes (e.g. POST /api/classify).
# Programmatic access would need a separate service (e.g. FastAPI) that calls
# predict_image; this script provides only the interactive UI.

image_url_ui = st.text_input("Enter Image URL", placeholder="Enter image URL here")

if image_url_ui:
    try:
        image = fetch_image(image_url_ui)
        st.image(image, caption='Image from URL', use_column_width=True)
        st.write("")
        st.write("Classifying...")
        prediction = predict_image(image)
        st.write(f"Predicted Class: {prediction}")
    except Exception as e:  # top-level UI boundary: surface any failure
        st.error(f"Error: {e}")