Spaces:
Sleeping
Sleeping
load model
Browse files
app.py
CHANGED
@@ -13,12 +13,6 @@ from transformers import pipeline
|
|
13 |
|
14 |
from sentence_transformers import SentenceTransformer, util
|
15 |
|
16 |
-
classifier_model = pipeline(
|
17 |
-
"zero-shot-classification", model="MoritzLaurer/deberta-v3-large-zeroshot-v1"
|
18 |
-
)
|
19 |
-
|
20 |
-
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
21 |
-
|
22 |
# get db info from env vars
|
23 |
db_host = os.environ.get("DB_HOST")
|
24 |
db_user = os.environ.get("DB_USER")
|
@@ -77,6 +71,9 @@ potential_labels = get_potential_labels()
|
|
77 |
# Function to handle the classification
|
78 |
def classify_email_and_generate_response(representative_email, constituent_email):
|
79 |
potential_labels = get_potential_labels()
|
|
|
|
|
|
|
80 |
print("classifying email")
|
81 |
model_out = classifier_model(constituent_email, potential_labels, multi_label=True)
|
82 |
print("classification complete")
|
@@ -148,6 +145,7 @@ def get_similar_messages(constituent_email):
|
|
148 |
)
|
149 |
|
150 |
messages_for_category = db_cursor.fetchall()
|
|
|
151 |
|
152 |
all_message_chains = []
|
153 |
|
|
|
13 |
|
14 |
from sentence_transformers import SentenceTransformer, util
|
15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
16 |
# get db info from env vars
|
17 |
db_host = os.environ.get("DB_HOST")
|
18 |
db_user = os.environ.get("DB_USER")
|
|
|
71 |
# Function to handle the classification
|
72 |
def classify_email_and_generate_response(representative_email, constituent_email):
|
73 |
potential_labels = get_potential_labels()
|
74 |
+
classifier_model = pipeline(
|
75 |
+
"zero-shot-classification", model="MoritzLaurer/deberta-v3-large-zeroshot-v1"
|
76 |
+
)
|
77 |
print("classifying email")
|
78 |
model_out = classifier_model(constituent_email, potential_labels, multi_label=True)
|
79 |
print("classification complete")
|
|
|
145 |
)
|
146 |
|
147 |
messages_for_category = db_cursor.fetchall()
|
148 |
+
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
|
149 |
|
150 |
all_message_chains = []
|
151 |
|