# nubbers / api.py
# zhou12189108's picture
# Upload api.py
# 802b66f verified
import atexit
import base64
import hashlib
import os
import uuid
import numpy as np
from apscheduler.schedulers.background import BackgroundScheduler
from flask import Flask, jsonify, request, logging as flog
from flask_limiter.util import get_remote_address
from ultralytics import YOLO
# Flask application object; the route and error-handler decorators in this
# file attach to it.
app = Flask(__name__)
# Detection model loaded from the ONNX file at import time. load_model()
# rebinds this global when it runs.
onnx_model = YOLO('numbers_yolov8s.onnx', task="detect")
# Label for each detector class id (array position == class id). ocr_png()
# indexes it to turn detections into text: positions 0-19 are the numbers
# '1'..'20', position 20 is 'equal', positions 21 and 22 are '-' and '+'.
# NOTE(review): presumably this order mirrors the model's training class
# order -- confirm against the model metadata.
cls_type_array = np.array(
['1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '2', '20', '3', '4', '5', '6', '7', '8', '9',
'equal', '-', '+'])
def load_model():
    """Rebuild the module-level detector and its class-label table.

    Rebinds the globals ``onnx_model`` and ``cls_type_array`` with freshly
    constructed objects: the YOLO detector is loaded again from
    ``numbers_yolov8s.onnx`` and the label array is recreated with the same
    23 entries, where the array position is the detector's class id.
    """
    global onnx_model, cls_type_array

    number_labels = [
        '1', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19',
        '2', '20', '3', '4', '5', '6', '7', '8', '9',
    ]
    symbol_labels = ['equal', '-', '+']

    onnx_model = YOLO('numbers_yolov8s.onnx', task="detect")
    cls_type_array = np.array(number_labels + symbol_labels)
# Periodic model refresh: load_model() is scheduled to run every 3600 seconds
# (one hour) for as long as the scheduler is running.
scheduler = BackgroundScheduler()
scheduler.add_job(
    func=load_model,
    trigger="interval",
    seconds=3600,
)
scheduler.start()
@atexit.register
def shutdown_scheduler():
    """Shut the module-level scheduler down; registered to run at interpreter exit."""
    scheduler.shutdown()
def generate_hashed_uuid():
    """Return the SHA-256 hex digest of a freshly generated random UUID.

    A new version-4 UUID is created on every call, rendered as text, encoded
    to bytes and hashed; the result is a 64-character lowercase hex string.
    """
    uuid_bytes = str(uuid.uuid4()).encode()
    return hashlib.sha256(uuid_bytes).hexdigest()
def ocr_png(filename, model=None, labels=None):
    """Read a three-operand arithmetic captcha image and compute its value.

    The detector is run on ``filename`` and every detection is reduced to its
    class id and the x coordinate of its box's left edge. Detections are
    ordered left to right; the first three numbers and the first two operators
    are interleaved as ``number operator number operator number`` and
    evaluated from left to right.

    With the default label table, class ids below 20 are the numbers
    '1'..'20', ids 21 and 22 are '-' and '+', and id 20 ('equal') is ignored.

    Args:
        filename: Image path (or any other input the detector accepts).
        model: Optional detector callable. Defaults to the module-level
            ``onnx_model``. It must return a sequence whose first item exposes
            ``boxes.cls`` and ``boxes.xyxy``.
        labels: Optional class-id -> label lookup. Defaults to the
            module-level ``cls_type_array``.

    Returns:
        dict: ``{"ocr": <expression text>, "result": <int value>}``.

    Raises:
        IndexError: Fewer than three numbers or fewer than two operators were
            detected. The type matches what the previous positional indexing
            raised, so existing handlers keep working.
    """
    detector = onnx_model if model is None else model
    class_labels = cls_type_array if labels is None else labels

    boxes = detector(filename)[0].boxes
    # One (left-edge x, class id) pair per detection, ordered left to right.
    detections = sorted(
        ((box[0].item(), int(cls.item())) for cls, box in zip(boxes.cls, boxes.xyxy)),
        key=lambda detection: detection[0],
    )

    numbers = [str(class_labels[cls_id]) for _, cls_id in detections if cls_id < 20]
    operators = [str(class_labels[cls_id]) for _, cls_id in detections if cls_id in (21, 22)]
    if len(numbers) < 3 or len(operators) < 2:
        raise IndexError(
            f"expected at least 3 numbers and 2 operators, "
            f"detected {len(numbers)} and {len(operators)}"
        )

    expression = numbers[0] + operators[0] + numbers[1] + operators[1] + numbers[2]

    # Only '+' and '-' can occur, so plain left-to-right arithmetic gives the
    # same value eval() did, without executing a model-derived string as code.
    value = int(numbers[0])
    for operator, operand in zip(operators[:2], numbers[1:3]):
        if operator == '+':
            value += int(operand)
        else:
            value -= int(operand)

    return {"ocr": expression, "result": value}
def get_ipaddr():
    """Return the client address for the current request.

    When the request's access route is non-empty, its first entry is printed
    and returned. Otherwise the request's remote address is returned, or
    '127.0.0.1' if that is missing.
    """
    route = request.access_route
    if not route:
        return request.remote_addr or '127.0.0.1'

    first_hop = route[0]
    print(first_hop)
    return first_hop
