{ "cells": [ { "cell_type": "markdown", "id": "f56cc5ad", "metadata": {}, "source": [ "# NDIS Project - Azure OpenAI - PBSP Scoring - Page 4 - Safety Strategies" ] }, { "cell_type": "code", "execution_count": null, "id": "a8d844ea", "metadata": { "hide_input": false }, "outputs": [], "source": [ "import openai\n", "import re\n", "from ipywidgets import interact\n", "import ipywidgets as widgets\n", "from IPython.display import display, clear_output, Javascript, HTML, Markdown\n", "import matplotlib.pyplot as plt\n", "import matplotlib.ticker as mtick\n", "import json\n", "import spacy\n", "from spacy import displacy\n", "from dotenv import load_dotenv\n", "import pandas as pd\n", "import argilla as rg\n", "from argilla.metrics.text_classification import f1\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "%matplotlib inline\n", "pd.set_option('display.max_rows', 500)\n", "pd.set_option('display.max_colwidth', 10000)\n", "pd.set_option('display.width', 10000)" ] }, { "cell_type": "code", "execution_count": null, "id": "96b83a1d", "metadata": {}, "outputs": [], "source": [ "#initializations\n", "openai.api_key = os.environ['API_KEY']\n", "openai.api_base = os.environ['API_BASE']\n", "openai.api_type = os.environ['API_TYPE']\n", "openai.api_version = os.environ['API_VERSION']\n", "deployment_name = os.environ['DEPLOYMENT_ID']\n", "\n", "#argilla\n", "rg.init(\n", " api_url=os.environ[\"ARGILLA_API_URL\"],\n", " api_key=os.environ[\"ARGILLA_API_KEY\"]\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "8934eadb", "metadata": {}, "outputs": [], "source": [ "#sentence extraction\n", "def extract_sentences(paragraph):\n", " symbols = ['\\\\.', '!', '\\\\?', ';', ':', ',', '\\\\_', '\\n', '\\\\-']\n", " pattern = '|'.join([f'{symbol}' for symbol in symbols])\n", " sentences = re.split(pattern, paragraph)\n", " sentences = [sentence.strip() for sentence in sentences if sentence.strip()]\n", " return sentences" ] }, { "cell_type": "code", "execution_count": null, "id": "02fda761", "metadata": {}, "outputs": [], "source": [ "def process_response(response, query):\n", " sentences = []\n", " topics = []\n", " scores = []\n", " lines = response.strip().split(\"\\n\")\n", " for line in lines:\n", " if \"Safety Strategies:\" in line:\n", " topic = \"SAFETY STRATEGY\"\n", " elif \"None:\" in line:\n", " topic = \"NO STRATEGY\"\n", " else:\n", " try:\n", " phrase = line.split(\"(Confidence Score:\")[0].strip()\n", " score = float(line.split(\"(Confidence Score:\")[1].strip().replace(\")\", \"\"))\n", " sentences.append(phrase)\n", " topics.append(topic)\n", " scores.append(score)\n", " except:\n", " pass\n", " result_df = pd.DataFrame({'Phrase': sentences, 'Topic': topics, 'Score': scores})\n", " try:\n", " result_df['Phrase'] = result_df['Phrase'].str.replace('\\d+\\.', '', regex=True)\n", " result_df['Phrase'] = result_df['Phrase'].str.replace('^\\s', '', regex=True)\n", " except:\n", " sentences = extract_sentences(query)\n", " topics = ['NO STRATEGY'] * len(sentences)\n", " scores = [0.9] * len(sentences)\n", " result_df = pd.DataFrame({'Phrase': sentences, 'Topic': topics, 'Score': scores})\n", " return result_df" ] }, { "cell_type": "code", "execution_count": null, "id": "714fafb4", "metadata": {}, "outputs": [], "source": [ "def get_prompt(query):\n", " prompt = f\"\"\"\n", " Given the paragraph below in a behaviour support plan written by a disability practitioner, identify the phrases that represent strategies the practitioner creates to ensure the safety of the 
person with disability and/or others.\n", "\n", " Paragraph:\n", " {query}\n", "\n", " All the following requirements must be met:\n", " - Provide your answer in a numbered list. \n", " - All the phrases in your answer must be exact substrings in the original paragraph. without changing any characters.\n", " - All the upper case and lower case characters in the phrases in your answer must match the upper case and lower case characters in the original paragraph.\n", " - Start numbering the phrases from number 1.\n", " - Start your answer for the phrases with the title \"Safety Strategies:\"\n", " - For each phrase in your answer, provide a confidence score that ranges between 0.50 and 1.00, where a score of 0.50 indicates you are very weakly confident that the phrase represents strategies the practitioner creates to ensure the safety of the person with disability and/or others, whereas a score of 1.00 indicates you are very strongly confident that the phrase represents strategies the practitioner creates to ensure the safety of the person with disability and/or others.\n", " - Include another numbered list titled \"None:\", which includes all the remaining phrases in the paragraph that do not represent strategies the practitioner creates to ensure the safety of the person with disability and/or others.\n", " - For each phrase that belongs to the \"None\" category, provide a confidence score that ranges between 0.50 and 1.00, where a score of 0.50 means you are very weakly confident that the sentence belongs to the \"None\" category, whereas a score of 1.00 means you are very strongly confident that the sentence belongs to the \"None\" category.\n", " - There must not be any phrase from the paragraph that is not included in your answer.\n", "\n", " Example Paragraph:\n", " If Taylor continues to escalate, ensure the safety of all by telling other people in the room to leave immediately, keeping Taylor in your line of sight, position your back to the door and continue to speak. If Taylor begins to attempt to hit staff with his head, commence seclusion protocol. If Taylor does not stop his aggresive behaviour, do not try to ensure his safety.\n", "\n", " Example answer:\n", " Safety Strategies:\n", " 1. ensure the safety of all by telling other people in the room to leave immediately. (Confidence Score: 0.97)\n", " 2. keeping Taylor in your line of sight. (Confidence Score: 0.85)\n", " 3. position your back to the door and continue to speak. (Confidence Score: 0.87)\n", " 4. commence seclusion protocol. (Confidence Score: 0.93)\n", " \n", " None:\n", " 1. If Taylor continues to escalate, (Confidence Score: 0.99)\n", " 2. If Taylor begins to attempt to hit staff with his head, (Confidence Score: 0.97)\n", " 3. 
If Taylor does not stop his aggresive behaviour, do not try to ensure his safety (Confidence Score: 0.92)\n", " \"\"\"\n", " return prompt" ] }, { "cell_type": "code", "execution_count": null, "id": "99da147a", "metadata": {}, "outputs": [], "source": [ "def get_response_chatgpt(prompt):\n", " response=openai.ChatCompletion.create( \n", " engine=deployment_name, \n", " messages=[ \n", " {\"role\": \"system\", \"content\": \"You are a helpful assistant.\"}, \n", " {\"role\": \"user\", \"content\": prompt} \n", " ],\n", " temperature=0\n", " )\n", " reply = response[\"choices\"][0][\"message\"][\"content\"]\n", " return reply" ] }, { "cell_type": "code", "execution_count": null, "id": "56d2bac8", "metadata": {}, "outputs": [], "source": [ "def convert_df(result_df):\n", " new_df = pd.DataFrame(columns=['text', 'prediction'])\n", " new_df['text'] = result_df['Phrase']\n", " new_df['prediction'] = result_df.apply(lambda row: [[row['Topic'], row['Score']]], axis=1)\n", " return new_df" ] }, { "cell_type": "code", "execution_count": null, "id": "905eaf2a", "metadata": {}, "outputs": [], "source": [ "topic_color_dict = {\n", " 'SAFETY STRATEGY': '#90EE90',\n", " 'NONE': '#F08080'\n", " }\n", "\n", "def color(df, color):\n", " return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color=color)\n", "\n", "def annotate_query(highlights, query, topics):\n", " ents = []\n", " for h, t in zip(highlights, topics):\n", " ent_dict = {}\n", " for match in re.finditer(h, query, re.IGNORECASE):\n", " ent_dict = {\"start\": match.start(), \"end\": match.end(), \"label\": t}\n", " break\n", " if len(ent_dict.keys()) > 0:\n", " ents.append(ent_dict)\n", " return ents\n", "\n", "def path_to_image_html(path):\n", " return ''\n", "\n", "passing_score = 0.75\n", "final_passing = 0.0\n", "def display_final_df(agg_df):\n", " crits = [\n", " 'SAFETY STRATEGY'\n", " ]\n", " if not isinstance(agg_df, str):\n", " tags = []\n", " orig_crits = crits\n", " crits = [x for x in crits if x in agg_df.index.tolist()]\n", " bools = [agg_df.loc[crit, 'Final_Score'] > final_passing for crit in crits]\n", " paths = ['./thumbs_up.png' if x else './thumbs_down.png' for x in bools]\n", " df = pd.DataFrame({'Safety Strategy': crits, 'USED': paths})\n", " rem_crits = [x for x in orig_crits if x not in crits]\n", " if len(rem_crits) > 0:\n", " df2 = pd.DataFrame({'Safety Strategy': rem_crits, 'USED': ['./thumbs_down.png'] * len(rem_crits)})\n", " df = pd.concat([df, df2])\n", " else:\n", " df = pd.DataFrame({'Safety Strategy': [crits[0]], 'USED': ['./thumbs_down.png']})\n", " df = df.set_index('Safety Strategy')\n", " pd.set_option('display.max_colwidth', None)\n", " display(HTML('
' + df.to_html(classes=[\"align-center\"], index=True, escape=False, formatters=dict(USED=path_to_image_html)) + '
'))\n", " " ] }, { "cell_type": "markdown", "id": "2c6e9fe7", "metadata": {}, "source": [ "### Strategies to ensure the safety of the person and/or others" ] }, { "cell_type": "code", "execution_count": null, "id": "76dd8cab", "metadata": { "scrolled": false }, "outputs": [], "source": [ "#demo with Voila\n", "\n", "bhvr_label = widgets.Label(value='Please type your answer:')\n", "bhvr_text_input = widgets.Textarea(\n", " value='',\n", " placeholder='Type your answer',\n", " description='',\n", " disabled=False,\n", " layout={'height': '300px', 'width': '90%'}\n", ")\n", "\n", "bhvr_nlp_btn = widgets.Button(\n", " description='Score Answer',\n", " disabled=False,\n", " button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", " tooltip='Score Answer',\n", " icon='check',\n", " layout={'height': '70px', 'width': '250px'}\n", ")\n", "bhvr_agr_btn = widgets.Button(\n", " description='Validate Data',\n", " disabled=False,\n", " button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", " tooltip='Validate Data',\n", " icon='check',\n", " layout={'height': '70px', 'width': '250px'}\n", ")\n", "bhvr_eval_btn = widgets.Button(\n", " description='Evaluate Model',\n", " disabled=False,\n", " button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", " tooltip='Evaluate Model',\n", " icon='check',\n", " layout={'height': '70px', 'width': '250px'}\n", ")\n", "btn_box = widgets.HBox([bhvr_nlp_btn, bhvr_agr_btn, bhvr_eval_btn], \n", " layout={'width': '100%', 'height': '160%'})\n", "bhvr_outt = widgets.Output()\n", "bhvr_outt.layout.height = '100%'\n", "bhvr_outt.layout.width = '100%'\n", "bhvr_box = widgets.VBox([bhvr_text_input, btn_box, bhvr_outt], \n", " layout={'width': '100%', 'height': '160%'})\n", "dataset_rg_name = 'pbsp-page4-safety-argilla-ds'\n", "dataset_rg_url = f'http://localhost:6900/datasets/argilla/{dataset_rg_name}'\n", "agrilla_df = None\n", "annotated = False\n", "def on_bhvr_button_next(b):\n", " global agrilla_df\n", " with bhvr_outt:\n", " clear_output()\n", " query = bhvr_text_input.value\n", " prompt = get_prompt(query)\n", " response = get_response_chatgpt(prompt)\n", " result_df = process_response(response, query)\n", " sub_result_df = result_df[(result_df['Score'] >= passing_score) & (result_df['Topic'] != 'NO STRATEGY')]\n", " sub_2_result_df = result_df[result_df['Topic'] == 'NO STRATEGY']\n", " highlights = []\n", " if len(sub_result_df) > 0:\n", " highlights = sub_result_df['Phrase'].tolist()\n", " highlight_topics = sub_result_df['Topic'].tolist() \n", " ents = annotate_query(highlights, query, highlight_topics)\n", " colors = {}\n", " for ent, ht in zip(ents, highlight_topics):\n", " colors[ent['label']] = topic_color_dict[ht]\n", "\n", " ex = [{\"text\": query,\n", " \"ents\": ents,\n", " \"title\": None}]\n", " title = \"Safety Strategy Highlights\"\n", " display(HTML(f'
<h2>{title}</h2>
'))\n", " html = displacy.render(ex, style=\"ent\", manual=True, jupyter=True, options={'colors': colors})\n", " display(HTML(html))\n", " title = \"Safety Strategy Classifications\"\n", " display(HTML(f'
<h2>{title}</h2>
'))\n", " for top in topic_color_dict.keys():\n", " top_result_df = sub_result_df[sub_result_df['Topic'] == top]\n", " if len(top_result_df) > 0:\n", " top_result_df = top_result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n", " top_result_df = top_result_df.set_index('Phrase')\n", " top_result_df = top_result_df[['Score']]\n", " display(HTML(\n", " f'
<h3>{top}</h3>
'))\n", " display(color(top_result_df, topic_color_dict[top]))\n", " \n", " agg_df = sub_result_df.groupby('Topic')['Score'].sum()\n", " agg_df = agg_df.to_frame()\n", " agg_df.index.name = 'Topic'\n", " agg_df.columns = ['Total Score']\n", " agg_df = agg_df.assign(\n", " Final_Score=lambda x: x['Total Score'] / x['Total Score'].sum() * 100.00\n", " )\n", " agg_df = agg_df.sort_values(by='Final_Score', ascending=False)\n", " agg_df['Topic'] = agg_df.index\n", " rem_topics= [x for x in list(topic_color_dict.keys()) if not x in agg_df.Topic.tolist()]\n", " if len(rem_topics) > 0:\n", " rem_agg_df = pd.DataFrame({'Topic': rem_topics, 'Final_Score': 0.0, 'Total Score': 0.0})\n", " agg_df = pd.concat([agg_df, rem_agg_df])\n", " title = \"Final Scores\"\n", " display(HTML(f'
<h2>{title}</h2>
'))\n", " display_final_df(agg_df)\n", " if len(sub_2_result_df) > 0:\n", " sub_result_df = pd.concat([sub_result_df, sub_2_result_df]).reset_index(drop=True)\n", " agrilla_df = sub_result_df.copy()\n", " else:\n", " print(query)\n", " display_final_df('None')\n", " if len(sub_2_result_df) > 0:\n", " agrilla_df = sub_2_result_df.copy()\n", "\n", "def on_agr_button_next(b):\n", " global agrilla_df, annotated\n", " with bhvr_outt:\n", " clear_output()\n", " if agrilla_df is not None:\n", " # convert the dataframe to the structure accepted by argilla\n", " converted_df = convert_df(agrilla_df)\n", " # convert pandas dataframe to DatasetForTextClassification\n", " dataset_rg = rg.DatasetForTextClassification.from_pandas(converted_df)\n", " # delete the old DatasetForTextClassification from the Argilla web app if exists\n", " rg.delete(dataset_rg_name, workspace=\"admin\")\n", " # load the new DatasetForTextClassification into the Argilla web app\n", " rg.log(dataset_rg, name=dataset_rg_name, workspace=\"admin\")\n", " # Make sure all classes are present for annotation\n", " rg_settings = rg.TextClassificationSettings(label_schema=list(topic_color_dict.keys()))\n", " rg.configure_dataset(name=dataset_rg_name, workspace=\"admin\", settings=rg_settings)\n", " annotated = True\n", " else:\n", " display(Markdown(\"
Please score the answer first!
\"))\n", " \n", "def on_eval_button_next(b):\n", " global annotated\n", " with bhvr_outt:\n", " clear_output()\n", " if annotated:\n", " display(f1(dataset_rg_name).visualize())\n", " else:\n", " display(Markdown(\"
Please score the answer and validate the data first!
\"))\n", "\n", "bhvr_nlp_btn.on_click(on_bhvr_button_next)\n", "bhvr_agr_btn.on_click(on_agr_button_next)\n", "bhvr_eval_btn.on_click(on_eval_button_next)\n", "\n", "display(bhvr_label, bhvr_box)" ] }, { "cell_type": "code", "execution_count": null, "id": "a2e51901", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "hide_input": false, "kernelspec": { "display_name": "Python 3.9 (Argilla)", "language": "python", "name": "argilla" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": true, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "258.097px" }, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }