"""NSFW image classifier exposed as a Streamlit UI and a Flask JSON API.

Both front ends share one Hugging Face image-classification model
('AdamCodd/vit-base-nsfw-detector') and one image-download helper.
"""
import streamlit as st
from transformers import ViTImageProcessor, AutoModelForImageClassification
from PIL import Image
import requests
from io import BytesIO
import json
from flask import Flask, request, jsonify
import torch

MODEL_NAME = 'AdamCodd/vit-base-nsfw-detector'

# Seconds to wait for the remote image server. Without a timeout,
# requests.get can block the UI / API worker indefinitely.
REQUEST_TIMEOUT = 10


@st.cache_resource
def _load_model():
    """Load the image processor and classification model once per process.

    Streamlit re-executes this whole script on every widget interaction;
    caching keeps those reruns from reloading the model weights each time.
    """
    loaded_processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
    loaded_model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)
    return loaded_processor, loaded_model


# Load the model and processor (module-level names kept for existing users).
processor, model = _load_model()


def _load_image_from_url(url):
    """Download ``url`` and decode it as an RGB PIL image.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-2xx HTTP status.
        PIL.UnidentifiedImageError: if the response body is not an image.
    """
    # NOTE(review): the API forwards client-supplied URLs here, so this server
    # can be made to fetch arbitrary hosts, including internal ones (SSRF).
    # Consider restricting schemes/hosts if the endpoint is publicly exposed.
    response = requests.get(url, timeout=REQUEST_TIMEOUT)
    # Report 404/500 pages as HTTP errors rather than letting PIL fail with
    # a confusing "cannot identify image file".
    response.raise_for_status()
    # Image.open is lazy. convert() forces decoding here, so corrupt data
    # fails inside the caller's try block. It also turns RGBA/greyscale/
    # palette images into plain 3-channel RGB before they reach the processor.
    return Image.open(BytesIO(response.content)).convert("RGB")


def predict_image(image):
    """Classify a PIL image and return the label from ``model.config.id2label``.

    Errors from the processor or model propagate to the caller instead of
    being returned as a string, so they cannot be mistaken for a label.
    """
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: disable autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


# Streamlit app
st.title("NSFW Image Classifier")

# Display API usage instructions.
# NOTE(review): the Flask app below is not started by `streamlit run`;
# confirm the deployment actually serves this URL.
st.write("You can use this app with the API endpoint below. Send a POST request with the image URL to get classification.")
st.write("Example URL to use with curl:")
st.code("curl -X POST https://huggingface.co/spaces/yeftakun/nsfw_api2/api/classify -H 'Content-Type: application/json' -d '{\"image_url\": \"https://example.jpg\"}'")

# URL input for UI
image_url = st.text_input("Enter Image URL", placeholder="Enter image URL here")

if image_url:
    try:
        image = _load_image_from_url(image_url)
        st.image(image, caption='Image from URL', use_column_width=True)
        st.write("")
        st.write("Classifying...")

        # Predict and display result
        prediction = predict_image(image)
        st.write(f"Predicted Class: {prediction}")
    except Exception as e:
        # Top-level UI boundary: show the failure instead of a traceback.
        st.write(f"Error: {e}")


# API Endpoint using Flask
app = Flask(__name__)


@app.route('/api/classify', methods=['POST'])
def classify():
    """Classify the image referenced by the JSON body's ``image_url`` field.

    Responses:
        200 ``{"predicted_class": <label>}`` on success.
        400 ``{"error": ...}`` when ``image_url`` is missing or the body is
            not a JSON object.
        500 ``{"error": ...}`` when downloading or classifying fails.
    """
    # silent=True returns None for a missing or invalid JSON body, instead of
    # raising and producing an HTML error page, so it gets the same JSON 400.
    data = request.get_json(silent=True)
    image_url = data.get('image_url') if isinstance(data, dict) else None
    if not image_url:
        return jsonify({"error": "Image URL is required"}), 400

    try:
        image = _load_image_from_url(image_url)
        prediction = predict_image(image)
    except Exception as e:
        app.logger.exception("Classification failed for %s", image_url)
        return jsonify({"error": str(e)}), 500
    return jsonify({"predicted_class": prediction})


def _running_under_streamlit():
    """Return True when this script is being executed by ``streamlit run``.

    Streamlit executes the script with ``__name__ == "__main__"``. Without
    this check, the blocking Flask server would start at the end of each
    Streamlit run and fail to rebind port 5000 on later reruns.
    """
    try:
        from streamlit import runtime
    except ImportError:  # Streamlit releases without the runtime module
        return False
    return runtime.exists()


if __name__ == '__main__' and not _running_under_streamlit():
    app.run(port=5000)