"""We can use Gradio to build the UI and then make it compatible for the Hugging face."""
import gradio as gr
import cv2
import numpy as np
import imutils
from PIL import Image
cv2.ocl.setUseOpenCL(False)  # Keep OpenCV on the plain CPU path (no OpenCL acceleration)
# Sharpening function
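# The 3x3 kernel below weights the centre pixel by 9 and its neighbours by -1; the weights
# sum to 1, so overall brightness is preserved while edges and fine detail are emphasised.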
def image_sharpening(image):
    kernel_sharpening = np.array([[-1, -1, -1],
                                  [-1, 9, -1],
                                  [-1, -1, -1]])
    sharpened = cv2.filter2D(image, -1, kernel_sharpening)
    return sharpened
# Remove black borders function
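# Any pixel brighter than 0 survives the binary threshold, so the largest contour covers the
# stitched content; cropping to its bounding box trims the unused black canvas around it.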
def remove_black_region(result):
    gray = cv2.cvtColor(result, cv2.COLOR_BGR2GRAY)
    thresh = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY)[1]
    cnts = cv2.findContours(thresh.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cnts = imutils.grab_contours(cnts)
    c = max(cnts, key=cv2.contourArea)
    (x, y, w, h) = cv2.boundingRect(c)
    crop = result[y:y + h, x:x + w]
    return crop
# Key point detection and descriptor function
def detectAndDescribe(image, method='orb'):
    if method == 'sift':
        descriptor = cv2.SIFT_create()
    elif method == 'brisk':
        descriptor = cv2.BRISK_create()
    elif method == 'orb':
        descriptor = cv2.ORB_create()
    (kps, features) = descriptor.detectAndCompute(image, None)
    return kps, features
# Matcher creation
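# SIFT/SURF produce floating-point descriptors, so they are compared with the L2 norm;
# ORB and BRISK produce binary descriptors, which are compared with the Hamming norm.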
def createMatcher(method, crossCheck):
    if method in ['sift', 'surf']:
        bf = cv2.BFMatcher(cv2.NORM_L2, crossCheck=crossCheck)
    else:
        bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=crossCheck)
    return bf
# Matching key points
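# Lowe's ratio test: a match is kept only if its best distance is clearly smaller than the
# second-best distance, which filters out ambiguous correspondences.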
def matchKeyPointsKNN(featuresA, featuresB, ratio, method):
    bf = createMatcher(method, crossCheck=False)
    rawMatches = bf.knnMatch(featuresA, featuresB, 2)
    matches = []
    for m, n in rawMatches:
        if m.distance < n.distance * ratio:
            matches.append(m)
    return matches
# Homography calculation
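# A homography has 8 degrees of freedom, so at least 4 point correspondences are required;
# RANSAC discards matches whose reprojection error exceeds reprojThresh pixels.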
def getHomography(kpsA, kpsB, featuresA, featuresB, matches, reprojThresh=4.0):
    kpsA = np.float32([kp.pt for kp in kpsA])
    kpsB = np.float32([kp.pt for kp in kpsB])
    if len(matches) > 4:
        ptsA = np.float32([kpsA[m.queryIdx] for m in matches])
        ptsB = np.float32([kpsB[m.trainIdx] for m in matches])
        (H, status) = cv2.findHomography(ptsA, ptsB, cv2.RANSAC, reprojThresh)
        return matches, H, status
    else:
        return None
# Stitching function for two images
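# The estimated homography maps trainImg coordinates into queryImg's frame, so trainImg is
# warped onto a larger canvas, queryImg is pasted at the origin, and the black border is cropped.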
def stitch_two_images(queryImg, trainImg, feature_extractor):
    queryImg_gray = cv2.cvtColor(queryImg, cv2.COLOR_BGR2GRAY)
    trainImg_gray = cv2.cvtColor(trainImg, cv2.COLOR_BGR2GRAY)
    kpsA, featuresA = detectAndDescribe(trainImg_gray, method=feature_extractor)
    kpsB, featuresB = detectAndDescribe(queryImg_gray, method=feature_extractor)
    matches = matchKeyPointsKNN(featuresA, featuresB, ratio=0.75, method=feature_extractor)
    M = getHomography(kpsA, kpsB, featuresA, featuresB, matches, reprojThresh=5)
    if M is None:
        return None
    (matches, H, status) = M
    width = trainImg.shape[1] + queryImg.shape[1]
    height = trainImg.shape[0] + queryImg.shape[0]
    result = cv2.warpPerspective(trainImg, H, (width, height))
    result[0:queryImg.shape[0], 0:queryImg.shape[1]] = queryImg
    crop_image = remove_black_region(result)
    return crop_image
# Calculate target brightness
def calculate_target_brightness(images):
    brightness_values = [np.mean(image.astype(np.float32)) for image in images]
    return np.mean(brightness_values)
# Brightness adjustment
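# Each image is shifted toward the mean brightness of the whole set so that exposure
# differences between frames do not show up as visible seams in the panorama.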
def global_brightness_adjustment(images, target_brightness):
    adjusted_images = []
    for image in images:
        image_float = image.astype(np.float32)
        avg_brightness = np.mean(image_float)
        brightness_shift = target_brightness - avg_brightness
        adjusted_image = image_float + brightness_shift
        adjusted_image = np.clip(adjusted_image, 0, 255).astype(np.uint8)
        adjusted_images.append(adjusted_image)
    return adjusted_images
# Main Stitching function
def stitch_images(uploaded_files, feature_extractor):
    # Convert to RGB first so grayscale or RGBA uploads do not break the RGB->BGR conversion.
    images = [cv2.cvtColor(np.array(Image.open(file).convert("RGB")), cv2.COLOR_RGB2BGR) for file in uploaded_files]
    if len(images) == 0:
        return None
    target_brightness = calculate_target_brightness(images)
    adjusted_images = global_brightness_adjustment(images, target_brightness)
    stitched_image = adjusted_images[0]
    for i in range(1, len(adjusted_images)):
        queryImg = stitched_image
        trainImg = adjusted_images[i]
        stitched_image = stitch_two_images(queryImg, trainImg, feature_extractor)
        if stitched_image is None:
            # Not enough matches for a homography; stop instead of crashing on a None image.
            return None
    return cv2.cvtColor(stitched_image, cv2.COLOR_BGR2RGB)
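# Example usage outside Gradio (file names are placeholders, not part of this repo):
#   panorama = stitch_images(["left.jpg", "middle.jpg", "right.jpg"], "orb")
#   # `panorama` is an RGB numpy array, or None if any pairwise homography fails.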
# Gradio interface with feature extractor selector
with gr.Blocks() as demo:
    gr.Markdown("## Image Stitching App with Feature Extractor Selection")
    image_input = gr.Files(label="Upload Images", type="filepath")
    extractor_input = gr.Dropdown(choices=["orb", "sift", "brisk"], label="Feature Extractor", value="orb")
    image_output = gr.Image(type="numpy", label="Stitched Image")
    process_button = gr.Button("Process Image")
    process_button.click(stitch_images, inputs=[image_input, extractor_input], outputs=image_output)
# Launch the Gradio app
demo.launch()