lterriel's picture
clean & refactor components + add doc
74e2066
# -*- coding:utf-8 -*-
from itertools import combinations
from collections import defaultdict, Counter
import pandas as pd
import seaborn as sns
import matplotlib as plt
plt.use('Agg')
import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize, word_tokenize
from n4a_analytics_lib.project import Project
from n4a_analytics_lib.metrics_utils import (fleiss_kappa_function, cohen_kappa_function, percentage_agreement_pov)
class GlobalStatistics(Project):
    """Label statistics over every annotated document of a project.

    Built on ``Project`` (initialised with ``type="global"``); turns its
    ``annotations`` mapping into pandas tables of label counts.
    """

    def __init__(self, zip_project, remote=False):
        super().__init__(zip_project=zip_project, remote=remote, type="global")
        # One (source file, label) tuple per annotation. Assumes each value of
        # ``self.annotations`` holds a ``'labels'`` list (the only key read here).
        self.data = [
            (src_file, ne_label)
            for src_file, ann in self.annotations.items()
            for ne_label in ann['labels']
        ]
        self.df_base = pd.DataFrame(self.data, columns=["SOURCE_FILE", "LABEL"])
        # Number of annotations per label over the whole project.
        self.df_i = (
            self.df_base.groupby(["LABEL"])["LABEL"].count().reset_index(name="TOTAL")
        )
        # Number of annotations per (source file, label) pair.
        self.df_details = (
            self.df_base.groupby(["SOURCE_FILE", "LABEL"])["LABEL"]
            .count()
            .reset_index(name="TOTAL")
        )
        self.total_annotations_project = self.df_i['TOTAL'].sum()

    def create_plot(self, type_data: str) -> "matplotlib.figure.Figure":
        """Return a bar chart of label counts for a single source file.

        :param type_data: the ``SOURCE_FILE`` value whose labels are plotted;
            also used as the figure title.
        :return: a new matplotlib figure containing the annotated bar chart.
        """
        # apply data filter
        data_tab_filtered = self.df_details.loc[self.df_details['SOURCE_FILE'] == type_data]
        # Draw on a dedicated figure: without ``ax=`` seaborn uses pyplot's
        # current axes, so successive calls would pile up on one figure.
        # NOTE: the figure is registered with pyplot; callers producing many
        # charts may want ``plt.pyplot.close(fig)`` once it has been rendered.
        fig, ax = plt.pyplot.subplots()
        sns.barplot(x='LABEL', y='TOTAL', data=data_tab_filtered, ax=ax)
        fig.suptitle(type_data)
        # write each bar's value on top of it
        for container in ax.containers:
            ax.bar_label(container)
        return fig
