cchukwu commited on
Commit
5807b0b
·
verified ·
1 Parent(s): 115af97

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -5
app.py CHANGED
@@ -53,16 +53,14 @@ def adr_predict(x):
53
  scores = output[0][0].detach()
54
  scores = torch.nn.functional.softmax(scores)
55
 
56
- shap_values = explainer([str("The young woman had a severe drug reaction.").lower()])
57
  # # Find the index of the class you want as the default reference (e.g., 'label_1')
58
  # label_1_index = np.where(np.array(explainer.output_names) == 'label_1')[0][0]
59
- label_1_index = np.where(np.array(explainer.output_names) == 'label_1')[0][0]
60
 
61
  # # Plot the SHAP values for a specific instance in your dataset (e.g., instance 0)
62
  # shap.plots.text(shap_values[label_1_index][0])
63
- shap.plots.text(shap_values[label_1_index][1])
64
-
65
- local_plot = shap.plots.text(shap_values[1], display=False)
66
 
67
  # med = med_score(classifier(x+str(", There is a medication."))[0])
68
  # sym = sym_score(classifier(x+str(", There is a symptom."))[0])
 
53
  scores = output[0][0].detach()
54
  scores = torch.nn.functional.softmax(scores)
55
 
56
+ shap_values = explainer([str(x).lower()])
57
  # # Find the index of the class you want as the default reference (e.g., 'label_1')
58
  # label_1_index = np.where(np.array(explainer.output_names) == 'label_1')[0][0]
 
59
 
60
  # # Plot the SHAP values for a specific instance in your dataset (e.g., instance 0)
61
  # shap.plots.text(shap_values[label_1_index][0])
62
+
63
+ local_plot = shap.plots.text(shap_values[0], display=False)
 
64
 
65
  # med = med_score(classifier(x+str(", There is a medication."))[0])
66
  # sym = sym_score(classifier(x+str(", There is a symptom."))[0])