Nguyễn Bá Thiêm
Add image super resolution functionality
b16ab70
raw
history blame
4.23 kB
import time
import streamlit as st
import numpy as np
from PIL import Image
import cv2 # If you're using OpenCV for image processing
from io import BytesIO
import base64
from models.HAT.hat import *
# Seed the per-session slots used by the app the first time the script runs
# in a session. Existing values are left alone so that results computed on an
# earlier rerun are still available on this one.
for _state_key, _initial_value in (
    ('hat_enhanced_image', None),
    ('rcan_enhanced_image', None),
    ('hat_clicked', False),
    ('rcan_clicked', False),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
st.markdown("<h1 style='text-align: center'>Image Super Resolution</h1>", unsafe_allow_html=True)

# Sidebar: let the user choose where the input picture comes from.
st.sidebar.title("Options")
app_mode = st.sidebar.selectbox(
    "Choose the input source",
    ["Upload image", "Take a photo"],
)

# Render the widget that matches the selected source. The callbacks are
# wrapped in lambdas because `reset_states` is defined further down in this
# file; the lambda defers the name lookup until the callback actually fires.
_picture_source = None
if app_mode == "Upload image":
    uploaded_file = st.file_uploader(
        "Choose an image...",
        type=["jpg", "png"],
        on_change=lambda: reset_states(),
    )
    _picture_source = uploaded_file
elif app_mode == "Take a photo":
    camera_input = st.camera_input(
        "Take a picture",
        on_change=lambda: reset_states(),
    )
    _picture_source = camera_input

# `image` is intentionally left unbound when no picture has been provided:
# the code below checks for the existence of the name itself.
if _picture_source is not None:
    image = Image.open(_picture_source).convert("RGB")
def reset_states():
    """Clear the stored enhancement results and the "clicked" flags.

    Used as the ``on_change`` callback of the input widgets, so results
    stored for a previous picture are discarded when the input changes.
    """
    cleared_values = {
        'hat_enhanced_image': None,
        'rcan_enhanced_image': None,
        'hat_clicked': False,
        'rcan_clicked': False,
    }
    for state_key, cleared in cleared_values.items():
        st.session_state[state_key] = cleared
def get_image_download_link(img, filename):
    """Render a Streamlit download button that serves an image as a PNG file.

    The image is always encoded as PNG and sent with the ``image/png`` MIME
    type, so the suggested file name is normalised to end in ``.png``.
    Without this, a caller passing e.g. ``"result.jpg"`` would hand the user
    PNG bytes under a JPEG file name.

    Args:
        img: Image to offer for download. It must provide a PIL-style
            ``save(fp, format=...)`` method.
        filename: Suggested download name. Its extension, if any, is
            replaced with ``.png``.

    Returns:
        Whatever ``st.download_button`` returns for the rendered button.
    """
    import os  # local import: only needed here to split off the extension

    # Encode the image into an in-memory buffer rather than a temp file.
    buffered = BytesIO()
    img.save(buffered, format="PNG")

    # Keep the advertised name consistent with the actual payload format.
    stem, _extension = os.path.splitext(filename)
    png_name = stem + ".png"

    return st.download_button(
        label="Download Image",
        data=buffered.getvalue(),
        file_name=png_name,
        mime="image/png",
    )
def _show_comparison(original, enhanced, download_name):
    """Show the original and enhanced images side by side with a download button.

    Args:
        original: Image rendered in the left column under "Original".
        enhanced: Image rendered in the right column under "Enhanced"; it is
            also the image offered for download.
        download_name: File name suggested for the downloaded image.
    """
    col1, col2 = st.columns(2)
    col1.header("Original")
    col1.image(original, use_column_width=True)
    col2.header("Enhanced")
    col2.image(enhanced, use_column_width=True)
    # Place the download button directly underneath the enhanced image.
    with col2:
        get_image_download_link(enhanced, download_name)


# `image` is only bound once the user has uploaded or captured a picture, so
# the presence of the name is what gates the rest of the page.
if 'image' in locals():
    st.write("")

    if st.button('Enhance with HAT'):
        with st.spinner('Processing using HAT...'), \
                st.spinner('Wait for it... the model is processing the image'):
            enhanced_image = HAT_for_deployment(image)
        st.session_state['hat_enhanced_image'] = enhanced_image
        st.session_state['hat_clicked'] = True
        st.success('Done!')

    # The result lives in session state so it stays visible on later reruns
    # (e.g. after the other button or the download button is pressed).
    if st.session_state['hat_enhanced_image'] is not None:
        # The download payload is PNG, so the suggested name uses `.png`.
        _show_comparison(image, st.session_state['hat_enhanced_image'], 'hat_enhanced.png')

    if st.button('Enhance with RCAN'):
        with st.spinner('Processing using RCAN...'), \
                st.spinner('Wait for it... the model is processing the image'):
            # TODO: placeholder only. No RCAN model is wired in yet, so this
            # waits briefly and returns the input image unchanged.
            time.sleep(2)
            enhanced_image = image
        st.session_state['rcan_enhanced_image'] = enhanced_image
        st.session_state['rcan_clicked'] = True
        st.success('Done!')

    if st.session_state['rcan_enhanced_image'] is not None:
        _show_comparison(image, st.session_state['rcan_enhanced_image'], 'rcan_enhanced.png')