Update my_model/tabs/results.py
Browse files- my_model/tabs/results.py +42 -13
my_model/tabs/results.py
CHANGED
@@ -47,19 +47,48 @@ class ResultDemonstrator(KBVQAEvaluator):
|
|
47 |
# Display the plot in Streamlit
|
48 |
st.pyplot(fig)
|
49 |
|
50 |
-
|
51 |
-
|
|
|
|
|
52 |
|
53 |
-
#
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
"color": ["red"] * 5 + ["blue"] * 5
|
58 |
-
})
|
59 |
|
60 |
-
#
|
61 |
-
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
-
# Display the
|
64 |
-
st.
|
65 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
47 |
# Display the plot in Streamlit
|
48 |
st.pyplot(fig)
|
49 |
|
50 |
+
############
|
51 |
+
|
52 |
+
# Load data from Excel
|
53 |
+
d = pd.read_excel('my_model/results/evaluation_results.xlsx', sheet_name="Main Data")
|
54 |
|
55 |
+
# Assume 'accuracies' and 'token_counts' need to be computed or are columns in the DataFrame
|
56 |
+
# Compute colors and labels for the plot (assuming these columns are already in the DataFrame)
|
57 |
+
d['color'] = d['accuracy'].apply(lambda x: 'green' if x == 1 else 'orange' if round(x, 2) == 0.67 else 'red')
|
58 |
+
d['label'] = d['accuracy'].apply(lambda x: 'Correct' if x == 1 else 'Partially Correct' if x == 0.67 else 'Incorrect')
|
|
|
|
|
59 |
|
60 |
+
# Creating the scatter plot
|
61 |
+
scatter_chart = alt.Chart(d).mark_circle(size=20).encode(
|
62 |
+
x=alt.X('index:Q', title='Index'), # Assuming 'index' is a column or can be created
|
63 |
+
y=alt.Y('token_counts:Q', title='Number of Tokens'),
|
64 |
+
color=alt.Color('color:N', legend=alt.Legend(title="VQA Score")),
|
65 |
+
tooltip=['index', 'token_counts', 'label']
|
66 |
+
).interactive()
|
67 |
|
68 |
+
# Display the chart
|
69 |
+
st.altair_chart(scatter_chart, use_container_width=True)
|
70 |
+
|
71 |
+
|
72 |
+
####################
|
73 |
+
|
74 |
+
# Define colors and labels for the legend
|
75 |
+
colors = ['green' if accuracy == 1 else 'orange' if round(accuracy,2) == 0.67 else 'red' for accuracy in accuracies]
|
76 |
+
labels = ['Correct' if accuracy == 1 else 'Partially Correct' if accuracy == 0.67 else 'Incorrect' for accuracy in accuracies]
|
77 |
+
plt.figure(figsize=(10, 6))
|
78 |
+
# Create a scatter plot with smaller dots using the 's' parameter
|
79 |
+
scatter = plt.scatter(range(len(token_counts)), token_counts, c=colors, s=20, label=labels)
|
80 |
+
|
81 |
+
# Create a custom legend
|
82 |
+
from matplotlib.lines import Line2D
|
83 |
+
legend_elements = [Line2D([0], [0], marker='o', color='w', label='Full VQA Score', markerfacecolor='green', markersize=10),
|
84 |
+
Line2D([0], [0], marker='o', color='w', label='Partial VQA Score', markerfacecolor='orange', markersize=10),
|
85 |
+
Line2D([0], [0], marker='o', color='w', label='Zero VQA Score', markerfacecolor='red', markersize=10)]
|
86 |
+
#plt.legend(handles=legend_elements, loc='upper right')
|
87 |
+
plt.legend(handles=legend_elements, loc='best', bbox_to_anchor=(1, 1))
|
88 |
+
# Set the title and labels
|
89 |
+
plt.title('Token Counts VS VQA Score')
|
90 |
+
plt.xlabel('Index')
|
91 |
+
plt.ylabel('Number of Tokens')
|
92 |
+
|
93 |
+
# Display the plot
|
94 |
+
plt.show()
|