import logging
import os
import re
import shutil
import time
import uuid
from pathlib import Path

from fastapi import APIRouter, UploadFile, File, HTTPException, Depends
from fastapi.responses import FileResponse

from auth import get_current_user
from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service
from data_lib.input_name_data import InputNameData
from data_lib.base_data import COL_NAME_SENTENCE
from mapping_lib.name_mapping_helper import NameMappingHelper
from config import UPLOAD_DIR, OUTPUT_DIR

logger = logging.getLogger(__name__)

router = APIRouter()

# Columns copied from the processed input dataframe into the output CSV, in output order.
_COLUMNS_TO_KEEP = ['シート名', '行', '科目', '分類', '名称', '摘要', '備考']

# Columns taken from the dataframe returned by NameMappingHelper.map_standard_names()
# and appended to the output CSV.
_PREDICTED_COLUMNS = ("出力_科目", "出力_項目名", "出力_確率度")


def _safe_path_component(value) -> str:
    """Return ``str(value)`` with every character except word chars, '.' and '-' replaced by '_'.

    Used so that user-controlled text (e.g. a username) embedded in a file name
    cannot contain path separators and escape the target directory.
    """
    return re.sub(r"[^\w.-]", "_", str(value))


def _save_upload(file: UploadFile, destination: str) -> None:
    """Copy the uploaded file's contents to *destination*, always closing the upload stream."""
    try:
        with open(destination, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    finally:
        file.file.close()


def _build_output_dataframe(sentence_service: SentenceTransformerService, input_file_path: str):
    """Run the name-mapping pipeline on the CSV at *input_file_path* and return the output dataframe.

    The returned frame holds the ``_COLUMNS_TO_KEEP`` columns of the processed input
    (plus the former index, added back as a column by ``reset_index``) followed by
    the ``_PREDICTED_COLUMNS`` produced by ``NameMappingHelper.map_standard_names()``.
    """
    # Load and preprocess the uploaded data.
    input_data = InputNameData(sentence_service.dic_standard_subject)
    input_data.load_data_from_csv(input_file_path)
    input_data.process_data()

    transformer_helper = sentence_service.sentenceTransformerHelper
    input_name_sentences = input_data.dataframe[COL_NAME_SENTENCE]
    input_name_sentence_embeddings = transformer_helper.create_embeddings(input_name_sentences)

    # Similarity between the service's sample sentences and the input sentences.
    similarity_matrix = transformer_helper.create_similarity_matrix_from_embeddings(
        sentence_service.sample_name_sentence_embeddings,
        input_name_sentence_embeddings,
    )

    # Map input rows to standard names.
    name_mapping_helper = NameMappingHelper(
        transformer_helper,
        input_data,
        sentence_service.sampleData,
        input_name_sentence_embeddings,
        sentence_service.sample_name_sentence_embeddings,
        similarity_matrix,
    )
    df_predicted = name_mapping_helper.map_standard_names()

    output_df = input_data.dataframe[_COLUMNS_TO_KEEP].copy()
    output_df.reset_index(drop=False, inplace=True)
    # NOTE(review): these assignments align on index. output_df has a fresh 0..n-1 index
    # after reset_index, so this presumably relies on df_predicted also being indexed 0..n-1 — confirm.
    for column in _PREDICTED_COLUMNS:
        output_df.loc[:, column] = df_predicted[column]
    return output_df


@router.post("/predict")
async def predict(
    current_user=Depends(get_current_user),
    file: UploadFile = File(...),
    sentence_service: SentenceTransformerService = Depends(lambda: sentence_transformer_service),
):
    """
    Process an input CSV file and return standardized names (requires authentication)
    """
    # Accept any casing of the extension; reject uploads with no file name at all.
    if not file.filename or not file.filename.lower().endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are supported")

    # Unique per-request file names: the random suffix prevents two uploads by the same
    # user within the same second from overwriting each other, and the username is
    # sanitized so it cannot inject path separators.
    timestamp = int(time.time())
    unique_part = f"{timestamp}_{_safe_path_component(current_user.username)}_{uuid.uuid4().hex[:8]}"
    input_file_path = os.path.join(UPLOAD_DIR, f"input_{unique_part}.csv")
    output_file_path = os.path.join(OUTPUT_DIR, f"output_{unique_part}.csv")

    _save_upload(file, input_file_path)

    try:
        # NOTE(review): this pipeline is synchronous and runs on the event-loop thread,
        # blocking other requests while it executes. Consider run_in_threadpool once the
        # shared sentence_service is confirmed safe for concurrent use.
        output_df = _build_output_dataframe(sentence_service, input_file_path)

        # Save with utf_8_sig encoding for Japanese Excel compatibility
        output_df.to_csv(output_file_path, index=False, encoding="utf_8_sig")

        # Let FileResponse build Content-Disposition itself: it emits an RFC 5987
        # filename* parameter for non-ASCII names, which a hand-built latin-1 header
        # cannot carry. The Content-Type comes from media_type.
        return FileResponse(
            path=output_file_path,
            filename=f"output_{Path(file.filename).stem}.csv",
            media_type="text/csv",
        )
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Error processing file %s", file.filename)
        raise HTTPException(status_code=500, detail=str(e)) from e