# app.py
"""Flask service that classifies a remote image with a ViT NSFW detector.

GET /api/classify?url=<image-url> responds with {"classification": <label>}
on success, or {"error": <message>} with a 4xx/5xx status on failure.
"""
import io
import logging
import os

import requests
import torch
from flask import Flask, jsonify, request
from flask_cors import CORS
from PIL import Image
from transformers import AutoModelForImageClassification, ViTImageProcessor

logger = logging.getLogger(__name__)

MODEL_NAME = "AdamCodd/vit-base-nsfw-detector"
# (connect, read) timeouts in seconds for fetching remote images; without a
# timeout a slow or stalled host would block the worker indefinitely.
FETCH_TIMEOUT = (5, 15)
# Upper bound on downloaded image size, so an untrusted URL cannot exhaust
# server memory.
MAX_IMAGE_BYTES = 20 * 1024 * 1024
_CHUNK_SIZE = 64 * 1024

# Initialize Flask app
app = Flask(__name__)
CORS(app)  # Enable CORS for all routes

# Load model and processor once at import time; they are reused by every request.
processor = ViTImageProcessor.from_pretrained(MODEL_NAME)
model = AutoModelForImageClassification.from_pretrained(MODEL_NAME)


class ImageFetchError(Exception):
    """Raised when the image at a URL cannot be downloaded or decoded."""


def _download(image_url):
    """Return the body of *image_url* as bytes, capped at MAX_IMAGE_BYTES.

    Raises:
        ImageFetchError: on network/HTTP errors or when the body is too large.
    """
    try:
        with requests.get(image_url, timeout=FETCH_TIMEOUT, stream=True) as response:
            response.raise_for_status()
            chunks = []
            size = 0
            # iter_content decodes gzip/deflate transfer encodings (unlike .raw)
            # and lets us stop as soon as the size limit is exceeded.
            for chunk in response.iter_content(chunk_size=_CHUNK_SIZE):
                size += len(chunk)
                if size > MAX_IMAGE_BYTES:
                    raise ImageFetchError(
                        f"Image exceeds size limit of {MAX_IMAGE_BYTES} bytes"
                    )
                chunks.append(chunk)
    except requests.RequestException as err:
        raise ImageFetchError(f"Could not fetch image: {err}") from err
    return b"".join(chunks)


def _load_image(image_url):
    """Download *image_url* and return it as an RGB PIL image.

    Raises:
        ImageFetchError: if the image cannot be downloaded or decoded.
    """
    data = _download(image_url)
    try:
        image = Image.open(io.BytesIO(data))
        # The ViT processor normalizes 3 channels; RGBA, palette and greyscale
        # inputs must be converted first. convert() also forces the lazy decode
        # so decoding errors surface here.
        return image.convert("RGB")
    except (OSError, Image.DecompressionBombError) as err:
        raise ImageFetchError(f"Could not decode image: {err}") from err


def _predict_label(image):
    """Run the classifier on *image* and return the highest-scoring label."""
    inputs = processor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    return model.config.id2label[predicted_class_idx]


def classify_image(image_url):
    """Return the predicted label for the image at *image_url*.

    Kept for backward compatibility: on any failure this still returns the
    error message as a string instead of raising, so callers cannot
    distinguish it from a label. The HTTP route uses the helpers directly to
    report failures with proper status codes.
    """
    try:
        return _predict_label(_load_image(image_url))
    except Exception as e:
        logger.exception("Classification failed for %s", image_url)
        return str(e)


# API route to classify the image
@app.route("/api/classify", methods=["GET"])
def classify():
    """Classify the image referenced by the ``url`` query parameter.

    Returns:
        200 {"classification": label} on success;
        400 {"error": ...} when the URL is missing or the image can't be loaded;
        500 {"error": ...} when model inference fails.
    """
    image_url = request.args.get("url")
    logger.debug("classify request url=%s", image_url)
    if not image_url:
        return jsonify({"error": "No image URL provided"}), 400

    # NOTE(review): this fetches arbitrary user-supplied URLs server-side
    # (SSRF risk); consider restricting hosts before exposing publicly.
    try:
        image = _load_image(image_url)
    except ImageFetchError as err:
        return jsonify({"error": str(err)}), 400

    try:
        classification = _predict_label(image)
    except Exception:
        logger.exception("Model inference failed for %s", image_url)
        return jsonify({"error": "Classification failed"}), 500
    return jsonify({"classification": classification})


# Run the Flask server
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    # The Werkzeug interactive debugger allows arbitrary code execution, so it
    # must not be on by default while listening on all interfaces; opt in
    # explicitly with FLASK_DEBUG=1 for local development.
    app.run(debug=os.environ.get("FLASK_DEBUG") == "1", host="0.0.0.0")