# NDIS Project - PBSP Scoring - Page 2

In [None]:
import os
from ipywidgets import interact
import ipywidgets as widgets
from IPython.display import display, clear_output, Javascript, HTML, Markdown
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import json
import spacy
from spacy import displacy
from flair.data import Corpus
from flair.datasets import CSVClassificationCorpus
from flair.models import TARSClassifier
from flair.data import Sentence
from flair.trainers import ModelTrainer
from sklearn.feature_extraction import text
from pprint import pprint
import re
import pandas as pd
import argilla as rg
from argilla.metrics.text_classification import f1
from argilla.training import ArgillaTrainer
import joblib
import random
from typing import Dict
import warnings
warnings.filterwarnings('ignore')
import logging
logging.getLogger().setLevel(logging.CRITICAL)
logging.getLogger("flair").setLevel(logging.ERROR)
logging.disable(logging.CRITICAL)
logging.basicConfig(level=logging.ERROR)
%matplotlib inline

In [None]:
# Initializations
# Directory that holds the fine-tuned few-shot TARS checkpoint.
tars_model_path = 'few-shot-model-2'
# `load` is called on the class itself: instantiating `TARSClassifier()` first
# would build a throwaway default model only to discard it immediately.
tars = TARSClassifier.load(f'{tars_model_path}/best-model.pt')

# Argilla client: URL and API key are read from the environment so that no
# credentials live in the notebook (raises KeyError if either is missing).
rg.init(
    api_url=os.environ["ARGILLA_API_URL"],
    api_key=os.environ["ARGILLA_API_KEY"]
)

In [None]:
# Text preprocessing
_SPACY_MODEL = 'en_core_web_sm'
try:
    nlp = spacy.load(_SPACY_MODEL)
except OSError:
    # Loading failed with OSError: download the model package, then retry once.
    spacy.cli.download(_SPACY_MODEL)
    nlp = spacy.load(_SPACY_MODEL)
# scikit-learn's English stop-word set, applied in `preprocess` on top of
# spaCy's own `is_stop` flag.
sw_lst = text.ENGLISH_STOP_WORDS
def preprocess(onto_lst):
    """Normalise every document into a space-joined string of lemmas.

    For each document the spaCy pipeline `nlp` is run and a token is kept only
    when it is not a stop word, punctuation, number-like or blank, and its
    lemma does not equal the lemma of any token tagged as a PERSON entity.
    Kept lemmas are lower-cased and then dropped if they are a single
    character, contain anything other than `a-z`/space, or are in `sw_lst`.

    Args:
        onto_lst: Iterable of raw text documents.

    Returns:
        A list with one cleaned string per input document, in input order.
    """
    lowercase_only = re.compile(r'^[a-z ]*$')

    def _clean(document):
        doc = nlp(document)
        # Lemmas of everything spaCy tagged as a person name in this document.
        person_lemmas = {tok.lemma_ for tok in doc if tok.ent_type_ == 'PERSON'}
        kept = []
        for tok in doc:
            if tok.is_stop or tok.is_punct or tok.like_num:
                continue
            if not tok.text.strip() or tok.lemma_ in person_lemmas:
                continue
            lemma = tok.lemma_.lower()
            if len(lemma) > 1 and lowercase_only.search(lemma) is not None and lemma not in sw_lst:
                kept.append(lemma)
        return " ".join(kept)

    return [_clean(document) for document in onto_lst]

In [None]:
#sentence extraction
def extract_sentences(query):
    """Split `query` into trimmed, non-empty phrases.

    The text is cut at every `.`, `,`, `;`, `!` or `?`; each piece is stripped
    of surrounding whitespace and pieces that end up empty are discarded.

    Args:
        query: The text to split.

    Returns:
        The list of phrases in their original order.
    """
    trimmed = (piece.strip() for piece in re.split(r'[.,;!?]', query))
    return [piece for piece in trimmed if piece]

In [None]:
#query and get predicted topic

