thechaiexperiment committed on
Commit
9b4d106
·
1 Parent(s): 58d8f07

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -71
app.py CHANGED
@@ -1,5 +1,4 @@
1
  from fastapi import FastAPI, HTTPException
2
-
3
  from pydantic import BaseModel
4
  from typing import List, Optional, Dict
5
  import pickle
@@ -20,92 +19,87 @@ from transformers import (
20
  import pandas as pd
21
  import time
22
 
23
- # Modify persistent_load function to ensure ASCII-only persistent IDs
24
- def persistent_load(pers_id):
25
- """
26
- Handle persistent IDs during unpickling.
27
- """
28
- # Ensure persistent IDs are ASCII-only
29
- pers_id = pers_id.encode('ascii', 'ignore').decode('ascii') # Convert to ASCII
30
- if pers_id == "sentence_transformer_model":
31
  try:
32
- # Load a pre-defined SentenceTransformer model
33
- model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
34
- return model
 
 
 
 
 
35
  except Exception as e:
36
- raise HTTPException(status_code=500, detail=f"Error loading SentenceTransformer model: {e}")
37
- else:
38
- raise HTTPException(status_code=500, detail=f"Unknown persistent ID: {pers_id}")
39
 
40
-
41
- def load_models():
42
  try:
43
- with open('embeddings.pkl', 'rb') as f:
44
- embeddings_data = pickle.load(f, encoding='latin1') # or 'bytes'
45
-
46
- # If embeddings_data is a dictionary, check its content
47
- if isinstance(embeddings_data, dict):
48
- print("Loaded embeddings dictionary")
49
-
50
- # Proceed with your logic using embeddings_data
51
- # For example, assign to global models or something similar
52
- global_models.embeddings_data = embeddings_data
53
-
54
- except Exception as e:
55
- print(f"Error loading embeddings data: {e}")
56
- raise HTTPException(status_code=500, detail="Failed to load embeddings data.")
57
-
58
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  @app.on_event("startup")
61
  async def load_models():
62
  """Initialize all models and data on startup"""
63
  try:
64
- # Load embedding models
65
  global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
 
 
 
 
 
 
 
 
66
  global_models.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
67
  global_models.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
68
 
69
- # Load BART models
70
- global_models.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
71
- global_models.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
72
-
73
- # Load Orca model
74
- model_name = "M4-ai/Orca-2.0-Tau-1.8B"
75
- global_models.tokenizer_f = AutoTokenizer.from_pretrained(model_name)
76
- global_models.model_f = AutoModelForCausalLM.from_pretrained(model_name)
77
-
78
- # Load translation models
79
- global_models.ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
80
- global_models.ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
81
- global_models.en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
82
- global_models.en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
83
-
84
- # Load Medical NER models
85
- global_models.bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
86
- global_models.bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
87
-
88
- # Load embeddings data with proper persistent_load handling
89
- try:
90
- with open('embeddings.pkl', 'rb') as file:
91
- unpickler = pickle.Unpickler(file)
92
- unpickler.persistent_load = persistent_load
93
- global_models.embeddings_data = unpickler.load()
94
- except (FileNotFoundError, pickle.UnpicklingError) as e:
95
- print(f"Error loading embeddings data: {e}")
96
- raise HTTPException(status_code=500, detail="Failed to load embeddings data.")
97
 
98
- # Load URL mapping data
99
- try:
100
- df = pd.read_excel('finalcleaned_excel_file.xlsx')
101
- global_models.file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}
102
- except Exception as e:
103
- print(f"Error loading URL mapping data: {e}")
104
- raise HTTPException(status_code=500, detail="Failed to load URL mapping data.")
105
 
106
  except Exception as e:
107
- print(f"Error loading models: {e}")
108
- raise HTTPException(status_code=500, detail="Failed to load models.")
 
 
109
 
110
  @app.get("/")
111
  async def root():
 
1
  from fastapi import FastAPI, HTTPException
 
2
  from pydantic import BaseModel
3
  from typing import List, Optional, Dict
4
  import pickle
 
19
  import pandas as pd
20
  import time
21
 
22
class CustomUnpickler(pickle.Unpickler):
    """Unpickler that resolves the custom persistent IDs stored in embeddings.pkl.

    The pickle was written with persistent references; the only known ID is
    "sentence_transformer_model", which is materialized as a fresh
    SentenceTransformer instance.  Any other ID is returned unchanged after
    being normalized to ASCII (the file appears to contain inconsistently
    encoded ID strings).
    """

    def persistent_load(self, pid):
        """Resolve a persistent ID ``pid`` (str or bytes) to an object.

        Returns a SentenceTransformer for the known model ID, otherwise the
        ASCII-normalized ID itself.

        Raises:
            pickle.UnpicklingError: if anything goes wrong while handling the ID.
        """
        try:
            # Handle string encoding issues: the pickle may contain byte
            # strings or non-ASCII characters, so normalize everything to a
            # plain ASCII str before comparing.
            if isinstance(pid, bytes):
                pid = pid.decode('utf-8', errors='ignore')
            pid = str(pid).encode('ascii', errors='ignore').decode('ascii')

            if pid == "sentence_transformer_model":
                return SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
            return pid
        except Exception as e:
            # Chain the original exception so the root cause is preserved
            # in the traceback instead of being silently discarded.
            raise pickle.UnpicklingError(f"Error handling persistent ID: {e}") from e
 
 
