Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
@@ -53,16 +53,14 @@ def adr_predict(x):
|
|
53 |
scores = output[0][0].detach()
|
54 |
scores = torch.nn.functional.softmax(scores)
|
55 |
|
56 |
-
shap_values = explainer([str(
|
57 |
# # Find the index of the class you want as the default reference (e.g., 'label_1')
|
58 |
# label_1_index = np.where(np.array(explainer.output_names) == 'label_1')[0][0]
|
59 |
-
label_1_index = np.where(np.array(explainer.output_names) == 'label_1')[0][0]
|
60 |
|
61 |
# # Plot the SHAP values for a specific instance in your dataset (e.g., instance 0)
|
62 |
# shap.plots.text(shap_values[label_1_index][0])
|
63 |
-
|
64 |
-
|
65 |
-
local_plot = shap.plots.text(shap_values[1], display=False)
|
66 |
|
67 |
# med = med_score(classifier(x+str(", There is a medication."))[0])
|
68 |
# sym = sym_score(classifier(x+str(", There is a symptom."))[0])
|
|
|
53 |
scores = output[0][0].detach()
|
54 |
scores = torch.nn.functional.softmax(scores)
|
55 |
|
56 |
+
shap_values = explainer([str(x).lower()])
|
57 |
# # Find the index of the class you want as the default reference (e.g., 'label_1')
|
58 |
# label_1_index = np.where(np.array(explainer.output_names) == 'label_1')[0][0]
|
|
|
59 |
|
60 |
# # Plot the SHAP values for a specific instance in your dataset (e.g., instance 0)
|
61 |
# shap.plots.text(shap_values[label_1_index][0])
|
62 |
+
|
63 |
+
local_plot = shap.plots.text(shap_values[0], display=False)
|
|
|
64 |
|
65 |
# med = med_score(classifier(x+str(", There is a medication."))[0])
|
66 |
# sym = sym_score(classifier(x+str(", There is a symptom."))[0])
|