"""Flask service that upscales uploaded images with Real-ESRGAN (2x, 4x or 8x)."""

import io
import logging

import torch
from flask import Flask, jsonify, request, send_file
from PIL import Image
from RealESRGAN import RealESRGAN

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f'Using device: {device}')

# Accepted values of the ``size`` form field -> (model scale factor, weights path).
MODEL_SPECS = {
    '2x': (2, 'weights/RealESRGAN_x2.pth'),
    '4x': (4, 'weights/RealESRGAN_x4.pth'),
    '8x': (8, 'weights/RealESRGAN_x8.pth'),
}

# For 8x upscaling, inputs whose width or height reaches this many pixels are
# rejected before inference.
MAX_8X_SIDE = 5000

# Error message returned by ``inference`` when the GPU runs out of memory even
# after the cache has been cleared; ``upscale`` maps it to HTTP 503.
OOM_ERROR = 'Not enough GPU memory to process this image.'


def _clear_cuda_cache():
    """Return cached, unused CUDA memory to the driver (no-op without CUDA)."""
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        logger.info('CUDA cache cleared')


def _load_model(size, download):
    """Construct the Real-ESRGAN model registered for ``size`` and load its weights."""
    scale, weights_path = MODEL_SPECS[size]
    model = RealESRGAN(device, scale=scale)
    model.load_weights(weights_path, download=download)
    logger.info(f'Model x{scale} loaded successfully')
    return model


# All models are loaded once at startup (weights downloaded if missing).
models = {size: _load_model(size, download=True) for size in MODEL_SPECS}


def inference(image, size):
    """Upscale ``image`` with the model selected by ``size``.

    Args:
        image: PIL image; it is converted to RGB before being fed to the model.
        size: One of the keys of ``MODEL_SPECS`` ('2x', '4x' or '8x').

    Returns:
        ``(result, None)`` on success, where ``result`` is whatever the model's
        ``predict`` returns, or ``(None, message)`` when the request cannot be
        served: unsupported ``size``, image too large for 8x, or CUDA out of
        memory even after clearing the cache and retrying once.
    """
    if size not in MODEL_SPECS:
        return None, f"Unsupported size '{size}'. Use one of: {', '.join(MODEL_SPECS)}."

    if size == '8x':
        width, height = image.size
        if width >= MAX_8X_SIDE or height >= MAX_8X_SIDE:
            return None, "The image is too large."

    _clear_cuda_cache()
    logger.info(f'Starting inference with scale {size}')
    rgb_image = image.convert('RGB')

    try:
        result = models[size].predict(rgb_image)
        logger.info(f'Inference completed for scale {size}')
        return result, None
    except torch.cuda.OutOfMemoryError as e:
        logger.error(f'OutOfMemoryError: {e}')

    # The retry deliberately happens outside the ``except`` block: while it is
    # active, the exception's traceback keeps the failed forward pass's frames
    # (and their tensors) alive, so the cache could not actually be released.
    # Reloading the same weights would not free anything either, so the
    # already-loaded model is reused.
    logger.info(f'Retrying inference for scale {size} after clearing CUDA cache')
    _clear_cuda_cache()
    try:
        result = models[size].predict(rgb_image)
    except torch.cuda.OutOfMemoryError as e:
        logger.error(f'OutOfMemoryError on retry: {e}')
        return None, OOM_ERROR
    logger.info(f'Retry succeeded and inference completed for scale {size}')
    return result, None


@app.route('/upscale', methods=['POST'])
def upscale():
    """Handle ``POST /upscale``.

    Expects a multipart ``image`` file and an optional ``size`` form field
    (default '2x'). Responds with the upscaled image as PNG, or with a JSON
    ``{"error": ...}`` body: 400 for bad input, 503 when the GPU ran out of
    memory.
    """
    if 'image' not in request.files:
        logger.warning('No image uploaded')
        return jsonify({"error": "No image uploaded"}), 400

    image_file = request.files['image']
    size = request.form.get('size', '2x')

    try:
        image = Image.open(image_file)
        # Image.open only parses the header; force a full decode here so that
        # truncated or corrupt uploads are reported as 400 instead of failing
        # later inside inference with an unhandled error.
        image.load()
        logger.info('Image uploaded and opened successfully')
    except Exception as e:
        logger.error(f'Invalid image file: {e}')
        return jsonify({"error": "Invalid image file"}), 400

    result, error = inference(image, size)
    if error:
        logger.error(f'Error during inference: {error}')
        status = 503 if error == OOM_ERROR else 400
        return jsonify({"error": error}), status

    img_io = io.BytesIO()
    result.save(img_io, 'PNG')
    img_io.seek(0)
    logger.info('Image processing completed and ready to be sent back')
    return send_file(img_io, mimetype='image/png')


if __name__ == '__main__':
    logger.info('Starting the Flask server...')
    app.run(host='0.0.0.0', port=5000)