# NDIS Project - PBSP Scoring - Page 3

In [None]:
import os
from ipywidgets import interact
import ipywidgets as widgets
from IPython.display import display, clear_output, Javascript, HTML, Markdown
from qdrant_client import QdrantClient
from qdrant_client.http.models import Distance, VectorParams, Batch, Filter, FieldCondition, Range, MatchValue
import json
import spacy
from spacy import displacy
import nltk
from nltk import sent_tokenize
from sklearn.feature_extraction import text
from pprint import pprint
import re
from flair.embeddings import TransformerDocumentEmbeddings
from flair.data import Sentence
from sentence_transformers import SentenceTransformer, util
import pandas as pd
import argilla as rg
from argilla.metrics.text_classification import f1
from typing import Dict
from setfit import SetFitModel
from tqdm import tqdm
import time
# Block for roughly 30 seconds (30 iterations x 1 s). The tqdm bar is created
# with disable=True, so no progress output is drawn while waiting.
# NOTE(review): the reason for this start-up delay is not evident from this
# file — confirm it is still required.
for i in tqdm(range(30), disable=True):
    time.sleep(1)

In [None]:
# Initialisations: ontology file names, embedding models, the Qdrant client,
# the SetFit classifiers, NLTK resources and the Argilla connection.
bhvr_onto_file = 'ontology_page3_bhvr.csv'
event_onto_file = 'ontology_page3_event.csv'
# Flair document embedder; only embeddings() and get_query_vector() use it.
embedding = TransformerDocumentEmbeddings('distilbert-base-uncased')
# Connection settings are read from the environment; a missing variable
# raises KeyError here.
# NOTE(review): the variable is named *_URL but is passed as `host` — confirm
# it holds a bare host name.
client = QdrantClient(
    host=os.environ["QDRANT_API_URL"], 
    api_key=os.environ["QDRANT_API_KEY"],
    timeout=60,
    port=443
)
collection_name = "my_collection"
# Sentence-transformers encoder used by sentence_embeddings() and
# sentence_get_query_vector().
model = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')
# Vector size of the collection; 384 corresponds to the MiniLM encoder above
# (the distilbert embedder would need 768).
vector_dim = 384 #{distilbert-base-uncased: 768, multi-qa-MiniLM-L6-cos-v1:384}
# SetFit classifiers for trigger and consequence sentences.
sf_trig_model_name = "setfit-zero-shot-classification-pbsp-p3-trig"
sf_trig_model = SetFitModel.from_pretrained(f"aammari/{sf_trig_model_name}")
sf_cons_model_name = "setfit-zero-shot-classification-pbsp-p3-cons"
sf_cons_model = SetFitModel.from_pretrained(f"aammari/{sf_cons_model_name}")

# download nltk 'punkt' if not available
try:
    nltk.data.find('tokenizers/punkt')
except LookupError:
    nltk.download('punkt')

# download nltk 'averaged_perceptron_tagger' if not available
try:
    nltk.data.find('taggers/averaged_perceptron_tagger')
except LookupError:
    nltk.download('averaged_perceptron_tagger')
    
# Argilla client; URL and key are also taken from the environment.
rg.init(
    api_url=os.environ["ARGILLA_API_URL"],
    api_key=os.environ["ARGILLA_API_KEY"]
)

In [None]:
# Behaviour ontology loaded from a CSV file.
# The file is read without a header row and rows containing missing values
# are dropped; the frame is then given the single column name 'text'
# (this assignment fails if the CSV has more than one column).
bhvr_onto_df = pd.read_csv(bhvr_onto_file, header=None).dropna()
bhvr_onto_df.columns = ['text']
# Plain list of the ontology phrases, in file order.
bhvr_onto_lst = bhvr_onto_df['text'].tolist()

In [None]:
# Setting-event ontology loaded from a CSV file.
# The file is read without a header row and rows containing missing values
# are dropped; the frame is then given the single column name 'text'
# (this assignment fails if the CSV has more than one column).
event_onto_df = pd.read_csv(event_onto_file, header=None).dropna()
event_onto_df.columns = ['text']
# Plain list of the ontology phrases, in file order.
event_onto_lst = event_onto_df['text'].tolist()

In [None]:
#Text Preprocessing
# Load the small English spaCy pipeline, downloading it first when
# spacy.load() reports it as missing (OSError).
try:
    nlp = spacy.load('en_core_web_sm')
except OSError:
    spacy.cli.download('en_core_web_sm')
    nlp = spacy.load('en_core_web_sm')
