chitkenkhoi commited on
Commit
6f2560d
·
1 Parent(s): c640f15

GPU to CPU

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -23,8 +23,8 @@ r = redis.Redis(
23
  decode_responses=True
24
  )
25
 
26
- # Device configuration
27
- device = "cuda" if torch.cuda.is_available() else "cpu"
28
 
29
  # Load CSV from Google Drive
30
  def load_csv_from_drive():
@@ -105,8 +105,8 @@ def retrieve_relevant_resources(query_vector, embeddings, similarity_threshold=0
105
  query_embedding = torch.from_numpy(query_vector).to(torch.float32)
106
  if len(query_embedding.shape) == 1:
107
  query_embedding = query_embedding.unsqueeze(0)
108
- query_embedding = query_embedding.cuda()
109
-
110
  if embeddings.shape[1] != query_embedding.shape[1]:
111
  query_embedding = torch.nn.functional.pad(
112
  query_embedding,
@@ -190,8 +190,8 @@ def ask_with_history_v3(query: str, conversation_id: str, isFirst):
190
  embeddings=embeddings
191
  )
192
 
193
- scores_cpu = [score.cpu() for score in scores]
194
- filtered_pairs = [(score, idx) for score, idx in zip(scores_cpu, indices) if score.item() >= threshold]
195
 
196
  if filtered_pairs:
197
  filtered_scores, filtered_indices = zip(*filtered_pairs)
 
23
  decode_responses=True
24
  )
25
 
26
+ # Device configuration - always use CPU
27
+ device = "cpu"
28
 
29
  # Load CSV from Google Drive
30
  def load_csv_from_drive():
 
105
  query_embedding = torch.from_numpy(query_vector).to(torch.float32)
106
  if len(query_embedding.shape) == 1:
107
  query_embedding = query_embedding.unsqueeze(0)
108
+
109
+ # Removed CUDA-specific code
110
  if embeddings.shape[1] != query_embedding.shape[1]:
111
  query_embedding = torch.nn.functional.pad(
112
  query_embedding,
 
190
  embeddings=embeddings
191
  )
192
 
193
+ # No need for CPU conversion since we're already on CPU
194
+ filtered_pairs = [(score.item(), idx) for score, idx in zip(scores, indices) if score.item() >= threshold]
195
 
196
  if filtered_pairs:
197
  filtered_scores, filtered_indices = zip(*filtered_pairs)