upscale-api / main.py
Arifzyn19
Add application file
8b019f8
import torch
from PIL import Image
from RealESRGAN import RealESRGAN
import os
# Function to perform inference and upscale the image
def upscale_image(image_path, scale):
    """Upscale the image at *image_path* with RealESRGAN and save it as PNG.

    The result is written to ``upscaled_image_x{scale}.png`` in the current
    working directory. Failures while opening the image or running
    inference are reported on stdout instead of being raised; errors while
    building the model, loading its weights, or saving the result propagate
    to the caller.

    Args:
        image_path: Path of the image file to upscale.
        scale: Integer upscaling factor. It is passed to the model and also
            selects the weights file ``weights/RealESRGAN_x{scale}.pth``.

    Returns:
        The path of the saved PNG on success, or ``None`` when the image
        could not be opened or inference failed.
    """
    # Open and decode the input first: a bad path should fail before the
    # model is built and its weights are (possibly) downloaded.
    try:
        # The context manager closes the underlying file once the pixels
        # have been converted, so no file handle is left open.
        with Image.open(image_path) as source:
            image = source.convert('RGB')
    except Exception as e:
        print(f"Error opening the image: {e}")
        return None

    # Initialize device: use the GPU when one is available.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Initialize and load models.
    # NOTE(review): `weights_only` is forwarded to the RealESRGAN
    # constructor unchanged - confirm the installed package accepts it.
    model = RealESRGAN(device, scale=scale, weights_only=False)
    model.load_weights(f'weights/RealESRGAN_x{scale}.pth', download=True)

    # Perform inference. The broad catch is kept deliberately: the function
    # reports failures on stdout rather than raising them.
    try:
        result = model.predict(image)
    except Exception as e:
        print(f"Error during inference: {e}")
        return None

    # Save the upscaled image.
    output_path = f'upscaled_image_x{scale}.png'
    result.save(output_path, 'PNG')
    print(f"Upscaled image saved to {output_path}")
    return output_path
if __name__ == '__main__':
    # Script entry point: ask for a scaling factor and upscale a fixed image.

    # Image that will be upscaled.
    image_path = './groot.jpeg'

    # Prompt for the scaling factor; only 2x, 4x and 8x are accepted.
    scale = input("Enter the scaling factor (2, 4, or 8): ")

    # The raw text is checked before conversion, so int() only ever sees
    # one of the three accepted values.
    if scale in ('2', '4', '8'):
        upscale_image(image_path, int(scale))
    else:
        print("Invalid scale factor. Please enter 2, 4, or 8.")