abdullahmubeen10 commited on
Commit
dfac0b9
·
verified ·
1 Parent(s): 049a87e

Update Demo.py

Browse files
Files changed (1) hide show
  1. Demo.py +5 -3
Demo.py CHANGED
@@ -60,7 +60,7 @@ def create_pipeline(model, task, zeroShotLables=['']):
60
  .pretrained(model, "en") \
61
  .setInputCols(["sentence", "token"]) \
62
  .setOutputCol("ner") \
63
- .setCaseSensitive(True) \
64
  .setMaxSentenceLength(512)
65
 
66
  ner_converter = NerConverter() \
@@ -221,7 +221,10 @@ except:
221
 
222
  # Initialize Spark and create pipeline
223
  spark = init_spark()
224
- pipeline = create_pipeline(model, task)
 
 
 
225
  output = fit_data(pipeline, text_to_analyze, task)
226
 
227
  # Display matched sentence
@@ -243,5 +246,4 @@ elif task == 'Zero-Shot Classification':
243
 
244
  elif task == 'Sequence Classification':
245
  st.markdown(f"Classified as : **{output[0]['class'][0].result}**")
246
-
247
 
 
60
  .pretrained(model, "en") \
61
  .setInputCols(["sentence", "token"]) \
62
  .setOutputCol("ner") \
63
+ .setCaseSensitive(False) \
64
  .setMaxSentenceLength(512)
65
 
66
  ner_converter = NerConverter() \
 
221
 
222
  # Initialize Spark and create pipeline
223
  spark = init_spark()
224
+ if task == 'Zero-Shot Classification':
225
+ pipeline = create_pipeline(model, task, zeroShotLables)
226
+ else:
227
+ pipeline = create_pipeline(model, task)
228
  output = fit_data(pipeline, text_to_analyze, task)
229
 
230
  # Display matched sentence
 
246
 
247
  elif task == 'Sequence Classification':
248
  st.markdown(f"Classified as : **{output[0]['class'][0].result}**")
 
249