# NDIS Project - Azure OpenAI - PBSP Scoring - Page 4 - Safety Strategies

In [None]:
import openai
import re
from ipywidgets import interact
import ipywidgets as widgets
from IPython.display import display, clear_output, Javascript, HTML, Markdown
import matplotlib.pyplot as plt
import matplotlib.ticker as mtick
import json
import spacy
from spacy import displacy
from dotenv import load_dotenv
import pandas as pd
import argilla as rg
from argilla.metrics.text_classification import f1
import warnings
# Silence every warning category so the rendered demo output stays clean.
warnings.filterwarnings('ignore')
# IPython magic: render matplotlib figures inline in the notebook output.
%matplotlib inline
# Widen pandas' display limits so long phrases and tables are not truncated.
pd.set_option('display.max_rows', 500)
pd.set_option('display.max_colwidth', 10000)
pd.set_option('display.width', 10000)

In [None]:
#initializations
# `os` is needed for the os.environ lookups below; it was never imported.
import os

# Load variables from a local .env file (if one exists) into the process
# environment. By default already-set environment variables are not
# overridden, so an explicitly configured environment keeps working.
load_dotenv()

# Azure OpenAI connection settings, all read from the environment.
openai.api_key = os.environ['API_KEY']
openai.api_base = os.environ['API_BASE']
openai.api_type = os.environ['API_TYPE']
openai.api_version = os.environ['API_VERSION']
deployment_name = os.environ['DEPLOYMENT_ID']

#argilla
rg.init(
    api_url=os.environ["ARGILLA_API_URL"],
    api_key=os.environ["ARGILLA_API_KEY"]
)

In [None]:
#sentence extraction
def extract_sentences(paragraph):
    """Split *paragraph* into trimmed, non-empty fragments.

    The text is cut at every full stop, exclamation mark, question mark,
    semicolon, colon, comma, underscore, newline and hyphen. Fragments that
    are empty after trimming whitespace are dropped.

    Args:
        paragraph: The text to split.

    Returns:
        A list of the remaining fragments, in their original order.
    """
    delimiters = ['\\.', '!', '\\?', ';', ':', ',', '\\_', '\n', '\\-']
    splitter = '|'.join(delimiters)
    fragments = []
    for piece in re.split(splitter, paragraph):
        trimmed = piece.strip()
        if trimmed:
            fragments.append(trimmed)
    return fragments

In [None]:
def process_response(response, query):
    """Parse the model's reply into a phrase/topic/score table.

    The reply is read line by line. A line containing "Safety Strategies:"
    or "None:" switches the current topic; any other line is expected to look
    like ``"<n>. <phrase> (Confidence Score: <float>)"``. Lines that carry no
    confidence marker, whose score is not a number, or that appear before the
    first heading are skipped.

    Only the leading list number is removed from a phrase, so numbers inside
    the phrase itself (e.g. "2.5 mg") are preserved.

    Args:
        response: Raw text returned by the model.
        query: The paragraph that was scored. It is only used as a fallback:
            when no phrase could be parsed from ``response`` it is split with
            ``extract_sentences`` and every fragment is labelled
            "NO STRATEGY" with a score of 0.9.

    Returns:
        A DataFrame with the columns ``Phrase``, ``Topic`` and ``Score``.
    """
    marker = "(Confidence Score:"
    rows = []
    topic = None  # no heading seen yet
    for line in response.strip().split("\n"):
        if "Safety Strategies:" in line:
            topic = "SAFETY STRATEGY"
            continue
        if "None:" in line:
            topic = "NO STRATEGY"
            continue
        phrase, found, tail = line.partition(marker)
        if not found or topic is None:
            # Blank/free-text line, or a scored line that belongs to no heading.
            continue
        try:
            score = float(tail.strip().replace(")", ""))
        except ValueError:
            continue
        # Strip only the list numbering at the start of the line.
        phrase = re.sub(r'^\s*\d+\.\s*', '', phrase).strip()
        rows.append((phrase, topic, score))

    if not rows:
        # Nothing usable in the reply: fall back to a naive split of the query.
        sentences = extract_sentences(query)
        return pd.DataFrame({
            'Phrase': sentences,
            'Topic': ['NO STRATEGY'] * len(sentences),
            'Score': [0.9] * len(sentences),
        })
    return pd.DataFrame(rows, columns=['Phrase', 'Topic', 'Score'])

