yuneun92 committed
Commit
f59d2d3
•
1 Parent(s): 3b4a8e0

Upload 2 files

Files changed (2)
  1. _test.py +76 -0
  2. main.py +155 -0
_test.py ADDED
@@ -0,0 +1,76 @@
+ #%%
+ from utils.load_model import load_ner
+ from utils.input_process import make_ner_input
+ from utils.ner_utils import make_name_list, show_name_list, combine_similar_names
+ import torch
+
+ from utils.train_model import KCSN
+ from utils.arguments import get_train_args
+
+
+ args = get_train_args()
+ path = 'model/model.ckpt'
+ model = KCSN(args)
+ checkpoint = torch.load(path)
+ model.load_state_dict(checkpoint['model'])
+
+ # model = checkpoint['model']
+
+
+ # %%
+ with open('test/test.txt', "r", encoding="utf-8") as f:
+     file_content = f.read()
+
+ content = make_ner_input(file_content)
+ name_list, time, place = make_name_list(content, checkpoint)
+ name_dic = show_name_list(name_list)
+ similar_name = combine_similar_names(name_dic)
+
+ for i in similar_name:
+     print(i)
+
+ # %% CSN model
+ import torch
+
+ from utils.fs_utils import get_alias2id, find_speak
+ from utils.ner_utils import make_name_list
+ from utils.input_process import make_ner_input, make_instance_list, input_data_loader
+
+ checkpoint = torch.load('./model/final.pth')
+ model = checkpoint['model']
+ model.to('cpu')
+ tokenizer = checkpoint['tokenizer']
+
+ check_name = './data/name.txt'
+ alias2id = get_alias2id(check_name)
+
+ with open('test/KoCSN_test.txt', "r", encoding="utf-8") as f:
+     file_content = f.read()
+
+ instances, instance_num = make_instance_list(file_content)
+ inputs = input_data_loader(instances, alias2id)
+ output = find_speak(model, inputs, tokenizer, alias2id)
+
+
+
+ def make_script(texts, instance_num, output):
+     # assumed completion: attach each predicted speaker to its quoted line
+     script = []
+     for idx, text in enumerate(texts):
+         if idx in instance_num:
+             script.append(f"{output[instance_num.index(idx)]}: {text}")
+         else:
+             script.append(text)
+     return script
+
+
+ #%%
+
+ n = int(input())
+ num = list(map(int, input().split()))
+ ans = []
+
+ for i, j in enumerate(num):
+     print(i, j)
+     if len(ans) == 0:
+         ans.append(i+1)
+     else:
+         ans.insert(len(ans)-j, i+1)
+
+ print(ans)
+ # %%
main.py ADDED
@@ -0,0 +1,155 @@
+ """
+ FastAPI app: upload a text file, extract character names with NER,
+ then run the KCSN model to attribute quoted lines to speakers.
+ """
+ from fastapi import FastAPI, File, UploadFile, Form, Request
+ from fastapi.staticfiles import StaticFiles
+ from fastapi.responses import HTMLResponse, RedirectResponse, JSONResponse
+ from fastapi.templating import Jinja2Templates
+ from pydantic import BaseModel
+ from typing import List
+
+ class AppData:
+     """Mutable state shared across requests."""
+     def __init__(self):
+         self.file_content = ""
+         self.name_list = []
+         self.place = []
+         self.times = []
+         self.name_dic = {}
+         self.end_output = []
+
+
+ class ItemListRequest(BaseModel):
+     nameList: List[str]
+
+
+ app_data = AppData()
+
+ # App setup
+ app = FastAPI()
+ app.mount("/static", StaticFiles(directory="static"), name="static")
+ templates = Jinja2Templates(directory="templates")
+
+
+ @app.get("/", response_class=HTMLResponse)
+ async def page_home(request: Request):
+     """index.html page"""
+     return templates.TemplateResponse("index.html", {"request": request})
+
+
+ @app.get("/put.html", response_class=HTMLResponse)
+ async def page_put(request: Request):
+     """put.html page"""
+     return templates.TemplateResponse("put.html", {"request": request})
+
+
+ @app.get("/confirm.html", response_class=HTMLResponse)
+ async def page_confirm(request: Request):
+     """confirm.html page"""
+     return templates.TemplateResponse("confirm.html", {
+         "request": request, "file_content": app_data.file_content})
+
+
+ @app.get("/result.html", response_class=HTMLResponse)
+ async def page_result(request: Request):
+     """result.html page"""
+     return templates.TemplateResponse("result.html", {"request": request})
+
+
+ @app.get("/user.html", response_class=HTMLResponse)
+ async def page_user(request: Request):
+     """user.html page"""
+     return templates.TemplateResponse("user.html", {"request": request})
+
+
+ @app.get("/final.html", response_class=HTMLResponse)
+ async def page_final(request: Request):
+     """final.html page"""
+     return templates.TemplateResponse("final.html", {"request": request,
+                                                      "output": app_data.end_output,
+                                                      "place": app_data.place,
+                                                      "time": app_data.times})
+
+
+ @app.post("/upload", response_class=HTMLResponse)
+ async def upload_file(file: UploadFile = File(...)):
+     """Upload the file and save it."""
+     with open("uploads/" + file.filename, "wb") as f:
+         f.write(file.file.read())
+
+     with open("uploads/" + file.filename, "r", encoding="utf-8") as f:
+         app_data.file_content = f.read()
+
+     return RedirectResponse(url="/put.html")
+
+
+ @app.post("/ners", response_class=JSONResponse)
+ async def ner_file():
+     """Run NER over the saved file to pick out speakers and places."""
+     from utils.load_model import load_ner
+     from utils.input_process import make_ner_input
+     from utils.ner_utils import make_name_list, show_name_list, combine_similar_names
+
+     content = app_data.file_content
+     _, ner_checkpoint = load_ner()
+
+     contents = make_ner_input(content)
+     name_list, times, places = make_name_list(contents, ner_checkpoint)
+     name_dic = show_name_list(name_list)
+     similar_name = combine_similar_names(name_dic)
+     result_list = [', '.join(names) for names, _ in similar_name.items()]
+     app_data.place = ' '.join(places)
+     app_data.times = ' '.join(times)
+
+
+     # Respond with JSONResponse
+     return JSONResponse(content={"itemList": result_list})
+
+
+ @app.post("/kcsn", response_class=JSONResponse)
+ async def kcsn_file(request_data: ItemListRequest):
+     """Run the KCSN model on the file the user uploaded."""
+     import torch
+     from utils.fs_utils import get_alias2id, find_speak, making_script
+     from utils.input_process import make_instance_list, input_data_loader
+     from utils.train_model import KCSN
+     from utils.ner_utils import convert_name2codename, convert_codename2name
+
+     content = app_data.file_content
+     name_list = request_data.nameList
+     name_dic = {}
+
+     # Map each comma-separated alias group to a codename such as '&C00&'.
+     for idx, name in enumerate(name_list):
+         name_dic[f'&C{idx:02d}&'] = name.split(', ')
+
+     content_re = convert_name2codename(name_dic, content)
+
+     # checkpoint = torch.load('./model/final.pth')
+     # model = checkpoint['model']
+     # model.to('cpu')
+     # tokenizer = checkpoint['tokenizer']
+
+     from utils.arguments import get_train_args
+     from transformers import AutoTokenizer
+
+     args = get_train_args()
+     path = 'model/model.ckpt'
+     model = KCSN(args)
+     model.to('cpu')
+
+     checkpoint = torch.load(path)
+     tokenizer = AutoTokenizer.from_pretrained(args.bert_pretrained_dir)
+     model.load_state_dict(checkpoint['model'])
+
+     check_name = 'data/name.txt'
+     alias2id = get_alias2id(check_name)
+     instances, instance_num = make_instance_list(content_re)
+     inputs = input_data_loader(instances, alias2id)
+     output = find_speak(model, inputs, tokenizer, alias2id)
+     outputs = convert_codename2name(name_dic, output)
+     app_data.end_output = making_script(content, outputs, instance_num)
+
+
+ if __name__ == "__main__":
+     import uvicorn
+
+     uvicorn.run(app, host="127.0.0.1", port=8000)
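
The endpoints above are meant to be called in sequence: /upload stores the text, /ners proposes grouped character names, and /kcsn attributes speakers using the (possibly edited) name list. A minimal client sketch, assuming the server runs locally on port 8000 as configured above; the requests library and the sample.txt filename are illustrative, not part of this commit:

import requests

BASE = "http://127.0.0.1:8000"

# 1) Upload a UTF-8 text file; the multipart form field must be named "file".
#    allow_redirects=False because the endpoint answers with a redirect to /put.html.
with open("sample.txt", "rb") as fh:  # sample.txt is a placeholder input file
    requests.post(f"{BASE}/upload", files={"file": ("sample.txt", fh)}, allow_redirects=False)

# 2) Run NER on the uploaded text; returns {"itemList": ["name, alias", ...]}.
names = requests.post(f"{BASE}/ners").json()["itemList"]

# 3) Send the (optionally user-edited) name groups to the KCSN speaker model.
requests.post(f"{BASE}/kcsn", json={"nameList": names})

# 4) The attributed script is rendered by GET /final.html.
print(requests.get(f"{BASE}/final.html").text)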