import sys
import os
import time
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, status
from fastapi.responses import FileResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
import uvicorn
import traceback
import pickle
import shutil
from pathlib import Path
from contextlib import asynccontextmanager
import pandas as pd
from typing import Annotated, Optional, Union
from datetime import datetime, timedelta, timezone
import jwt
from jwt.exceptions import InvalidTokenError
from passlib.context import CryptContext
from pydantic import BaseModel

current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(current_dir, "meisai-check-ai"))

from sentence_transformer_lib.sentence_transformer_helper import (
    SentenceTransformerHelper,
)
from data_lib.input_name_data import InputNameData
from data_lib.subject_data import SubjectData
from data_lib.sample_name_data import SampleNameData
from clustering_lib.sentence_clustering_lib import SentenceClusteringLib
from data_lib.base_data import (
    COL_STANDARD_NAME,
    COL_STANDARD_NAME_KEY,
    COL_STANDARD_SUBJECT,
)
from mapping_lib.name_mapping_helper import NameMappingHelper

# Initialize global variables for model and data
sentenceTransformerHelper = None
dic_standard_subject = None
sample_name_sentence_embeddings = None
sample_name_sentence_similarities = None
sampleData = None
sentence_clustering_lib = None
name_groups = None

# Create data directories if they don't exist
os.makedirs(os.path.join(current_dir, "data"), exist_ok=True)
os.makedirs(os.path.join(current_dir, "uploads"), exist_ok=True)
os.makedirs(os.path.join(current_dir, "outputs"), exist_ok=True)

# Authentication related settings
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_HOURS = 24  # Token expiration set to 24 hours

# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# OAuth2 scheme for token
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


# User database models
class Token(BaseModel):
    access_token: str
    token_type: str


class TokenData(BaseModel):
    username: Optional[str] = None


class User(BaseModel):
    username: str
    email: Optional[str] = None
    full_name: Optional[str] = None
    disabled: Optional[bool] = None


class UserInDB(User):
    hashed_password: str


# Fake users database with hashed passwords
users_db = {
    "chien_vm": {
        "username": "chien_vm",
        "full_name": "Chien VM",
        "email": "chien_vm@detomo.co.jp",
        "hashed_password": "$2b$12$RtcKFk7B3hKd7vYkwxdFN.eBXSiryQIRUG.OoJ07Pl9lzHNUkugMi",
        "disabled": False,
    },
    "hoi_nv": {
        "username": "hoi_nv",
        "full_name": "Hoi NV",
        "email": "hoi_nv@detomo.co.jp",
        "hashed_password": "$2b$12$RtcKFk7B3hKd7vYkwxdFN.eBXSiryQIRUG.OoJ07Pl9lzHNUkugMi",
        "disabled": False,
    },
}


# Authentication helper functions
def verify_password(plain_password, hashed_password):
    return pwd_context.verify(plain_password, hashed_password)


def get_user(db, username: str):
    if username in db:
        user_dict = db[username]
        return UserInDB(**user_dict)
    return None


def authenticate_user(fake_db, username: str, password: str):
    user = get_user(fake_db, username)
    if not user:
        return False
    if not verify_password(password, user.hashed_password):
        return False
    return user


def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt


async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username = payload.get("sub")
        if username is None:
            raise credentials_exception
        token_data = TokenData(username=username)
    except InvalidTokenError:
        raise credentials_exception
    user = get_user(users_db, username=token_data.username)
    if user is None:
        raise credentials_exception
    return user


async def get_current_active_user(
    current_user: Annotated[User, Depends(get_current_user)],
):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup and shutdown events"""
    global sentenceTransformerHelper, dic_standard_subject, sample_name_sentence_embeddings
    global sample_name_sentence_similarities, sampleData, sentence_clustering_lib, name_groups
    try:
        # Load sentence transformer model
        sentenceTransformerHelper = SentenceTransformerHelper(
            convert_to_zenkaku_flag=True, replace_words=None, keywords=None
        )
        sentenceTransformerHelper.load_model_by_name(
            "Detomo/cl-nagoya-sup-simcse-ja-for-standard-name-v1_0"
        )

        # Load standard subject dictionary
        dic_standard_subject = SubjectData.create_standard_subject_dic_from_file(
            "data/subjectData.csv"
        )

        # Load pre-computed embeddings and similarities
        with open(
            "data/sample_name_sentence_embeddings(cl-nagoya-sup-simcse-ja-for-standard-name-v1_1).pkl",
            "rb",
        ) as f:
            sample_name_sentence_embeddings = pickle.load(f)

        with open(
            "data/sample_name_sentence_similarities(cl-nagoya-sup-simcse-ja-for-standard-name-v1_1).pkl",
            "rb",
        ) as f:
            sample_name_sentence_similarities = pickle.load(f)

        # Load and process sample data
        sampleData = SampleNameData()
        file_path = os.path.join(current_dir, "data", "sampleData.csv")
        sampleData.load_data_from_csv(file_path)
        sampleData.process_data()

        # Create sentence clusters
        sentence_clustering_lib = SentenceClusteringLib(sample_name_sentence_embeddings)
        best_name_eps = 0.07
        name_groups, _ = sentence_clustering_lib.create_sentence_cluster(best_name_eps)

        sampleData._create_key_column(
            COL_STANDARD_NAME_KEY, COL_STANDARD_SUBJECT, COL_STANDARD_NAME
        )
        sampleData.set_name_sentence_labels(name_groups)
        sampleData.build_search_tree()

        print("Models and data loaded successfully")
    except Exception as e:
        print(f"Error during startup: {e}")
        traceback.print_exc()

    yield  # This is where the app runs

    # Cleanup code (if needed) goes here
    print("Shutting down application")


app = FastAPI(lifespan=lifespan)


@app.get("/")
async def root():
    return {"message": "Hello World"}


@app.get("/health")
async def health_check():
    return {"status": "ok", "timestamp": time.time()}


@app.post("/token")
async def login_for_access_token(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
) -> Token:
    """
    Login endpoint to get an access token
    """
    user = authenticate_user(users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return Token(access_token=access_token, token_type="bearer")


@app.post("/predict")
async def predict(
    current_user: Annotated[User, Depends(get_current_active_user)],
    file: UploadFile = File(...),
):
    """
    Process an input CSV file and return standardized names (requires authentication)
    """
    global sentenceTransformerHelper, dic_standard_subject, sample_name_sentence_embeddings
    global sample_name_sentence_similarities, sampleData, name_groups

    if not file.filename.endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are supported")

    # Save uploaded file
    timestamp = int(time.time())
    input_file_path = os.path.join(
        current_dir, "uploads", f"input_{timestamp}_{current_user.username}.csv"
    )
    # Use CSV format with correct extension
    output_file_path = os.path.join(
        current_dir, "outputs", f"output_{timestamp}_{current_user.username}.csv"
    )

    try:
        with open(input_file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    finally:
        file.file.close()

    try:
        # Process input data
        inputData = InputNameData(dic_standard_subject)
        inputData.load_data_from_csv(input_file_path)
        inputData.process_data()

        # Map standard names
        nameMappingHelper = NameMappingHelper(
            sentenceTransformerHelper,
            inputData,
            sampleData,
            sample_name_sentence_embeddings,
            sample_name_sentence_similarities,
        )
        df_predicted = nameMappingHelper.map_standard_names()

        # Create output dataframe and save to CSV
        print("Columns of inputData.dataframe", inputData.dataframe.columns)
        column_to_keep = ['シート名', '行', '科目', '分類', '名称', '摘要', '備考']
        output_df = inputData.dataframe[column_to_keep].copy()
        output_df.reset_index(drop=False, inplace=True)
        output_df.loc[:, "出力_科目"] = df_predicted["出力_科目"]
        output_df.loc[:, "出力_項目名"] = df_predicted["出力_項目名"]
        output_df.loc[:, "出力_確率度"] = df_predicted["出力_確率度"]

        # Save with utf_8_sig encoding for Japanese Excel compatibility
        output_df.to_csv(output_file_path, index=False, encoding="utf_8_sig")

        # Return the file as a download with the correct content type and headers
        return FileResponse(
            path=output_file_path,
            filename=f"output_{Path(file.filename).stem}.csv",
            media_type="text/csv",
            headers={
                "Content-Disposition": f'attachment; filename="output_{Path(file.filename).stem}.csv"',
            },
        )
    except Exception as e:
        print(f"Error processing file: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
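

# Example client usage: a minimal sketch of how a caller might authenticate against
# /token and submit a CSV to /predict. It assumes the server is running locally on
# port 8000 as configured above and that the `requests` package is installed; the
# plain-text password and the "input.csv" / "output.csv" paths are placeholders, not
# values defined in this module.
#
#   import requests
#
#   # Exchange form-encoded credentials for a bearer token
#   token = requests.post(
#       "http://localhost:8000/token",
#       data={"username": "chien_vm", "password": "<plain-text password>"},
#   ).json()["access_token"]
#
#   # Upload a CSV for name standardization and save the returned CSV
#   with open("input.csv", "rb") as f:
#       response = requests.post(
#           "http://localhost:8000/predict",
#           headers={"Authorization": f"Bearer {token}"},
#           files={"file": ("input.csv", f, "text/csv")},
#       )
#   response.raise_for_status()
#   with open("output.csv", "wb") as out:
#       out.write(response.content)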