# scikit-learn's English stop-word collection; preprocess() applies it on top
# of spaCy's own stop-word flag.
sw_lst = text.ENGLISH_STOP_WORDS
def preprocess(onto_lst):
    """Clean every phrase in ``onto_lst`` and return the cleaned phrases.

    Each phrase is parsed with the module-level spaCy pipeline ``nlp``.
    Tokens that are stop words, punctuation, number-like, blank, or whose
    lemma belongs to a PERSON entity of the same phrase are discarded.  The
    remaining lemmas are lower-cased and kept only when they are longer than
    one character, consist solely of ``a-z``/space, and are not in ``sw_lst``.

    Returns a list with one space-joined string per input phrase (possibly
    an empty string), in input order.
    """
    alpha_only = re.compile(r'^[a-z ]*$')
    cleaned = []
    for raw_phrase in onto_lst:
        parsed = nlp(raw_phrase)
        # Lemmas of tokens tagged as part of a person name in this phrase.
        person_lemmas = [tok.lemma_ for tok in parsed if tok.ent_type_ == 'PERSON']
        lemmas = []
        for tok in parsed:
            if tok.is_stop or tok.is_punct or tok.like_num:
                continue
            if len(tok.text.strip()) == 0:
                continue
            if tok.lemma_ in person_lemmas:
                continue
            lemmas.append(tok.lemma_.lower())
        kept = [
            lemma for lemma in lemmas
            if len(lemma) > 1
            and alpha_only.search(lemma) is not None
            and lemma not in sw_lst
        ]
        cleaned.append(" ".join(kept))
    return cleaned

cl_bhvr_onto_lst = preprocess(bhvr_onto_lst)
cl_event_onto_lst = preprocess(event_onto_lst)

#pprint(cl_bhvr_onto_lst)
#pprint(cl_event_onto_lst)

In [None]:
#compute document embeddings

# distilbert-base-uncased from Flair
def embeddings(cl_onto_lst):
    """Embed each phrase with the Flair document embedder.

    Returns one embedding per phrase, each converted with ``.tolist()``,
    in input order.
    """
    def _embed_one(phrase):
        flair_sentence = Sentence(phrase)
        embedding.embed(flair_sentence)
        return flair_sentence.embedding.tolist()

    return [_embed_one(phrase) for phrase in cl_onto_lst]

# multi-qa-MiniLM-L6-cos-v1 from sentence_transformers
def sentence_embeddings(cl_onto_lst):
    """Encode the phrases with the sentence-transformers ``model``.

    Returns a list with one entry per phrase; each encoded row is converted
    with ``.tolist()``.
    """
    return [row.tolist() for row in model.encode(cl_onto_lst)]

'''
emb_bhvr_onto_lst = embeddings(cl_bhvr_onto_lst)
emb_fh_onto_lst = embeddings(cl_fh_onto_lst)
emb_rep_onto_lst = embeddings(cl_rep_onto_lst)
'''

emb_bhvr_onto_lst = sentence_embeddings(cl_bhvr_onto_lst)
emb_event_onto_lst = sentence_embeddings(cl_event_onto_lst)

In [None]:
#add to qdrant collection
def add_to_collection():
    """Recreate the Qdrant collection and upload both ontologies.

    Behaviour phrases are uploaded first, followed by setting-event phrases.
    Point ids run from 1 to the total number of embeddings; every payload
    records the ontology name and the cleaned phrase.
    """
    client.recreate_collection(
        collection_name=collection_name,
        vectors_config=VectorParams(size=vector_dim, distance=Distance.COSINE),
    )

    behaviour_payloads = [
        {"ontology": "behaviours", "phrase": phrase} for phrase in cl_bhvr_onto_lst
    ]
    event_payloads = [
        {"ontology": "setting_events", "phrase": phrase} for phrase in cl_event_onto_lst
    ]
    total_points = len(emb_bhvr_onto_lst) + len(emb_event_onto_lst)

    batch = Batch(
        ids=list(range(1, total_points + 1)),
        payloads=behaviour_payloads + event_payloads,
        vectors=emb_bhvr_onto_lst + emb_event_onto_lst,
    )
    client.upsert(collection_name=f"{collection_name}", points=batch)

def count_collection():
    """Return the exact number of points stored in the collection.

    The count endpoint is used instead of measuring one ``scroll`` page:
    ``scroll`` returns a single page of records (10 by default), so its
    length under-reports any collection larger than one page and would cap
    the ``limit`` later passed to the search call.
    """
    return client.count(collection_name=collection_name, exact=True).count

add_to_collection()
point_count = count_collection()

In [None]:
#print(point_count)

In [None]:
#query_filter=Filter(
#        must=[ 
#            FieldCondition(
#                key='ontology',
#                match=MatchValue(value="setting_events")# Condition based on values of `rand_number` field.
#            )
#        ]
#    )

