sksameermujahid commited on
Commit
213ba13
·
verified ·
1 Parent(s): 280956d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -22
app.py CHANGED
@@ -28,6 +28,10 @@ import numpy as np
28
  import tempfile
29
  from pathlib import Path
30
 
 
 
 
 
31
  # Configure logging
32
  logging.basicConfig(level=logging.INFO)
33
 
@@ -65,29 +69,32 @@ def load_sentence_transformer():
65
  cache_dir = Path('/cache')
66
  cache_dir.mkdir(parents=True, exist_ok=True)
67
 
68
- model_embedding = SentenceTransformer(
69
- "jinaai/jina-embeddings-v3",
70
- trust_remote_code=True,
71
- cache_folder=str(cache_dir)
72
- ).to(device)
73
-
74
- # Load and optimize model state dict
75
- if os.path.exists(model_path):
76
- state_dict = torch.load(model_path, map_location=device)
77
-
78
- # Dequantize if needed
79
- for key, tensor in state_dict.items():
80
- if hasattr(tensor, 'dequantize'):
81
- state_dict[key] = tensor.dequantize().to(dtype=torch.float32)
82
- elif tensor.dtype == torch.bfloat16:
83
- state_dict[key] = tensor.to(dtype=torch.float32)
84
-
85
- model_embedding.load_state_dict(state_dict)
86
- print("SentenceTransformer model loaded successfully.")
 
 
 
 
 
87
  else:
88
- print(f"Warning: Model file not found at {model_path}")
89
-
90
- return model_embedding
91
  except Exception as e:
92
  print(f"Error loading model: {str(e)}")
93
  raise
 
28
  import tempfile
29
  from pathlib import Path
30
 
31
+ # Add at the top of app.py after imports
32
+ if not hasattr(np, '__version__') or tuple(map(int, np.__version__.split('.'))) < (1, 25, 0):
33
+ raise ImportError("This application requires numpy >= 1.25.0")
34
+
35
  # Configure logging
36
  logging.basicConfig(level=logging.INFO)
37
 
 
69
  cache_dir = Path('/cache')
70
  cache_dir.mkdir(parents=True, exist_ok=True)
71
 
72
+ # Ensure numpy version compatibility
73
+ if tuple(map(int, np.__version__.split('.'))) >= (1, 25, 0):
74
+ model_embedding = SentenceTransformer(
75
+ "jinaai/jina-embeddings-v3",
76
+ trust_remote_code=True,
77
+ cache_folder=str(cache_dir)
78
+ ).to(device)
79
+
80
+ if os.path.exists(model_path):
81
+ state_dict = torch.load(model_path, map_location=device)
82
+
83
+ # Handle tensor types
84
+ for key, tensor in state_dict.items():
85
+ if hasattr(tensor, 'dequantize'):
86
+ state_dict[key] = tensor.dequantize().to(dtype=torch.float32)
87
+ elif tensor.dtype == torch.bfloat16:
88
+ state_dict[key] = tensor.to(dtype=torch.float32)
89
+
90
+ model_embedding.load_state_dict(state_dict)
91
+ print("SentenceTransformer model loaded successfully.")
92
+ else:
93
+ print(f"Warning: Model file not found at {model_path}")
94
+
95
+ return model_embedding
96
  else:
97
+ raise ImportError("Incompatible numpy version")
 
 
98
  except Exception as e:
99
  print(f"Error loading model: {str(e)}")
100
  raise