Update Demo.py
Demo.py
CHANGED
@@ -60,7 +60,7 @@ def create_pipeline(model, task, zeroShotLables=['']):
         .pretrained(model, "en") \
         .setInputCols(["sentence", "token"]) \
         .setOutputCol("ner") \
-        .setCaseSensitive(
+        .setCaseSensitive(False) \
         .setMaxSentenceLength(512)
 
     ner_converter = NerConverter() \
@@ -221,7 +221,10 @@ except:
 
 # Initialize Spark and create pipeline
 spark = init_spark()
-
+if task == 'Zero-Shot Classification':
+    pipeline = create_pipeline(model, task, zeroShotLables)
+else:
+    pipeline = create_pipeline(model, task)
 output = fit_data(pipeline, text_to_analyze, task)
 
 # Display matched sentence
@@ -243,5 +246,4 @@ elif task == 'Zero-Shot Classification':
 
 elif task == 'Sequence Classification':
     st.markdown(f"Classified as : **{output[0]['class'][0].result}**")
-
 
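Note on the first hunk: the old line 63 left `.setCaseSensitive(` unfinished, so the chained call could not even parse; the new line completes it as `.setCaseSensitive(False) \`. For context, here is a minimal sketch of how the NER branch of create_pipeline() plausibly fits together. The upstream stages (DocumentAssembler, SentenceDetector, Tokenizer) and the annotator class (BertForTokenClassification, chosen only because setCaseSensitive and setMaxSentenceLength are transformer-classifier parameters) are assumptions, not part of this commit:

from pyspark.ml import Pipeline
from sparknlp.base import DocumentAssembler
from sparknlp.annotator import SentenceDetector, Tokenizer, BertForTokenClassification, NerConverter

def create_ner_pipeline(model):
    # Assumed upstream stages producing the "sentence" and "token" columns
    # consumed by the changed lines.
    document_assembler = DocumentAssembler() \
        .setInputCol("text") \
        .setOutputCol("document")
    sentence_detector = SentenceDetector() \
        .setInputCols(["document"]) \
        .setOutputCol("sentence")
    tokenizer = Tokenizer() \
        .setInputCols(["sentence"]) \
        .setOutputCol("token")
    # The fixed chain from the diff: case-insensitive matching, 512-token cap.
    ner_model = BertForTokenClassification \
        .pretrained(model, "en") \
        .setInputCols(["sentence", "token"]) \
        .setOutputCol("ner") \
        .setCaseSensitive(False) \
        .setMaxSentenceLength(512)
    ner_converter = NerConverter() \
        .setInputCols(["sentence", "token", "ner"]) \
        .setOutputCol("ner_chunk")
    return Pipeline(stages=[document_assembler, sentence_detector,
                            tokenizer, ner_model, ner_converter])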
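Note on the second hunk: it adds the missing creation of `pipeline` before `fit_data(pipeline, text_to_analyze, task)` runs, forwarding `zeroShotLables` only when the selected task is Zero-Shot Classification; every other task falls back to the default `zeroShotLables=['']`. Inside create_pipeline() those labels would typically be wired into the zero-shot stage roughly as sketched below; the annotator class (BartForZeroShotClassification) and the column names are guesses, not taken from the commit:

from sparknlp.annotator import BartForZeroShotClassification

def create_zero_shot_stage(model, zeroShotLables):
    # Hypothetical sketch: the labels forwarded by the new conditional are
    # normally set as candidate labels on the zero-shot classifier stage.
    return BartForZeroShotClassification \
        .pretrained(model, "en") \
        .setInputCols(["document", "token"]) \
        .setOutputCol("class") \
        .setCandidateLabels(zeroShotLables)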
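Note on the output handling: the indexing `output[0]['class'][0].result` in the last hunk suggests `fit_data` returns `LightPipeline.fullAnnotate` results. A minimal sketch under that assumption; the helper body is not part of this commit, and `spark` refers to the module-level session returned by `init_spark()`:

from sparknlp.base import LightPipeline

def fit_data(pipeline, text_to_analyze, task):
    # Assumed behaviour: fit the pretrained-only pipeline on an empty frame,
    # then annotate the free-text input so the result can be read as
    # output[0]['class'][0].result, as Demo.py does below.
    empty_df = spark.createDataFrame([[""]]).toDF("text")
    light_model = LightPipeline(pipeline.fit(empty_df))
    return light_model.fullAnnotate(text_to_analyze)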