# nsfw_api2 / app.py
# Hugging Face Space file by yeftakun - last commit 83c18ee ("Update app.py"), 2.55 kB.
import streamlit as st
from transformers import ViTImageProcessor, AutoModelForImageClassification
from PIL import Image
import requests
from io import BytesIO
import json
from flask import Flask, request, jsonify
# Load the model and processor.
# Streamlit re-executes this whole script on every widget interaction, so the
# checkpoint is loaded through st.cache_resource: it is fetched/deserialized
# once per process instead of on every rerun.
MODEL_ID = "AdamCodd/vit-base-nsfw-detector"


@st.cache_resource
def _load_model(model_id=MODEL_ID):
    """Return ``(processor, model)`` loaded from the checkpoint *model_id*."""
    return (
        ViTImageProcessor.from_pretrained(model_id),
        AutoModelForImageClassification.from_pretrained(model_id),
    )


processor, model = _load_model()

# Prediction helper shared by the Streamlit UI and the Flask API
def predict_image(image, image_processor=None, classifier=None):
    """Classify *image* and return the predicted label.

    Args:
        image: The image to classify (a PIL image when called from this app).
            If it exposes a ``mode`` other than ``"RGB"`` (e.g. RGBA PNGs,
            grayscale or palette images) it is converted to RGB first, so
            the processor always receives 3-channel input.
        image_processor: Processor to use; defaults to the module-level
            ``processor``. Mainly useful for testing.
        classifier: Model to use; defaults to the module-level ``model``.
            Mainly useful for testing.

    Returns:
        The label from ``classifier.config.id2label`` for the highest logit.

    Raises:
        Exception: errors from preprocessing or inference are propagated
            rather than returned as a string, so they can no longer be
            mistaken for a predicted label.
    """
    # Local import: torch is already required because of return_tensors="pt".
    import torch

    if image_processor is None:
        image_processor = processor
    if classifier is None:
        classifier = model

    # Normalise channel layout (RGBA, "L", "P", ... -> RGB).
    if getattr(image, "mode", "RGB") != "RGB":
        image = image.convert("RGB")

    inputs = image_processor(images=image, return_tensors="pt")
    # Inference only: disable autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = classifier(**inputs).logits

    predicted_class_idx = logits.argmax(-1).item()
    return classifier.config.id2label[predicted_class_idx]

# Streamlit app
st.title("NSFW Image Classifier")

# Display API usage instructions
st.write("You can use this app with the API endpoint below. Send a POST request with the image URL to get classification.")
st.write("Example URL to use with curl:")
st.code("curl -X POST https://huggingface.co/spaces/yeftakun/nsfw_api2/api/classify -H 'Content-Type: application/json' -d '{\"image_url\": \"https://example.jpg\"}'")

# URL input for UI
image_url = st.text_input("Enter Image URL", placeholder="Enter image URL here")
if image_url:
    try:
        # Load image from URL. The timeout keeps an unresponsive host from
        # hanging the script forever, and raise_for_status() reports HTTP
        # errors (404, 403, ...) directly instead of letting PIL fail on an
        # HTML error page with a confusing "cannot identify image file".
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        st.image(image, caption='Image from URL', use_column_width=True)
        st.write("")
        st.write("Classifying...")
        # Predict and display result
        prediction = predict_image(image)
        st.write(f"Predicted Class: {prediction}")
    except Exception as e:
        st.write(f"Error: {e}")

# API Endpoint using Flask
app = Flask(__name__)


@app.route('/api/classify', methods=['POST'])
def classify():
    """Classify the image at the URL supplied in the JSON request body.

    Request body: a JSON object such as ``{"image_url": "https://..."}``.

    Returns:
        ``{"predicted_class": <label>}`` with status 200 on success;
        ``{"error": ...}`` with status 400 if the body is not a JSON object
        or ``image_url`` is missing, empty or not a string;
        ``{"error": ...}`` with status 500 if the image cannot be fetched,
        decoded or classified.
    """
    # silent=True makes a missing/non-JSON/malformed body yield None instead
    # of raising; this parsing happens outside the try below, so a raise here
    # would bypass the JSON error responses entirely.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "Request body must be a JSON object"}), 400
    image_url = data.get('image_url')
    if not image_url or not isinstance(image_url, str):
        return jsonify({"error": "Image URL is required"}), 400
    try:
        # NOTE(review): the server fetches an arbitrary caller-supplied URL
        # (SSRF risk) and buffers the whole response in memory; consider
        # restricting schemes/hosts and capping the download size.
        # The timeout keeps a slow host from tying up the worker forever, and
        # raise_for_status() surfaces HTTP errors instead of a PIL decode error.
        response = requests.get(image_url, timeout=10)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content))
        # Predict image
        prediction = predict_image(image)
        return jsonify({"predicted_class": prediction})
    except Exception as e:
        return jsonify({"error": str(e)}), 500


if __name__ == '__main__':
    # Serve the API with Flask's development server on port 5000.
    # NOTE(review): `streamlit run app.py` appears to execute the script with
    # __name__ == "__main__" too, in which case this blocking call would also
    # run at the end of each Streamlit script run - confirm how the app is
    # launched before relying on the endpoint advertised in the UI.
    app.run(port=5000)