jonathanjordan21 commited on
Commit
eb89c17
·
verified ·
1 Parent(s): 625565d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -4
app.py CHANGED
@@ -387,8 +387,8 @@ model = None
387
  codes_emb = None
388
 
389
  def load_model(model_id):
390
- global model
391
- global codes_emb
392
  if model_id in model_ids[-2:]:
393
  model = CrossEncoder(
394
  # "jinaai/jina-reranker-v2-base-multilingual",
@@ -397,13 +397,15 @@ def load_model(model_id):
397
  automodel_args={"torch_dtype": "auto"},
398
  trust_remote_code=True,
399
  )
 
400
  else:
401
  model = SentenceTransformer(model_id, trust_remote_code=True)
402
  # codes_emb = model.encode([x[6:] for x in codes])
403
  codes_emb = model.encode([x["examples"] for x in examples])#.mean(axis=1)
404
  # codes_emb = np.mean([model.encode(x["examples"]) for x in examples], axis=1)
 
405
 
406
- load_model(model_id)
407
 
408
  # for x in examples:
409
  # codes_emb.append(model.encode(x["examples"]))
@@ -711,7 +713,7 @@ def reload(chosen_model_id):
711
  global codes_emb
712
 
713
  if chosen_model_id != model_id:
714
- load_model(model_id)
715
  # model = SentenceTransformer(chosen_model_id, trust_remote_code=True)
716
  # model_id = chosen_model_id
717
  # codes_emb = model.encode([x[6:] for x in codes])
 
387
  codes_emb = None
388
 
389
  def load_model(model_id):
390
+ # global model
391
+ # global codes_emb
392
  if model_id in model_ids[-2:]:
393
  model = CrossEncoder(
394
  # "jinaai/jina-reranker-v2-base-multilingual",
 
397
  automodel_args={"torch_dtype": "auto"},
398
  trust_remote_code=True,
399
  )
400
+ return model, _
401
  else:
402
  model = SentenceTransformer(model_id, trust_remote_code=True)
403
  # codes_emb = model.encode([x[6:] for x in codes])
404
  codes_emb = model.encode([x["examples"] for x in examples])#.mean(axis=1)
405
  # codes_emb = np.mean([model.encode(x["examples"]) for x in examples], axis=1)
406
+ return model, codes_emb
407
 
408
+ model, codes_emb = load_model(model_id)
409
 
410
  # for x in examples:
411
  # codes_emb.append(model.encode(x["examples"]))
 
713
  global codes_emb
714
 
715
  if chosen_model_id != model_id:
716
+ model, codes_emb = load_model(chosen_model_id)
717
  # model = SentenceTransformer(chosen_model_id, trust_remote_code=True)
718
  # model_id = chosen_model_id
719
  # codes_emb = model.encode([x[6:] for x in codes])