In [None]:
def get_prompt(query):
    """Build the instruction prompt that asks the model to label *query*.

    The prompt embeds the paragraph, lists the formatting requirements for the
    answer (two numbered lists titled "Safety Strategies:" and "None:", each
    phrase followed by a confidence score) and ends with a worked example.

    Args:
        query: The paragraph to be analysed; inserted verbatim into the prompt.

    Returns:
        The complete prompt text.
    """
    return f"""
    Given the paragraph below in a behaviour support plan written by a disability practitioner, identify the phrases that represent strategies the practitioner creates to ensure the safety of the person with disability and/or others.

    Paragraph:
    {query}

    All the following requirements must be met:
    - Provide your answer in a numbered list. 
    - All the phrases in your answer must be exact substrings in the original paragraph. without changing any characters.
    - All the upper case and lower case characters in the phrases in your answer must match the upper case and lower case characters in the original paragraph.
    - Start numbering the phrases from number 1.
    - Start your answer for the phrases with the title "Safety Strategies:"
    - For each phrase in your answer, provide a confidence score that ranges between 0.50 and 1.00, where a score of 0.50 indicates you are very weakly confident that the phrase represents strategies the practitioner creates to ensure the safety of the person with disability and/or others, whereas a score of 1.00 indicates you are very strongly confident that the phrase represents strategies the practitioner creates to ensure the safety of the person with disability and/or others.
    - Include another numbered list titled "None:", which includes all the remaining phrases in the paragraph that do not represent strategies the practitioner creates to ensure the safety of the person with disability and/or others.
    - For each phrase that belongs to the "None" category, provide a confidence score that ranges between 0.50 and 1.00, where a score of 0.50 means you are very weakly confident that the sentence belongs to the "None" category, whereas a score of 1.00 means you are very strongly confident that the sentence belongs to the "None" category.
    - There must not be any phrase from the paragraph that is not included in your answer.

    Example Paragraph:
    If Taylor continues to escalate, ensure the safety of all by telling other people in the room to leave immediately, keeping Taylor in your line of sight, position your back to the door and continue to speak. If Taylor begins to attempt to hit staff with his head, commence seclusion protocol. If Taylor does not stop his aggresive behaviour, do not try to ensure his safety.

    Example answer:
    Safety Strategies:
    1. ensure the safety of all by telling other people in the room to leave immediately. (Confidence Score: 0.97)
    2. keeping Taylor in your line of sight. (Confidence Score: 0.85)
    3. position your back to the door and continue to speak. (Confidence Score: 0.87)
    4. commence seclusion protocol. (Confidence Score: 0.93)
    
    None:
    1. If Taylor continues to escalate, (Confidence Score: 0.99)
    2. If Taylor begins to attempt to hit staff with his head, (Confidence Score: 0.97)
    3. If Taylor does not stop his aggresive behaviour, do not try to ensure his safety (Confidence Score: 0.92)
    """

