Spaces:
Running
Running
| import os | |
| import time | |
| import shutil | |
| from pathlib import Path | |
| from fastapi import APIRouter, UploadFile, File, HTTPException, Depends, Body | |
| from fastapi.responses import FileResponse | |
| from auth import get_current_user | |
| from services.sentence_transformer_service import SentenceTransformerService, sentence_transformer_service | |
| from data_lib.input_name_data import InputNameData | |
| from data_lib.base_name_data import COL_NAME_SENTENCE | |
| from mapping_lib.subject_mapper import SubjectMapper | |
| from mapping_lib.name_mapper import NameMapper | |
| from config import UPLOAD_DIR, OUTPUT_DIR | |
| from models import ( | |
| EmbeddingRequest, | |
| PredictRawRequest, | |
| PredictRawResponse, | |
| PredictRecord, | |
| PredictResult, | |
| ) | |
| import pandas as pd | |
# Module-level APIRouter for the endpoints defined in this file.
# NOTE(review): none of the handlers below carries a visible
# `@router.<method>(...)` decorator in this view — confirm they are actually
# registered on this router (otherwise they are never exposed as routes).
router = APIRouter()
async def predict(
    current_user=Depends(get_current_user),
    file: UploadFile = File(...),
    sentence_service: SentenceTransformerService = Depends(lambda: sentence_transformer_service),
):
    """
    Process an input CSV file and return standardized names (requires authentication).

    Pipeline:
    1. Save the upload under ``UPLOAD_DIR``.
    2. Load it with ``InputNameData``.
    3. Map standard subjects with ``SubjectMapper`` and run ``process_data()``.
    4. Map standard names with ``NameMapper`` (top 3 candidates).
    5. Write the selected input columns plus the predictions to ``OUTPUT_DIR``
       and return that CSV as a file download.

    Raises:
        HTTPException: 400 if the upload has no filename ending in ``.csv``;
            500 if any loading, mapping, or writing step fails.
    """
    import uuid  # function-scope: only used to make per-request file names unique

    # UploadFile.filename can be None; accept the extension case-insensitively.
    if not file.filename or not file.filename.lower().endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are supported")

    # Save uploaded file. The random suffix stops two uploads by the same user
    # within the same second from overwriting each other's input/output files.
    timestamp = int(time.time())
    request_tag = f"{timestamp}_{current_user.username}_{uuid.uuid4().hex[:8]}"
    input_file_path = os.path.join(UPLOAD_DIR, f"input_{request_tag}.csv")
    output_file_path = os.path.join(OUTPUT_DIR, f"output_{request_tag}.csv")
    try:
        with open(input_file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    finally:
        file.file.close()

    try:
        start_time = time.time()

        # Load the raw CSV rows.
        try:
            inputData = InputNameData()
            inputData.load_data_from_csv(input_file_path)
        except Exception as e:
            print(f"Error processing input data: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        # Map subjects to standard subjects, attach that mapping to the input
        # data, then run InputNameData.process_data().
        try:
            subject_mapper = SubjectMapper(
                sentence_service.sentenceTransformerHelper,
                sentence_service.dic_standard_subject,
            )
            dic_subject_map = subject_mapper.map_standard_subjects(inputData.dataframe)
            inputData.dic_standard_subject = dic_subject_map
            inputData.process_data()
        except Exception as e:
            print(f"Error processing input data: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        # Map standard names (top 3 candidates).
        try:
            nameMapper = NameMapper(
                sentence_service.sentenceTransformerHelper,
                sentence_service.standardNameMapData,
                top_count=3,
            )
            df_predicted = nameMapper.predict(inputData)
        except Exception as e:
            print(f"Error mapping standard names: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        # Build the output: selected input columns plus the predictions.
        # (An earlier column list also included 'ファイル名'.)
        column_to_keep = ["シート名", "行", "科目", "中科目", "分類", "名称", "摘要", "備考"]
        output_df = inputData.dataframe[column_to_keep].copy()
        output_df.reset_index(drop=False, inplace=True)
        # NOTE(review): these assignments align on the index. output_df now has
        # a fresh RangeIndex, so df_predicted is assumed to be indexed 0..n-1 in
        # the same row order — confirm against NameMapper.predict.
        output_df.loc[:, "出力_科目"] = df_predicted["標準科目"]
        output_df.loc[:, "出力_項目名"] = df_predicted["標準項目名"]
        output_df.loc[:, "出力_確率度"] = df_predicted["基準名称類似度"]
        # Save with utf_8_sig encoding for Japanese Excel compatibility
        output_df.to_csv(output_file_path, index=False, encoding="utf_8_sig")

        execution_time = time.time() - start_time
        print(f"Execution time: {execution_time} seconds")

        # Let FileResponse set Content-Type (text/csv) and Content-Disposition
        # itself. It RFC 5987-encodes non-ASCII (e.g. Japanese) filenames, which
        # a hand-written header value (encoded as latin-1) cannot carry. The
        # previous explicit Content-Type also wrongly overrode text/csv.
        return FileResponse(
            path=output_file_path,
            filename=f"output_{Path(file.filename).stem}.csv",
            media_type="text/csv",
        )
    except HTTPException:
        # Already carries its status and detail; re-wrapping it below would
        # turn the detail into "500: <original detail>".
        raise
    except Exception as e:
        print(f"Error processing file: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e
async def create_embeddings(
    request: EmbeddingRequest,
    current_user=Depends(get_current_user),
    sentence_service: SentenceTransformerService = Depends(
        lambda: sentence_transformer_service
    ),
):
    """
    Embed every sentence in ``request.sentences`` (requires authentication).

    Returns:
        ``{"embeddings": ...}``, where the value is the helper's result after
        ``.tolist()``, so the JSON encoder receives plain Python lists.

    Raises:
        HTTPException: 500 if creating or converting the embeddings fails.
    """
    try:
        began = time.time()
        vectors = sentence_service.sentenceTransformerHelper.create_embeddings(request.sentences)
        elapsed = time.time() - began
        print(f"Execution time: {elapsed} seconds")
        # The helper returns an array type; .tolist() makes it JSON-serializable.
        payload = {"embeddings": vectors.tolist()}
        return payload
    except Exception as e:
        print(f"Error creating embeddings: {e}")
        raise HTTPException(status_code=500, detail=str(e))
async def predict_raw(
    request: PredictRawRequest,
    current_user=Depends(get_current_user),
    sentence_service: SentenceTransformerService = Depends(
        lambda: sentence_transformer_service
    ),
):
    """
    Process raw input records and return standardized names (requires authentication).

    Each record becomes one DataFrame row with the same Japanese column names
    the CSV path uses. The frame is processed with ``InputNameData`` and mapped
    with ``NameMapper`` (top 3 candidates). One ``PredictResult`` is built per
    row of the mapper's output. An empty ``records`` list yields an empty
    result list without touching the model.

    Raises:
        HTTPException: 500 if building, processing, or mapping the data fails.
    """
    try:
        if not request.records:
            return PredictRawResponse(results=[])

        # Convert input records to a DataFrame. シート名/行 are required by
        # BaseNameData but not used here, so they get empty placeholders.
        columns = ["科目", "中科目", "分類", "名称", "摘要", "備考", "シート名", "行"]
        rows = [
            (
                record.subject,
                record.sub_subject,
                record.name_category,
                record.name,
                record.abstract or "",
                record.memo or "",
                "",  # シート名 placeholder
                "",  # 行 placeholder
            )
            for record in request.records
        ]
        df = pd.DataFrame(rows, columns=columns)

        # Process input data.
        # NOTE(review): this constructs InputNameData(dic_standard_subject) and
        # calls process_data(helper), while the CSV endpoint uses
        # InputNameData() + SubjectMapper + process_data() with no arguments —
        # confirm which InputNameData API is current.
        try:
            inputData = InputNameData(sentence_service.dic_standard_subject)
            # Use _add_raw_data instead of direct assignment
            inputData._add_raw_data(df)
            inputData.process_data(sentence_service.sentenceTransformerHelper)
        except Exception as e:
            print(f"Error processing input data: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        # Map standard names (top 3 candidates).
        try:
            nameMapper = NameMapper(
                sentence_service.sentenceTransformerHelper,
                sentence_service.standardNameMapData,
                top_count=3,
            )
            df_predicted = nameMapper.predict(inputData)
        except Exception as e:
            print(f"Error mapping standard names: {e}")
            raise HTTPException(status_code=500, detail=str(e)) from e

        # Convert results to response format.
        results = [
            PredictResult(
                standard_subject=row["標準科目"],
                standard_name=row["標準項目名"],
                anchor_name=row["基準名称"],
                similarity_score=float(row["基準名称類似度"]),
            )
            for _, row in df_predicted.iterrows()
        ]
        return PredictRawResponse(results=results)
    except HTTPException:
        # Already carries its status and detail; re-wrapping it below would
        # turn the detail into "500: <original detail>".
        raise
    except Exception as e:
        print(f"Error processing records: {e}")
        raise HTTPException(status_code=500, detail=str(e)) from e