In [None]:
#noun phrase extraction
def extract_noun_phrases(text):
    """Return the noun phrases found in ``text`` as space-joined strings.

    The text is tokenised and POS-tagged with NLTK, chunked with a regular
    expression grammar, and every subtree labelled ``NP`` contributes one
    phrase, in the order the tree yields its subtrees.
    """
    # Chunk grammar: three alternative NP rules.
    grammar = r"""
    NP: {<DT|PP\$>?<JJ>*<NN|NNS|NNP|NNPS>+}  # noun phrase with optional determiner and adjectives
        {<NNP>+}                              # proper noun phrase
        {<PRP\$>?<NN|NNS|NNP|NNPS>+}          # noun phrase with optional possessive pronoun
    """

    tagged_tokens = nltk.pos_tag(nltk.word_tokenize(text))
    tree = nltk.RegexpParser(grammar).parse(tagged_tokens)

    # Leaves are (word, tag) pairs; keep only the words.
    return [
        " ".join([leaf[0] for leaf in subtree.leaves()])
        for subtree in tree.subtrees()
        if subtree.label() == "NP"
    ]


#verb phrase extraction
def extract_vbs(data_chunked):
    """Yield the text of every element of ``data_chunked`` longer than two items.

    For each qualifying element, the first item of each member is converted
    to ``str`` and the results are joined with single spaces.  Elements of
    length two or less (for instance plain ``(word, tag)`` pairs) are skipped.
    """
    for node in data_chunked:
        if len(node) <= 2:
            continue
        words = [str(member[0]) for member in node]
        yield " ".join(words)

def get_verb_phrases(nltk_query):
    """Return verb-led chunks of ``nltk_query`` as space-joined strings.

    The query is tokenised and POS-tagged once, then chunked with forty
    single-rule grammars: every verb tag below combined with every object
    ending, allowing up to three arbitrary tags in between.  The grammars
    are applied verb tag by verb tag, and the chunks each one produces
    (filtered by ``extract_vbs``) are collected in that order.
    """
    tagged = nltk.pos_tag(nltk.word_tokenize(nltk_query))

    verb_tags = ("VB", "VBN", "VBG", "VBP", "VBZ")
    # Order matters: it determines the order of the returned phrases.
    object_endings = (
        "<NN>",
        "<NNP>",
        "<PRP><NN>",
        "<PRP><NNS>",
        "<NNPS>",
        "<NNS>",
        "<PRP><NNP>",
        "<PRP><NNPS>",
    )

    vbs = []
    for verb_tag in verb_tags:
        for ending in object_endings:
            rule = "CUSTOMCHUNK: {<" + verb_tag + "><.*>{0,3}" + ending + "}"
            chunked = nltk.RegexpParser(rule).parse(tagged)
            vbs.extend(extract_vbs(chunked))
    return vbs

In [None]:
#text = "The quick brown fox jumps over the lazy dog."
#phrases = extract_noun_phrases(text)
#cl_phrases = preprocess(phrases)
#print(cl_phrases) 

In [None]:
#use the get_verb_phrases function to enrich the behaviour ontology
#from itertools import chain
#maria_file = 'behaviour_score_maria_2.csv'
#maria_df = pd.read_csv(maria_file).dropna()
#clf_df = maria_df[['Behaviour', 'Behaviour Score']]
#one_lst = clf_df[clf_df['Behaviour Score'] == 1]['Behaviour']
#list_of_lists = [get_verb_phrases(x) for x in one_lst]
#vbs = list(set(list(chain.from_iterable(list_of_lists))))
#cl_vbs = preprocess(vbs)
#cl_vbs = [x for x in cl_vbs if len(x.split()) > 1]
#for cl_vb in cl_vbs:
#    print(cl_vb)

In [None]:
#use the get_verb_phrases function to enrich the setting events ontology
#from itertools import chain
#maria_file = 'behaviour_score_maria_2.csv'
#maria_df = pd.read_csv(maria_file).dropna()
#clf_df = maria_df[['Setting Event']]
#one_lst = clf_df['Setting Event'].tolist()
#list_of_lists = [get_verb_phrases(x) for x in one_lst]
#vbs = list(set(list(chain.from_iterable(list_of_lists))))
#cl_vbs = preprocess(vbs)
#cl_vbs = list(set([x for x in cl_vbs if len(x.split()) > 1]))
#for cl_vb in cl_vbs:
#    print(cl_vb)

In [None]:
#use the extract_noun_phrases function to enrich the setting events ontology
#from itertools import chain
#maria_file = 'behaviour_score_maria_2.csv'
#maria_df = pd.read_csv(maria_file).dropna()
#clf_df = maria_df[['Setting Event']]
#one_lst = clf_df['Setting Event'].tolist()
#list_of_lists = [extract_noun_phrases(x) for x in one_lst]
#vbs = list(set(list(chain.from_iterable(list_of_lists))))
#cl_vbs = preprocess(vbs)
#cl_vbs = list(set([x for x in cl_vbs if len(x.split()) > 1]))
#for cl_vb in cl_vbs:
#    print(cl_vb)

