Spaces:
Sleeping
Sleeping
import atexit | |
import base64 | |
import hashlib | |
import os | |
import uuid | |
import numpy as np | |
from apscheduler.schedulers.background import BackgroundScheduler | |
from flask import Flask, jsonify, request, logging as flog | |
from flask_limiter.util import get_remote_address | |
from ultralytics import YOLO | |
# Flask application object for this service.
app = Flask(__name__)

# YOLO detector loaded from the ONNX export; ocr_png() calls it on each image.
onnx_model = YOLO('numbers_yolov8s.onnx', task="detect")

# Label lookup table: cls_type_array[i] is the text for detector class index i.
# Indices 0-19 hold the numbers 1-20 (in string-sorted order); indices 20, 21
# and 22 are 'equal', '-' and '+'.
cls_type_array = np.array([
    '1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19',
    '2', '20', '3', '4', '5', '6', '7', '8', '9',
    'equal', '-', '+',
])
def load_model():
    """Re-create the detector and the label table in the module globals.

    Rebinds ``onnx_model`` to a freshly constructed YOLO model loaded from
    'numbers_yolov8s.onnx' and ``cls_type_array`` to a new copy of the
    class-index -> label table.
    """
    global onnx_model, cls_type_array

    onnx_model = YOLO('numbers_yolov8s.onnx', task="detect")

    label_names = [
        '1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19',
        '2', '20', '3', '4', '5', '6', '7', '8', '9',
        'equal', '-', '+',
    ]
    cls_type_array = np.array(label_names)
# Background scheduler that calls load_model() every 3600 seconds (one hour).
scheduler = BackgroundScheduler()
scheduler.add_job(load_model, "interval", seconds=3600)
scheduler.start()
@atexit.register
def shutdown_scheduler():
    """Shut the background scheduler down; registered to run at interpreter exit."""
    scheduler.shutdown()
def generate_hashed_uuid():
    """Return the SHA-256 hex digest of a newly generated random UUID.

    A version-4 UUID is created, rendered in its canonical string form,
    encoded to bytes and hashed.

    Returns:
        A 64-character lowercase hexadecimal string.
    """
    uuid_bytes = str(uuid.uuid4()).encode()
    return hashlib.sha256(uuid_bytes).hexdigest()
def ocr_png(filename, *, model=None, labels=None):
    """Read an arithmetic captcha from an image and compute its value.

    The detector is run on ``filename`` and its detections are ordered left
    to right by the x coordinate of each box's left edge.  The first three
    number labels and the first two operator labels ('+' or '-') are
    interleaved into an expression ``a op b op c``, which is then evaluated.
    Detections with any other label (such as 'equal') are ignored.

    Args:
        filename: Image source handed unchanged to the detector.
        model: Optional detector to use instead of the module-level
            ``onnx_model``.  It is called as ``model(filename)`` and must
            return a sequence whose first item exposes ``boxes.cls`` and
            ``boxes.xyxy``.
        labels: Optional class-index -> label table to use instead of the
            module-level ``cls_type_array``.

    Returns:
        dict: ``{"ocr": <expression string>, "result": <int value>}``.

    Raises:
        ValueError: If fewer than three numbers or fewer than two operators
            were detected (previously this surfaced as an IndexError).
    """
    detector = onnx_model if model is None else model
    label_table = cls_type_array if labels is None else labels

    boxes = detector(filename)[0].boxes

    # Collect (left-edge x, label) for every detection whose class index has
    # an entry in the label table, then order them left to right.
    detections = []
    for cls, box in zip(boxes.cls, boxes.xyxy):
        class_index = int(cls.item())
        if 0 <= class_index < len(label_table):
            detections.append((box[0].item(), str(label_table[class_index])))
    detections.sort()

    numbers = [label for _, label in detections if label.isdigit()]
    operators = [label for _, label in detections if label in ("+", "-")]
    if len(numbers) < 3 or len(operators) < 2:
        raise ValueError(
            "expected at least 3 numbers and 2 operators, "
            f"detected {len(numbers)} and {len(operators)}"
        )

    expression = numbers[0] + operators[0] + numbers[1] + operators[1] + numbers[2]

    # NOTE(review): this used to be eval(expression).  Only '+' and '-' can
    # occur, so plain left-to-right integer arithmetic yields the same value
    # without executing a dynamically built string.
    value = int(numbers[0])
    for operator, operand in zip(operators, numbers[1:3]):
        if operator == "+":
            value += int(operand)
        else:
            value -= int(operand)

    return {"ocr": expression, "result": value}
def get_ipaddr():
    """Return the client address for the current request.

    The first entry of ``request.access_route`` is printed and returned when
    the route is non-empty; otherwise ``request.remote_addr`` is returned,
    falling back to '127.0.0.1' when that is empty too.
    """
    route = request.access_route
    if not route:
        return request.remote_addr or '127.0.0.1'

    client_ip = route[0]
    print(client_ip)
    return client_ip
