napatswift commited on
Commit
2f1c18f
·
1 Parent(s): 48fd428

Refactor get_closest_budget_item function to accept a parameter for the number of results

Browse files
Files changed (1) hide show
  1. app.py +4 -4
app.py CHANGED
@@ -28,10 +28,10 @@ budget_items = df.apply(get_name, axis=1).unique().tolist()
28
 
29
  budget_item_embeddings = torch.stack(list(map(get_embedding, budget_items)))
30
 
31
- def get_closest_budget_item(text):
32
  text_embedding = get_embedding(text)
33
- scores = (budget_item_embeddings * text_embedding).sum(axis=1)
34
- top_idx = scores.argsort(descending=True)[:5]
35
  return pd.DataFrame({
36
  'budget_item': np.array(budget_items)[top_idx],
37
  'score': scores[top_idx].tolist()
@@ -39,7 +39,7 @@ def get_closest_budget_item(text):
39
 
40
  demo = gr.Interface(
41
  fn=get_closest_budget_item,
42
- inputs='textbox',
43
  outputs='dataframe',
44
  )
45
 
 
28
 
29
  budget_item_embeddings = torch.stack(list(map(get_embedding, budget_items)))
30
 
31
+ def get_closest_budget_item(text, num_results=5):
32
  text_embedding = get_embedding(text)
33
+ scores = torch.norm(budget_item_embeddings - text_embedding, dim=1)
34
+ top_idx = scores.argsort()[:num_results]
35
  return pd.DataFrame({
36
  'budget_item': np.array(budget_items)[top_idx],
37
  'score': scores[top_idx].tolist()
 
39
 
40
  demo = gr.Interface(
41
  fn=get_closest_budget_item,
42
+ inputs=['textbox', gr.Slider(minimum=1, maximum=50, step=5, default=5, label="Number of results")],
43
  outputs='dataframe',
44
  )
45