In [None]:
#query and get score

# distilbert-base-uncased from Flair
def get_query_vector(query):
    """Embed ``query`` with the Flair document embedder and return it via ``.tolist()``."""
    flair_sentence = Sentence(query)
    embedding.embed(flair_sentence)
    return flair_sentence.embedding.tolist()

# multi-qa-MiniLM-L6-cos-v1 from sentence_transformers
def sentence_get_query_vector(query):
    """Encode ``query`` with the sentence-transformers ``model`` and return the result unchanged."""
    return model.encode(query)

def search_collection(ontology, query_vector):
    """Search the collection for points of one ontology closest to ``query_vector``.

    Only points whose ``ontology`` payload field equals ``ontology`` are
    considered.  At most ``point_count`` hits are requested, with payloads
    attached.  Returns whatever ``client.search`` returns.
    """
    ontology_condition = FieldCondition(
        key='ontology',
        match=MatchValue(value=ontology),
    )
    return client.search(
        collection_name=f"{collection_name}",
        query_vector=query_vector,
        query_filter=Filter(must=[ontology_condition]),
        append_payload=True,
        limit=point_count,
    )

semantic_passing_score = 0.50


#ontology = 'behaviours'
#query = 'punch father face'
#query_vector = sentence_get_query_vector(query)
#hist = search_collection(ontology, query_vector)

In [None]:
# format output
def bhvr_color(df):
    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#90EE90')

def event_color(df):
    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#9370DB')

def annotate_query(highlights, query, topic):
    """Locate each highlight phrase in ``query`` and describe it as an entity span.

    Args:
        highlights: phrases to look for; each is matched literally.
        query: the text being annotated.
        topic: label assigned to every span.

    Returns:
        A list of ``{"start", "end", "label"}`` dicts, one per highlight that
        occurs in ``query`` (first occurrence only), in highlight order.
        Highlights that do not occur are skipped.
    """
    ents = []
    for phrase in highlights:
        # Literal substring search: the phrases come from user text, so
        # characters such as "(", "?" or "+" must not be treated as regular
        # expression syntax (that either mis-matches or raises re.error).
        start = query.find(phrase)
        if start != -1:
            ents.append({"start": start, "end": start + len(phrase), "label": topic})
    return ents

In [None]:
#setfit trig and cons sentence extraction
def extract_sentences(nltk_query):
    """Split ``nltk_query`` into sentences with NLTK's ``sent_tokenize``."""
    return sent_tokenize(nltk_query)

In [None]:
def convert_df(result_df):
    """Convert a Phrase/Topic/Score frame into a text/prediction frame.

    Args:
        result_df: frame with 'Phrase', 'Topic' and 'Score' columns.

    Returns:
        A frame with the same index and two columns: 'text' (the phrase) and
        'prediction', a one-element list ``[[topic, score]]`` with the score
        capped at 1.0.

    The predictions are built with a plain loop rather than a row-wise
    ``DataFrame.apply``: on a frame without rows ``apply`` does not yield a
    Series, which made the column assignment fail for empty input.
    """
    predictions = [
        [[topic, min(score, 1.0)]]
        for topic, score in zip(result_df['Topic'], result_df['Score'])
    ]
    new_df = pd.DataFrame({'text': result_df['Phrase']}, columns=['text'])
    new_df['prediction'] = predictions
    return new_df

In [None]:
def custom_f1(data: Dict[str, float], title: str):
    """Build a two-row Plotly bar figure from a metrics dictionary.

    Row 1 plots the values whose key contains "macro" against the fixed
    labels precision / recall / f1.  Row 2 plots every entry whose key
    contains none of "macro", "micro" or "support"; consecutive groups of
    three bars share one colour drawn at random from Plotly's qualitative
    palette.  Returns the figure.
    """
    from plotly.subplots import make_subplots
    import plotly.colors
    import random

    fig = make_subplots(
        rows=2,
        cols=1,
        subplot_titles=["Overall Model Score", "Model Score By Category"],
    )

    # Overall (macro) scores; relies on the dictionary's insertion order.
    macro_values = [value for key, value in data.items() if "macro" in key]
    fig.add_bar(x=['precision', 'recall', 'f1'], y=macro_values, row=1, col=1)

    excluded_tags = ["macro", "micro", "support"]
    per_label = {
        key: value
        for key, value in data.items()
        if not any(tag in key for tag in excluded_tags)
    }

    # Three metrics are expected per label, hence one colour per group of three.
    label_count = int(len(per_label) / 3)
    palette = [str(shade) for shade in plotly.colors.qualitative.Plotly]
    chosen = random.sample(palette, label_count)
    bar_colors = [chosen[int(position / 3)] for position in range(len(per_label))]

    fig.add_bar(
        x=list(per_label.keys()),
        y=list(per_label.values()),
        row=2,
        col=1,
        marker_color=bar_colors,
    )
    fig.update_layout(showlegend=False, title_text=title)

    return fig