# Module-level reference to Flask's default log handler; nothing else in the
# visible code uses it.
handler = flog.default_handler
def get_token():
    """Return the API token currently in force.

    The token is read from the file ``token`` in the current working
    directory, with surrounding whitespace stripped. If that file does not
    exist, the built-in default ``"init_token"`` is returned.

    Returns:
        str: The stored token, or ``"init_token"`` when no token file exists.
    """
    default_token = "init_token"
    try:
        # Open directly instead of checking existence first: this avoids the
        # check-then-open race, and the context manager closes the handle.
        with open("token", "r") as token_file:
            return token_file.read().strip()
    except FileNotFoundError:
        return default_token
def check_request(required_data, data):
    """Validate a decoded JSON request body and its access token.

    Args:
        required_data: Iterable of keys that must be present in ``data``.
        data: The decoded request body. Anything that is not a non-empty
            dict (``None``, a list, a string, ...) is rejected.

    Returns:
        bool: ``True`` only when every required key is present and the
        supplied ``"token"`` equals the token returned by ``get_token()``.
        Each rejection prints a diagnostic line and returns ``False``.
    """
    import hmac

    # A JSON body can legally be a list or a string; those pass a plain
    # ``key in data`` test and then fail on ``data["token"]``, so require a dict.
    if not isinstance(data, dict) or not data or any(key not in data for key in required_data):
        print("Error:Invalid Request Data\n" + str(data))
        return False

    supplied = data.get("token")
    # The token file is only read once the payload shape has been accepted.
    expected = get_token()
    # Compare as bytes in constant time. "surrogatepass" keeps lone surrogates
    # from a JSON escape from raising; they can never equal a stored token.
    token_ok = isinstance(supplied, str) and hmac.compare_digest(
        supplied.encode("utf-8", "surrogatepass"),
        expected.encode("utf-8", "surrogatepass"),
    )
    if not token_ok:
        print("Error:Invalid Token\n" + str(data))
        return False
    return True
@app.errorhandler(429)
def rate_limit_exceeded(e):
    """Handle HTTP 429: print the requester's address and reply with a JSON message."""
    requester = get_remote_address()
    print(requester)
    payload = jsonify(msg="Too many request")
    return payload, 429
@app.errorhandler(405)
def method_not_allowed(e):
    """Handle HTTP 405: print the requester's address and reply with a JSON message."""
    requester = get_remote_address()
    print(requester)
    payload = jsonify(msg="Unauthorized Request")
    return payload, 405
@app.route("/", methods=["GET"])
def index():
    """Root endpoint: report a status code of 200 together with the caller's address."""
    caller_ip = get_ipaddr()
    return jsonify(status_code=200, ip=caller_ip)
@app.route("/update/token", methods=["POST"])
def update_token():
    """Replace the stored API token.

    Expects a JSON body with ``token`` (the current token) and ``new_token``.

    Returns:
        403 with ``msg="Unauthorized Request"`` when the body fails
        ``check_request``; 400 when ``new_token`` is not a non-blank string;
        otherwise a success message after ``new_token`` has been written to
        the ``token`` file.
    """
    require_data = ["token", "new_token"]
    data = request.get_json(force=True, silent=True)
    if not check_request(require_data, data):
        return jsonify(msg="Unauthorized Request"), 403

    new_token = data["new_token"]
    # Validate before opening: opening for writing truncates the file, so a
    # value that cannot be written would otherwise leave an empty token behind.
    # Blank values are refused because get_token() strips what it reads.
    if not isinstance(new_token, str) or not new_token.strip():
        return jsonify(msg="Invalid new_token", success=False), 400

    with open("token", "w") as token_file:
        token_file.write(new_token)
    return jsonify(msg="Token updated successfully", success=True)
@app.route("/api/solve", methods=["POST"])
def solver_captcha():
    """Solve a captcha supplied as a base64-encoded image.

    Expects a JSON body with ``token`` and ``data`` (the base64 image). The
    decoded bytes are written to a uniquely named ``.png`` file in the working
    directory, passed to ``ocr_png``, and the file is removed afterwards.

    Returns:
        403 with ``msg="Unauthorized Request"`` when the body fails
        ``check_request``; the dict produced by ``ocr_png`` on success; the
        plain text ``"error"`` with status 500 when decoding, writing or
        recognition raises.
    """
    require_data = ["token", "data"]
    data = request.get_json(force=True, silent=True)
    if not check_request(require_data, data):
        return jsonify(msg="Unauthorized Request"), 403

    image_path = f"{generate_hashed_uuid()}.png"
    try:
        image_data = base64.b64decode(data["data"])
        with open(image_path, "wb") as f:
            f.write(image_data)
        return ocr_png(image_path)
    except Exception as e:
        # Top-level boundary of the request: report the failure and answer 500.
        print(e)
        return "error", 500
    finally:
        # The file does not exist if decoding failed before it was created.
        # Removing it unconditionally would raise here and replace the
        # "error"/500 reply above with an unhandled exception.
        try:
            os.remove(image_path)
        except FileNotFoundError:
            pass
# Start Flask's built-in server, bound to 0.0.0.0 on port 8081.
# NOTE(review): there is no `if __name__ == "__main__":` guard in the visible
# code, so this call also runs (and blocks) when the module is imported.
app.run(host="0.0.0.0", port=8081)