Spaces:
Build error
Build error
''' | |
M-LSD | |
Copyright 2021-present NAVER Corp. | |
Apache License v2.0 | |
''' | |
# for demo | |
import os | |
from flask import Flask, request, session, json, Response, render_template, abort, send_from_directory | |
import requests | |
from urllib.request import urlopen | |
from io import BytesIO | |
import uuid | |
import cv2 | |
import time | |
import argparse | |
# for tflite | |
import numpy as np | |
from PIL import Image | |
import tensorflow as tf | |
# for square detector | |
from utils import pred_squares | |
# CPU mode: an empty CUDA_VISIBLE_DEVICES hides every CUDA device from this process.
# NOTE(review): this runs after `import tensorflow`; confirm TF has not already
# initialised CUDA by then, otherwise the setting has no effect.
os.environ.update(CUDA_VISIBLE_DEVICES='')

# flask
app = Flask(__name__)
app.logger.info('init demo app')
logger = app.logger
# config
def _build_arg_parser():
    """Return the command-line parser for the M-LSD square-detection demo.

    Options cover the TFLite model, the line-segment score threshold, the
    intersection-point filters, the box-ranking weights and the server port.
    """
    arg_parser = argparse.ArgumentParser()
    add = arg_parser.add_argument

    # -- model parameters
    add('--tflite_path', type=str,
        default='./tflite_models/M-LSD_512_large_fp16.tflite')
    add('--input_size', type=int, default=512,
        help='The size of input images.')

    # -- LSD parameter
    add('--score_thr', type=float, default=0.10,
        help='Discard center points when the score < score_thr.')

    # -- intersection point parameters
    add('--outside_ratio', type=float, default=0.10,
        help='Discard an intersection point\n'
             'when it is located outside a line segment farther than line_length * outside_ratio.')
    add('--inside_ratio', type=float, default=0.50,
        help='Discard an intersection point\n'
             'when it is located inside a line segment farther than line_length * inside_ratio.')

    # -- ranking boxes parameters
    add('--w_overlap', type=float, default=0.0,
        help='When increasing w_overlap, the final box tends to overlap with\n'
             'the detected line segments as much as possible.')
    add('--w_degree', type=float, default=1.14,
        help='When increasing w_degree, the final box tends to be\n'
             'a parallel quadrilateral with reference to the angle of the box.')
    add('--w_length', type=float, default=0.03,
        help='When increasing w_length, the final box tends to be\n'
             'a parallel quadrilateral with reference to the length of the box.')
    add('--w_area', type=float, default=1.84,
        help='When increasing w_area, the final box tends to be the largest one out of candidates.')
    add('--w_center', type=float, default=1.46,
        help='When increasing w_center, the final box tends to be located in the center of input image.')

    # -- flask demo parameter
    add('--port', type=int, default=5000,
        help='flask demo will be running on http://0.0.0.0:port/')
    return arg_parser


