vumichien committed on
Commit
8600f2c
·
1 Parent(s): ced3dcd

support 2 new endpoints

Browse files
Files changed (3) hide show
  1. main.py +13 -4
  2. models.py +27 -0
  3. routes/predict.py +107 -1
main.py CHANGED
@@ -34,10 +34,19 @@ app = FastAPI(
34
  version="1.0",
35
  lifespan=lifespan,
36
  openapi_tags=[
37
- {"name": "Health", "description": "Health check endpoints"},
38
- {"name": "Authentication", "description": "User authentication and token management"},
39
- {"name": "Prediction", "description": "Predict and process CSV files"},
40
- ]
 
 
 
 
 
 
 
 
 
41
  )
42
 
43
  # Include Routers
 
34
  version="1.0",
35
  lifespan=lifespan,
36
  openapi_tags=[
37
+ {
38
+ "name": "Health",
39
+ "description": "Health check endpoints",
40
+ },
41
+ {
42
+ "name": "Authentication",
43
+ "description": "User authentication and token management",
44
+ },
45
+ {
46
+ "name": "AI Model",
47
+ "description": "AI model endpoints for prediction and embedding",
48
+ },
49
+ ],
50
  )
51
 
52
  # Include Routers
models.py CHANGED
@@ -21,3 +21,30 @@ class UserInDB(User):
21
  class UserCreate(BaseModel):
22
  username: str
23
  password: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  class UserCreate(BaseModel):
22
  username: str
23
  password: str
24
+
25
+
26
class EmbeddingRequest(BaseModel):
    """Request body for the ``POST /embeddings`` endpoint."""

    # Sentences to embed; the endpoint hands this list unchanged to the
    # sentence-transformer helper's ``create_embeddings``.
    sentences: list[str]
28
+
29
+
30
class PredictRecord(BaseModel):
    """One raw input row submitted to the ``POST /predict-raw`` endpoint.

    The endpoint copies each field into a DataFrame column before the
    name-mapping pipeline runs; the target column is noted per field.
    """

    # Copied into the "科目" column.
    subject: str
    # Copied into the "中科目" column.
    sub_subject: str
    # Copied into the "分類" column.
    name_category: str
    # Copied into the "名称" column.
    name: str
    # Copied into the "摘要" column; a missing/empty value becomes "".
    abstract: str | None = None
    # Copied into the "備考" column; a missing/empty value becomes "".
    memo: str | None = None
37
+
38
+
39
class PredictResult(BaseModel):
    """Standardized mapping returned by ``POST /predict-raw`` for one prediction row."""

    # Taken from the "標準科目" column of the prediction output.
    standard_subject: str
    # Taken from the "標準項目名" column of the prediction output.
    standard_name: str
    # Taken from the "基準名称" column of the prediction output.
    anchor_name: str
43
+
44
+
45
class PredictRawRequest(BaseModel):
    """Request body for the ``POST /predict-raw`` endpoint."""

    # Input rows to standardize, processed together as one batch.
    records: list[PredictRecord]
47
+
48
+
49
class PredictRawResponse(BaseModel):
    """Response body of the ``POST /predict-raw`` endpoint."""

    # One entry per row of the prediction output produced by the name mapper.
    results: list[PredictResult]
routes/predict.py CHANGED
@@ -2,7 +2,7 @@ import os
2
  import time
3
  import shutil
4
  from pathlib import Path
5
- from fastapi import APIRouter, UploadFile, File, HTTPException, Depends
6
  from fastapi.responses import FileResponse
7
  from auth import get_current_user
8
  from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service
@@ -10,6 +10,14 @@ from data_lib.input_name_data import InputNameData
10
  from data_lib.base_name_data import COL_NAME_SENTENCE
11
  from mapping_lib.name_mapper import NameMapper
12
  from config import UPLOAD_DIR, OUTPUT_DIR
 
 
 
 
 
 
 
 
13
 
14
  router = APIRouter()
15
 
