# HWT / app.py
# ankankbhunia's picture
# INIT
# 13580fb verified
# raw
# history blame
# 3.55 kB
import gradio as gr
from PIL import Image
import numpy as np
from io import BytesIO
import glob
import os
import time
from data.dataset import load_itw_samples, crop_
import torch
import cv2
import os
import numpy as np
from models.model import TRGAN
from params import *
from torch import nn
from data.dataset import get_transform
import pickle
from PIL import Image
import tqdm
import shutil
# Checkpoint holding the generator weights, and the batch size the model is
# built with (also used to repeat the encoded text in generate_image).
model_path = 'files/iam_model.pth'
batch_size = 1

print('(1) Loading model...')
model = TRGAN(batch_size=batch_size)
# map_location forces the checkpoint tensors onto the CPU while loading.
model.netG.load_state_dict(
    torch.load(model_path, map_location=torch.device('cpu'))
)
print(f'{model_path} : Model loaded Successfully')
# Inference only from here on (presumably nn.Module.eval — confirm in TRGAN).
model.eval()
def generate_image(text, folder, _ch3, images):
    """Render ``text`` in the handwriting style of the supplied samples.

    Args:
        text: Text to synthesise. Newlines and tabs are stripped and the
            remainder is split into words on single spaces.
        folder: Style folder selected in the dropdown; only used when
            ``images`` is empty.
        _ch3: Value of the "### OR" Markdown separator in the UI; ignored.
        images: Uploaded word images; take precedence over ``folder``.

    Returns:
        A ``(style_image, generated_image)`` pair of RGB PIL images, both
        padded to the same width with a 45-row band of 255 on top, or
        ``(None, None)`` when neither ``images`` nor ``folder`` is given.
    """
    # Uploaded samples win over the bundled style folder.
    style_source = images or folder
    if not style_source:
        # The interface declares two outputs, so return one value per output.
        return None, None
    style_inputs, width_length = load_itw_samples(style_source)

    # Encode the text word by word for the generator.
    text = text.replace("\n", "").replace("\t", "")
    text_encode = [word.encode() for word in text.split(' ')]
    eval_text_encode, eval_len_text = model.netconverter.encode(text_encode)
    # Use the configured DEVICE (same as the style tensor below) rather than a
    # hard-coded 'cuda', which fails on CPU-only hosts.
    eval_text_encode = eval_text_encode.to(DEVICE).repeat(batch_size, 1, 1)

    # Pure inference: no gradients are needed.
    with torch.no_grad():
        input_styles, page_val = model._generate_page(
            style_inputs.to(DEVICE).clone(),
            width_length,
            eval_text_encode,
            eval_len_text,
            no_concat=True,
        )

    page_val = crop_(page_val[0] * 255)
    input_styles = crop_(input_styles[0] * 255)

    # Right-pad both pages with 255 up to the wider of the two; the pad for
    # the wider page has zero columns and leaves it unchanged.
    max_width = max(page_val.shape[1], input_styles.shape[1])
    page_val = np.concatenate(
        [page_val, np.ones((page_val.shape[0], max_width - page_val.shape[1])) * 255], 1
    )
    input_styles = np.concatenate(
        [input_styles, np.ones((input_styles.shape[0], max_width - input_styles.shape[1])) * 255], 1
    )

    # Add a 45-row band of 255 above each page.
    upper_pad = np.ones((45, max_width)) * 255
    input_styles = np.concatenate([upper_pad, input_styles], 0)
    page_val = np.concatenate([upper_pad, page_val], 0)

    page_val = Image.fromarray(page_val).convert('RGB')
    input_styles = Image.fromarray(input_styles).convert('RGB')
    return input_styles, page_val
# Build the Gradio UI: text + style selection in, two images out.
_default_text = (
    "In the quiet hum of everyday life, the dance of existence unfolds. "
    "Time, an ever-flowing river, carries the stories of triumph and heartache. "
    "Each fleeting moment is a brushstroke on the canvas of our memories. "
    "Within the tapestry of human connection, threads of empathy weave a fabric that binds us all. "
    "Nature's symphony plays, a harmonious blend of rustling leaves and birdsong. "
    "In the labyrinth of possibility, dreams take flight. "
    "Beneath the veneer of routine, lies the extraordinary. "
    "Embrace the kaleidoscope of experience, for in the ordinary, the extraordinary often reveals itself."
)

# Input components, created in the order generate_image receives them.
_text_input = gr.Textbox(value=_default_text, label="Input text")
_style_dropdown = gr.Dropdown(
    value="files/example_data/style-30",
    choices=glob.glob('files/example_data/*'),
    label="Choose from provided writer styles",
)
_separator = gr.Markdown("### OR")
_upload_input = gr.File(label="Upload multiple word images", file_count="multiple")

# Output components, matching the (style, generated) pair the callback returns.
_style_output = gr.Image(type="pil", label="Style Image")
_generated_output = gr.Image(type="pil", label="Generated Image")

iface = gr.Interface(
    fn=generate_image,
    inputs=[_text_input, _style_dropdown, _separator, _upload_input],
    outputs=[_style_output, _generated_output],
)

# Start the app.
iface.launch(debug=True, share=True)