In [None]:
def get_null_class_df(sentences, result_df):
    """Build 'NONE' rows for the sentences that are absent from ``result_df``.

    Every entry of ``sentences`` not found in ``result_df['Phrase']`` gets a
    row with Topic 'NONE' and Score 0.90.  Returns a frame with the columns
    Phrase, Topic and Score.
    """
    detected = result_df['Phrase'].tolist()
    undetected = []
    for sentence in sentences:
        if sentence not in detected:
            undetected.append(sentence)
    row_count = len(undetected)
    return pd.DataFrame({
        'Phrase': undetected,
        'Topic': ['NONE'] * row_count,
        'Score': [0.90] * row_count,
    })

In [None]:
#setfit trig query and get predicted topic

def get_sf_trig_topic(sentences):
    """Run the trigger SetFit model on ``sentences`` and return its predictions as a list."""
    return list(sf_trig_model(sentences))
def get_sf_trig_topic_scores(sentences):
    """Return, per sentence, the largest value the trigger model's ``predict_proba`` gives it."""
    probabilities = sf_trig_model.predict_proba(sentences)
    return [max(list(row)) for row in probabilities]

In [None]:
# setfit trig format output
# Maps the trigger classifier's predicted class index to a display label.
ind_trig_topic_dict = {
        0: 'NO TRIGGER',
        1: 'TRIGGER',
    }

# NOTE(review): highlight_threshold is not referenced anywhere else in the
# portion of the file visible here — confirm whether it is still needed.
highlight_threshold = 0.25
# Scores at or above this value get the stage colour when SetFit highlights
# are rendered; lower scores get the fallback colour.
passing_score = 0.50

def sf_trig_color(df):
    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#ADD8E6')

def sf_annotate_query(highlights, query, topics):
    """Locate each highlight in ``query`` and label it with its paired topic.

    Args:
        highlights: sentences/phrases to look for; each is matched literally.
        query: the text being annotated.
        topics: labels paired position-wise with ``highlights``.

    Returns:
        A list of ``{"start", "end", "label"}`` dicts, one per highlight that
        occurs in ``query`` (first occurrence only), in highlight order.
        Highlights that do not occur are skipped.
    """
    ents = []
    for phrase, label in zip(highlights, topics):
        # Literal substring search: sentences routinely contain "?", "(" or
        # ".", which must not be interpreted as regular-expression syntax
        # (that either shortens/mis-places the span or raises re.error).
        start = query.find(phrase)
        if start != -1:
            ents.append({"start": start, "end": start + len(phrase), "label": label})
    return ents

In [None]:
#setfit cons query and get predicted topic

def get_sf_cons_topic(sentences):
    """Run the consequence SetFit model on ``sentences`` and return its predictions as a list."""
    return list(sf_cons_model(sentences))
def get_sf_cons_topic_scores(sentences):
    """Return, per sentence, the largest value the consequence model's ``predict_proba`` gives it."""
    probabilities = sf_cons_model.predict_proba(sentences)
    return [max(list(row)) for row in probabilities]

In [None]:
# setfit cons format output
ind_cons_topic_dict = {
        0: 'NO CONSEQUENCE',
        1: 'CONSEQUENCE',
    }

def sf_cons_color(df):
    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#F08080')

In [None]:
def rem_prev_detections(sub_current_df, sub_prev_df):
    """Return the rows of ``sub_current_df`` whose 'Phrase' is absent from ``sub_prev_df``.

    The original index of the surviving rows is kept.
    """
    already_seen = sub_prev_df['Phrase'].tolist()
    is_new = ~sub_current_df['Phrase'].isin(already_seen)
    return sub_current_df[is_new]

def path_to_image_html(path):
    """Return an ``<img>`` tag (30x15) whose source is ``path``."""
    pieces = ('<img src="', path, '" width="30" height="15" />')
    return "".join(pieces)

def display_final_df(tags):
    """Display the four-criterion score table with thumbs-up/down images.

    Args:
        tags: four truthy/falsy values, in the order Setting Event, Triggers,
            Behaviour, Consequences.  A truthy value shows the thumbs-up
            image, a falsy one the thumbs-down image.

    The table is rendered to HTML (images are injected unescaped through
    ``path_to_image_html``) and shown centred.  As a side effect the pandas
    option ``display.max_colwidth`` is set to ``None`` so the descriptions
    are not truncated.
    """
    crits = [
        'Setting Event',
        'Triggers',
        'Behaviour',
        'Consequences'
    ]
    descs = [
        'Does the plan identify at least one setting event for the behaviour?',
        'Does the plan identify at least one immediate trigger/antecedent for the behaviour?',
        'Does the plan identify at least one behaviour?',
        'Does the plan identify at least one maintaining consequence for the behaviour?'
    ]
    paths = ['./thumbs_up.png' if x else './thumbs_down.png' for x in tags]
    # Column header spelling fixed ('Descrption' -> 'Description'); it is
    # shown to the user as the table heading.
    df = pd.DataFrame({'Criteria': crits, 'Description': descs, 'Score': paths})
    df = df.set_index('Criteria')
    pd.set_option('display.max_colwidth', None)
    table_html = df.to_html(
        classes=["align-center"],
        index=True,
        escape=False,
        formatters=dict(Score=path_to_image_html),
    )
    display(HTML('<div style="text-align: center;">' + table_html + '</div>'))

