File size: 3,180 Bytes
26b3f96
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fcc142
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import cv2
import numpy as np
import gradio as gr
from PIL import Image
import tempfile

def equalize_exposure(images, clip_limit=2.0, tile_grid_size=(8, 8)):
    """Equalize local contrast of each image with CLAHE on its lightness channel.

    Each image is converted BGR -> LAB, CLAHE is applied to the L channel only
    (the a/b colour channels are passed through unchanged), and the result is
    converted back to BGR.

    Args:
        images: iterable of BGR images. Assumed 8-bit (``stitch_images`` builds
            them from PIL RGB arrays) -- CLAHE rejects other depths.
        clip_limit: CLAHE contrast clip limit (default matches the original
            hard-coded 2.0).
        tile_grid_size: CLAHE tile grid size (default matches the original
            hard-coded (8, 8)).

    Returns:
        A new list of equalized BGR images, in the same order as ``images``.
    """
    # Hoisted out of the loop: the CLAHE parameters are identical for every image.
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    equalized_images = []
    for img in images:
        l, a, b = cv2.split(cv2.cvtColor(img, cv2.COLOR_BGR2LAB))
        lab_eq = cv2.merge((clahe.apply(l), a, b))
        equalized_images.append(cv2.cvtColor(lab_eq, cv2.COLOR_LAB2BGR))
    return equalized_images

def _load_bgr_images(image_files):
    """Decode each file with PIL, force RGB, and return BGR arrays for OpenCV."""
    images = []
    for file in image_files:
        # The context manager closes the file handle PIL opened lazily.
        with Image.open(file) as img_pil:
            rgb = np.array(img_pil.convert('RGB'))
        images.append(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))
    return images


def _crop_to_content(stitched):
    """Crop a BGR image to the bounding box of its non-black (non-zero gray) pixels.

    Returns the image unchanged if it contains no non-zero pixel at all.
    """
    gray = cv2.cvtColor(stitched, cv2.COLOR_BGR2GRAY)
    coords = cv2.findNonZero(gray)
    if coords is None:
        # All-black result: boundingRect(None) would raise.
        return stitched
    x, y, w, h = cv2.boundingRect(coords)
    # Equivalent to the former warpPerspective with rect->same-size rect
    # (a pure translation by (-x, -y)), but exact and without resampling.
    return stitched[y:y + h, x:x + w]


def _save_temp_png(image):
    """Write a PIL image to a new temporary .png file and return its path."""
    # delete=False keeps the file for Gradio to serve after we return; the handle
    # is closed before PIL writes so the path can be reopened on every OS.
    # NOTE(review): these files are never cleaned up by this module.
    with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as tmp:
        path = tmp.name
    image.save(path)
    return path


def stitch_images(image_files):
    """Stitch uploaded images into a panorama and save it to a temporary PNG.

    Pipeline: load images -> CLAHE exposure equalization -> OpenCV panorama
    stitcher (confidence threshold 0.8, wave correction off) -> crop away the
    black border -> convert to PIL and save to a temp file.

    Args:
        image_files: iterable of image file paths (the UI passes the value of a
            ``gr.Files(type="filepath")`` component). ``None`` is treated as no
            files.

    Returns:
        ``(PIL.Image.Image, str)``: the stitched image and the path of the saved
        PNG, or ``(None, None)`` if fewer than two images were given or OpenCV
        stitching failed (a message is printed in both cases).
    """
    images = _load_bgr_images(image_files or [])

    if len(images) < 2:
        print("Need at least two images to stitch.")
        return None, None

    images_eq = equalize_exposure(images)

    stitcher = cv2.Stitcher_create(cv2.Stitcher_PANORAMA)
    stitcher.setPanoConfidenceThresh(0.8)
    stitcher.setWaveCorrection(False)

    status, stitched = stitcher.stitch(images_eq)
    if status != cv2.Stitcher_OK:
        print(f"Image stitching failed ({status})")
        return None, None

    cropped = _crop_to_content(stitched)
    stitched_image = Image.fromarray(cv2.cvtColor(cropped, cv2.COLOR_BGR2RGB))
    return stitched_image, _save_temp_png(stitched_image)

# Gradio Interface: upload several images, click "Stitch", and get the panorama
# preview plus a downloadable PNG (both produced by stitch_images).
with gr.Blocks() as interface:
    gr.Markdown("<h1 style='color: #2196F3; text-align: center;'>Image Stitcher 🧵</h1>")
    gr.Markdown("<h3 style='color: #2196F3; text-align: center;'>Upload the images you want to stitch</h3>")

    image_upload = gr.Files(type="filepath", label="Upload Images")
    stitch_button = gr.Button("Stitch", variant="primary")
    stitched_image = gr.Image(type="pil", label="Stitched Image")
    download_button = gr.File(label="Download Stitched Image")

    # stitch_images returns (PIL image, file path), matching these two outputs.
    stitch_button.click(stitch_images, inputs=image_upload, outputs=[stitched_image, download_button])

# Only start the server when run as a script, so importing this module
# (e.g. from tests or another app) does not block on launch().
if __name__ == "__main__":
    interface.launch()