thechaiexperiment commited on
Commit
1a71739
·
1 Parent(s): 56b03dc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -0
app.py CHANGED
@@ -19,8 +19,48 @@ from transformers import (
19
  import pandas as pd
20
  import time
21
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  app = FastAPI()
23
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  # Models and data structures to store loaded models
25
  class GlobalModels:
26
  embedding_model = None
 
19
  import pandas as pd
20
  import time
21
 
22
+ # Define persistent_load to handle persistent IDs
23
+ def persistent_load(pers_id):
24
+ """
25
+ Handle persistent IDs during unpickling.
26
+ """
27
+ if pers_id == "sentence_transformer_model":
28
+ try:
29
+ # Load a pre-defined SentenceTransformer model
30
+ model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
31
+ return model
32
+ except Exception as e:
33
+ raise HTTPException(status_code=500, detail=f"Error loading SentenceTransformer model: {e}")
34
+ else:
35
+ raise HTTPException(status_code=500, detail=f"Unknown persistent ID: {pers_id}")
36
+
37
+ # Function to load models
38
+ def load_models():
39
+ try:
40
+ # Load embeddings data with custom persistent_load function
41
+ with open("models/embeddings.pkl", "rb") as file:
42
+ global_models.embeddings_data = pickle.load(file, persistent_load=persistent_load)
43
+ print("Embeddings data loaded successfully.")
44
+ except pickle.UnpicklingError as e:
45
+ raise HTTPException(status_code=500, detail=f"Unpickling error: {e}")
46
+ except Exception as e:
47
+ raise HTTPException(status_code=500, detail=f"Failed to load models: {e}")
48
+
49
  app = FastAPI()
50
 
51
+ @app.on_event("startup")
52
+ async def startup_event():
53
+ """
54
+ Load models at application startup.
55
+ """
56
+ print("Loading models...")
57
+ load_models()
58
+ print("Models loaded.")
59
+
60
+ @app.get("/")
61
+ async def root():
62
+ return {"message": "Server is running"}
63
+
64
  # Models and data structures to store loaded models
65
  class GlobalModels:
66
  embedding_model = None