Update Demo.py
Browse files
Demo.py
CHANGED
|
@@ -60,7 +60,7 @@ def create_pipeline(model, task, zeroShotLables=['']):
|
|
| 60 |
.pretrained(model, "en") \
|
| 61 |
.setInputCols(["sentence", "token"]) \
|
| 62 |
.setOutputCol("ner") \
|
| 63 |
-
.setCaseSensitive(
|
| 64 |
.setMaxSentenceLength(512)
|
| 65 |
|
| 66 |
ner_converter = NerConverter() \
|
|
@@ -221,7 +221,10 @@ except:
|
|
| 221 |
|
| 222 |
# Initialize Spark and create pipeline
|
| 223 |
spark = init_spark()
|
| 224 |
-
|
|
|
|
|
|
|
|
|
|
| 225 |
output = fit_data(pipeline, text_to_analyze, task)
|
| 226 |
|
| 227 |
# Display matched sentence
|
|
@@ -243,5 +246,4 @@ elif task == 'Zero-Shot Classification':
|
|
| 243 |
|
| 244 |
elif task == 'Sequence Classification':
|
| 245 |
st.markdown(f"Classified as : **{output[0]['class'][0].result}**")
|
| 246 |
-
|
| 247 |
|
|
|
|
| 60 |
.pretrained(model, "en") \
|
| 61 |
.setInputCols(["sentence", "token"]) \
|
| 62 |
.setOutputCol("ner") \
|
| 63 |
+
.setCaseSensitive(False) \
|
| 64 |
.setMaxSentenceLength(512)
|
| 65 |
|
| 66 |
ner_converter = NerConverter() \
|
|
|
|
| 221 |
|
| 222 |
# Initialize Spark and create pipeline
|
| 223 |
spark = init_spark()
|
| 224 |
+
if task == 'Zero-Shot Classification':
|
| 225 |
+
pipeline = create_pipeline(model, task, zeroShotLables)
|
| 226 |
+
else:
|
| 227 |
+
pipeline = create_pipeline(model, task)
|
| 228 |
output = fit_data(pipeline, text_to_analyze, task)
|
| 229 |
|
| 230 |
# Display matched sentence
|
|
|
|
| 246 |
|
| 247 |
elif task == 'Sequence Classification':
|
| 248 |
st.markdown(f"Classified as : **{output[0]['class'][0].result}**")
|
|
|
|
| 249 |
|