@@ -85,3 +93,101 @@ async def predict(
85
  except Exception as e:
86
  print(f"Error processing file: {e}")
87
  raise HTTPException(status_code=500, detail=str(e))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import time
3
  import shutil
4
  from pathlib import Path
5
+ from fastapi import APIRouter, UploadFile, File, HTTPException, Depends, Body
6
  from fastapi.responses import FileResponse
7
  from auth import get_current_user
8
  from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service
 
10
  from data_lib.base_name_data import COL_NAME_SENTENCE
11
  from mapping_lib.name_mapper import NameMapper
12
  from config import UPLOAD_DIR, OUTPUT_DIR
13
+ from models import (
14
+ EmbeddingRequest,
15
+ PredictRawRequest,
16
+ PredictRawResponse,
17
+ PredictRecord,
18
+ PredictResult,
19
+ )
20
+ import pandas as pd
21
 
22
  router = APIRouter()
23
 
 
93
  except Exception as e:
94
  print(f"Error processing file: {e}")
95
  raise HTTPException(status_code=500, detail=str(e))
96
+
97
+
98
@router.post("/embeddings")
async def create_embeddings(
    request: EmbeddingRequest,
    current_user=Depends(get_current_user),
    sentence_service: SentenceTransformerService = Depends(
        lambda: sentence_transformer_service
    ),
):
    """
    Create embeddings for a list of input sentences (requires authentication)

    Args:
        request: Body carrying the ``sentences`` to embed.
        current_user: Resolved by ``get_current_user``; only present to
            enforce authentication, the value itself is not used.
        sentence_service: Shared sentence-transformer service instance.

    Returns:
        ``{"embeddings": [...]}`` with one entry per input sentence, converted
        to plain lists so the payload is JSON serializable.

    Raises:
        HTTPException: 500 with the underlying error message when the
            embedding helper fails.
    """
    # Nothing to embed: answer directly instead of invoking the model with an
    # empty batch.
    if not request.sentences:
        return {"embeddings": []}

    try:
        embeddings = sentence_service.sentenceTransformerHelper.create_embeddings(
            request.sentences
        )
        # Convert numpy array to list for JSON serialization
        return {"embeddings": embeddings.tolist()}
    except Exception as e:
        print(f"Error creating embeddings: {e}")
        # Chain the cause so the original traceback is kept in server logs.
        raise HTTPException(status_code=500, detail=str(e)) from e
119
+
120
+
121
def _records_to_dataframe(records: list[PredictRecord]) -> pd.DataFrame:
    """Lay out API records in the column format fed to ``InputNameData``."""
    return pd.DataFrame(
        {
            "科目": [record.subject for record in records],
            "中科目": [record.sub_subject for record in records],
            "分類": [record.name_category for record in records],
            "名称": [record.name for record in records],
            # Optional fields are normalized to empty strings.
            "摘要": [record.abstract or "" for record in records],
            "備考": [record.memo or "" for record in records],
            # Required by BaseNameData but not used: placeholder values only.
            "シート名": [""] * len(records),
            "行": [""] * len(records),
        }
    )


@router.post("/predict-raw", response_model=PredictRawResponse)
async def predict_raw(
    request: PredictRawRequest,
    current_user=Depends(get_current_user),
    sentence_service: SentenceTransformerService = Depends(
        lambda: sentence_transformer_service
    ),
):
    """
    Process raw input records and return standardized names (requires authentication)

    Args:
        request: Body carrying the ``records`` to standardize.
        current_user: Resolved by ``get_current_user``; only present to
            enforce authentication, the value itself is not used.
        sentence_service: Shared sentence-transformer service instance that
            supplies the helper, subject dictionary and standard-name map.

    Returns:
        ``PredictRawResponse`` with one ``PredictResult`` per row of the
        DataFrame returned by ``NameMapper.predict``.

    Raises:
        HTTPException: 500 with the underlying error message when input
            processing, name mapping or result conversion fails.
    """
    # No records: skip the pipeline and return an empty result set.
    if not request.records:
        return PredictRawResponse(results=[])

    try:
        # Convert input records to DataFrame
        df = _records_to_dataframe(request.records)

        # Process input data
        try:
            input_data = InputNameData(sentence_service.dic_standard_subject)
            # Use _add_raw_data instead of direct assignment
            input_data._add_raw_data(df)
            input_data.process_data(sentence_service.sentenceTransformerHelper)
        except Exception as e:
            print(f"Error processing input data: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        # Map standard names
        try:
            name_mapper = NameMapper(
                sentence_service.sentenceTransformerHelper,
                sentence_service.standardNameMapData,
                top_count=3,
            )
            df_predicted = name_mapper.predict(input_data)
        except Exception as e:
            print(f"Error mapping standard names: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        # Convert results to response format
        results = [
            PredictResult(
                standard_subject=standard_subject,
                standard_name=standard_name,
                anchor_name=anchor_name,
            )
            for standard_subject, standard_name, anchor_name in zip(
                df_predicted["標準科目"],
                df_predicted["標準項目名"],
                df_predicted["基準名称"],
            )
        ]
        return PredictRawResponse(results=results)

    except HTTPException:
        # Already logged and shaped by the step that failed. Without this
        # clause the generic handler below would log it a second time and
        # re-wrap it, prefixing the detail with the status code.
        raise
    except Exception as e:
        print(f"Error processing records: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e