import torch
import numpy as np
from super_image import HanModel
from typing import Union

def create_superimage_model(
    device: Union[str, torch.device] = "cuda"
) -> HanModel:
    """Create the pretrained HAN super-resolution model.

    Loads the ``'eugenesiow/han'`` checkpoint via ``HanModel.from_pretrained``
    with a 4x upscaling factor (``scale=4``) and moves it to ``device``.

    Args:
        device (Union[str, torch.device], optional): The device to place the
            model on. Defaults to "cuda".

    Returns:
        HanModel: The super image model, already moved to ``device``.
    """
    return HanModel.from_pretrained('eugenesiow/han', scale=4).to(device)

def run_superimage(
    model: HanModel,
    lr: np.ndarray,
    hr: np.ndarray,
    device: Union[str, torch.device] = "cuda"    
    """ Run the super image model

        model (HanModel): The super image model
        lr (np.ndarray): The low resolution image
        hr (np.ndarray): The high resolution image
        device (Union[str, torch.device], optional): The device to run the model on. Defaults to "cuda".

        dict: The results
    # Convert the images to tensors
    lr_tensor = (torch.from_numpy(lr[[3, 2, 1]]).to(device) / 2000).float()
    # Run the model
    with torch.no_grad():
        sr_tensor = model(lr_tensor[None])

    # Convert the tensors to numpy arrays
    lr = (lr_tensor.cpu().numpy() * 2000).astype(np.uint16)
    sr = (sr_tensor.cpu().numpy() * 2000).astype(np.uint16)

    # Return the results
    return {
        "lr": lr.squeeze(),
        "hr": hr[0:3].squeeze(),
        "sr": sr.squeeze()