Spaces:
Sleeping
Sleeping
napatswift
commited on
Commit
·
2f1c18f
1
Parent(s):
48fd428
Refactor get_closest_budget_item function to accept a parameter for the number of results
Browse files
app.py
CHANGED
@@ -28,10 +28,10 @@ budget_items = df.apply(get_name, axis=1).unique().tolist()
|
|
28 |
|
29 |
budget_item_embeddings = torch.stack(list(map(get_embedding, budget_items)))
|
30 |
|
31 |
-
def get_closest_budget_item(text):
|
32 |
text_embedding = get_embedding(text)
|
33 |
-
scores = (budget_item_embeddings
|
34 |
-
top_idx = scores.argsort(
|
35 |
return pd.DataFrame({
|
36 |
'budget_item': np.array(budget_items)[top_idx],
|
37 |
'score': scores[top_idx].tolist()
|
@@ -39,7 +39,7 @@ def get_closest_budget_item(text):
|
|
39 |
|
40 |
demo = gr.Interface(
|
41 |
fn=get_closest_budget_item,
|
42 |
-
inputs='textbox',
|
43 |
outputs='dataframe',
|
44 |
)
|
45 |
|
|
|
28 |
|
29 |
budget_item_embeddings = torch.stack(list(map(get_embedding, budget_items)))
|
30 |
|
31 |
+
def get_closest_budget_item(text, num_results=5):
|
32 |
text_embedding = get_embedding(text)
|
33 |
+
scores = torch.norm(budget_item_embeddings - text_embedding, dim=1)
|
34 |
+
top_idx = scores.argsort()[:num_results]
|
35 |
return pd.DataFrame({
|
36 |
'budget_item': np.array(budget_items)[top_idx],
|
37 |
'score': scores[top_idx].tolist()
|
|
|
39 |
|
40 |
demo = gr.Interface(
|
41 |
fn=get_closest_budget_item,
|
42 |
+
inputs=['textbox', gr.Slider(minimum=1, maximum=50, step=5, default=5, label="Number of results")],
|
43 |
outputs='dataframe',
|
44 |
)
|
45 |
|