# Module-level alias for the default log handler provided by flask.logging.
handler = flog.default_handler
def get_token():
    """Return the API token currently in force.

    The token is read from the file named ``token`` in the current working
    directory, with surrounding whitespace stripped.  When that file does
    not exist the built-in default ``"init_token"`` is returned.

    Returns:
        str: The stored token, or ``"init_token"`` if no token file exists.
    """
    default_token = "init_token"
    try:
        # Open directly instead of checking os.path.exists first: this avoids
        # the check-then-open race, and the context manager closes the
        # handle, which the previous bare open() never did.
        with open("token", "r") as token_file:
            return token_file.read().strip()
    except FileNotFoundError:
        return default_token
def check_request(required_data, data):
    """Validate a request payload and its access token.

    Args:
        required_data: Iterable of keys that must all be present in ``data``.
        data: Parsed JSON body of the request (may be ``None`` or any JSON
            type, since it comes straight from the client).

    Returns:
        bool: ``True`` only when ``data`` is a JSON object containing every
        required key and its ``"token"`` entry equals the token returned by
        ``get_token()``; ``False`` otherwise.  The reason for a rejection is
        printed.
    """
    # Local import keeps this block self-contained without touching the
    # file's import section.
    import hmac

    # A JSON list or string would pass a bare `key in data` test and then
    # blow up on data["token"], so insist on a mapping up front.
    if not isinstance(data, dict) or any(key not in data for key in required_data):
        print("Error:Invalid Request Data\n" + str(data))
        return False

    supplied = data.get("token")
    expected = get_token()
    # compare_digest runs in constant time; comparing encoded bytes also
    # supports non-ASCII tokens, which the str form of compare_digest rejects.
    if not isinstance(supplied, str) or not hmac.compare_digest(
        supplied.encode(), expected.encode()
    ):
        print("Error:Invalid Token\n" + str(data))
        return False
    return True
def rate_limit_exceeded(e):
    """Build the JSON reply for a rate-limited request (HTTP 429).

    Prints the remote address reported by ``get_remote_address()``.
    NOTE(review): presumably registered as the 429 error handler; the
    registration is not visible in this file view.

    Args:
        e: The triggering error (unused).
    """
    remote = get_remote_address()
    print(remote)
    payload = jsonify(msg="Too many request")
    return payload, 429
def method_not_allowed(e):
    """Build the JSON reply for a disallowed HTTP method (HTTP 405).

    Prints the remote address reported by ``get_remote_address()``.
    NOTE(review): presumably registered as the 405 error handler; the
    registration is not visible in this file view.

    Args:
        e: The triggering error (unused).
    """
    remote = get_remote_address()
    print(remote)
    payload = jsonify(msg="Unauthorized Request")
    return payload, 405
def index():
    """Return a JSON status document containing the caller's IP address."""
    caller_ip = get_ipaddr()
    return jsonify(status_code=200, ip=caller_ip)
def update_token():
    """Replace the stored API token with the one supplied in the request.

    The JSON body must contain ``token`` (the current token) and
    ``new_token`` (the replacement).

    Returns:
        * 403 with ``msg="Unauthorized Request"`` when the payload or the
          current token is invalid.
        * 400 with ``msg="Invalid new_token"`` when ``new_token`` is not a
          non-blank string.
        * 200 with ``success=True`` once the new token has been written to
          the ``token`` file.
    """
    require_data = ["token", "new_token"]
    data = request.get_json(force=True, silent=True)
    if not check_request(require_data, data):
        return jsonify(msg="Unauthorized Request"), 403

    new_token = data["new_token"]
    # Validate BEFORE opening the file: opening for writing truncates it, so
    # a non-string value used to leave an empty token file behind when
    # write() raised.  A blank token would likewise read back as "".
    if not isinstance(new_token, str) or not new_token.strip():
        return jsonify(msg="Invalid new_token", success=False), 400

    with open("token", "w") as token_file:
        token_file.write(new_token)
    return jsonify(msg="Token updated successfully", success=True)
def solver_captcha():
    """Solve the captcha image supplied in the request body.

    The JSON body must contain ``token`` (the API token) and ``data`` (the
    base64-encoded image).  The image is written to a uniquely named
    ``.png`` file in the working directory, passed to ``ocr_png`` and
    removed again afterwards.

    Returns:
        * 403 with ``msg="Unauthorized Request"`` when the payload or token
          is invalid.
        * The dict produced by ``ocr_png`` on success.
        * ``("error", 500)`` when decoding, writing or recognition fails.
    """
    require_data = ["token", "data"]
    data = request.get_json(force=True, silent=True)
    if not check_request(require_data, data):
        return jsonify(msg="Unauthorized Request"), 403

    image_path = f"{generate_hashed_uuid()}.png"
    try:
        image_data = base64.b64decode(data["data"])
        with open(image_path, "wb") as f:
            f.write(image_data)
        return ocr_png(image_path)
    except Exception as e:
        # Request boundary: report any failure as a generic 500.
        print(e)
        return "error", 500
    finally:
        # The file does not exist when decoding or opening failed; an
        # unguarded remove would raise here and replace the response above.
        try:
            os.remove(image_path)
        except FileNotFoundError:
            pass
# Start the Flask server on all interfaces (0.0.0.0), port 8081.
# NOTE(review): this runs at import time; confirm the module is only ever
# executed as a script before wrapping it in a __main__ guard.
app.run(host="0.0.0.0", port=8081)