m7mdal7aj commited on
Commit
e0a8d2a
·
verified ·
1 Parent(s): a56e6a3

Update my_model/tabs/results.py

Browse files
Files changed (1) hide show
  1. my_model/tabs/results.py +42 -13
my_model/tabs/results.py CHANGED
@@ -47,19 +47,48 @@ class ResultDemonstrator(KBVQAEvaluator):
47
  # Display the plot in Streamlit
48
  st.pyplot(fig)
49
 
50
- import plotly.express as px
51
-
 
 
52
 
53
- # Data
54
- df = pd.DataFrame({
55
- "x": range(10),
56
- "y": [2, 1, 4, 3, 5, 6, 9, 7, 10, 8],
57
- "color": ["red"] * 5 + ["blue"] * 5
58
- })
59
 
60
- # Create an interactive scatter plot
61
- fig = px.scatter(df, x='x', y='y', color='color')
 
 
 
 
 
62
 
63
- # Display the plot in Streamlit
64
- st.plotly_chart(fig)
65
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  # Display the plot in Streamlit
48
  st.pyplot(fig)
49
 
50
+ ############
51
+
52
+ # Load data from Excel
53
+ d = pd.read_excel('my_model/results/evaluation_results.xlsx', sheet_name="Main Data")
54
 
55
+ # Assume 'accuracies' and 'token_counts' need to be computed or are columns in the DataFrame
56
+ # Compute colors and labels for the plot (assuming these columns are already in the DataFrame)
57
+ d['color'] = d['accuracy'].apply(lambda x: 'green' if x == 1 else 'orange' if round(x, 2) == 0.67 else 'red')
58
+ d['label'] = d['accuracy'].apply(lambda x: 'Correct' if x == 1 else 'Partially Correct' if x == 0.67 else 'Incorrect')
 
 
59
 
60
+ # Creating the scatter plot
61
+ scatter_chart = alt.Chart(d).mark_circle(size=20).encode(
62
+ x=alt.X('index:Q', title='Index'), # Assuming 'index' is a column or can be created
63
+ y=alt.Y('token_counts:Q', title='Number of Tokens'),
64
+ color=alt.Color('color:N', legend=alt.Legend(title="VQA Score")),
65
+ tooltip=['index', 'token_counts', 'label']
66
+ ).interactive()
67
 
68
+ # Display the chart
69
+ st.altair_chart(scatter_chart, use_container_width=True)
70
+
71
+
72
+ ####################
73
+
74
+ # Define colors and labels for the legend
75
+ colors = ['green' if accuracy == 1 else 'orange' if round(accuracy,2) == 0.67 else 'red' for accuracy in accuracies]
76
+ labels = ['Correct' if accuracy == 1 else 'Partially Correct' if accuracy == 0.67 else 'Incorrect' for accuracy in accuracies]
77
+ plt.figure(figsize=(10, 6))
78
+ # Create a scatter plot with smaller dots using the 's' parameter
79
+ scatter = plt.scatter(range(len(token_counts)), token_counts, c=colors, s=20, label=labels)
80
+
81
+ # Create a custom legend
82
+ from matplotlib.lines import Line2D
83
+ legend_elements = [Line2D([0], [0], marker='o', color='w', label='Full VQA Score', markerfacecolor='green', markersize=10),
84
+ Line2D([0], [0], marker='o', color='w', label='Partial VQA Score', markerfacecolor='orange', markersize=10),
85
+ Line2D([0], [0], marker='o', color='w', label='Zero VQA Score', markerfacecolor='red', markersize=10)]
86
+ #plt.legend(handles=legend_elements, loc='upper right')
87
+ plt.legend(handles=legend_elements, loc='best', bbox_to_anchor=(1, 1))
88
+ # Set the title and labels
89
+ plt.title('Token Counts VS VQA Score')
90
+ plt.xlabel('Index')
91
+ plt.ylabel('Number of Tokens')
92
+
93
+ # Display the plot
94
+ plt.show()