# upscale-api / app.py
# Arifzyn19
# Add application file
# 6cd4595
import torch
from PIL import Image
from RealESRGAN import RealESRGAN
from flask import Flask, request, jsonify, send_file
import io
import logging
# Logging: configure the root logger at INFO and get this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_has_cuda = torch.cuda.is_available()
device = torch.device('cuda') if _has_cuda else torch.device('cpu')
logger.info(f'Using device: {device}')


def _load_startup_model(scale):
    """Build the RealESRGAN model for *scale* and load its weights.

    Weights come from ``weights/RealESRGAN_x{scale}.pth``. ``download=True`` is
    passed through to ``load_weights`` (presumably so a missing file is
    fetched — confirm against the RealESRGAN package).
    """
    net = RealESRGAN(device, scale=scale)
    net.load_weights(f'weights/RealESRGAN_x{scale}.pth', download=True)
    logger.info(f'Model x{scale} loaded successfully')
    return net


# One model per supported scale, all loaded eagerly at import time.
model2 = _load_startup_model(2)
model4 = _load_startup_model(4)
model8 = _load_startup_model(8)
def inference(image, size):
global model2, model4, model8
if torch.cuda.is_available():
torch.cuda.empty_cache()
logger.info('CUDA cache cleared')
logger.info(f'Starting inference with scale {size}')
try:
if size == '2x':
result = model2.predict(image.convert('RGB'))
elif size == '4x':
result = model4.predict(image.convert('RGB'))
else:
width, height = image.size
if width >= 5000 or height >= 5000:
return None, "The image is too large."
result = model8.predict(image.convert('RGB'))
logger.info(f'Inference completed for scale {size}')
except torch.cuda.OutOfMemoryError as e:
logger.error(f'OutOfMemoryError: {e}')
logger.info(f'Reloading model for scale {size}')
if size == '2x':
model2 = RealESRGAN(device, scale=2)
model2.load_weights('weights/RealESRGAN_x2.pth', download=False)
result = model2.predict(image.convert('RGB'))
elif size == '4x':
model4 = RealESRGAN(device, scale=4)
model4.load_weights('weights/RealESRGAN_x4.pth', download=False)
result = model4.predict(image.convert('RGB'))
else:
model8 = RealESRGAN(device, scale=8)
model8.load_weights('weights/RealESRGAN_x8.pth', download=False)
result = model8.predict(image.convert('RGB'))
logger.info(f'Model reloaded and inference completed for scale {size}')
return result, None
@app.route('/upscale', methods=['POST'])
def upscale():
    """Upscale an uploaded image and return it as a PNG.

    Expects a POST whose ``request.files`` contains ``image``. The optional
    ``size`` form field (default ``'2x'``) is passed through to ``inference``.

    Responds with the upscaled image as ``image/png``. Otherwise it returns a
    JSON ``{"error": ...}`` body with HTTP 400 when no image was uploaded,
    the upload cannot be decoded, or ``inference`` reports an error.
    """
    if 'image' not in request.files:
        logger.warning('No image uploaded')
        return jsonify({"error": "No image uploaded"}), 400

    image_file = request.files['image']
    size = request.form.get('size', '2x')

    try:
        image = Image.open(image_file)
        # Image.open only identifies the file and reads its header; pixel data
        # is decoded lazily. Force the decode here so truncated or corrupt
        # uploads are rejected with a 400 instead of failing later inside
        # inference with a 500.
        image.load()
        logger.info('Image uploaded and opened successfully')
    except Exception as e:
        logger.error(f'Invalid image file: {e}')
        return jsonify({"error": "Invalid image file"}), 400

    result, error = inference(image, size)
    if error:
        logger.error(f'Error during inference: {error}')
        return jsonify({"error": error}), 400

    img_io = io.BytesIO()
    result.save(img_io, 'PNG')
    img_io.seek(0)  # rewind so send_file streams from the start of the buffer
    logger.info('Image processing completed and ready to be sent back')
    return send_file(img_io, mimetype='image/png')
def _serve():
    """Start the Flask server on all interfaces (0.0.0.0), port 5000."""
    logger.info('Starting the Flask server...')
    app.run(host='0.0.0.0', port=5000)


if __name__ == '__main__':
    _serve()