parser = _build_arg_parser()
class model_graph:
    """M-LSD TFLite model plus the image I/O used by the Flask demo.

    Wraps a ``tf.lite.Interpreter`` and forwards images to ``pred_squares``.
    It also fetches/decodes input images and downscales them. For each session
    it writes visualisations and a ``results.json`` under
    ``static/results/<session_id>``.
    """

    # Root (relative to the current working directory) of the per-session output folders.
    RESULTS_ROOT = 'static/results'

    def __init__(self, args):
        """Load the TFLite model and gather the square-detector parameters.

        Args:
            args: parsed CLI namespace. Reads ``tflite_path``, ``input_size``,
                ``score_thr``, ``outside_ratio``, ``inside_ratio``,
                ``w_overlap``, ``w_degree``, ``w_length``, ``w_area`` and
                ``w_center``.
        """
        self.interpreter, self.input_details, self.output_details = self.load_tflite(args.tflite_path)
        # Passed verbatim as ``params=`` to pred_squares() in pred_tflite().
        self.params = {
            'score': args.score_thr,
            'outside_ratio': args.outside_ratio,
            'inside_ratio': args.inside_ratio,
            'w_overlap': args.w_overlap,
            'w_degree': args.w_degree,
            'w_length': args.w_length,
            'w_area': args.w_area,
            'w_center': args.w_center,
        }
        self.args = args

    def load_tflite(self, tflite_path):
        """Create a TFLite interpreter for ``tflite_path`` and allocate its tensors.

        Returns:
            (interpreter, input_details, output_details)
        """
        interpreter = tf.lite.Interpreter(model_path=tflite_path)
        interpreter.allocate_tensors()
        input_details = interpreter.get_input_details()
        output_details = interpreter.get_output_details()
        return interpreter, input_details, output_details

    def pred_tflite(self, image):
        """Run the square detector on ``image``.

        The model input is ``input_size`` x ``input_size`` (from the CLI args).

        Returns:
            dict with keys ``segments``, ``squares``, ``scores`` and
            ``inter_points``, holding the four values returned by pred_squares().
        """
        input_shape = [self.args.input_size, self.args.input_size]
        segments, squares, score_array, inter_points = pred_squares(
            image, self.interpreter, self.input_details, self.output_details,
            input_shape, params=self.params)
        return {
            'segments': segments,
            'squares': squares,
            'scores': score_array,
            'inter_points': inter_points,
        }

    def read_image(self, image_url, max_size=1024, timeout=30):
        """Download ``image_url`` and return it as an RGB array.

        The image is downscaled so its longer side is at most ``max_size`` pixels.

        Args:
            image_url: http(s) URL of the image.
            max_size: upper bound for the longer side (default 1024).
            timeout: seconds to wait for the server, passed to ``requests.get``.
                Without it, a stalled server would block the request forever.

        Raises:
            requests.HTTPError: on a non-2xx response.
            requests.Timeout: if the server does not answer within ``timeout``.

        NOTE(security): in this demo the URL comes from the request form, so the
        server fetches arbitrary user-supplied addresses (SSRF). Restrict
        schemes/hosts before exposing this publicly.
        """
        with requests.get(image_url, timeout=timeout) as response:
            response.raise_for_status()
            image = np.asarray(Image.open(BytesIO(response.content)).convert('RGB'))
        return self.init_resize_image(image, max_size)

    def init_resize_image(self, im, maximum_size=1024):
        """Downscale ``im`` so its longer side is at most ``maximum_size`` pixels.

        The aspect ratio is kept. The shorter side is floored but never goes
        below 1 px. Images that already fit are returned unchanged, as the same
        object.
        """
        h, w = im.shape[:2]
        long_side = max(h, w)
        if long_side <= maximum_size:
            return im
        # Integer arithmetic: float ratio * length can land just under the target
        # (e.g. 1023.999...) and lose a pixel when truncated.
        short_side = max(1, int(min(h, w) * maximum_size // long_side))
        if h >= w:
            new_h, new_w = int(maximum_size), short_side
        else:
            new_h, new_w = short_side, int(maximum_size)
        return cv2.resize(im, (new_w, new_h), interpolation=cv2.INTER_AREA)

    def _session_dir(self, session_id):
        """Return ``RESULTS_ROOT/<session_id>``, creating it if it does not exist."""
        dirpath = os.path.join(self.RESULTS_ROOT, session_id)
        os.makedirs(dirpath, exist_ok=True)
        return dirpath

    def decode_image(self, session_id, rawimg):
        """Decode uploaded image bytes, downscale them, and save ``input.png``.

        The image is downscaled to at most 1024 px on its longer side.

        Args:
            session_id: name of the per-session folder under ``RESULTS_ROOT``.
            rawimg: encoded image bytes (any format ``cv2.imdecode`` accepts).

        Returns:
            (img, input_image_url): the RGB image array and the relative path
            of the saved copy.

        Raises:
            ValueError: if ``rawimg`` is empty or cannot be decoded as an image.
        """
        if not rawimg:
            raise ValueError('uploaded image is empty')
        decoded = cv2.imdecode(np.frombuffer(rawimg, dtype='uint8'), 1)
        if decoded is None:  # imdecode returns None for corrupt/unsupported data
            raise ValueError('uploaded file could not be decoded as an image')
        img = self.init_resize_image(decoded[:, :, ::-1])  # OpenCV decodes BGR; flip to RGB
        input_image_url = os.path.join(self._session_dir(session_id), 'input.png')
        cv2.imwrite(input_image_url, img[:, :, ::-1])  # back to BGR for imwrite
        return img, input_image_url

    @staticmethod
    def _polyline_points(square):
        """Reshape a square's corners into the (N, 1, 2) int32 array cv2.polylines requires."""
        return np.asarray(square).reshape(-1, 1, 2).astype(np.int32)

    def draw_output(self, image, output, save_path='test.png', tile_size=512):
        """Render the predictions as a 2x2 grid, write it to ``save_path`` and return it.

        Grid layout (each tile downscaled so its longer side is <= ``tile_size``)::

            square image | square candidates image
            inter image  | line image

        Args:
            image: RGB image array the predictions refer to. It is flipped to
                BGR only for ``cv2.imwrite``.
            output: dict from ``pred_tflite`` with ``segments``,
                ``inter_points`` and ``squares``.
            save_path: destination file for the grid image.
            tile_size: longest side of each grid tile (default 512).

        Returns:
            The RGB grid image.
        """
        color_dict = {'red': [255, 0, 0],
                      'green': [0, 255, 0],
                      'blue': [0, 0, 255],
                      'cyan': [0, 255, 255],
                      'black': [0, 0, 0],
                      'yellow': [255, 255, 0],
                      'dark_yellow': [200, 200, 0]}
        line_thick = 5
        line_image = image.copy()
        square_image = image.copy()
        square_candidate_image = image.copy()

        # Detected line segments, one (x_start, y_start, x_end, y_end) per entry.
        for line in output['segments']:
            x_start, y_start, x_end, y_end = [int(val) for val in line]
            cv2.line(line_image, (x_start, y_start), (x_end, y_end), color_dict['red'], line_thick)

        # Intersection points are drawn on a copy of the line image.
        inter_image = line_image.copy()
        for pt in output['inter_points']:
            x, y = [int(val) for val in pt]
            cv2.circle(inter_image, (x, y), 10, color_dict['blue'], -1)

        # Every candidate square, outlined.
        for square in output['squares']:
            cv2.polylines(square_candidate_image, [self._polyline_points(square)], True,
                          color_dict['dark_yellow'], line_thick)
        # Only the first square is highlighted, with its corners marked.
        # Presumably it is the best-ranked candidate from pred_squares — confirm.
        for square in output['squares'][0:1]:
            cv2.polylines(square_image, [self._polyline_points(square)], True,
                          color_dict['yellow'], line_thick)
            for pt in square:
                cv2.circle(square_image, (int(pt[0]), int(pt[1])), 10, color_dict['cyan'], -1)

        # All four images are copies of `image`, so they resize to identical tile shapes.
        top = np.concatenate([self.init_resize_image(square_image, tile_size),
                              self.init_resize_image(square_candidate_image, tile_size)], axis=1)
        bottom = np.concatenate([self.init_resize_image(inter_image, tile_size),
                                 self.init_resize_image(line_image, tile_size)], axis=1)
        output_image = np.concatenate([top, bottom], axis=0)
        cv2.imwrite(save_path, output_image[:, :, ::-1])  # RGB -> BGR for OpenCV
        return output_image

    def save_output(self, session_id, input_image_url, image, output):
        """Draw ``output`` over ``image`` and record the result paths for the session.

        Writes two files under ``RESULTS_ROOT/<session_id>``:

        - ``output.png``: the 2x2 visualisation grid.
        - ``results.json``: holds ``input_image_url``, ``session_id`` and
          ``output_image_url``.
        """
        dirpath = self._session_dir(session_id)
        output_image_url = os.path.join(dirpath, 'output.png')
        self.draw_output(image, output, save_path=output_image_url)
        rst = {
            'input_image_url': input_image_url,
            'session_id': session_id,
            'output_image_url': output_image_url,
        }
        with open(os.path.join(dirpath, 'results.json'), 'w', encoding='utf-8') as f:
            json.dump(rst, f)
def init_worker(args):
    """Build the shared model from parsed CLI ``args``.

    The result is published as the module-level ``model`` that the request
    handlers use.
    """
    global model
    loaded_model = model_graph(args)
    model = loaded_model
def index():
    """Render the landing page with a placeholder session id.

    NOTE(review): no ``@app.route`` decorator is visible for this view;
    confirm how it is registered with ``app``.
    """
    context = {'session_id': 'dummy_session_id'}
    return render_template('index_scan.html', **context)
def index_post():
    """Handle an image-URL or upload form submission and render the result page.

    If the ``image_url`` form field is non-empty, the image is downloaded.
    Otherwise the uploaded ``image`` file is decoded. Predictions are saved
    under a fresh session id, which is passed to the template.

    NOTE(review): no ``@app.route`` decorator is visible for this view;
    confirm how it is registered with ``app``.
    """
    request_start = time.time()
    # uuid4 is random. uuid1 would embed the host MAC address and a timestamp
    # in the id, which names the publicly served results directory.
    session_id = str(uuid.uuid4())
    image_url = request.form.get('image_url', '').strip()
    if image_url:
        image = model.read_image(image_url)
    else:
        upload = request.files.get('image')
        if upload is None:
            abort(400, description='Provide either an image_url or an image file.')
        image, image_url = model.decode_image(session_id, upload.read())
    output = model.pred_tflite(image)
    model.save_output(session_id, image_url, image, output)
    logger.info('processed session %s in %.3fs', session_id, time.time() - request_start)
    return render_template('index_scan.html', session_id=session_id)
def favicon():
    """Serve ``favicon.ico`` from the ``static`` folder under the app root."""
    static_dir = os.path.join(app.root_path, 'static')
    return send_from_directory(static_dir, 'favicon.ico',
                               mimetype='image/vnd.microsoft.icon')
if __name__ == '__main__':
    # Parse the CLI options and load the module-level `model` used by the
    # request handlers, then start the development server on all interfaces.
    args = parser.parse_args()
    init_worker(args)
    app.run(port=args.port, host='0.0.0.0')