In [None]:
def get_response_chatgpt(prompt):
    """Send *prompt* to the configured chat deployment and return the reply text.

    The request uses a fixed system message and temperature 0.

    Args:
        prompt: The user message to send.

    Returns:
        The content of the first choice's message.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    completion = openai.ChatCompletion.create(
        engine=deployment_name,
        messages=conversation,
        temperature=0,
    )
    return completion["choices"][0]["message"]["content"]

In [None]:
def convert_df(result_df):
    """Reshape the scoring table into the ``text``/``prediction`` layout.

    Each output row keeps the phrase as ``text`` and wraps its topic and score
    as ``[[topic, score]]`` in ``prediction``. The input index is preserved.

    The predictions are built with a plain ``zip`` rather than a row-wise
    ``DataFrame.apply``, so an empty input yields an empty frame with both
    columns instead of failing on the column assignment.

    Args:
        result_df: DataFrame with the columns ``Phrase``, ``Topic`` and
            ``Score``.

    Returns:
        A new DataFrame with the columns ``text`` and ``prediction``.
    """
    new_df = pd.DataFrame({'text': result_df['Phrase']})
    new_df['prediction'] = [
        [[topic, score]]
        for topic, score in zip(result_df['Topic'], result_df['Score'])
    ]
    return new_df

In [None]:
# Highlight colour per topic label. The keys double as the label schema sent
# to Argilla, so they must be the same labels that process_response assigns
# ('SAFETY STRATEGY' / 'NO STRATEGY').
topic_color_dict = {
    'SAFETY STRATEGY': '#90EE90',
    'NO STRATEGY': '#F08080',
}

def color(df, color):
    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color=color)

def annotate_query(highlights, query, topics):
    """Locate each highlighted phrase in *query* and return entity spans.

    Phrases are matched literally and case-insensitively; regex
    metacharacters in a phrase (parentheses, "?", "+", ...) are escaped so
    they can neither raise ``re.error`` nor match unintended text. Only the
    first occurrence of a phrase is used.

    If the phrase is not found as given, it is retried without trailing
    punctuation or spaces, because the model may append a full stop that is
    not in the paragraph. Empty phrases and phrases that cannot be found are
    skipped, so the result may be shorter than *highlights*.

    Args:
        highlights: Phrases to look for.
        query: The text to search in.
        topics: Label for each phrase, aligned with *highlights*.

    Returns:
        A list of ``{"start", "end", "label"}`` dicts, one per located phrase,
        in the order of *highlights*.
    """
    ents = []
    for phrase, label in zip(highlights, topics):
        for candidate in (phrase, phrase.rstrip(' .,;:!?')):
            if not candidate:
                continue
            match = re.search(re.escape(candidate), query, re.IGNORECASE)
            if match:
                ents.append({"start": match.start(), "end": match.end(), "label": label})
                break
    return ents

def path_to_image_html(path):
    """Return a 30x15 ``<img>`` tag whose source is *path*."""
    return f'<img src="{path}" width="30" height="15" />'

# Minimum confidence score a phrase needs to be kept as a detected strategy
# (compared against the 'Score' column in on_bhvr_button_next).
passing_score = 0.75
# A topic is shown as used when its Final_Score is strictly greater than this
# value (compared in display_final_df).
final_passing = 0.0
def display_final_df(agg_df):
    """Render the thumbs-up/thumbs-down summary table for the safety strategy.

    Args:
        agg_df: Either the aggregated score table (indexed by topic, with a
            ``Final_Score`` column) or a string, which means nothing was
            detected and yields a thumbs-down row.

    Side effects:
        Sets pandas' ``display.max_colwidth`` option to ``None`` and displays
        a centred HTML table in the current output.
    """
    thumbs_up = './thumbs_up.png'
    thumbs_down = './thumbs_down.png'
    all_crits = [
        'SAFETY STRATEGY'
    ]
    if isinstance(agg_df, str):
        df = pd.DataFrame({'Safety Strategy': [all_crits[0]], 'USED': [thumbs_down]})
    else:
        # Criteria that actually have a row in the aggregated table.
        scored = [crit for crit in all_crits if crit in agg_df.index.tolist()]
        icons = [
            thumbs_up if agg_df.loc[crit, 'Final_Score'] > final_passing else thumbs_down
            for crit in scored
        ]
        df = pd.DataFrame({'Safety Strategy': scored, 'USED': icons})
        # Criteria without a row are reported as not used.
        unscored = [crit for crit in all_crits if crit not in scored]
        if len(unscored) > 0:
            filler = pd.DataFrame({'Safety Strategy': unscored, 'USED': [thumbs_down] * len(unscored)})
            df = pd.concat([df, filler])
    df = df.set_index('Safety Strategy')
    pd.set_option('display.max_colwidth', None)
    table_html = df.to_html(classes=["align-center"], index=True, escape=False ,formatters=dict(USED=path_to_image_html))
    display(HTML('<div style="text-align: center;">' + table_html + '</div>'))
    

### Strategies to ensure the safety of the person and/or others

In [None]:
#demo with Voila

def _make_action_button(caption):
    """Build one large green action button; *caption* is both label and tooltip."""
    return widgets.Button(
        description=caption,
        disabled=False,
        button_style='success', # 'success', 'info', 'warning', 'danger' or ''
        tooltip=caption,
        icon='check',
        layout={'height': '70px', 'width': '250px'}
    )

# Input area for the practitioner's answer.
bhvr_label = widgets.Label(value='Please type your answer:')
bhvr_text_input = widgets.Textarea(
    value='',
    placeholder='Type your answer',
    description='',
    disabled=False,
    layout={'height': '300px', 'width': '90%'}
)

# The three actions: score, send to Argilla, evaluate.
bhvr_nlp_btn = _make_action_button('Score Answer')
bhvr_agr_btn = _make_action_button('Validate Data')
bhvr_eval_btn = _make_action_button('Evaluate Model')
btn_box = widgets.HBox(
    [bhvr_nlp_btn, bhvr_agr_btn, bhvr_eval_btn],
    layout={'width': '100%', 'height': '160%'}
)

# Output area that all button handlers render into.
bhvr_outt = widgets.Output()
bhvr_outt.layout.height = '100%'
bhvr_outt.layout.width = '100%'
bhvr_box = widgets.VBox(
    [bhvr_text_input, btn_box, bhvr_outt],
    layout={'width': '100%', 'height': '160%'}
)

# Argilla dataset used for validation, and the state shared by the handlers.
dataset_rg_name = 'pbsp-page4-safety-argilla-ds'
dataset_rg_url = f'http://localhost:6900/datasets/argilla/{dataset_rg_name}'
agrilla_df = None
annotated = False
def on_bhvr_button_next(b):
    """Handle "Score Answer": classify the typed text and render the results.

    The text-area content is sent to the model and the reply is parsed into a
    phrase/topic/score table. Phrases not labelled 'NO STRATEGY' whose score
    is at least ``passing_score`` are highlighted in the text, listed in a
    per-topic table and aggregated into the final verdict. If there are none,
    the plain text and a thumbs-down verdict are shown instead.

    The rows kept for later validation are stored in the global
    ``agrilla_df``; it is reset to ``None`` when the run produced no rows, so
    results from an earlier run cannot be validated by mistake.

    Args:
        b: The clicked button (unused; required by the ``on_click`` callback).
    """
    global agrilla_df
    with bhvr_outt:
        clear_output()
        query = bhvr_text_input.value
        prompt = get_prompt(query)
        response = get_response_chatgpt(prompt)
        result_df = process_response(response, query)
        sub_result_df = result_df[(result_df['Score'] >= passing_score) & (result_df['Topic'] != 'NO STRATEGY')]
        sub_2_result_df = result_df[result_df['Topic'] == 'NO STRATEGY']
        if len(sub_result_df) > 0:
            highlights = sub_result_df['Phrase'].tolist()
            highlight_topics = sub_result_df['Topic'].tolist()
            ents = annotate_query(highlights, query, highlight_topics)
            # Colour each entity by its own label. Pairing ents with
            # highlight_topics by position would misalign as soon as one
            # phrase could not be located in the text.
            colors = {ent['label']: topic_color_dict[ent['label']] for ent in ents}

            ex = [{"text": query,
                   "ents": ents,
                   "title": None}]
            title = "Safety Strategy Highlights"
            display(HTML(f'<center><h1>{title}</h1></center>'))
            # In Jupyter mode displacy displays the markup itself and the call
            # yields None, so its result must not be wrapped in HTML() and
            # displayed a second time.
            displacy.render(ex, style="ent", manual=True, jupyter=True, options={'colors': colors})
            title = "Safety Strategy Classifications"
            display(HTML(f'<center><h1>{title}</h1></center>'))
            for top in topic_color_dict.keys():
                top_result_df = sub_result_df[sub_result_df['Topic'] == top]
                if len(top_result_df) > 0:
                    top_result_df = top_result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)
                    top_result_df = top_result_df.set_index('Phrase')
                    top_result_df = top_result_df[['Score']]
                    display(HTML(
                        f'<left><h2 style="text-decoration: underline; text-decoration-color:{topic_color_dict[top]};">{top}</h2></left>'))
                    display(color(top_result_df, topic_color_dict[top]))

            # Share of the summed confidence per topic, in percent.
            agg_df = sub_result_df.groupby('Topic')['Score'].sum()
            agg_df = agg_df.to_frame()
            agg_df.index.name = 'Topic'
            agg_df.columns = ['Total Score']
            agg_df = agg_df.assign(
                Final_Score=lambda x: x['Total Score'] / x['Total Score'].sum() * 100.00
            )
            agg_df = agg_df.sort_values(by='Final_Score', ascending=False)
            agg_df['Topic'] = agg_df.index
            # Topics with no passing phrase get an explicit zero row.
            rem_topics = [x for x in list(topic_color_dict.keys()) if x not in agg_df.Topic.tolist()]
            if len(rem_topics) > 0:
                rem_agg_df = pd.DataFrame({'Topic': rem_topics, 'Final_Score': 0.0, 'Total Score': 0.0})
                agg_df = pd.concat([agg_df, rem_agg_df])
            title = "Final Scores"
            display(HTML(f'<left><h1>{title}</h1></left>'))
            display_final_df(agg_df)
            if len(sub_2_result_df) > 0:
                sub_result_df = pd.concat([sub_result_df, sub_2_result_df]).reset_index(drop=True)
            agrilla_df = sub_result_df.copy()
        else:
            print(query)
            display_final_df('None')
            # Keep the 'NO STRATEGY' rows for validation, or clear any rows
            # left over from a previous run when there are none.
            agrilla_df = sub_2_result_df.copy() if len(sub_2_result_df) > 0 else None

def on_agr_button_next(b):
    """Handle "Validate Data": push the last scored rows to Argilla.

    If nothing has been scored yet a red reminder is shown. Otherwise the
    stored rows are converted, the existing Argilla dataset is replaced with
    them, the label schema is configured, and ``annotated`` is set to True.

    Args:
        b: The clicked button (unused; required by the ``on_click`` callback).
    """
    global agrilla_df, annotated
    with bhvr_outt:
        clear_output()
        if agrilla_df is None:
            display(Markdown("<h2 style='color:red; text-align:center;'>Please score the answer first!</h2>"))
            return
        # Reshape the rows into the text/prediction layout Argilla accepts.
        records_df = convert_df(agrilla_df)
        dataset_rg = rg.DatasetForTextClassification.from_pandas(records_df)
        # Replace the previous dataset in the Argilla web app with the new one.
        rg.delete(dataset_rg_name, workspace="admin")
        rg.log(dataset_rg, name=dataset_rg_name, workspace="admin")
        # Make sure all classes are available for annotation.
        label_settings = rg.TextClassificationSettings(label_schema=list(topic_color_dict.keys()))
        rg.configure_dataset(name=dataset_rg_name, workspace="admin", settings=label_settings)
        annotated = True
            
def on_eval_button_next(b):
    """Handle "Evaluate Model": show the F1 visualisation for the dataset.

    The visualisation is shown only after data has been sent for validation
    (``annotated`` is True); otherwise a red reminder is displayed.

    Args:
        b: The clicked button (unused; required by the ``on_click`` callback).
    """
    global annotated
    with bhvr_outt:
        clear_output()
        if not annotated:
            display(Markdown("<h2 style='color:red; text-align:center;'>Please score the answer and validate the data first!</h2>"))
            return
        display(f1(dataset_rg_name).visualize())

# Connect each button to its click handler, then show the whole form.
for _button, _handler in (
    (bhvr_nlp_btn, on_bhvr_button_next),
    (bhvr_agr_btn, on_agr_button_next),
    (bhvr_eval_btn, on_eval_button_next),
):
    _button.on_click(_handler)

display(bhvr_label, bhvr_box)