### Please complete the following A-B-C chain to demonstrate how the identified <font color='blue'>triggers</font> are linked to the person’s <font color='green'>behaviour</font>, and what happens after the behaviour to reinforce it, and therefore maintain the <font color='red'>consequences</font>. Also include <font color='purple'>setting events</font>.

In [None]:
#demo with Voila

# --- Answer inputs: one coloured label plus one free-text area per element ---
event_label = widgets.Label(value = r'\(\color{purple} {' + 'Setting Events:'  + '}\)')
event_text_input = widgets.Textarea(
    value='',
    placeholder='Type your answer',
    description='',
    disabled=False,
    layout={'height': '100%', 'width': '90%'}
)

trig_label = widgets.Label(value = r'\(\color{blue} {' + 'Triggers:'  + '}\)')
trig_text_input = widgets.Textarea(
    value='',
    placeholder='Type your answer',
    description='',
    disabled=False,
    layout={'height': '100%', 'width': '90%'}
)

bhvr_label = widgets.Label(value = r'\(\color{green} {' + 'Behaviours:'  + '}\)')
bhvr_text_input = widgets.Textarea(
    value='',
    placeholder='Type your answer',
    description='',
    disabled=False,
    layout={'height': '100%', 'width': '90%'}
)

cons_label = widgets.Label(value = r'\(\color{red} {' + 'Consequences:'  + '}\)')
cons_text_input = widgets.Textarea(
    value='',
    placeholder='Type your answer',
    description='',
    disabled=False,
    layout={'height': '100%', 'width': '90%'}
)

# --- Action buttons: score the answers, push results to Argilla, evaluate ---
bhvr_nlp_btn = widgets.Button(
    description='Score Answer',
    disabled=False,
    button_style='success', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Score Answer',
    icon='check',
    layout={'height': '70px', 'width': '250px'}
)
bhvr_agr_btn = widgets.Button(
    description='Validate Data',
    disabled=False,
    button_style='success', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Validate Data',
    icon='check',
    layout={'height': '70px', 'width': '250px'}
)
bhvr_eval_btn = widgets.Button(
    description='Evaluate Model',
    disabled=False,
    button_style='success', # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Evaluate Model',
    icon='check',
    layout={'height': '70px', 'width': '250px'}
)
btn_box = widgets.HBox([bhvr_nlp_btn, bhvr_agr_btn, bhvr_eval_btn], 
                       layout={'width': '100%', 'height': '160%'})
# Shared output area: every button callback clears it and writes into it.
bhvr_outt = widgets.Output()
bhvr_outt.layout.height = '100%'
bhvr_outt.layout.width = '100%'

# --- Layout: four label/text columns side by side, buttons, then output ---
event_answer_box = widgets.VBox([event_label, event_text_input], 
                   layout={'width': '400px', 'height': '200px'})

trig_answer_box = widgets.VBox([trig_label, trig_text_input], 
                   layout={'width': '400px', 'height': '200px'})

bhvr_answer_box = widgets.VBox([bhvr_label, bhvr_text_input], 
                   layout={'width': '400px', 'height': '200px'})

cons_answer_box = widgets.VBox([cons_label, cons_text_input], 
                   layout={'width': '400px', 'height': '200px'})

answer_box = widgets.HBox([event_answer_box, trig_answer_box, bhvr_answer_box, cons_answer_box], 
                   layout={'width': '90%', 'height': '400px'})

total_box = widgets.VBox([answer_box, btn_box, bhvr_outt], 
                   layout={'width': '100%', 'height': '100%'})
# --- Module state shared by the button callbacks ---
# Name of the Argilla dataset that is deleted and re-logged on validation.
dataset_rg_name = 'pbsp-page3-abc-argilla-ds'
# Latest combined Phrase/Topic/Score frame; None until an answer is scored.
agrilla_df = None
# Set to True once results have been logged to Argilla.
annotated = False
# Per-stage result frames collected by the scoring callback.
# NOTE(review): this list is never cleared in the visible code, so frames
# accumulate across button clicks — confirm that is intended.
sub_2_result_dfs = []
def _render_entities(query, ents, colors, title):
    """Show a centred heading and the displaCy entity markup; do nothing without entities."""
    if len(ents) == 0:
        return
    options = {"ents": list(colors), "colors": colors}
    ex = [{"text": query,
           "ents": ents,
           "title": None}]
    display(HTML(f'<center><h1>{title}</h1></center>'))
    html = displacy.render(ex, style="ent", manual=True, options=options)
    display(HTML(html))