class IaaStatistics(Project):
    """Inter-annotator agreement (IAA) statistics for a project.

    Built on ``Project`` (initialised with ``type="iaa"``). Each annotator's
    mentions and labels are aligned on a shared mention index (``base_df``),
    from which agreement/disagreement counts, Fleiss' kappa, pairwise Cohen's
    kappa and agreement percentages are derived.
    """

    def __init__(self, zip_project, baseline_text, remote=False):
        super().__init__(zip_project=zip_project, remote=remote, type="iaa")
        # Raw bytes are decoded as UTF-8 (original behaviour); an already
        # decoded str is accepted as-is.
        if isinstance(baseline_text, (bytes, bytearray)):
            baseline_text = baseline_text.decode('utf-8')
        self.baseline_text = baseline_text
        self.mentions_per_coder = self.extract_refs(self.annotations, self.annotators, type="mentions")
        self.labels_per_coder = self.extract_refs(self.annotations, self.annotators, type="labels")
        # NOTE(review): annotators are paired with annotation entries
        # positionally (zip), so both are assumed to share the same order.
        self.annotations_per_coders = {
            coder: dict(zip(ann[1]['mentions'], ann[1]['labels']))
            for coder, ann in zip(self.annotators, self.annotations.items())
        }
        self.coders_pairs = list(combinations(self.annotations_per_coders, 2))
        # Every mention annotated by at least one coder, deduplicated,
        # in first-seen order.
        self.similar_mention = list(dict.fromkeys(
            mention
            for mentions in self.mentions_per_coder.values()
            for mention in mentions
        ))
        # Every label used by at least one coder, deduplicated, first-seen order.
        self.labels_schema = list(dict.fromkeys(
            label
            for labels in self.labels_per_coder.values()
            for label in labels
        ))
        # dataframes and matrix analysis
        self.base_df = self.build_base_df()
        coder_columns = self.base_df[self.annotators]
        self.df_agree = self.base_df[coder_columns.apply(self.check_all_equal, axis=1)]
        self.df_disagree = self.base_df[coder_columns.apply(self.check_all_not_equal, axis=1)]
        # One row per mention, one column per distinct value (label or 'None'),
        # cell = how many coders chose that value for that mention.
        self.coders_matrix = (
            self.base_df.apply(pd.Series.value_counts, axis=1).fillna(0).astype(int).values
        )
        # totals
        self.total_annotations = len(self.base_df)
        self.total_agree = len(self.df_agree)
        self.total_disagree = len(self.df_disagree)
        # access to metrics
        self.fleiss_kappa = round(fleiss_kappa_function(self.coders_matrix), 2)
        self.cohen_kappa_pairs = self.compute_pairs_cohen_kappa()
        self.percent_agree = percentage_agreement_pov(self.total_agree, self.total_annotations)
        self.percent_disagree = percentage_agreement_pov(self.total_disagree, self.total_annotations)

    @staticmethod
    def extract_refs(annotations: dict, annotators: list, type: str) -> dict:
        """Map each annotator to the ``type`` entry of its annotation data.

        Annotators are paired positionally with ``annotations.items()``;
        an annotator whose data has no ``type`` key is omitted.

        :param annotations: mapping whose values are dicts of annotation data.
        :param annotators: annotator names, in the same order as ``annotations``.
        :param type: key to extract (e.g. ``"mentions"`` or ``"labels"``).
        :return: ``{annotator: annotation_data[type]}``.
        """
        return {
            coder: data[type]
            for coder, (_, data) in zip(annotators, annotations.items())
            if type in data
        }

    @staticmethod
    def check_all_equal(iterator: list) -> bool:
        """Return True when all values are identical (or there are none)."""
        return len(set(iterator)) <= 1

    @staticmethod
    def check_all_not_equal(iterator: list) -> bool:
        """Return True when at least two distinct values are present."""
        return len(set(iterator)) > 1

    def plot_confusion_matrix(self, width: int, height: int) -> "matplotlib.axes.Axes":
        """Draw the pairwise Cohen's kappa values as an annotated heatmap.

        Rows and columns are the annotators appearing in
        ``self.cohen_kappa_pairs`` (sorted); cell (a, b) holds kappa(a, b) for
        each computed pair. Cells equal to 0, which includes pairs that were
        not computed, are masked.

        :param width: figure width in inches.
        :param height: figure height in inches.
        :return: the axes holding the heatmap.
        """
        kappa_by_pair = defaultdict(Counter)
        for (src, tgt), kappa in self.cohen_kappa_pairs.items():
            kappa_by_pair[src][tgt] = kappa
        coders = sorted(
            {tgt for row in kappa_by_pair.values() for tgt in row} | set(kappa_by_pair)
        )
        # Missing pairs read as 0 thanks to Counter's default.
        matrix = [[kappa_by_pair[src][tgt] for tgt in coders] for src in coders]
        df_cm = pd.DataFrame(matrix, coders, coders)
        mask = df_cm.values == 0
        sns.set(font_scale=0.7)  # seaborn-wide setting: smaller label size
        colors = ["#e74c3c", "#f39c12", "#f4d03f", "#5dade2", "#58d68d", "#28b463"]
        fig, ax = plt.pyplot.subplots(figsize=(width, height))
        # vmin/vmax pin the colour scale to [0, 1]; values outside that range
        # are drawn with the end colours.
        sns.heatmap(df_cm, cmap=colors, annot=True, mask=mask, annot_kws={"size": 7},
                    vmin=0, vmax=1, ax=ax)
        return ax

    def build_base_df(self) -> pd.DataFrame:
        """Build the mention x annotator label table.

        Index is ``self.similar_mention``; there is one column per annotator.
        A cell holds the label the annotator gave that mention, or the string
        ``'None'`` when the annotator did not annotate it.
        """
        df = pd.DataFrame(index=self.similar_mention)
        for ann in self.annotators:
            coder_labels = self.annotations_per_coders[ann]
            # Build the whole column at once instead of one .loc write per cell.
            df[ann] = [coder_labels.get(mention, 'None') for mention in self.similar_mention]
        return df

    def compute_pairs_cohen_kappa(self) -> dict:
        """Return ``{(coder_a, coder_b): kappa}`` for every pair of annotators."""
        # NOTE(review): this compares each coder's raw label list, which is only
        # item-aligned if both coders annotated the same mentions in the same
        # order; ``self.base_df[c1]`` / ``self.base_df[c2]`` are aligned by
        # mention. Confirm the intended input before changing the metric.
        return {
            (c1, c2): cohen_kappa_function(self.labels_per_coder[c1], self.labels_per_coder[c2])
            for c1, c2 in self.coders_pairs
        }

    def count_total_annotations_label(self) -> list:
        """Return ``(label, n_rows)`` for each label in ``labels_schema``.

        ``n_rows`` is the number of mentions (rows of ``base_df``) for which
        at least one annotator used the label.
        """
        cells = self.base_df.astype(object)
        return [(label, cells.eq(label).any(axis=1).sum()) for label in self.labels_schema]

    def total_agree_disagree_per_label(self) -> list:
        """Return per-label agreement figures.

        Each item is ``(label, total_rows_with_label, agree_pct, disagree_pct)``
        where ``agree_pct`` is the share (in %) of the label's rows on which all
        annotators gave that same label, and ``disagree_pct`` is the rest.
        """
        # Rows where every annotator gave the same value; computed once for all labels.
        unanimous = self.base_df[self.base_df.nunique(axis=1).eq(1)]
        results = []
        for label, total in self.count_total_annotations_label():
            agreed = unanimous.eq(label).any(axis=1).sum()
            results.append((
                label,
                total,
                (agreed / total) * 100,
                ((total - agreed) / total) * 100,
            ))
        return results

    def plot_agreement_pies(self) -> "matplotlib.figure.Figure":
        """Return a figure with one agree/disagree pie chart per label."""
        my_labels = 'agree', 'disagree'
        my_colors = ['#47DBCD', '#F5B14C']
        my_explode = (0, 0.1)
        tasks_to_pie = self.total_agree_disagree_per_label()
        # squeeze=False keeps ``axes`` 2-D: with a single label, subplots(1, 1)
        # would otherwise return a bare Axes that cannot be indexed.
        fig, axes = plt.pyplot.subplots(1, len(tasks_to_pie), figsize=(20, 3), squeeze=False)
        for ax, (label, _, agree_pct, disagree_pct) in zip(axes[0], tasks_to_pie):
            ax.pie([agree_pct, disagree_pct], autopct='%1.1f%%', startangle=15, shadow=True,
                   colors=my_colors, explode=my_explode)
            ax.set_title(label)
            ax.axis('equal')
        fig.set_facecolor("white")
        fig.legend(labels=my_labels, loc="center right", borderaxespad=0.1, title="Labels alignement")
        return fig

    def analyze_text(self) -> list:
        """Return ``[n_sentences, n_words, n_characters]`` of the baseline text.

        Sentences and words are counted with NLTK's French tokenizers.
        """
        return [
            len(sent_tokenize(self.baseline_text, language="french")),
            len(word_tokenize(self.baseline_text, language="french")),
            len(self.baseline_text),
        ]