thechaiexperiment committed on
Commit
b2cd2f5
·
1 Parent(s): 1a71739

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -8
app.py CHANGED
@@ -39,7 +39,9 @@ def load_models():
39
  try:
40
  # Load embeddings data with custom persistent_load function
41
  with open("models/embeddings.pkl", "rb") as file:
42
- global_models.embeddings_data = pickle.load(file, persistent_load=persistent_load)
 
 
43
  print("Embeddings data loaded successfully.")
44
  except pickle.UnpicklingError as e:
45
  raise HTTPException(status_code=500, detail=f"Unpickling error: {e}")
@@ -49,13 +51,54 @@ def load_models():
49
  app = FastAPI()
50
 
51
  @app.on_event("startup")
52
- async def startup_event():
53
- """
54
- Load models at application startup.
55
- """
56
- print("Loading models...")
57
- load_models()
58
- print("Models loaded.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  @app.get("/")
61
  async def root():
 
39
  try:
40
  # Load embeddings data with custom persistent_load function
41
  with open("models/embeddings.pkl", "rb") as file:
42
+ unpickler = pickle.Unpickler(file)
43
+ unpickler.persistent_load = persistent_load
44
+ global_models.embeddings_data = unpickler.load()
45
  print("Embeddings data loaded successfully.")
46
  except pickle.UnpicklingError as e:
47
  raise HTTPException(status_code=500, detail=f"Unpickling error: {e}")
 
51
  app = FastAPI()
52
 
53
  @app.on_event("startup")
54
+ async def load_models():
55
+ """Initialize all models and data on startup"""
56
+ try:
57
+ # Load embedding models
58
+ global_models.embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
59
+ global_models.cross_encoder = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2', max_length=512)
60
+ global_models.semantic_model = SentenceTransformer('all-MiniLM-L6-v2')
61
+
62
+ # Load BART models
63
+ global_models.tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
64
+ global_models.model = BartForConditionalGeneration.from_pretrained("facebook/bart-base")
65
+
66
+ # Load Orca model
67
+ model_name = "M4-ai/Orca-2.0-Tau-1.8B"
68
+ global_models.tokenizer_f = AutoTokenizer.from_pretrained(model_name)
69
+ global_models.model_f = AutoModelForCausalLM.from_pretrained(model_name)
70
+
71
+ # Load translation models
72
+ global_models.ar_to_en_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
73
+ global_models.ar_to_en_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-ar-en")
74
+ global_models.en_to_ar_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
75
+ global_models.en_to_ar_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-ar")
76
+
77
+ # Load Medical NER models
78
+ global_models.bio_tokenizer = AutoTokenizer.from_pretrained("blaze999/Medical-NER")
79
+ global_models.bio_model = AutoModelForTokenClassification.from_pretrained("blaze999/Medical-NER")
80
+
81
+ # Load embeddings data with proper persistent_load handling
82
+ try:
83
+ with open('embeddings.pkl', 'rb') as file:
84
+ unpickler = pickle.Unpickler(file)
85
+ unpickler.persistent_load = persistent_load
86
+ global_models.embeddings_data = unpickler.load()
87
+ except (FileNotFoundError, pickle.UnpicklingError) as e:
88
+ print(f"Error loading embeddings data: {e}")
89
+ raise HTTPException(status_code=500, detail="Failed to load embeddings data.")
90
+
91
+ # Load URL mapping data
92
+ try:
93
+ df = pd.read_excel('finalcleaned_excel_file.xlsx')
94
+ global_models.file_name_to_url = {f"article_{index}.html": url for index, url in enumerate(df['Unnamed: 0'])}
95
+ except Exception as e:
96
+ print(f"Error loading URL mapping data: {e}")
97
+ raise HTTPException(status_code=500, detail="Failed to load URL mapping data.")
98
+
99
+ except Exception as e:
100
+ print(f"Error loading models: {e}")
101
+ raise HTTPException(status_code=500, detail="Failed to load models.")
102
 
103
  @app.get("/")
104
  async def root():