def _semantic_matches(phrases, cl_phrases, ontology, glossary_lookup, min_words=0):
    """Search one ontology for every phrase and keep the hits that pass the threshold.

    ``phrases`` are the raw phrases and ``cl_phrases`` their cleaned forms
    (same length and order).  Phrases with fewer than ``min_words`` words are
    skipped.  Returns ``(highlights, frames)``: the raw phrases that had at
    least one hit scoring ``semantic_passing_score`` or more, and for each of
    them a Score/Glossary/Phrase frame sorted by descending score.
    """
    highlights = []
    frames = []
    if not cl_phrases:
        # Nothing to embed or search.
        return highlights, frames
    for phrase, query_vector in zip(phrases, sentence_embeddings(cl_phrases)):
        if len(phrase.split()) < min_words:
            continue
        hits = [dict(hit) for hit in search_collection(ontology, query_vector)]
        frame = pd.DataFrame({
            'Score': [hit['score'] for hit in hits],
            # Map the cleaned phrase stored in the payload back to the original wording.
            'Glossary': [glossary_lookup[hit['payload']['phrase']] for hit in hits],
        })
        # Copy so that adding a column below works on an independent frame,
        # not on a filtered view.
        frame = frame[frame['Score'] >= semantic_passing_score].copy()
        if len(frame) == 0:
            continue
        frame['Phrase'] = phrase
        frame = frame.sort_values(by='Score', ascending=False).reset_index(drop=True)
        highlights.append(phrase)
        frames.append(frame)
    return highlights, frames


def _score_semantic_stage(query, phrases, cl_phrases, ontology, glossary_lookup,
                          topic, color, title, styler, min_words=0):
    """Run, display and record one ontology-matching stage; return True when anything matched."""
    highlights, frames = _semantic_matches(
        phrases, cl_phrases, ontology, glossary_lookup, min_words=min_words)

    if len(highlights) > 0:
        ents = annotate_query(highlights, query, topic)
        colors = {ent['label']: color for ent in ents}
        _render_entities(query, ents, colors, title)

    if len(frames) == 0:
        return False

    result_df = pd.concat(frames).reset_index(drop=True)
    result_df = result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)

    # Rows kept for later validation: every matched phrase with its topic,
    # plus a 'NONE' row for every extracted phrase that did not match.
    labelled_df = result_df.assign(Topic=topic)
    labelled_df = labelled_df[['Phrase', 'Topic', 'Score']].drop_duplicates().reset_index(drop=True)
    null_df = get_null_class_df(phrases, labelled_df)
    if len(null_df) > 0:
        labelled_df = pd.concat([labelled_df, null_df]).reset_index(drop=True)
    sub_2_result_dfs.append(labelled_df)

    # Displayed table: for each phrase only its best-scoring glossary rows.
    # Duplicate rows are dropped for both semantic stages (previously only
    # for setting events) so repeated phrases do not repeat table rows.
    best_score = result_df.groupby('Phrase')['Score'].transform('max')
    table_df = result_df.loc[result_df['Score'] == best_score, ['Phrase', 'Glossary', 'Score']]
    table_df = table_df.drop_duplicates().set_index('Phrase')
    display(styler(table_df))
    return True


def _score_setfit_stage(query, predict_topics, predict_scores, topic_names,
                        positive_topic, color, title, styler):
    """Run, display and record one SetFit sentence-classification stage.

    Returns True when at least one sentence received ``positive_topic``.
    """
    sentences = extract_sentences(query)
    if not sentences:
        # Empty answer: do not call the classifier with an empty batch.
        return False
    cl_sentences = preprocess(sentences)
    # int()/float() turn array or tensor scalars into plain Python values so
    # they work as dictionary keys and as ordinary numbers downstream.
    topics = [topic_names[int(ind)] for ind in predict_topics(cl_sentences)]
    scores = [float(score) for score in predict_scores(cl_sentences)]
    result_df = pd.DataFrame({'Phrase': sentences, 'Topic': topics, 'Score': scores})

    positive_df = result_df[result_df['Topic'] == positive_topic]
    negative_df = result_df[result_df['Topic'] != positive_topic]
    # Recorded for later validation: negatives first, then positives.
    sub_2_result_dfs.append(pd.concat([negative_df, positive_df]).reset_index(drop=True))

    if len(positive_df) == 0:
        return False

    ents = sf_annotate_query(positive_df['Phrase'].tolist(), query, positive_df['Topic'].tolist())
    colors = {}
    for ent, score in zip(ents, positive_df['Score'].tolist()):
        colors[ent['label']] = color if score >= passing_score else '#FFCC66'
    _render_entities(query, ents, colors, title)

    table_df = positive_df.sort_values(by='Score', ascending=False).reset_index(drop=True)
    display(styler(table_df.set_index('Phrase')))
    return True