35
 
36
def safe_load_embeddings(path='embeddings.pkl'):
    """Load and validate the pickled embeddings dictionary.

    Args:
        path: Location of the pickle file (defaults to 'embeddings.pkl',
            preserving the original behavior).

    Returns:
        The embeddings dict on success, or ``None`` on any load/validation
        failure (the error is printed rather than raised, so the caller
        decides how to react).
    """
    try:
        with open(path, 'rb') as file:
            unpickler = CustomUnpickler(file)
            embeddings_data = unpickler.load()

        # Verify the data structure.
        if not isinstance(embeddings_data, dict):
            raise ValueError("Loaded data is not a dictionary")

        # An empty dict would make next(iter(...)) below raise StopIteration,
        # which is not in the except tuple; treat it as invalid data instead.
        if not embeddings_data:
            raise ValueError("Embeddings dictionary is empty")

        # Verify the embeddings format by spot-checking a single entry.
        first_key = next(iter(embeddings_data))
        if not isinstance(embeddings_data[first_key], (np.ndarray, list)):
            raise ValueError("Embeddings are not in the expected format")

        return embeddings_data

    except (FileNotFoundError, pickle.UnpicklingError, ValueError) as e:
        print(f"Error loading embeddings: {str(e)}")
        return None
56
+
57
class GlobalModels:
    """Process-wide container for every model and dataset loaded at startup.

    All attributes begin as ``None`` and are populated by the startup
    handler; request handlers read them through the shared ``global_models``
    instance below.
    """

    # Sentence-embedding / ranking models
    embedding_model = None
    cross_encoder = None
    semantic_model = None

    # BART model (facebook/bart-base) and its tokenizer
    tokenizer = None
    model = None

    # Causal LM (Orca) and its tokenizer
    tokenizer_f = None
    model_f = None

    # Arabic <-> English translation models
    ar_to_en_tokenizer = None
    ar_to_en_model = None
    en_to_ar_tokenizer = None
    en_to_ar_model = None

    # Loaded data: pickled embeddings and article-file -> URL mapping
    embeddings_data = None
    file_name_to_url = None

    # Medical NER model and tokenizer
    bio_tokenizer = None
    bio_model = None


# Single shared instance used throughout the application.
global_models = GlobalModels()
75
 
76
@app.on_event("startup")
async def load_models():
    """Initialize all models and data on startup.

    Populates the shared ``global_models`` container with the embedding,
    ranking, summarization, causal-LM, translation and NER models, plus the
    embeddings pickle and the article-to-URL mapping.  Any failure aborts
    startup by raising, since the API cannot serve requests without them.
    """
    try:
        # Load the embedding model first.
        global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

        # Load embeddings data with the safe loader; abort startup on failure.
        embeddings_data = safe_load_embeddings()
        if embeddings_data is None:
            raise HTTPException(status_code=500, detail="Failed to load embeddings data")
        global_models.embeddings_data = embeddings_data

        # Continue loading other models only if embeddings loaded successfully.
        global_models.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
        global_models.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')

        # BART model (the previous revision loaded these; the placeholder
        # comment here had dropped them, leaving most models unloaded).
        global_models.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
        global_models.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")

        # Orca causal LM.
        model_name = "M4-ai/Orca-2.0-Tau-1.8B"
        global_models.tokenizer_f = AutoTokenizer.from_pretrained(model_name)
        global_models.model_f = AutoModelForCausalLM.from_pretrained(model_name)

        # Arabic <-> English translation models.
        global_models.ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        global_models.ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
        global_models.en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
        global_models.en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")

        # Medical NER models.
        global_models.bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
        global_models.bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")

        # Map generated article file names to their source URLs.
        df = pd.read_excel('finalcleaned_excel_file.xlsx')
        global_models.file_name_to_url = {
            f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])
        }

        print("All models loaded successfully")

    except Exception as e:
        print(f"Error during startup: {str(e)}")
        # NOTE(review): raising HTTPException in a startup hook aborts app
        # startup rather than producing an HTTP 500 response — confirm this
        # is the intended failure mode.
        raise HTTPException(status_code=500, detail=f"Failed to initialize application: {str(e)}") from e
101
+
102
+ # Rest of your FastAPI application code remains the same...
103
 
104
  @app.get("/")
105
  async def root():