import time
import gc
import torch

from PIL import Image
from torchvision import transforms
import gradio as gr

from transformers import AutoConfig, AutoModelForImageSegmentation

# 1) Wrap config loading in a helper that monkey-patches a dummy get_text_config().

def load_model():
    config = AutoConfig.from_pretrained("zhengpeng7/BiRefNet_lite", trust_remote_code=True)
    config.is_encoder_decoder = False

    # We define a dummy function that returns a minimal object
    # with a tie_word_embeddings attribute, so tie_weights() won't fail.
    def dummy_text_config(decoder=True):
        class DummyTextConfig:
            tie_word_embeddings = False
        return DummyTextConfig()

    # Patch the config so Transformers' tie_weights() machinery won't fail
    setattr(config, "get_text_config", dummy_text_config)

    model = AutoModelForImageSegmentation.from_pretrained(
        "zhengpeng7/BiRefNet_lite",
        config=config,
        trust_remote_code=True
    )

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    model.eval()
    return model, device
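
# A minimal sketch (optional), assuming the BiRefNet_lite checkpoint is
# fp16-safe: loading the weights in half precision roughly halves GPU memory
# and can avoid the OOM fallback below. The input tensor in run_inference
# would then need a matching .to(torch.float16).
#
#   model = AutoModelForImageSegmentation.from_pretrained(
#       "zhengpeng7/BiRefNet_lite",
#       config=config,
#       trust_remote_code=True,
#       torch_dtype=torch.float16,
#   )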

# 2) Initialize global model & device
birefnet, device = load_model()

# 3) Preprocessing transform
image_size = (1024, 1024)
transform_image = transforms.Compose([
    transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225])
])
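
# Illustrative note: transform_image(Image.new("RGB", (640, 480))) yields a
# float tensor of shape [3, 1024, 1024], normalized with the ImageNet
# mean/std used above.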

def run_inference(images, model, device):
    """Segment a batch of PIL images and return RGBA cutouts."""
    inputs = []
    original_sizes = []
    for img in images:
        original_sizes.append(img.size)
        inputs.append(transform_image(img))

    input_tensor = torch.stack(inputs).to(device)
    try:
        with torch.no_grad():
            output = model(input_tensor)
            # BiRefNet's reference demo takes the last element of the model
            # output and applies sigmoid. If your model returns a tuple or an
            # object with a .logits field instead, adapt this line (see the
            # hedged extract_preds sketch after run_inference).
            preds = output[-1].sigmoid().cpu()
    except torch.cuda.OutOfMemoryError:
        # Drop the batch tensor and release cached blocks before re-raising,
        # so the caller's fallback path starts from a cleaner state.
        del input_tensor
        torch.cuda.empty_cache()
        raise

    # Post-process: resize each mask back to the source resolution and use it
    # as a soft alpha mask over a transparent canvas.
    results = []
    for i, img in enumerate(images):
        pred = preds[i].squeeze()
        # ToPILImage maps the [0, 1] float mask to an 8-bit "L" image.
        pred_pil = transforms.ToPILImage()(pred)
        mask = pred_pil.resize(original_sizes[i])
        result = Image.new("RGBA", original_sizes[i], (0, 0, 0, 0))
        result.paste(img, mask=mask)
        results.append(result)

    # Cleanup
    del input_tensor, preds
    gc.collect()
    torch.cuda.empty_cache()

    return results
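
# A minimal sketch (not wired in), illustrating how the `preds = ...` line
# above could be adapted if the model's return type differs. The tuple/list
# and .logits branches are assumptions, not behavior guaranteed by the
# BiRefNet_lite remote code.
def extract_preds(output):
    if isinstance(output, (list, tuple)):
        return output[-1].sigmoid().cpu()     # multi-head models: last output
    if hasattr(output, "logits"):
        return output.logits.sigmoid().cpu()  # HF-style ModelOutput
    return output.sigmoid().cpu()             # plain tensor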

def binary_search_max(images):
    """Binary-search the largest batch size that fits in GPU memory."""
    global birefnet, device
    low, high = 1, len(images)
    best, best_count = None, 0

    while low <= high:
        mid = (low + high) // 2
        batch = images[:mid]
        try:
            # Reload the model so each probe starts from a fresh allocator
            # state, avoiding leftover memory fragmentation.
            birefnet, device = load_model()
            res = run_inference(batch, birefnet, device)
            best, best_count = res, mid
            low = mid + 1
        except torch.cuda.OutOfMemoryError:
            high = mid - 1

    return best, best_count
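
# A minimal sketch (not called by the UI), assuming a caller wants to process
# *all* images after an OOM rather than just the first batch: reuse the safe
# batch size found by binary_search_max and walk the list in chunks.
def process_in_chunks(images, chunk_size):
    results = []
    for start in range(0, len(images), chunk_size):
        results.extend(run_inference(images[start:start + chunk_size],
                                     birefnet, device))
    return results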

def extract_objects(filepaths):
    images = [Image.open(p).convert("RGB") for p in filepaths]
    start_time = time.time()

    # First attempt: all images at once
    try:
        results = run_inference(images, birefnet, device)
        end_time = time.time()
        total_time = end_time - start_time
        summary = f"Total request time: {total_time:.2f}s\nProcessed {len(images)} images successfully."
        return results, summary

    except torch.cuda.OutOfMemoryError:
        # OOM on the full batch: fall back to a binary search for the largest
        # batch that fits.
        oom_time = time.time()
        initial_attempt_time = oom_time - start_time

        best, best_count = binary_search_max(images)
        end_time = time.time()
        total_time = end_time - start_time

        if best is None:
            # Not even 1 image can be processed
            summary = (
                f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
                f"Could not process even a single image.\n"
                f"Total time including fallback attempts: {total_time:.2f}s."
            )
            return [], summary
        else:
            summary = (
                f"Initial attempt OOM after {initial_attempt_time:.2f}s.\n"
                f"Found that {best_count} images can be processed without OOM.\n"
                f"Total time including fallback attempts: {total_time:.2f}s.\n"
                f"Next time, try using up to {best_count} images."
            )
            return best, summary

iface = gr.Interface(
    fn=extract_objects,
    inputs=gr.File(label="Upload Multiple Images", type="filepath", file_count="multiple"),
    outputs=[gr.Gallery(label="Processed Images"), gr.Textbox(label="Timing Info")],
    title="BiRefNet Bulk Background Removal (with fallback)",
    description="Upload multiple images. If an OOM occurs, the app falls back to smaller batches."
)
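
# Optional (a small sketch): for long-running inference, enabling Gradio's
# request queue avoids HTTP timeouts, e.g. `iface.queue().launch()`. This
# assumes a Gradio version that exposes .queue(), which 3.x and 4.x do.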

if __name__ == "__main__":
    iface.launch()