"""
This is main.py
"""
from fastapi import FastAPI, File, UploadFile, Form, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
from fastapi.templating import Jinja2Templates
from pydantic import BaseModel
from typing import List
class AppData:
    """Mutable module-level state shared across requests."""
    def __init__(self):
        self.file_content = ""
        self.name_list = []
        self.place = []
        self.times = []
        self.name_dic = {}
        self.end_output = []
class ItemListRequest(BaseModel):
    nameList: List[str]
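# Illustrative request body for POST /kcsn (example values, not from the repo):
#   {"nameList": ["ํ™๊ธธ๋™, ๊ธธ๋™", "์ž„๊บฝ์ •"]}
# Each entry is one character's group of aliases joined by ', '.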
app_data = AppData()
# App setup: serve static assets and render Jinja2 templates.
app = FastAPI()
app.mount("/static", StaticFiles(directory="static"), name="static")
templates = Jinja2Templates(directory="templates")
@app.get("/", response_class=HTMLResponse)
async def page_home(request: Request):
"""INDEX.HTML ํ™”๋ฉด"""
return templates.TemplateResponse("index.html", {"request": request})
@app.get("/put.html", response_class=HTMLResponse)
async def page_put(request: Request):
"""PUT.HTML ํ™”๋ฉด"""
return templates.TemplateResponse("put.html", {"request": request})
@app.get("/confirm.html", response_class=HTMLResponse)
async def page_confirm(request: Request):
"""confirm.HTML ํ™”๋ฉด"""
return templates.TemplateResponse("confirm.html",{
"request": request, "file_content": app_data.file_content})
@app.get("/result.html", response_class=HTMLResponse)
async def page_result(request: Request):
"""result.HTML ํ™”๋ฉด"""
return templates.TemplateResponse("result.html", {"request": request})
@app.get("/user.html", response_class=HTMLResponse)
async def page_user(request: Request):
"""user.HTML ํ™”๋ฉด"""
return templates.TemplateResponse("user.html", {"request": request})
@app.get("/final.html", response_class=HTMLResponse)
async def page_final(request: Request):
"""final.HTML ํ™”๋ฉด"""
return templates.TemplateResponse("final.html", {"request": request,
"output": app_data.end_output,
"place": app_data.place,
"time": app_data.times})
@app.post("/upload", response_class=HTMLResponse)
async def upload_file(file: UploadFile = File(...)):
"""ํŒŒ์ผ ์—…๋กœ๋“œ ๋ฐ ์ €์žฅ"""
with open("uploads/" + file.filename, "wb") as f:
f.write(file.file.read())
with open("uploads/" + file.filename, "r", encoding="utf-8") as f:
app_data.file_content = f.read()
return RedirectResponse(url="/put.html")
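# Illustrative upload from the command line (assumes the server is running
# locally and that a sample.txt file exists; names are examples only):
#   curl -F "file=@sample.txt" http://127.0.0.1:8000/upload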
@app.post("/ners", response_class=JSONResponse)
async def ner_file():
"""์ €์žฅ๋œ ํŒŒ์ผ์— NER ์ž‘์—…์„ ํ•ด์„œ ํ™”์ž๋ž‘ ์žฅ์†Œ๋ฅผ ๊ตฌ๋ถ„"""
from utils.load_model import load_ner
from utils.input_process import make_ner_input
from utils.ner_utils import make_name_list, show_name_list, combine_similar_names
content = app_data.file_content
_, ner_checkpoint = load_ner()
contents = make_ner_input(content)
name_list, times, places = make_name_list(contents, ner_checkpoint)
name_dic = show_name_list(name_list)
similar_name = combine_similar_names(name_dic)
result_list = [', '.join(names) for names, _ in similar_name.items()]
app_data.place = ' '.join(places)
app_data.times = ' '.join(times)
# JSONResponse๋กœ ์‘๋‹ต
return JSONResponse(content={"itemList": result_list})
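# Illustrative /ners response (example values): each itemList entry is one
# group of similar names joined by ', ', e.g. {"itemList": ["ํ™๊ธธ๋™, ๊ธธ๋™"]}.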
@app.post("/kcsn", response_class=JSONResponse)
async def kcsn_file(request_data: ItemListRequest):
"""์‚ฌ์šฉ์ž๊ฐ€ ์˜ฌ๋ ค์ค€ ํŒŒ์ผ์— ๋Œ€ํ•ด์„œ KCSN ๋ชจ๋ธ ๋™์ž‘"""
import torch
from utils.fs_utils import get_alias2id, find_speak, making_script
from utils.input_process import make_instance_list, input_data_loader
from utils.train_model import KCSN
from utils.ner_utils import convert_name2codename, convert_codename2name
content = app_data.file_content
name_list = request_data.nameList
name_dic = {}
for idx, name in enumerate(name_list):
name_dic[f'&C{idx:02d}&'] = name.split(', ')
content_re = convert_name2codename(name_dic, content)
# checkpoint = torch.load('./model/final.pth')
# model = checkpoint['model']
# model.to('cpu')
# tokenizer = checkpoint['tokenizer']
from utils.arguments import get_train_args
from transformers import AutoTokenizer
args = get_train_args()
path ='model/model.ckpt'
model = KCSN(args)
model.to('cpu')
checkpoint = torch.load(path)
tokenizer = AutoTokenizer.from_pretrained(args.bert_pretrained_dir)
model.load_state_dict(checkpoint['model'])
check_name = 'data/name.txt'
alias2id = get_alias2id(check_name)
instances, instance_num = make_instance_list(content_re)
inputs = input_data_loader(instances, alias2id)
output = find_speak(model, inputs, tokenizer, alias2id)
outputs = convert_codename2name(name_dic, output)
app_data.end_output = making_script(content, outputs, instance_num)
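# Illustrative client call for /kcsn (example payload, not from the repo):
#   curl -X POST http://127.0.0.1:8000/kcsn \
#        -H "Content-Type: application/json" \
#        -d '{"nameList": ["ํ™๊ธธ๋™, ๊ธธ๋™"]}'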
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="127.0.0.1", port=8000)
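# A typical development invocation using the uvicorn CLI instead of the block
# above (the --reload flag is optional):
#   uvicorn main:app --host 127.0.0.1 --port 8000 --reload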