import sys
import os
import time
from fastapi import FastAPI, UploadFile, File, HTTPException, Depends, status
from fastapi.responses import FileResponse
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
import uvicorn
import traceback
import pickle
import shutil
from pathlib import Path
from contextlib import asynccontextmanager
import pandas as pd
from typing import Annotated, Optional
from datetime import datetime, timedelta, timezone
import jwt
from jwt.exceptions import InvalidTokenError
from passlib.context import CryptContext
from pydantic import BaseModel
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(current_dir, "meisai-check-ai"))

from sentence_transformer_lib.sentence_transformer_helper import (
    SentenceTransformerHelper,
)
from data_lib.input_name_data import InputNameData
from data_lib.subject_data import SubjectData
from data_lib.sample_name_data import SampleNameData
from clustering_lib.sentence_clustering_lib import SentenceClusteringLib
from data_lib.base_data import (
    COL_STANDARD_NAME,
    COL_STANDARD_NAME_KEY,
    COL_STANDARD_SUBJECT,
)
from mapping_lib.name_mapping_helper import NameMappingHelper
# Initialize global variables for model and data
sentenceTransformerHelper = None
dic_standard_subject = None
sample_name_sentence_embeddings = None
sample_name_sentence_similarities = None
sampleData = None
sentence_clustering_lib = None
name_groups = None
# Create data directories if they don't exist
os.makedirs(os.path.join(current_dir, "data"), exist_ok=True)
os.makedirs(os.path.join(current_dir, "uploads"), exist_ok=True)
os.makedirs(os.path.join(current_dir, "outputs"), exist_ok=True)
# Authentication related settings
# NOTE: a hardcoded secret is fine for a demo, but in production load it from an
# environment variable instead (a key like this can be generated with: openssl rand -hex 32)
SECRET_KEY = "09d25e094faa6ca2556c818166b7a9563b93f7099f6f0f4caa6cf63b88e8d3e7"
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_HOURS = 24  # Token expiration set to 24 hours

# Password hashing context
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")

# OAuth2 scheme for token
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# User database models
class Token(BaseModel):
    access_token: str
    token_type: str

class TokenData(BaseModel):
    username: Optional[str] = None

class User(BaseModel):
    username: str
    email: Optional[str] = None
    full_name: Optional[str] = None
    disabled: Optional[bool] = None

class UserInDB(User):
    hashed_password: str
# Fake users database with hashed passwords
users_db = {
    "chien_vm": {
        "username": "chien_vm",
        "full_name": "Chien VM",
        "email": "[email protected]",
        "hashed_password": "$2b$12$RtcKFk7B3hKd7vYkwxdFN.eBXSiryQIRUG.OoJ07Pl9lzHNUkugMi",
        "disabled": False,
    },
    "hoi_nv": {
        "username": "hoi_nv",
        "full_name": "Hoi NV",
        "email": "[email protected]",
        "hashed_password": "$2b$12$RtcKFk7B3hKd7vYkwxdFN.eBXSiryQIRUG.OoJ07Pl9lzHNUkugMi",
        "disabled": False,
    },
}
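# A minimal sketch for extending users_db with the same CryptContext as above;
# the username and password here are hypothetical:
#
#   users_db["new_user"] = {
#       "username": "new_user",
#       "full_name": "New User",
#       "email": "new_user@example.com",
#       "hashed_password": pwd_context.hash("change-me"),
#       "disabled": False,
#   }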
# Authentication helper functions
def verify_password(plain_password, hashed_password):
    return pwd_context.verify(plain_password, hashed_password)

def get_user(db, username: str):
    if username in db:
        user_dict = db[username]
        return UserInDB(**user_dict)
    return None

def authenticate_user(fake_db, username: str, password: str):
    user = get_user(fake_db, username)
    if not user:
        return False
    if not verify_password(password, user.hashed_password):
        return False
    return user
def create_access_token(data: dict, expires_delta: Optional[timedelta] = None):
    to_encode = data.copy()
    if expires_delta:
        expire = datetime.now(timezone.utc) + expires_delta
    else:
        expire = datetime.now(timezone.utc) + timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
    to_encode.update({"exp": expire})
    encoded_jwt = jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
    return encoded_jwt
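# Illustrative round trip, assuming the settings above: the encoded token carries
# the "sub" and "exp" claims and decodes with the same key and algorithm.
#
#   token = create_access_token({"sub": "chien_vm"}, expires_delta=timedelta(minutes=5))
#   claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
#   assert claims["sub"] == "chien_vm"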
async def get_current_user(token: Annotated[str, Depends(oauth2_scheme)]):
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
        username = payload.get("sub")
        if username is None:
            raise credentials_exception
        token_data = TokenData(username=username)
    except InvalidTokenError:
        raise credentials_exception
    user = get_user(users_db, username=token_data.username)
    if user is None:
        raise credentials_exception
    return user
async def get_current_active_user(
    current_user: Annotated[User, Depends(get_current_user)],
):
    if current_user.disabled:
        raise HTTPException(status_code=400, detail="Inactive user")
    return current_user
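# Any endpoint can require a logged-in, non-disabled user by depending on
# get_current_active_user; a hypothetical route for illustration:
#
#   @app.get("/users/me")
#   async def read_users_me(
#       current_user: Annotated[User, Depends(get_current_active_user)],
#   ):
#       return current_user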
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Lifespan context manager for startup and shutdown events"""
    global sentenceTransformerHelper, dic_standard_subject, sample_name_sentence_embeddings
    global sample_name_sentence_similarities, sampleData, sentence_clustering_lib, name_groups

    try:
        # Load sentence transformer model
        sentenceTransformerHelper = SentenceTransformerHelper(
            convert_to_zenkaku_flag=True, replace_words=None, keywords=None
        )
        sentenceTransformerHelper.load_model_by_name(
            "Detomo/cl-nagoya-sup-simcse-ja-for-standard-name-v1_0"
        )

        # Load standard subject dictionary
        dic_standard_subject = SubjectData.create_standard_subject_dic_from_file(
            "data/subjectData.csv"
        )

        # Load pre-computed embeddings and similarities
        with open(
            "data/sample_name_sentence_embeddings(cl-nagoya-sup-simcse-ja-for-standard-name-v1_1).pkl",
            "rb",
        ) as f:
            sample_name_sentence_embeddings = pickle.load(f)
        with open(
            "data/sample_name_sentence_similarities(cl-nagoya-sup-simcse-ja-for-standard-name-v1_1).pkl",
            "rb",
        ) as f:
            sample_name_sentence_similarities = pickle.load(f)

        # Load and process sample data
        sampleData = SampleNameData()
        file_path = os.path.join(current_dir, "data", "sampleData.csv")
        sampleData.load_data_from_csv(file_path)
        sampleData.process_data()

        # Create sentence clusters
        sentence_clustering_lib = SentenceClusteringLib(sample_name_sentence_embeddings)
        best_name_eps = 0.07
        name_groups, _ = sentence_clustering_lib.create_sentence_cluster(best_name_eps)

        sampleData._create_key_column(
            COL_STANDARD_NAME_KEY, COL_STANDARD_SUBJECT, COL_STANDARD_NAME
        )
        sampleData.set_name_sentence_labels(name_groups)
        sampleData.build_search_tree()
        print("Models and data loaded successfully")
    except Exception as e:
        print(f"Error during startup: {e}")
        traceback.print_exc()

    yield  # This is where the app runs

    # Cleanup code (if needed) goes here
    print("Shutting down application")
app = FastAPI(lifespan=lifespan)

@app.get("/")
async def root():
    return {"message": "Hello World"}

@app.get("/health")  # path assumed; adjust to the deployment's health-check route
async def health_check():
    return {"status": "ok", "timestamp": time.time()}
@app.post("/token")  # matches tokenUrl in the OAuth2 scheme above
async def login_for_access_token(
    form_data: Annotated[OAuth2PasswordRequestForm, Depends()]
) -> Token:
    """
    Login endpoint to get an access token
    """
    user = authenticate_user(users_db, form_data.username, form_data.password)
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )
    access_token_expires = timedelta(hours=ACCESS_TOKEN_EXPIRE_HOURS)
    access_token = create_access_token(
        data={"sub": user.username}, expires_delta=access_token_expires
    )
    return Token(access_token=access_token, token_type="bearer")
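# Example token request with curl, assuming the server runs on localhost:8000
# (the password is whatever plaintext corresponds to the stored bcrypt hash):
#
#   curl -X POST http://localhost:8000/token \
#        -d "username=chien_vm&password=<password>"
#
# Response: {"access_token": "<jwt>", "token_type": "bearer"}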
@app.post("/predict")  # path assumed for this endpoint
async def predict(
    current_user: Annotated[User, Depends(get_current_active_user)],
    file: UploadFile = File(...),
):
    """
    Process an input CSV file and return standardized names (requires authentication)
    """
    global sentenceTransformerHelper, dic_standard_subject, sample_name_sentence_embeddings
    global sample_name_sentence_similarities, sampleData, name_groups

    if not file.filename.endswith(".csv"):
        raise HTTPException(status_code=400, detail="Only CSV files are supported")

    # Save uploaded file
    timestamp = int(time.time())
    input_file_path = os.path.join(
        current_dir, "uploads", f"input_{timestamp}_{current_user.username}.csv"
    )
    # Use CSV format with correct extension
    output_file_path = os.path.join(
        current_dir, "outputs", f"output_{timestamp}_{current_user.username}.csv"
    )

    try:
        with open(input_file_path, "wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
    finally:
        file.file.close()

    try:
        # Process input data
        inputData = InputNameData(dic_standard_subject)
        inputData.load_data_from_csv(input_file_path)
        inputData.process_data()

        # Map standard names
        nameMappingHelper = NameMappingHelper(
            sentenceTransformerHelper,
            inputData,
            sampleData,
            sample_name_sentence_embeddings,
            sample_name_sentence_similarities,
        )
        df_predicted = nameMappingHelper.map_standard_names()

        # Create output dataframe and save to CSV
        print("Columns of inputData.dataframe:", inputData.dataframe.columns)
        columns_to_keep = ["シート名", "行", "科目", "分類", "名称", "摘要", "備考"]
        output_df = inputData.dataframe[columns_to_keep].copy()
        output_df.reset_index(drop=False, inplace=True)
        output_df.loc[:, "出力_科目"] = df_predicted["出力_科目"]
        output_df.loc[:, "出力_項目名"] = df_predicted["出力_項目名"]
        output_df.loc[:, "出力_確率度"] = df_predicted["出力_確率度"]

        # Save with utf_8_sig encoding for Japanese Excel compatibility
        output_df.to_csv(output_file_path, index=False, encoding="utf_8_sig")

        # Return the file as a download; media_type sets Content-Type: text/csv
        return FileResponse(
            path=output_file_path,
            filename=f"output_{Path(file.filename).stem}.csv",
            media_type="text/csv",
            headers={
                "Content-Disposition": f'attachment; filename="output_{Path(file.filename).stem}.csv"',
            },
        )
    except Exception as e:
        print(f"Error processing file: {e}")
        traceback.print_exc()
        raise HTTPException(status_code=500, detail=str(e))
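# Example prediction request with curl, assuming the /predict path above and a
# token obtained from /token:
#
#   curl -X POST http://localhost:8000/predict \
#        -H "Authorization: Bearer <access_token>" \
#        -F "file=@input.csv" -o output.csv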
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)