sksameermujahid commited on
Commit
4fea5a9
·
verified ·
1 Parent(s): e68a0f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -28
app.py CHANGED
@@ -28,9 +28,9 @@ import numpy as np
28
  import tempfile
29
  from pathlib import Path
30
 
31
- # Add at the top of app.py after imports
32
- if not hasattr(np, '__version__') or tuple(map(int, np.__version__.split('.'))) < (1, 25, 0):
33
- raise ImportError("This application requires numpy >= 1.25.0")
34
 
35
  # Configure logging
36
  logging.basicConfig(level=logging.INFO)
@@ -69,32 +69,34 @@ def load_sentence_transformer():
69
  cache_dir = Path('/cache')
70
  cache_dir.mkdir(parents=True, exist_ok=True)
71
 
72
- # Ensure numpy version compatibility
73
- if tuple(map(int, np.__version__.split('.'))) >= (1, 25, 0):
74
- model_embedding = SentenceTransformer(
75
- "jinaai/jina-embeddings-v3",
76
- trust_remote_code=True,
77
- cache_folder=str(cache_dir)
78
- ).to(device)
79
-
80
- if os.path.exists(model_path):
81
- state_dict = torch.load(model_path, map_location=device)
82
-
83
- # Handle tensor types
84
- for key, tensor in state_dict.items():
85
- if hasattr(tensor, 'dequantize'):
86
- state_dict[key] = tensor.dequantize().to(dtype=torch.float32)
87
- elif tensor.dtype == torch.bfloat16:
88
- state_dict[key] = tensor.to(dtype=torch.float32)
89
-
90
- model_embedding.load_state_dict(state_dict)
91
- print("SentenceTransformer model loaded successfully.")
92
- else:
93
- print(f"Warning: Model file not found at {model_path}")
94
-
95
- return model_embedding
96
  else:
97
- raise ImportError("Incompatible numpy version")
 
 
98
  except Exception as e:
99
  print(f"Error loading model: {str(e)}")
100
  raise
 
28
  import tempfile
29
  from pathlib import Path
30
 
31
+ # Update the numpy version check at the top of the file
32
+ if not hasattr(np, '__version__') or tuple(map(int, np.__version__.split('.'))) > (1, 24, 0):
33
+ print(f"Warning: Using numpy version {np.__version__}. Some features may not work properly.")
34
 
35
  # Configure logging
36
  logging.basicConfig(level=logging.INFO)
 
69
  cache_dir = Path('/cache')
70
  cache_dir.mkdir(parents=True, exist_ok=True)
71
 
72
+ # Import einops here to ensure it's available
73
+ try:
74
+ import einops
75
+ except ImportError:
76
+ raise ImportError("einops is required. Please install it with 'pip install einops'")
77
+
78
+ model_embedding = SentenceTransformer(
79
+ "jinaai/jina-embeddings-v3",
80
+ trust_remote_code=True,
81
+ cache_folder=str(cache_dir)
82
+ ).to(device)
83
+
84
+ if os.path.exists(model_path):
85
+ state_dict = torch.load(model_path, map_location=device)
86
+
87
+ # Handle tensor types
88
+ for key, tensor in state_dict.items():
89
+ if hasattr(tensor, 'dequantize'):
90
+ state_dict[key] = tensor.dequantize().to(dtype=torch.float32)
91
+ elif tensor.dtype == torch.bfloat16:
92
+ state_dict[key] = tensor.to(dtype=torch.float32)
93
+
94
+ model_embedding.load_state_dict(state_dict)
95
+ print("SentenceTransformer model loaded successfully.")
96
  else:
97
+ print(f"Warning: Model file not found at {model_path}")
98
+
99
+ return model_embedding
100
  except Exception as e:
101
  print(f"Error loading model: {str(e)}")
102
  raise