def on_bhvr_button_next(b):
    """'Score Answer' callback: score the four text areas and show the results.

    Inside the shared output widget this clears previous output, then runs,
    in order: ontology matching for behaviours (verb phrases), ontology
    matching for setting events (verb and noun phrases of at least two
    words), SetFit classification of trigger sentences, and SetFit
    classification of consequence sentences.  Each stage shows its
    highlighted text and score table when it finds something.  Finally the
    pass/fail table is shown and the collected per-stage rows are combined
    into the module-level ``agrilla_df``.

    Args:
        b: the clicked button (unused).
    """
    global agrilla_df
    with bhvr_outt:
        clear_output()

        #semantic behaviour
        query = bhvr_text_input.value
        vbs = get_verb_phrases(query)
        bhvr_tag = _score_semantic_stage(
            query, vbs, preprocess(vbs), 'behaviours',
            dict(zip(cl_bhvr_onto_lst, bhvr_onto_lst)),
            "BEHAVIOUR", '#90EE90', "Behaviour Phrases", bhvr_color)

        #semantic setting events
        query = event_text_input.value
        vbs = get_verb_phrases(query)
        nouns = extract_noun_phrases(query)
        event_tag = _score_semantic_stage(
            query, vbs + nouns, preprocess(vbs) + preprocess(nouns), 'setting_events',
            dict(zip(cl_event_onto_lst, event_onto_lst)),
            "SETTING EVENT", '#9370DB', "Setting Event Phrases", event_color,
            min_words=2)

        #setfit trig
        trig_tag = _score_setfit_stage(
            trig_text_input.value, get_sf_trig_topic, get_sf_trig_topic_scores,
            ind_trig_topic_dict, 'TRIGGER', '#ADD8E6', "Trigger Phrases", sf_trig_color)

        #setfit cons
        cons_tag = _score_setfit_stage(
            cons_text_input.value, get_sf_cons_topic, get_sf_cons_topic_scores,
            ind_cons_topic_dict, 'CONSEQUENCE', '#F08080', "Consequence Phrases", sf_cons_color)

        title = "Final Scores"
        display(HTML(f'<left><h1>{title}</h1></left>'))
        display_final_df([event_tag, trig_tag, bhvr_tag, cons_tag])

        # NOTE(review): sub_2_result_dfs is module-level and is not cleared
        # here, so rows from earlier clicks are included — confirm intended.
        if len(sub_2_result_dfs) > 0:
            agrilla_df = pd.concat(sub_2_result_dfs).reset_index(drop=True)

def on_agr_button_next(b):
    """'Validate Data' callback: push the scored rows to Argilla.

    Clears the shared output area.  When no answer has been scored yet a
    red reminder is shown; otherwise the collected frame is converted,
    the existing Argilla dataset of the same name is deleted, the new
    records are logged, and ``annotated`` is set to True.
    """
    global agrilla_df, annotated
    with bhvr_outt:
        clear_output()
        if agrilla_df is None:
            display(Markdown("<h2 style='color:red; text-align:center;'>Please score the answer first!</h2>"))
            return
        # Reshape to text/prediction columns and wrap as an Argilla dataset.
        records = rg.DatasetForTextClassification.from_pandas(convert_df(agrilla_df))
        # Replace any dataset previously stored under the same name.
        rg.delete(dataset_rg_name)
        rg.log(records, name=dataset_rg_name)
        annotated = True
            
def on_eval_button_next(b):
    """'Evaluate Model' callback: plot the F1 metrics of the Argilla dataset.

    Clears the shared output area.  When nothing has been logged to Argilla
    yet a red reminder is shown; otherwise the dataset's F1 metrics are
    fetched and displayed as a bar figure.
    """
    with bhvr_outt:
        clear_output()
        if not annotated:
            display(Markdown("<h2 style='color:red; text-align:center;'>Please score the answer and validate the data first!</h2>"))
            return
        metrics = dict(f1(dataset_rg_name))['data']
        display(custom_f1(metrics, "Model Evaluation Results"))

# Attach the three callbacks to their buttons.
bhvr_nlp_btn.on_click(on_bhvr_button_next)
bhvr_agr_btn.on_click(on_agr_button_next)
bhvr_eval_btn.on_click(on_eval_button_next)

# Render the complete UI (answer boxes, buttons, output area).
display(total_box)