# Numeric class id for each label name; `get_topic` looks up `sentence.tag` here.
p_classes = dict(
    psychiatric_assessment=0,
    medical_assessment=1,
    no_assessment=2,
    speech_and_language_assessment=3,
)
def get_topic(sentences):
    """Predict a numeric class id for every sentence with the TARS model.

    Each text is wrapped in a flair `Sentence`, passed to `tars.predict`, and
    the resulting `sentence.tag` is translated through `p_classes`.

    Args:
        sentences: Iterable of text strings to classify.

    Returns:
        A list of class ids, one per input. When the tag cannot be read or is
        not a key of `p_classes`, the id of 'no_assessment' is used instead.
    """
    preds = []
    for t in sentences:
        sentence = Sentence(t)
        tars.predict(sentence)
        try:
            pred = p_classes[sentence.tag]
        except Exception:
            # Deliberate best-effort fallback. `Exception` (rather than a bare
            # `except:`) lets KeyboardInterrupt / SystemExit propagate.
            pred = p_classes['no_assessment']
        preds.append(pred)
    return preds
def get_topic_scores(sentences):
    """Return the TARS prediction score for every sentence.

    Each text is wrapped in a flair `Sentence`, passed to `tars.predict`, and
    `sentence.score` is collected.

    Args:
        sentences: Iterable of text strings to classify.

    Returns:
        A list of scores, one per input. When the score cannot be read, the
        fixed fallback value 0.75 is used.
    """
    fallback_score = 0.75
    preds = []
    for t in sentences:
        sentence = Sentence(t)
        tars.predict(sentence)
        try:
            pred = sentence.score
        except Exception:
            # Deliberate best-effort fallback. `Exception` (rather than a bare
            # `except:`) lets KeyboardInterrupt / SystemExit propagate.
            pred = fallback_score
        preds.append(pred)
    return preds

In [None]:
def convert_df(result_df):
    """Reshape a scoring result into the two-column frame handed to Argilla.

    Args:
        result_df: DataFrame with 'Phrase', 'Topic' and 'Score' columns.

    Returns:
        A DataFrame on the same index with a 'text' column (the phrase) and a
        'prediction' column holding `[[topic, score]]` for each row.
    """
    predictions = result_df.apply(
        lambda row: [[row['Topic'], row['Score']]],
        axis=1,
    )
    return pd.DataFrame({'text': result_df['Phrase'], 'prediction': predictions})

In [None]:
def get_trainer_preds(trainer, records):
    """Tabulate the first prediction of every record.

    Args:
        trainer: Unused; kept so existing callers need not change.
        records: Sequence of objects exposing `.text` and `.prediction`, where
            `prediction[0]` is indexable as (label, score).

    Returns:
        A DataFrame with 'Phrase', 'Topic' and 'Score' columns, one row per
        record in input order.
    """
    phrases, topics, scores = [], [], []
    for record in records:
        first_prediction = record.prediction[0]
        phrases.append(record.text)
        topics.append(first_prediction[0])
        scores.append(first_prediction[1])
    return pd.DataFrame({'Phrase': phrases, 'Topic': topics, 'Score': scores})

