Commit
·
1308769
1
Parent(s):
437edbb
Updating app.py
Browse filesAdding app with BART and zero shot pipeline to classify transactions
app.py
CHANGED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
from transformers import pipeline
|
3 |
+
|
4 |
+
title = "Fold: Contextual Tag Recommendation System"
|
5 |
+
description = "powered by bart-large-mnli, made by @abhisheky127"
|
6 |
+
|
7 |
+
classifier = pipeline("zero-shot-classification",
|
8 |
+
model="facebook/bart-large-mnli")
|
9 |
+
|
10 |
+
#define a function to process your input and output
|
11 |
+
def zero_shot(doc, candidates):
|
12 |
+
given_labels = candidates.split(", ")
|
13 |
+
dictionary = classifier(doc, given_labels)
|
14 |
+
labels = dictionary['labels']
|
15 |
+
scores = dictionary['scores']
|
16 |
+
return dict(zip(labels, scores))
|
17 |
+
|
18 |
+
#define a function to preprocess transaction query
|
19 |
+
def preprocess(transaction):
|
20 |
+
pattern = r'([A-Za-z0-9\s]+)(?:/| |$)'
|
21 |
+
match = re.search(pattern, transaction)
|
22 |
+
if match:
|
23 |
+
return match.group(1).strip()
|
24 |
+
return None
|
25 |
+
|
26 |
+
|
27 |
+
#create input and output objects
|
28 |
+
#input object1
|
29 |
+
input1 = gr.Textbox(label="Text")
|
30 |
+
|
31 |
+
#input object 2
|
32 |
+
input2 = gr.Textbox(label="Labels")
|
33 |
+
|
34 |
+
#output object
|
35 |
+
output = gr.Label(label="Output")
|
36 |
+
|
37 |
+
#example object
|
38 |
+
transactions_and_tags = [
|
39 |
+
["MPS/TRUFFLES /202303261700/034587/Bangalore", "Medical,Food,Shopping,Subscription,Travel, Miscellaneous"],
|
40 |
+
["MPS/TACO BELL /202304012247/108300/BANGALORE", "Medical,Food,Shopping,Subscription,Travel, Miscellaneous"],
|
41 |
+
["POS XXXXXXXXXXXX0001 APOLLO PHARMACY", "Medical,Food,Shopping,Subscription,Travel, Miscellaneous"],
|
42 |
+
["BIL/ONL/000471093694/1MG Techno/X7ZRUSVLURFQZO", "Medical,Food,Shopping,Subscription,Travel, Miscellaneous"],
|
43 |
+
["POS XXXXXXXXXXXX1111 DECATHLON SPORTS", "Medical,Food,Shopping,Subscription,Travel, Miscellaneous"],
|
44 |
+
["POS XXXXXXXXXXXX1111 IKEA INDIA PVT L", "Medical,Food,Shopping,Subscription,Travel, Miscellaneous"],
|
45 |
+
["POS XXXXXXXXXXXX1111 WWW AMAZON IN", "Medical,Food,Shopping,Subscription,Travel, Miscellaneous"],
|
46 |
+
["ME DC SI XXXXXXXXXXXX1111 SPOTIFY SI", "Medical,Food,Shopping,Subscription,Travel, Miscellaneous"],
|
47 |
+
["POS/NETFLIX/1140920002/100623/17:25", "Medical,Food,Shopping,Subscription,Travel, Miscellaneous"],
|
48 |
+
["POS XXXXXXXXXXXX1110 MAKEMYTRIP INDIA", "Medical,Food,Shopping,Subscription,Travel, Miscellaneous"],
|
49 |
+
["BIL/ONL/000691178015/IRCTC Serv/XZZBX91LTCY1AZ", "Medical,Food,Shopping,Subscription,Travel, Miscellaneous"]
|
50 |
+
]
|
51 |
+
|
52 |
+
#create interface
|
53 |
+
gui = gr.Interface(title=title,
|
54 |
+
description=description,
|
55 |
+
fn=zero_shot,
|
56 |
+
inputs=[preprocess(input1), input2],
|
57 |
+
outputs=[output],
|
58 |
+
examples=transactions_and_tags)
|
59 |
+
|
60 |
+
#display the interface
|
61 |
+
gui.launch()
|