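# app.py
# Flask API that fetches an image from a URL and classifies it with the
# AdamCodd/vit-base-nsfw-detector image-classification model.
# NOTE: this file was copied from a Hugging Face Space page whose status read
# "Runtime error".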
import io
import os

import requests
import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor
# Initialize the Flask app and enable CORS on every route, so browser clients
# served from other origins can call the API.
app = Flask(__name__)
CORS(app)

# Hugging Face checkpoint used by both the processor and the model. It is kept
# in one constant so the two can never be loaded from different checkpoints.
MODEL_ID = 'AdamCodd/vit-base-nsfw-detector'

# Loaded once at import time and shared by every request. The first run
# downloads the checkpoint, so startup can be slow.
processor = ViTImageProcessor.from_pretrained(MODEL_ID)
model = AutoModelForImageClassification.from_pretrained(MODEL_ID)
# Classification function
def classify_image(image_url, *, timeout=10):
    """Fetch the image at *image_url* and return the model's predicted label.

    The image is downloaded with ``requests``, decoded with Pillow and
    converted to RGB. It is then preprocessed by the module-level
    ``processor`` and scored by the module-level ``model``. The
    highest-scoring class index is mapped to its name through
    ``model.config.id2label``.

    Args:
        image_url: URL of the image to classify. It comes from the HTTP
            request, so treat it as untrusted.
        timeout: Seconds to wait for the remote server, passed to
            ``requests.get``. Keyword-only; defaults to 10.

    Returns:
        The predicted label on success. On any failure (bad URL, network
        error, non-2xx status, undecodable image, model error), the exception
        message is returned instead of a label. This preserves the function's
        original contract. The traceback is logged via ``app.logger``.
    """
    try:
        # NOTE(security): fetching a caller-supplied URL is a server-side
        # request forgery (SSRF) vector. The server will request any host it
        # can reach, including internal addresses. Add a host allow-list
        # before exposing this beyond a demo.
        # The timeout stops a slow or unresponsive host from blocking the
        # worker indefinitely. The ``with`` block releases the connection.
        with requests.get(image_url, timeout=timeout) as response:
            # Fail fast on 4xx/5xx instead of handing an error page to PIL.
            response.raise_for_status()
            # Unlike ``.raw``, ``.content`` undoes gzip/deflate
            # Content-Encoding, so compressed responses decode correctly.
            image_bytes = response.content

        # Convert to RGB so grayscale, palette and RGBA images all reach the
        # processor with the same 3-channel layout. This assumes the
        # checkpoint expects RGB input, as ViT checkpoints normally do.
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
        inputs = processor(images=image, return_tensors="pt")

        # Inference only: disable autograd so no gradient graph is built.
        with torch.no_grad():
            logits = model(**inputs).logits
        predicted_class_idx = logits.argmax(-1).item()
        return model.config.id2label[predicted_class_idx]
    except Exception as e:  # deliberate catch-all at the request boundary
        app.logger.exception("Failed to classify image from %s", image_url)
        return str(e)
# API route to classify the image.
# NOTE(review): the original had no @app.route decorator, so this view was
# never registered and was unreachable. '/classify' is assumed; confirm it
# matches the path clients call.
@app.route('/classify', methods=['GET'])
def classify():
    """Classify the image whose URL is given in the ``url`` query parameter.

    Returns:
        - 200 with ``{"classification": <label>}`` on success.
        - 400 with ``{"error": "No image URL provided"}`` if ``url`` is
          missing or empty.
        - 502 with ``{"error": <message>}`` if the image could not be
          fetched or classified.
    """
    image_url = request.args.get('url')
    app.logger.debug('classify called with url=%r', image_url)
    if not image_url:
        return jsonify({'error': 'No image URL provided'}), 400

    classification = classify_image(image_url)
    # classify_image reports failures by returning the exception message
    # rather than raising. Any value that is not one of the model's labels is
    # therefore an error, and it is returned in the same {'error': ...} shape
    # as the 400 case above instead of masquerading as a label with HTTP 200.
    if classification not in model.config.id2label.values():
        return jsonify({'error': classification}), 502
    return jsonify({'classification': classification})
# Run the Flask development server when executed directly.
if __name__ == '__main__':
    # Debug mode is opt-in via FLASK_DEBUG. The Werkzeug debugger lets anyone
    # who can reach the server run arbitrary Python, which is unsafe on
    # 0.0.0.0. Debug mode also enables the reloader, which imports this
    # module a second time in a child process and so loads the model twice.
    debug = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    # Let the host platform choose the port (Hugging Face Docker Spaces
    # expect 7860 by default). Otherwise keep Flask's default of 5000.
    port = int(os.environ.get('PORT', '5000'))
    app.run(debug=debug, host='0.0.0.0', port=port)