In [None]:
def custom_f1(data: Dict[str, float], title: str):
    """Build a two-row plotly bar figure from a flat metrics dictionary.

    Row 1 plots the values whose key contains "macro" against the fixed axis
    labels precision / recall / f1 (this relies on the dictionary listing them
    in that order). Row 2 plots every key that contains none of "macro",
    "micro" or "support"; consecutive groups of three bars share one colour.

    Args:
        data: Mapping of metric name to value.
        title: Figure title.

    Returns:
        The plotly figure.
    """
    from plotly.subplots import make_subplots
    import plotly.colors

    metrics_per_label = 3

    fig = make_subplots(
        rows=2,
        cols=1,
        subplot_titles=["Overall Model Score", "Model Score By Category"],
    )

    macro_data = [v for k, v in data.items() if "macro" in k]
    fig.add_bar(
        x=['precision', 'recall', 'f1'],
        y=macro_data,
        row=1,
        col=1,
    )

    per_label = {
        k: v
        for k, v in data.items()
        if all(key not in k for key in ["macro", "micro", "support"])
    }

    # Ceiling division: a trailing, incomplete group still needs a colour
    # (flooring made the colour lookup below run out of range).
    num_labels = -(-len(per_label) // metrics_per_label)
    palette = [str(shade) for shade in plotly.colors.qualitative.Plotly]
    if num_labels <= len(palette):
        colors = random.sample(palette, num_labels)
    else:
        # More labels than palette entries: `random.sample` would raise
        # ValueError, so cycle through a shuffled palette instead.
        shuffled = random.sample(palette, len(palette))
        colors = [shuffled[i % len(shuffled)] for i in range(num_labels)]

    fig.add_bar(
        x=list(per_label.keys()),
        y=list(per_label.values()),
        row=2,
        col=1,
        marker_color=[colors[i // metrics_per_label] for i in range(len(per_label))],
    )
    fig.update_layout(showlegend=False, title_text=title)

    return fig

In [None]:
# format output
# Display name for each numeric class id (the position in this list is the id,
# matching the values of `p_classes`).
ind_topic_dict = dict(enumerate([
    'PSYCHIATRIC-ASSESSMENT',
    'MEDICAL-ASSESSMENT',
    'NO-ASSESSMENT',
    'SPEECH-AND-LANGUAGE-ASSESSMENT',
]))

# Highlight / bar colour for each topic that is reported; 'NO-ASSESSMENT' has
# no entry here.
topic_color_dict = {
    'PSYCHIATRIC-ASSESSMENT': '#FFCCCC',
    'MEDICAL-ASSESSMENT': '#CCFFFF',
    'SPEECH-AND-LANGUAGE-ASSESSMENT': '#FF69B4',
}

# Minimum score a phrase needs to be reported for its topic.
passing_score = 0.25

def color(df, color):
    """Style `df` so its 'Score' column shows as a percentage with a data bar.

    Args:
        df: DataFrame containing a 'Score' column.
        color: Colour used for the in-cell bar.

    Returns:
        The styler produced by `df.style` after formatting and adding the bar.
    """
    as_percent = '{:,.2%}'.format
    styled = df.style.format({'Score': as_percent})
    return styled.bar(subset=['Score'], color=color)

def annotate_query(highlights, query, topics):
    """Locate each highlighted phrase in `query` and describe it as an entity.

    Phrases are matched literally. They used to be compiled as regular
    expressions, so characters such as `(`, `[` or `+` either prevented the
    match or raised `re.error`.

    Args:
        highlights: Phrases to locate in `query`.
        query: The full text being annotated.
        topics: Label for each phrase, parallel to `highlights`.

    Returns:
        A list of `{"start", "end", "label"}` dicts, one per phrase found,
        pointing at the first occurrence of that phrase. Empty phrases and
        phrases that do not occur in `query` are skipped.
    """
    ents = []
    for phrase, topic in zip(highlights, topics):
        if not phrase:
            # An empty phrase would only yield a meaningless zero-width span.
            continue
        start = query.find(phrase)
        if start != -1:
            ents.append({"start": start, "end": start + len(phrase), "label": topic})
    return ents

In [None]:
def path_to_image_html(path):
    """Return an HTML `<img>` tag (30x15) whose source is `path`."""
    return f'<img src="{path}" width="30" height="15" />'

# A topic counts as used only when its 'Final_Score' is strictly above this.
final_passing = 0.0
def display_final_df(agg_df):
    """Show a centred table marking each assessment type as used or not.

    Assessments found in the index of `agg_df` are listed first and get a
    green tick when their 'Final_Score' exceeds `final_passing`, otherwise a
    red cross. Assessments missing from the index follow, each with a red
    cross.

    Args:
        agg_df: DataFrame indexed by topic name with a 'Final_Score' column.

    Side effects:
        Sets the pandas option 'display.max_colwidth' to None and displays
        the table as HTML.
    """
    all_crits = [
        'PSYCHIATRIC-ASSESSMENT',
        'MEDICAL-ASSESSMENT',
        'SPEECH-AND-LANGUAGE-ASSESSMENT',
    ]
    indexed_topics = agg_df.index.tolist()
    found = [crit for crit in all_crits if crit in indexed_topics]
    missing = [crit for crit in all_crits if crit not in found]

    icons = []
    for crit in found:
        passed = agg_df.loc[crit, 'Final_Score'] > final_passing
        icons.append('./tick_green.png' if passed else './cross_red.png')
    icons.extend('./cross_red.png' for _ in missing)

    table = pd.DataFrame({'Assessment': found + missing, 'USED': icons})
    table = table.set_index('Assessment')
    pd.set_option('display.max_colwidth', None)
    markup = table.to_html(
        classes=["align-center"],
        index=True,
        escape=False,
        formatters=dict(USED=path_to_image_html),
    )
    display(HTML('<div style="text-align: center;">' + markup + '</div>'))

### Outline any additional non-behavioural assessments undertaken or recent assessments reviewed 

In [None]:
# Demo UI (rendered with Voila)

def _make_action_button(caption):
    """Create one of the four identical-looking action buttons.

    The caption is used both as the button text and as its tooltip.
    """
    return widgets.Button(
        description=caption,
        disabled=False,
        button_style='success',  # 'success', 'info', 'warning', 'danger' or ''
        tooltip=caption,
        icon='check',
        layout={'height': '70px', 'width': '250px'}
    )

# Answer entry.
bhvr_label = widgets.Label(value='Please type your answer:')
bhvr_text_input = widgets.Textarea(
    value='',
    placeholder='Type your answer',
    description='',
    disabled=False,
    layout={'height': '300px', 'width': '90%'}
)

# Action buttons, laid out left to right.
bhvr_nlp_btn = _make_action_button('Score Answer')
bhvr_agr_btn = _make_action_button('Validate Data')
bhvr_eval_btn = _make_action_button('Evaluate Model')
bhvr_trn_btn = _make_action_button('Re-train Model')
btn_box = widgets.HBox(
    [bhvr_nlp_btn, bhvr_agr_btn, bhvr_eval_btn, bhvr_trn_btn],
    layout={'width': '100%', 'height': '160%'}
)

# Shared output area every click handler renders into.
bhvr_outt = widgets.Output(layout={'height': '100%', 'width': '100%'})
bhvr_box = widgets.VBox(
    [bhvr_text_input, btn_box, bhvr_outt],
    layout={'width': '100%', 'height': '160%'}
)

# Names used on the Argilla side.
dataset_rg_name = 'pbsp-page2-q2-argilla-ds'
trainer_rg_name = 'pbsp-page2-q2-argilla-trn'

# State shared between the click handlers.
agrilla_df = None   # scored phrases; set by on_bhvr_button_next
dataset_rg = None   # Argilla dataset built from agrilla_df; set by on_agr_button_next
annotated = False   # becomes True once on_agr_button_next has logged the dataset
trainer = None      # ArgillaTrainer; set by on_trn_button_next
def on_bhvr_button_next(b):
    """Handle a click on "Score Answer".

    Splits the typed answer into phrases and classifies each one, using the
    re-trained Argilla trainer once one exists and the TARS model otherwise.
    It then renders into `bhvr_outt`: the answer with assessment phrases
    highlighted, one score table per topic, a coverage bar chart and the final
    used / not-used table. The scored phrases (plus the 'NO-ASSESSMENT' ones)
    are stored in the global `agrilla_df` for the "Validate Data" step.

    When no phrase reaches `passing_score` for a real topic, the raw answer is
    printed instead and `agrilla_df` is left untouched.

    Args:
        b: The clicked button widget (unused).
    """
    global agrilla_df
    with bhvr_outt:
        clear_output()
        query = bhvr_text_input.value
        sentences = extract_sentences(query)
        if not sentences:
            # Nothing to classify: same output as the "no phrase passed" case,
            # without handing an empty batch to a model.
            print(query)
            return
        if trainer is not None:
            # Score the phrases of the answer being submitted. This used to
            # predict on the texts of the stored Argilla dataset, so a newly
            # typed answer was ignored once a model had been re-trained.
            records = trainer.predict(sentences, as_argilla_records=True)
            result_df = get_trainer_preds(trainer, records)
        else:
            cl_sentences = preprocess(sentences)
            # NOTE: get_topic and get_topic_scores each call the model, so
            # every phrase is classified twice here.
            topic_inds = get_topic(cl_sentences)
            topics = [ind_topic_dict[i] for i in topic_inds]
            scores = get_topic_scores(cl_sentences)
            result_df = pd.DataFrame({'Phrase': sentences, 'Topic': topics, 'Score': scores})
        sub_result_df = result_df[(result_df['Score'] >= passing_score) & (result_df['Topic'] != 'NO-ASSESSMENT')]
        sub_2_result_df = result_df[result_df['Topic'] == 'NO-ASSESSMENT']
        if len(sub_result_df) > 0:
            highlights = sub_result_df['Phrase'].tolist()
            highlight_topics = sub_result_df['Topic'].tolist()
            ents = annotate_query(highlights, query, highlight_topics)
            # Colour by each entity's own label: zipping `ents` with
            # `highlight_topics` misaligns as soon as one phrase is not found.
            colors = {ent['label']: topic_color_dict[ent['label']] for ent in ents}

            ex = [{"text": query,
                   "ents": ents,
                   "title": None}]
            title = "Highlighting Assessments"
            display(HTML(f'<center><h1>{title}</h1></center>'))
            # jupyter=False makes displacy return the markup, so it is shown
            # exactly once here. With jupyter=True it displays the markup
            # itself and returns None, which the old display(HTML(html)) then
            # rendered as a stray object repr.
            html = displacy.render(ex, style="ent", manual=True, jupyter=False, options={'colors': colors})
            display(HTML(f'<span class="tex2jax_ignore">{html}</span>'))
            title = "Assessment Classifications"
            display(HTML(f'<center><h1>{title}</h1></center>'))
            for top in topic_color_dict.keys():
                top_result_df = sub_result_df[sub_result_df['Topic'] == top]
                if len(top_result_df) > 0:
                    top_result_df = top_result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)
                    top_result_df = top_result_df.set_index('Phrase')
                    top_result_df = top_result_df[['Score']]
                    display(HTML(
                        f'<left><h2 style="text-decoration: underline; text-decoration-color:{topic_color_dict[top]};">{top}</h2></left>'))
                    display(color(top_result_df, topic_color_dict[top]))

            # Share of the summed scores per topic, as a percentage.
            agg_df = sub_result_df.groupby('Topic')['Score'].sum()
            agg_df = agg_df.to_frame()
            agg_df.index.name = 'Topic'
            agg_df.columns = ['Total Score']
            agg_df = agg_df.assign(
                Final_Score=lambda x: x['Total Score'] / x['Total Score'].sum() * 100.00
            )
            agg_df = agg_df.sort_values(by='Final_Score', ascending=False)
            title = "Assessment Coverage"
            display(HTML(f'<center><h1>{title}</h1></center>'))
            agg_df['Topic'] = agg_df.index
            # Topics with no passing phrase still get a (zero) bar in the chart.
            rem_topics = [x for x in list(topic_color_dict.keys()) if x not in agg_df.Topic.tolist()]
            if len(rem_topics) > 0:
                rem_agg_df = pd.DataFrame({'Topic': rem_topics, 'Final_Score': 0.0, 'Total Score': 0.0})
                agg_df = pd.concat([agg_df, rem_agg_df])
            labels = agg_df['Final_Score'].round(1).astype('str') + '%'
            ax = agg_df.plot.bar(x='Topic', y='Final_Score', rot=0, figsize=(20, 5), align='center')
            for container in ax.containers:
                ax.bar_label(container, labels=labels)
                ax.yaxis.set_major_formatter(mtick.PercentFormatter())
                ax.legend(["Final Score (%)"])
                ax.set_xlabel('')
            plt.show()
            title = "Final Assessments Scores"
            display(HTML(f'<left><h1>{title}</h1></left>'))
            display_final_df(agg_df)
            # Keep the 'NO-ASSESSMENT' phrases too so they can be validated.
            if len(sub_2_result_df) > 0:
                sub_result_df = pd.concat([sub_result_df, sub_2_result_df]).reset_index(drop=True)
            agrilla_df = sub_result_df.copy()
        else:
            print(query)

def on_agr_button_next(b):
    """Handle a click on "Validate Data": publish the scored phrases to Argilla.

    Requires `agrilla_df` to have been filled by a previous "Score Answer"
    click; otherwise a red reminder is shown and nothing else happens. On
    success the global `dataset_rg` holds the uploaded dataset and `annotated`
    is set to True.

    Args:
        b: The clicked button widget (unused).
    """
    global annotated, dataset_rg
    with bhvr_outt:
        clear_output()
        if agrilla_df is None:
            display(Markdown("<h2 style='color:red; text-align:center;'>Please score the answer first!</h2>"))
            return

        # Reshape into the structure Argilla accepts, then wrap it as a
        # DatasetForTextClassification.
        dataset_rg = rg.DatasetForTextClassification.from_pandas(convert_df(agrilla_df))
        # Replace whatever was stored under this name in the Argilla web app.
        rg.delete(dataset_rg_name, workspace="admin")
        rg.log(dataset_rg, name=dataset_rg_name, workspace="admin")
        # Offer every class for annotation, not only the predicted ones.
        label_settings = rg.TextClassificationSettings(label_schema=list(ind_topic_dict.values()))
        rg.configure_dataset(name=dataset_rg_name, workspace="admin", settings=label_settings)
        annotated = True
            
def on_eval_button_next(b):
    """Handle a click on "Evaluate Model": plot the F1 metrics of the dataset.

    Shows a red reminder instead when `annotated` is still False, i.e. the
    data has not been sent to Argilla yet.

    Args:
        b: The clicked button widget (unused).
    """
    with bhvr_outt:
        clear_output()
        if not annotated:
            display(Markdown("<h2 style='color:red; text-align:center;'>Please score the answer and validate the data first!</h2>"))
            return
        metrics = dict(f1(dataset_rg_name))['data']
        display(custom_f1(metrics, "Model Evaluation Results"))

def on_trn_button_next(b):
    """Handle a click on "Re-train Model": train a SetFit model via Argilla.

    Shows a red reminder instead when `annotated` is still False. The global
    `trainer` is replaced only after training has finished, so a failed run
    leaves the previously active model (or none) in place rather than an
    untrained trainer that "Score Answer" would then pick up.

    Args:
        b: The clicked button widget (unused).
    """
    global trainer
    with bhvr_outt:
        clear_output()
        if not annotated:
            display(Markdown("<h2 style='color:red; text-align:center;'>Please score the answer and validate the data first!</h2>"))
            return

        new_trainer = ArgillaTrainer(
            name=dataset_rg_name,
            workspace="admin",
            framework="setfit",
            train_size=1.0
        )
        new_trainer.update_config(
            pretrained_model_name_or_path = "all-mpnet-base-v2",
            force_download = False,
            resume_download = False,
            proxies = None,
            token = None,
            cache_dir = None,
            local_files_only = False,
            num_iterations=10
        )
        new_trainer.train(output_dir=trainer_rg_name)
        # Publish only once training completed without raising.
        trainer = new_trainer
            
# Attach each button to its click handler, then render the demo UI.
for _button, _handler in (
    (bhvr_nlp_btn, on_bhvr_button_next),
    (bhvr_agr_btn, on_agr_button_next),
    (bhvr_eval_btn, on_eval_button_next),
    (bhvr_trn_btn, on_trn_button_next),
):
    _button.on_click(_handler)

display(bhvr_label, bhvr_box)