sksameermujahid committed on
Commit
8ffabbf
·
verified ·
1 Parent(s): 430d641

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -12
app.py CHANGED
@@ -25,6 +25,8 @@ from werkzeug.utils import secure_filename
25
  from geopy.geocoders import Nominatim
26
  import pickle
27
  import numpy as np
 
 
28
 
29
  # Configure logging
30
  logging.basicConfig(level=logging.INFO)
@@ -50,24 +52,41 @@ model_dir = "./models/llm_model"
50
  device = "cuda" if torch.cuda.is_available() else "cpu"
51
  print(f"Using device: {device}")
52
 
 
 
 
 
 
53
  # Load SentenceTransformer model
54
  def load_sentence_transformer():
55
  print("Loading SentenceTransformer model...")
56
  try:
57
- model_embedding = SentenceTransformer("jinaai/jina-embeddings-v3", trust_remote_code=True).to(device)
 
 
 
 
 
 
 
 
58
 
59
  # Load and optimize model state dict
60
- state_dict = torch.load(model_path, map_location=device)
61
-
62
- # Dequantize if needed
63
- for key, tensor in state_dict.items():
64
- if hasattr(tensor, 'dequantize'): # Check if tensor is quantized
65
- state_dict[key] = tensor.dequantize().to(dtype=torch.float32) # Convert to FP32
66
- elif tensor.dtype == torch.bfloat16: # Handle bfloat16 tensors
67
- state_dict[key] = tensor.to(dtype=torch.float32) # Convert to FP32
68
-
69
- model_embedding.load_state_dict(state_dict)
70
- print("SentenceTransformer model loaded successfully.")
 
 
 
 
71
  return model_embedding
72
  except Exception as e:
73
  print(f"Error loading model: {str(e)}")
 
25
  from geopy.geocoders import Nominatim
26
  import pickle
27
  import numpy as np
28
+ import tempfile
29
+ from pathlib import Path
30
 
31
  # Configure logging
32
  logging.basicConfig(level=logging.INFO)
 
52
  device = "cuda" if torch.cuda.is_available() else "cpu"
53
  print(f"Using device: {device}")
54
 
55
+ # Configure cache directories
56
+ os.environ['TRANSFORMERS_CACHE'] = '/cache'
57
+ os.environ['HF_HOME'] = '/cache'
58
+ os.environ['XDG_CACHE_HOME'] = '/cache'
59
+
60
  # Load SentenceTransformer model
61
  def load_sentence_transformer():
62
  print("Loading SentenceTransformer model...")
63
  try:
64
+ # Create cache directory if it doesn't exist
65
+ cache_dir = Path('/cache')
66
+ cache_dir.mkdir(parents=True, exist_ok=True)
67
+
68
+ model_embedding = SentenceTransformer(
69
+ "jinaai/jina-embeddings-v3",
70
+ trust_remote_code=True,
71
+ cache_folder=str(cache_dir)
72
+ ).to(device)
73
 
74
  # Load and optimize model state dict
75
+ if os.path.exists(model_path):
76
+ state_dict = torch.load(model_path, map_location=device)
77
+
78
+ # Dequantize if needed
79
+ for key, tensor in state_dict.items():
80
+ if hasattr(tensor, 'dequantize'):
81
+ state_dict[key] = tensor.dequantize().to(dtype=torch.float32)
82
+ elif tensor.dtype == torch.bfloat16:
83
+ state_dict[key] = tensor.to(dtype=torch.float32)
84
+
85
+ model_embedding.load_state_dict(state_dict)
86
+ print("SentenceTransformer model loaded successfully.")
87
+ else:
88
+ print(f"Warning: Model file not found at {model_path}")
89
+
90
  return model_embedding
91
  except Exception as e:
92
  print(f"Error loading model: {str(e)}")