|
import pandas as pd |
|
from fuzzywuzzy import fuzz |
|
from collections import Counter |
|
from nltk.stem import PorterStemmer |
|
from ast import literal_eval |
|
from typing import Union, List |
|
import streamlit as st |
|
from my_model.results.evaluation import KBVQAEvaluator |
|
from my_model.config import evaluation_config as config |
|
|
|
class ResultDemonstrator(KBVQAEvaluator):
    """Streamlit view that plots KB-VQA evaluation results.

    ``run`` renders two placeholder demo charts, then loads a results sheet
    from an Excel workbook and draws two scatter plots of per-sample token
    counts, coloured by VQA score.
    """

    # Score categories shared by both result plots, so a sample's colour,
    # tooltip label and legend entry always come from the same rule.
    _CORRECT = 'Correct'
    _PARTIAL = 'Partially Correct'
    _INCORRECT = 'Incorrect'
    _CATEGORY_COLORS = {_CORRECT: 'green', _PARTIAL: 'orange', _INCORRECT: 'red'}
    # Legend text used by the Matplotlib plot for each category.
    _LEGEND_TEXT = {
        _CORRECT: 'Full VQA Score',
        _PARTIAL: 'Partial VQA Score',
        _INCORRECT: 'Zero VQA Score',
    }

    @classmethod
    def _score_category(cls, score):
        """Classify a single VQA score.

        Returns ``'Correct'`` for a score equal to 1, ``'Partially Correct'``
        for a score that rounds to 0.67 at two decimals (so 2/3 qualifies),
        and ``'Incorrect'`` for anything else, including NaN.
        """
        if score == 1:
            return cls._CORRECT
        # Compare the rounded value: a stored 2/3 is 0.666..., not 0.67.
        if round(score, 2) == 0.67:
            return cls._PARTIAL
        return cls._INCORRECT

    def run(self, results_path='my_model/results/evaluation_results.xlsx', sheet_name="Main Data"):
        """Render the results page into the current Streamlit app.

        Args:
            results_path: Excel workbook holding the evaluation results.
            sheet_name: Worksheet of ``results_path`` to read.
        """
        self._render_demo_charts()
        results = pd.read_excel(results_path, sheet_name=sheet_name)
        self._render_accuracy_scatter(results)
        self._render_token_count_scatter(results)

    @staticmethod
    def _render_demo_charts():
        """Render the placeholder Altair and Matplotlib demo charts, once each."""
        # TODO(review): both charts use hard-coded / random data and look like
        # leftover scaffolding; remove once they are no longer wanted.
        import altair as alt
        import matplotlib.pyplot as plt
        import numpy as np

        demo = pd.DataFrame({
            'x': range(10),
            'y': [2, 1, 4, 3, 5, 6, 9, 7, 10, 8],
        })
        chart = alt.Chart(demo).mark_point().encode(x='x', y='y')
        st.altair_chart(chart, use_container_width=True)

        x = np.random.randn(100)
        y = np.random.randn(100)
        fig, ax = plt.subplots()
        ax.scatter(x, y)
        st.pyplot(fig)
        # Streamlit reruns the script on every interaction; close the figure
        # so open Matplotlib figures do not accumulate.
        plt.close(fig)

    def _render_accuracy_scatter(self, results):
        """Draw an interactive Altair scatter of token counts coloured by accuracy.

        Assumes ``results`` has ``index``, ``token_counts`` and ``accuracy``
        columns (TODO confirm against the workbook schema).
        """
        import altair as alt

        labelled = results.assign(label=results['accuracy'].apply(self._score_category))
        # Map category names to explicit colours; encoding colour-name strings
        # as a nominal field would let Vega-Lite pick its own palette instead.
        color_scale = alt.Scale(
            domain=list(self._CATEGORY_COLORS),
            range=list(self._CATEGORY_COLORS.values()),
        )
        scatter_chart = alt.Chart(labelled).mark_circle(size=20).encode(
            x=alt.X('index:Q', title='Index'),
            y=alt.Y('token_counts:Q', title='Number of Tokens'),
            color=alt.Color('label:N', scale=color_scale, legend=alt.Legend(title="VQA Score")),
            tooltip=['index', 'token_counts', 'label'],
        ).interactive()
        st.altair_chart(scatter_chart, use_container_width=True)

    def _render_token_count_scatter(self, results):
        """Draw a Matplotlib scatter of trimmed token counts coloured by VQA score.

        Assumes ``results`` has ``vqa_score_13b_caption+detic`` and
        ``trimmed_tokens_count_caption_detic`` columns (TODO confirm).
        """
        import matplotlib.pyplot as plt
        from matplotlib.lines import Line2D

        scores = results['vqa_score_13b_caption+detic']
        token_counts = results['trimmed_tokens_count_caption_detic']
        colors = [self._CATEGORY_COLORS[self._score_category(score)] for score in scores]

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.scatter(range(len(token_counts)), token_counts, c=colors, s=20)

        # Proxy artists: one legend entry per category, coloured like the points.
        legend_elements = [
            Line2D([0], [0], marker='o', color='w', label=self._LEGEND_TEXT[category],
                   markerfacecolor=color, markersize=10)
            for category, color in self._CATEGORY_COLORS.items()
        ]
        ax.legend(handles=legend_elements, loc='best', bbox_to_anchor=(1, 1))
        ax.set_title('Token Counts VS VQA Score')
        ax.set_xlabel('Index')
        ax.set_ylabel('Number of Tokens')

        # plt.show() does not display anything inside Streamlit; hand the
        # figure to Streamlit explicitly, then release it.
        st.pyplot(fig)
        plt.close(fig)