File size: 3,856 Bytes
dab7e6b
 
 
 
 
 
 
 
7bb2c40
dab7e6b
c41a81d
dab7e6b
bd62fcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e0a8d2a
 
 
 
bd62fcd
e0a8d2a
 
 
 
bd62fcd
e0a8d2a
 
 
 
 
 
 
bd62fcd
e0a8d2a
 
 
 
 
7bb2c40
 
e0a8d2a
 
7bb2c40
 
e0a8d2a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import pandas as pd
from fuzzywuzzy import fuzz
from collections import Counter
from nltk.stem import PorterStemmer
from ast import literal_eval
from typing import Union, List
import streamlit as st
from my_model.results.evaluation import KBVQAEvaluator
from my_model.config import evaluation_config as config 

class ResultDemonstrator(KBVQAEvaluator):
    """Render KB-VQA evaluation result plots on the current Streamlit page.

    Only ``run`` is public. The helpers below use Streamlit, Altair and
    Matplotlib, plus the evaluation workbook at ``_RESULTS_PATH``. Nothing
    here relies on state from the ``KBVQAEvaluator`` base class.
    """

    # Workbook written by the evaluation run, and the sheet holding per-sample rows.
    _RESULTS_PATH = 'my_model/results/evaluation_results.xlsx'
    _RESULTS_SHEET = "Main Data"

    # Category labels and their colours, in legend order (shared by both result plots).
    _CATEGORY_LABELS = ['Correct', 'Partially Correct', 'Incorrect']
    _CATEGORY_COLORS = ['green', 'orange', 'red']

    @staticmethod
    def _vqa_score_category(score):
        """Map a VQA score to a ``(colour, label)`` pair.

        A score equal to 1 is full credit. A score that rounds to 0.67 is
        partial credit. Anything else (including NaN) is zero credit. Colour
        and label come from the same test, so the two can never disagree.
        """
        if score == 1:
            return 'green', 'Correct'
        # Rounded comparison so 2/3 stored as 0.6666... still counts as partial.
        if round(score, 2) == 0.67:
            return 'orange', 'Partially Correct'
        return 'red', 'Incorrect'

    def run(self):
        """Render every result plot, in order, on the current Streamlit page."""
        self._render_sample_altair_chart()
        self._render_random_scatter()
        results = self._read_results_sheet()
        self._render_accuracy_scatter(results)
        self._render_token_count_scatter(results)

    @staticmethod
    def _render_sample_altair_chart():
        """Show a fixed 10-point Altair scatter (appears to be placeholder demo data)."""
        import altair as alt

        sample = pd.DataFrame({
            'x': range(10),
            'y': [2, 1, 4, 3, 5, 6, 9, 7, 10, 8],
        })
        chart = alt.Chart(sample).mark_point().encode(x='x', y='y')
        # Rendered once; the original displayed this same chart twice in a row.
        st.altair_chart(chart, use_container_width=True)

    @staticmethod
    def _render_random_scatter():
        """Show a Matplotlib scatter of 100 standard-normal points (placeholder demo)."""
        import matplotlib.pyplot as plt
        import numpy as np

        x = np.random.randn(100)
        y = np.random.randn(100)
        fig, ax = plt.subplots()
        ax.scatter(x, y)
        st.pyplot(fig)
        # st.pyplot does not close an explicitly passed figure; release it so
        # figures do not accumulate across Streamlit reruns.
        plt.close(fig)

    def _read_results_sheet(self):
        """Load the per-sample results sheet and ensure it has an ``index`` column.

        Returns:
            pandas.DataFrame: the sheet's rows. An ``index`` column holding the
            row position is added when the sheet does not already provide one.
        """
        results = pd.read_excel(self._RESULTS_PATH, sheet_name=self._RESULTS_SHEET)
        if 'index' not in results.columns:
            # The Altair chart plots 'index' as a field, so it must be a real column.
            results = results.reset_index()
        return results

    def _render_accuracy_scatter(self, results):
        """Show an interactive Altair scatter of token count per sample, coloured by score.

        Args:
            results: DataFrame from ``_read_results_sheet``. It is assumed to have
                'accuracy' and 'token_counts' columns (not verified here; TODO
                confirm against the workbook). A 'label' column is added in place.
        """
        import altair as alt

        results['label'] = [self._vqa_score_category(score)[1] for score in results['accuracy']]

        chart = alt.Chart(results).mark_circle(size=20).encode(
            x=alt.X('index:Q', title='Index'),
            y=alt.Y('token_counts:Q', title='Number of Tokens'),
            # Map category labels to explicit colours. Encoding colour-name
            # strings as a nominal field would push them through Altair's
            # default palette, so the dots would not be green/orange/red.
            color=alt.Color(
                'label:N',
                scale=alt.Scale(domain=self._CATEGORY_LABELS, range=self._CATEGORY_COLORS),
                legend=alt.Legend(title="VQA Score"),
            ),
            tooltip=['index', 'token_counts', 'label'],
        ).interactive()
        st.altair_chart(chart, use_container_width=True)

    def _render_token_count_scatter(self, results):
        """Show a Matplotlib scatter of trimmed token counts, coloured by VQA score.

        Args:
            results: DataFrame containing the 'vqa_score_13b_caption+detic' and
                'trimmed_tokens_count_caption_detic' columns.
        """
        import matplotlib.pyplot as plt
        from matplotlib.lines import Line2D

        scores = results['vqa_score_13b_caption+detic']
        token_counts = results['trimmed_tokens_count_caption_detic']
        colors = [self._vqa_score_category(score)[0] for score in scores]

        # Use an explicit figure rather than pyplot's global state so it can be
        # handed to Streamlit and then closed.
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.scatter(range(len(token_counts)), token_counts, c=colors, s=20)

        # Custom legend: one marker per score category rather than one entry per point.
        legend_entries = (
            ('Full VQA Score', 'green'),
            ('Partial VQA Score', 'orange'),
            ('Zero VQA Score', 'red'),
        )
        legend_elements = [
            Line2D([0], [0], marker='o', color='w', label=text,
                   markerfacecolor=colour, markersize=10)
            for text, colour in legend_entries
        ]
        ax.legend(handles=legend_elements, loc='best', bbox_to_anchor=(1, 1))
        ax.set_title('Token Counts VS VQA Score')
        ax.set_xlabel('Index')
        ax.set_ylabel('Number of Tokens')

        # plt.show() renders nothing inside a Streamlit app; hand the figure to Streamlit.
        st.pyplot(fig)
        plt.close(fig)