{ "cells": [ { "cell_type": "markdown", "id": "f56cc5ad", "metadata": {}, "source": [ "# NDIS Project - PBSP Scoring - Page 3" ] }, { "cell_type": "code", "execution_count": null, "id": "a8d844ea", "metadata": { "hide_input": false }, "outputs": [], "source": [ "import os\n", "from ipywidgets import interact\n", "import ipywidgets as widgets\n", "from IPython.display import display, clear_output, Javascript, HTML, Markdown\n", "from qdrant_client import QdrantClient\n", "from qdrant_client.http.models import Distance, VectorParams, Batch, Filter, FieldCondition, Range, MatchValue\n", "import json\n", "import spacy\n", "from spacy import displacy\n", "import nltk\n", "from nltk import sent_tokenize\n", "from sklearn.feature_extraction import text\n", "from pprint import pprint\n", "import re\n", "from flair.embeddings import TransformerDocumentEmbeddings\n", "from flair.data import Sentence\n", "from sentence_transformers import SentenceTransformer, util\n", "import pandas as pd\n", "import argilla as rg\n", "from argilla.metrics.text_classification import f1\n", "from typing import Dict\n", "from setfit import SetFitModel\n", "from tqdm import tqdm\n", "import time\n", "for i in tqdm(range(60), disable=True):\n", " time.sleep(1)\n", "import warnings\n", "warnings.filterwarnings('ignore')\n", "%matplotlib inline\n", "pd.set_option('display.max_rows', 500)\n", "pd.set_option('display.max_colwidth', 10000)\n", "pd.set_option('display.width', 10000)" ] }, { "cell_type": "code", "execution_count": null, "id": "96b83a1d", "metadata": {}, "outputs": [], "source": [ "#initializations\n", "embedding = TransformerDocumentEmbeddings('distilbert-base-uncased')\n", "client = QdrantClient(\n", " host=os.environ[\"QDRANT_API_URL\"], \n", " api_key=os.environ[\"QDRANT_API_KEY\"],\n", " timeout=60,\n", " port=443\n", ")\n", "collection_name = \"my_collection\"\n", "model = SentenceTransformer('./sentence-transformers_multi-qa-MiniLM-L6-cos-v1')\n", "vector_dim = 384 
#{distilbert-base-uncased: 768, multi-qa-MiniLM-L6-cos-v1:384}\n", "sf_bhvr_model_name = \"setfit-zero-shot-classification-pbsp-p3-bhvr\"\n", "sf_bhvr_model = SetFitModel.from_pretrained(f\"aammari/{sf_bhvr_model_name}\")\n", "sf_sev_model_name = \"setfit-zero-shot-classification-pbsp-p3-sev\"\n", "sf_sev_model = SetFitModel.from_pretrained(f\"aammari/{sf_sev_model_name}\")\n", "\n", "# download nltk 'punkt' if not available\n", "try:\n", " nltk.data.find('tokenizers/punkt')\n", "except LookupError:\n", " nltk.download('punkt')\n", "\n", "# download nltk 'averaged_perceptron_tagger' if not available\n", "try:\n", " nltk.data.find('taggers/averaged_perceptron_tagger')\n", "except LookupError:\n", " nltk.download('averaged_perceptron_tagger')\n", " \n", "#argilla\n", "rg.init(\n", " api_url=os.environ[\"ARGILLA_API_URL\"],\n", " api_key=os.environ[\"ARGILLA_API_KEY\"]\n", ")" ] }, { "cell_type": "markdown", "id": "84add56f", "metadata": { "hide_input": true }, "source": [ "### Domain Expert Section\n", "#### Enter the Topic Glossary" ] }, { "cell_type": "code", "execution_count": null, "id": "17fe501c", "metadata": { "hide_input": false }, "outputs": [], "source": [ "bhvr_onto_lst = [\n", " 'hit employees',\n", " 'push people',\n", " 'throw objects',\n", " 'beat students' \n", "]\n", "bhvr_onto_text_input = widgets.Textarea(\n", " value='\\n'.join(bhvr_onto_lst),\n", " placeholder='Type your answer',\n", " description='',\n", " disabled=False,\n", " layout={'height': '100%', 'width': '90%'}\n", ")\n", "bhvr_onto_label = widgets.Label(value='Behaviours')\n", "bhvr_onto_box = widgets.VBox([bhvr_onto_label, bhvr_onto_text_input], \n", " layout={'width': '400px', 'height': '150px'})" ] }, { "cell_type": "code", "execution_count": null, "id": "7fa6ce86", "metadata": {}, "outputs": [], "source": [ "fh_onto_lst = [\n", " 'Gain the teacher attention',\n", " 'Complete work in class',\n", " 'Avoid difficult work'\n", "]\n", "\n", "fh_onto_text_input = widgets.Textarea(\n", " 
value='\\n'.join(fh_onto_lst),\n", " placeholder='Type your answer',\n", " description='',\n", " disabled=False,\n", " layout={'height': '100%', 'width': '90%'}\n", ")\n", "fh_onto_label = widgets.Label(value='Functional Hypothesis')\n", "fh_onto_box = widgets.VBox([fh_onto_label, fh_onto_text_input], \n", " layout={'width': '400px', 'height': '150px'})" ] }, { "cell_type": "code", "execution_count": null, "id": "20a1c75c", "metadata": { "scrolled": true }, "outputs": [], "source": [ "rep_onto_lst = [\n", " 'Ask teacher for help',\n", " 'Replace full body slam',\n", " 'Use a next sign'\n", "]\n", "\n", "rep_onto_text_input = widgets.Textarea(\n", " value='\\n'.join(rep_onto_lst),\n", " placeholder='Type your answer',\n", " description='',\n", " disabled=False,\n", " layout={'height': '100%', 'width': '90%'}\n", ")\n", "rep_onto_label = widgets.Label(value='Replacement Behaviour')\n", "rep_onto_box = widgets.VBox([rep_onto_label, rep_onto_text_input], \n", " layout={'width': '400px', 'height': '150px'})\n", "\n", "#onto_boxes = widgets.HBox([bhvr_onto_box, fh_onto_box, rep_onto_box], \n", "# layout={'width': '90%', 'height': '150px'})\n", "\n", "onto_boxes = widgets.HBox([bhvr_onto_box], \n", " layout={'width': '90%', 'height': '150px'})\n", "\n", "display(onto_boxes)" ] }, { "cell_type": "code", "execution_count": null, "id": "72c2c6f9", "metadata": { "hide_input": false }, "outputs": [], "source": [ "#Text Preprocessing\n", "try:\n", " nlp = spacy.load('en_core_web_sm')\n", "except OSError:\n", " spacy.cli.download('en_core_web_sm')\n", " nlp = spacy.load('en_core_web_sm')\n", "sw_lst = text.ENGLISH_STOP_WORDS\n", "def preprocess(onto_lst):\n", " cleaned_onto_lst = []\n", " pattern = re.compile(r'^[a-z ]*$')\n", " for document in onto_lst:\n", " text = []\n", " doc = nlp(document)\n", " person_tokens = []\n", " for w in doc:\n", " if w.ent_type_ == 'PERSON':\n", " person_tokens.append(w.lemma_)\n", " for w in doc:\n", " if not w.is_stop and not w.is_punct and not 
w.like_num and not len(w.text.strip()) == 0 and not w.lemma_ in person_tokens:\n", " text.append(w.lemma_.lower())\n", " texts = [t for t in text if len(t) > 1 and pattern.search(t) is not None and t not in sw_lst]\n", " cleaned_onto_lst.append(\" \".join(texts))\n", " return cleaned_onto_lst\n", "\n", "cl_bhvr_onto_lst = preprocess(bhvr_onto_lst)\n", "cl_fh_onto_lst = preprocess(fh_onto_lst)\n", "cl_rep_onto_lst = preprocess(rep_onto_lst)\n", "\n", "#pprint(cl_bhvr_onto_lst)\n", "#pprint(cl_fh_onto_lst)\n", "#pprint(cl_rep_onto_lst)" ] }, { "cell_type": "code", "execution_count": null, "id": "a1f934eb", "metadata": {}, "outputs": [], "source": [ "#compute document embeddings\n", "\n", "# distilbert-base-uncased from Flair\n", "def embeddings(cl_onto_lst):\n", " emb_onto_lst = []\n", " for doc in cl_onto_lst:\n", " sentence = Sentence(doc)\n", " embedding.embed(sentence)\n", " emb_onto_lst.append(sentence.embedding.tolist())\n", " return emb_onto_lst\n", "\n", "# multi-qa-MiniLM-L6-cos-v1 from sentence_transformers\n", "def sentence_embeddings(cl_onto_lst):\n", " emb_onto_lst_temp = model.encode(cl_onto_lst)\n", " emb_onto_lst = [x.tolist() for x in emb_onto_lst_temp]\n", " return emb_onto_lst\n", "\n", "'''\n", "emb_bhvr_onto_lst = embeddings(cl_bhvr_onto_lst)\n", "emb_fh_onto_lst = embeddings(cl_fh_onto_lst)\n", "emb_rep_onto_lst = embeddings(cl_rep_onto_lst)\n", "'''\n", "\n", "emb_bhvr_onto_lst = sentence_embeddings(cl_bhvr_onto_lst)\n", "emb_fh_onto_lst = sentence_embeddings(cl_fh_onto_lst)\n", "emb_rep_onto_lst = sentence_embeddings(cl_rep_onto_lst)" ] }, { "cell_type": "code", "execution_count": null, "id": "6302e312", "metadata": { "scrolled": false }, "outputs": [], "source": [ "#add to qdrant collection\n", "def add_to_collection():\n", " global cl_bhvr_onto_lst, emb_bhvr_onto_lst, cl_fh_onto_lst, emb_fh_onto_lst, cl_rep_onto_lst, emb_rep_onto_lst\n", " client.recreate_collection(\n", " collection_name=collection_name,\n", " 
vectors_config=VectorParams(size=vector_dim, distance=Distance.COSINE),\n", " )\n", " doc_count = len(emb_bhvr_onto_lst) + len(emb_fh_onto_lst) + len(emb_rep_onto_lst)\n", " ids = list(range(1, doc_count+1))\n", " payloads = [{\"ontology\": \"behaviours\", \"phrase\": x} for x in cl_bhvr_onto_lst] + \\\n", " [{\"ontology\": \"functional_hypothesis\", \"phrase\": y} for y in cl_fh_onto_lst] + \\\n", " [{\"ontology\": \"replacement_behaviour\", \"phrase\": z} for z in cl_rep_onto_lst]\n", " vectors = emb_bhvr_onto_lst+emb_fh_onto_lst+emb_rep_onto_lst\n", " client.upsert(\n", " collection_name=f\"{collection_name}\",\n", " points=Batch(\n", " ids=ids,\n", " payloads=payloads,\n", " vectors=vectors\n", " ),\n", " )\n",
" # The collection was just recreated with exactly doc_count points; refresh the module-level\n",
" # point_count (used as the search limit) so re-uploads of an edited glossary are not truncated.\n",
" global point_count\n",
" point_count = doc_count\n",
"\n",
"def count_collection():\n",
" # Exact number of stored points. scroll() is paginated (default limit=10), so counting\n",
" # only its first page silently capped the result at 10.\n",
" return client.count(\n",
" collection_name=f\"{collection_name}\",\n",
" exact=True\n",
" ).count\n",
"\n", "add_to_collection()\n", "point_count = count_collection()\n", "#print(point_count)" ] }, { "cell_type": "code", "execution_count": null, "id": "b74861d4", "metadata": {}, "outputs": [], "source": [ "query_filter=Filter(\n", " must=[ \n", " FieldCondition(\n", " key='ontology',\n", " match=MatchValue(value=\"functional_hypothesis\") # Condition on the `ontology` payload field.\n", " )\n", " ]\n", " )" ] }, { "cell_type": "code", "execution_count": null, "id": "40070313", "metadata": {}, "outputs": [], "source": [ "#verb phrase extraction\n", "def extract_vbs(data_chunked):\n", " for tup in data_chunked:\n", " if len(tup) > 2:\n", " yield(str(\" \".join(str(x[0]) for x in tup)))\n", "\n", "def get_verb_phrases(nltk_query):\n", " data_tok = nltk.word_tokenize(nltk_query) #tokenisation\n", " data_pos = nltk.pos_tag(data_tok) #POS tagging\n", " cfgs = [\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK:
{<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\",\n", " \"CUSTOMCHUNK: {<.*>{0,3}}\"\n", " ]\n", " vbs = []\n", " for cfg_1 in cfgs: \n", " chunker = nltk.RegexpParser(cfg_1)\n", " data_chunked = chunker.parse(data_pos)\n", " vbs += extract_vbs(data_chunked)\n", " return vbs" ] }, { "cell_type": "code", "execution_count": null, "id": "1550437b", "metadata": {}, "outputs": [], "source": [ "#query and get score\n", "\n", "# distilbert-base-uncased from Flair\n", "def get_query_vector(query):\n", " sentence = Sentence(query)\n", " embedding.embed(sentence)\n", " query_vector = sentence.embedding.tolist()\n", " return query_vector\n", "\n", "# multi-qa-MiniLM-L6-cos-v1 from sentence_transformers\n", "def sentence_get_query_vector(query):\n", " query_vector = model.encode(query)\n", " return query_vector\n", "\n", "def search_collection(ontology, query_vector):\n", " query_filter=Filter(\n", " must=[ \n", " 
FieldCondition(\n", " key='ontology',\n", " match=MatchValue(value=ontology)\n", " )\n", " ]\n", " )\n", " \n", " hits = client.search(\n", " collection_name=f\"{collection_name}\",\n", " query_vector=query_vector,\n", " query_filter=query_filter, \n", " append_payload=True, \n", " limit=point_count \n", " )\n", " return hits\n", "\n", "semantic_passing_score = 0.50\n", "\n", "\n", "#ontology = 'behaviours'\n", "#query = 'punch father face'\n", "#query_vector = sentence_get_query_vector(query)\n", "#hist = search_collection(ontology, query_vector)" ] }, { "cell_type": "code", "execution_count": null, "id": "02fda761", "metadata": {}, "outputs": [], "source": [
"# format output\n",
"def color(df):\n",
"    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#ADD8E6')\n",
"\n",
"def annotate_query(highlights, query):\n",
"    \"\"\"Return {start, end, label='GLOSSARY'} span dicts for highlighting `query`.\n",
"\n",
"    For each phrase in `highlights`, its first literal occurrence in `query`\n",
"    becomes one span; phrases that do not occur in `query` are skipped.\n",
"    \"\"\"\n",
"    ents = []\n",
"    for h in highlights:\n",
"        # re.escape: phrases are raw user text (may contain '.', '?', '(' ...) and\n",
"        # must be matched literally, not interpreted as regex syntax.\n",
"        match = re.search(re.escape(h), query)\n",
"        if match is not None:\n",
"            ents.append({\"start\": match.start(), \"end\": match.end(), \"label\": 'GLOSSARY'})\n",
"    return ents" ] }, { "cell_type": "code", "execution_count": null, "id": "79b519a6", "metadata": {}, "outputs": [], "source": [ "#setfit bhvr sentence extraction\n", "def extract_sentences(nltk_query):\n", " sentences = sent_tokenize(nltk_query)\n", " return sentences" ] }, { "cell_type": "code", "execution_count": null, "id": "0bd4f2f0", "metadata": {}, "outputs": [], "source": [ "def convert_df(result_df):\n", " new_df = pd.DataFrame(columns=['text', 'prediction'])\n", " new_df['text'] = result_df['Phrase']\n", " new_df['prediction'] = result_df.apply(lambda row: [[row['Topic'], row['Score']]], axis=1)\n", " return new_df" ] }, { "cell_type": "code", "execution_count": null, "id": "26ebff60", "metadata": {}, "outputs": [], "source": [ "def custom_f1(data: Dict[str, float], title: str):\n", " from plotly.subplots import make_subplots\n", " import
plotly.colors\n", " import random\n", "\n", " fig = make_subplots(\n", " rows=2,\n", " cols=1,\n", " subplot_titles=[ \"Overall Model Score\", \"Model Score By Category\", ],\n", " )\n", "\n", " x = ['precision', 'recall', 'f1']\n", " macro_data = [v for k, v in data.items() if \"macro\" in k]\n", " fig.add_bar(\n", " x=x,\n", " y=macro_data,\n", " row=1,\n", " col=1,\n", " )\n", " per_label = {\n", " k: v\n", " for k, v in data.items()\n", " if all(key not in k for key in [\"macro\", \"micro\", \"support\"])\n", " }\n", "\n", " num_labels = int(len(per_label.keys())/3)\n", " fixed_colors = [str(color) for color in plotly.colors.qualitative.Plotly]\n", " colors = random.sample(fixed_colors, num_labels)\n", "\n", " fig.add_bar(\n", " x=[k for k, v in per_label.items()],\n", " y=[v for k, v in per_label.items()],\n", " row=2,\n", " col=1,\n", " marker_color=[colors[int(i/3)] for i in range(0, len(per_label.keys()))]\n", " )\n", " fig.update_layout(showlegend=False, title_text=title)\n", "\n", " return fig" ] }, { "cell_type": "code", "execution_count": null, "id": "48627758", "metadata": {}, "outputs": [], "source": [ "def get_null_class_df(sentences, result_df):\n", " sents = result_df['Phrase'].tolist()\n", " null_sents = [x for x in sentences if x not in sents]\n", " topics = ['NONE'] * len(null_sents)\n", " scores = [0.90] * len(null_sents)\n", " null_df = pd.DataFrame({'Phrase': null_sents, 'Topic': topics, 'Score': scores})\n", " return null_df" ] }, { "cell_type": "code", "execution_count": null, "id": "0d5f8d3c", "metadata": {}, "outputs": [], "source": [ "#setfit bhvr query and get predicted topic\n", "\n", "def get_sf_bhvr_topic(sentences):\n", " preds = list(sf_bhvr_model(sentences))\n", " return preds\n", "def get_sf_bhvr_topic_scores(sentences):\n", " preds = sf_bhvr_model.predict_proba(sentences)\n", " preds = [max(list(x)) for x in preds]\n", " return preds" ] }, { "cell_type": "code", "execution_count": null, "id": "67bf2154", "metadata": {}, 
"outputs": [], "source": [
"# setfit bhvr format output\n",
"ind_bhvr_topic_dict = {\n",
"    0: 'NO BEHAVIOUR',\n",
"    1: 'BEHAVIOUR',\n",
"}\n",
"\n",
"highlight_threshold = 0.25\n",
"passing_score = 0.50\n",
"\n",
"def sf_bhvr_color(df):\n",
"    return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#CCFFCC')\n",
"\n",
"def sf_annotate_query(highlights, query, topics):\n",
"    \"\"\"Return {start, end, label} span dicts, labelling each highlight with its topic.\n",
"\n",
"    `highlights` and `topics` are paired element-wise (zip stops at the shorter one).\n",
"    Each phrase's first literal occurrence in `query` becomes one span; phrases that\n",
"    do not occur in `query` are skipped.\n",
"    \"\"\"\n",
"    ents = []\n",
"    for h, t in zip(highlights, topics):\n",
"        # re.escape: phrases are whole sentences or regex-extracted text and may contain\n",
"        # '.', '?', '(' etc.; they must be matched literally, not as regex syntax.\n",
"        match = re.search(re.escape(h), query)\n",
"        if match is not None:\n",
"            ents.append({\"start\": match.start(), \"end\": match.end(), \"label\": t})\n",
"    return ents" ] }, { "cell_type": "code", "execution_count": null, "id": "a66eaa42", "metadata": {}, "outputs": [], "source": [ "#regex freq query and get predicted topic\n", "\n", "def detect_frequency(sentences):\n", " frequency_patterns = [\n", " r\"(\\d+|(once|twice|thrice))\\s*(time(s)?)?\\s*(per)?\\s*(a|an)?\\s*((minute|hour|day|week|month|year)s?)\\b\",\n", " r\"(\\b\\d+\\b)(\\s*\\btime(s)?\\b)?\\s*\\b(a|an)?\\s*\\b(minute(s)?|hour(s)?|day(s)?|week(s)?|month(s)?|year(s)?|month(s)?\\b)\",\n", " r\"\\b(hourly|daily|weekly|fortnightly|monthly|yearly)\\b\",\n", " r\"(\\d+(\\.\\d+)?(\\s*\\w+)?(\\s+\\w+)?\\s*per\\s*(hr|hour|day|fortnight|month|year))\",\n", " r\"\\b\\d+(\\s+or\\s+\\d+)?\\s+\\w+\\s+(every|each|per)\\s+(a\\s+single\\s+|single\\s+|couple\\s+of\\s+|\\d+\\s+)?(minute(s)?|min(s)?|hour(s)?|hr(s)?|day(s)?|week(s)?|month(s)?|year(s)?|yr(s)?)\\b\",\n", " r\"\\b(one|two|three|four|five|six|seven|eight|nine|ten)\\s+(or\\s+(one|two|three|four|five|six|seven|eight|nine|ten))?\\s+\\w+\\s+(every|each|per)\\s+(a\\s+single\\s+|single\\s+|couple\\s+of\\s+|(one|two|three|four|five|six|seven|eight|nine|ten)\\s+)?(minute(s)?|min(s)?|hour(s)?|hr(s)?|day(s)?|week(s)?|month(s)?|year(s)?|yr(s)?)\\b\",\n", "
r\"((once|twice|thrice)\\s*(every|each|per)?\\s*(\\d+)\\s*((minute|hour|day|week|month|year)s?)\\b)\"\n", " ]\n", "\n", " sf_freq_result_df = pd.DataFrame(columns=['Phrase', 'Topic', 'Score'])\n", "\n", " for sentence in sentences:\n", " freq_matches = []\n", " temp_matches = []\n", " for pattern in frequency_patterns:\n", " match = re.search(pattern, sentence, flags=re.IGNORECASE)\n", " if match:\n", " temp_matches.append(match.group(0))\n", " if temp_matches:\n", " freq_matches.append(max(temp_matches, key=len))\n", "\n", " if freq_matches:\n", " sf_freq_result_df = sf_freq_result_df.append({'Phrase': \", \".join(freq_matches),\n", " 'Topic': 'FREQUENCY',\n", " 'Score': 0.75}, ignore_index=True)\n", " else:\n", " sf_freq_result_df = sf_freq_result_df.append({'Phrase': '',\n", " 'Topic': 'NO FREQUENCY',\n", " 'Score': 0.75}, ignore_index=True)\n", "\n", " if len(sf_freq_result_df) > 0:\n", " for i in range(len(sf_freq_result_df)):\n", " phrase = sf_freq_result_df.loc[i, 'Phrase']\n", " if ',' in phrase:\n", " sf_freq_result_df.loc[i, 'Phrase'] = phrase.split(',')[0]\n", " \n", " duration_patterns = [\n", " r\"\\b\\d+\\s*(minute(s)?|hour(s)?|day(s)?|week(s)?|month(s)?|year(s)?)\\b\",\n", " r\"\\bhalf an hour\\b|\\ban hour\\b|\\btwo hours\\b|\\ba day\\b|\\btwo days\\b|\\bthree days\\b|\\ba week\\b|\\btwo weeks\\b|\\bthree weeks\\b|\\ba month\\b|\\btwo months\\b|\\bthree months\\b|\\ba year\\b|\\btwo years\\b|\\bthree years\\b\"\n", " ]\n", "\n", " sf_dur_result_df = pd.DataFrame(columns=['Phrase', 'Topic', 'Score'])\n", "\n", " for sentence in sentences:\n", " dur_matches = []\n", "\n", " for pattern in duration_patterns:\n", " match = re.search(pattern, sentence, flags=re.IGNORECASE)\n", " if match:\n", " dur_matches.append(match.group(0))\n", "\n", " if dur_matches:\n", " sf_dur_result_df = sf_dur_result_df.append({'Phrase': \", \".join(dur_matches),\n", " 'Topic': 'DURATION',\n", " 'Score': 0.75}, ignore_index=True)\n", " else:\n", " sf_dur_result_df = 
sf_dur_result_df.append({'Phrase': '',\n", " 'Topic': 'NO DURATION',\n", " 'Score': 0.75}, ignore_index=True)\n", "\n", " if len(sf_dur_result_df) > 0:\n", " for i in range(len(sf_dur_result_df)):\n", " phrase = sf_dur_result_df.loc[i, 'Phrase']\n", " if ',' in phrase:\n", " sf_dur_result_df.loc[i, 'Phrase'] = phrase.split(',')[0]\n", " sf_dur_lst = sf_dur_result_df['Phrase'].tolist()\n", " sf_freq_result_df = sf_freq_result_df[~sf_freq_result_df['Phrase'].isin(sf_dur_lst)] \n", "\n", " return sf_freq_result_df" ] }, { "cell_type": "code", "execution_count": null, "id": "91774d6f", "metadata": {}, "outputs": [], "source": [ "# setfit freq format output\n", "ind_freq_topic_dict = {\n", " 0: 'NO FREQUENCY',\n", " 1: 'FREQUENCY',\n", " }\n", "\n", "def sf_freq_color(df):\n", " return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#FFFF00')" ] }, { "cell_type": "code", "execution_count": null, "id": "86e8c99b", "metadata": {}, "outputs": [], "source": [ "#regex dur query and get predicted topic\n", "\n", "def detect_duration(sentences, sf_freq_result_df):\n", " duration_patterns = [\n", " r\"\\b\\d+\\s*(minute(s)?|hour(s)?|day(s)?|week(s)?|month(s)?|year(s)?)\\b\",\n", " r\"\\bhalf an hour\\b|\\ban hour\\b|\\btwo hours\\b|\\bthree hours\\b|\\bfour hours\\b|\\bfive hours\\b|\\bsix hours\\b|\\bseven hours\\b|\\beight hours\\b|\\bnine hours\\b|\\bten hours\\b|\\ba minute\\b|\\btwo minutes\\b|\\bthree minutes\\b|\\bfour minutes\\b|\\bfive minutes\\b|\\bsix minutes\\b|\\bseven minutes\\b|\\beight minutes\\b|\\bnine minutes\\b|\\bten minutes\\b|\\ba day\\b|\\btwo days\\b|\\bthree days\\b|\\bfour days\\b|\\bfive days\\b|\\bsix days\\b|\\bseven days\\b|\\beight days\\b|\\bnine days\\b|\\bten days\\b|\\ba week\\b|\\btwo weeks\\b|\\bthree weeks\\b|\\bfour weeks\\b|\\bfive weeks\\b|\\bsix weeks\\b|\\bseven weeks\\b|\\beight weeks\\b|\\bnine weeks\\b|\\bten weeks\\b|\\ba month\\b|\\btwo months\\b|\\bthree months\\b|\\bfour months\\b|\\bfive 
months\\b|\\bsix months\\b|\\bseven months\\b|\\beight months\\b|\\bnine months\\b|\\bten months\\b|\\ba year\\b|\\btwo years\\b|\\bthree years\\b|\\bfour years\\b|\\bfive years\\b|\\bsix years\\b|\\bseven years\\b|\\beight years\\b|\\bnine years\\b|\\bten years\\b\",\n", " r\"\\b\\d+\\s*(min|mins)\\b\", # e.g., \"5 mins\"\n", " r\"\\b\\d+\\s*(hr|hrs|hour|hours)\\b\", # e.g., \"2 hrs\"\n", " r\"\\b\\d+\\s*(d|day|days)\\b\", # e.g., \"3 days\"\n", " r\"\\b\\d+\\s*(w|week|weeks)\\b\", # e.g., \"4 weeks\"\n", " r\"\\b\\d+\\s*(m|month|months)\\b\", # e.g., \"6 months\"\n", " r\"\\b\\d+\\s*(y|yr|year|years)\\b\", # e.g., \"1 yr\"\n", " r\"\\b(\\d+\\s*(minute(s)?|hour(s)?|day(s)?|week(s)?|month(s)?|year(s)?)\\s*,\\s*){2,}\\d+\\s*(minute(s)?|hour(s)?|day(s)?|week(s)?|month(s)?|year(s)?)\\b\", # e.g., \"2 hours, 30 minutes\"\n", " r\"\\b(half|quarter)\\s+an?\\s+(hour|hr)\\b\", # e.g., \"half an hour\"\n", " r\"\\b(\\d+(?:\\.\\d+)?|\\d+(?:/\\d+))\\s*(hour|hr)s?\\s*(and|&)\\s*(\\d+(?:\\.\\d+)?|\\d+(?:/\\d+))\\s*(minute|min)s?\\b\", # e.g., \"1.5 hours & 30 mins\"\n", " r\"\\b(\\d+)\\s*-\\s*(\\d+)\\s*(minute|min|hour|hr|day|week|month|year)s?\\b\", # e.g., \"5 - 10 mins\"\n", " r\"\\b(more than|less than)\\s*\\d+\\s*(minute(s)?|hour(s)?|day(s)?|week(s)?|month(s)?|year(s)?)\\b\" # e.g., \"more than 3 hours\"\n", " ]\n", "\n", " if len(sf_freq_result_df) > 0:\n", " sf_freq_lst = sf_freq_result_df['Phrase'].tolist()\n", " else:\n", " sf_freq_lst = []\n", " \n", " sf_dur_result_df = pd.DataFrame(columns=['Phrase', 'Topic', 'Score'])\n", "\n", " for sentence in sentences:\n", " dur_matches = []\n", " temp_matches = []\n", " \n", " for phrase in sf_freq_lst:\n", " sentence = sentence.replace(phrase, \"\")\n", " # Remove extra spaces\n", " sentence = ' '.join(sentence.split())\n", "\n", " for pattern in duration_patterns:\n", " match = re.search(pattern, sentence, flags=re.IGNORECASE)\n", " if match:\n", " temp_matches.append(match.group(0))\n", " if temp_matches:\n", " 
dur_matches.append(max(temp_matches, key=len))\n", "\n", " if dur_matches:\n", " sf_dur_result_df = sf_dur_result_df.append({'Phrase': \", \".join(dur_matches),\n", " 'Topic': 'DURATION',\n", " 'Score': 0.75}, ignore_index=True)\n", " else:\n", " sf_dur_result_df = sf_dur_result_df.append({'Phrase': '',\n", " 'Topic': 'NO DURATION',\n", " 'Score': 0.75}, ignore_index=True)\n", "\n", " if len(sf_dur_result_df) > 0:\n", " for i in range(len(sf_dur_result_df)):\n", " phrase = sf_dur_result_df.loc[i, 'Phrase']\n", " if ',' in phrase:\n", " sf_dur_result_df.loc[i, 'Phrase'] = phrase.split(',')[0]\n", "\n", " return sf_dur_result_df" ] }, { "cell_type": "code", "execution_count": null, "id": "bfc3f615", "metadata": {}, "outputs": [], "source": [ "# setfit dur format output\n", "ind_dur_topic_dict = {\n", " 0: 'NO DURATION',\n", " 1: 'DURATION',\n", " }\n", "\n", "def sf_dur_color(df):\n", " return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#DDA0DD')" ] }, { "cell_type": "code", "execution_count": null, "id": "41d468ac", "metadata": {}, "outputs": [], "source": [ "#setfit sev query and get predicted topic\n", "\n", "def get_sf_sev_topic(sentences):\n", " preds = list(sf_sev_model(sentences))\n", " return preds\n", "def get_sf_sev_topic_scores(sentences):\n", " preds = sf_sev_model.predict_proba(sentences)\n", " preds = [max(list(x)) for x in preds]\n", " return preds" ] }, { "cell_type": "code", "execution_count": null, "id": "4dd74f0c", "metadata": {}, "outputs": [], "source": [ "# setfit sev format output\n", "ind_sev_topic_dict = {\n", " 0: 'NO SEVERITY',\n", " 1: 'SEVERITY',\n", " }\n", "\n", "def sf_sev_color(df):\n", " return df.style.format({'Score': '{:,.2%}'.format}).bar(subset=['Score'], color='#FFCCCB')" ] }, { "cell_type": "code", "execution_count": null, "id": "39ba54ec", "metadata": {}, "outputs": [], "source": [ "def path_to_image_html(path):\n", " return ''\n", "\n", "def display_final_df(tags):\n", " crits = [\n", " 
'Behaviour',\n", " 'Frequency',\n", " 'Duration',\n", " 'Severity'\n", " ]\n", " descs = [\n", " 'Are all behaviours described in a way that would allow another person to act them out?',\n", " 'Has information been provided about how often the behaviours occur?',\n", " 'Has information been provided about how long the behaviours last for?',\n", " 'Has information been provided about how damaging or destructive the behaviours are?'\n", " ]\n", " paths = ['./thumbs_up.png' if x else './thumbs_down.png' for x in tags]\n", " # One row per criterion; the 'Score' column holds an image path rendered via path_to_image_html.\n", " df = pd.DataFrame({'Criteria': crits, 'Description': descs, 'Score': paths})\n", " df = df.set_index('Criteria')\n", " pd.set_option('display.max_colwidth', None)\n", " display(HTML('
' + df.to_html(classes=[\"align-center\"], index=True, escape=False ,formatters=dict(Score=path_to_image_html)) + '
'))" ] }, { "cell_type": "markdown", "id": "2c6e9fe7", "metadata": {}, "source": [ "### Practitioner Section\n", "#### Enter description of behaviours that align with this function. Include frequency, duration, and severity" ] }, { "cell_type": "code", "execution_count": null, "id": "76dd8cab", "metadata": { "scrolled": false }, "outputs": [], "source": [ "#demo with Voila\n", "\n", "bhvr_label = widgets.Label(value='Please type your answer:')\n", "bhvr_text_input = widgets.Textarea(\n", " value='',\n", " placeholder='Type your answer',\n", " description='',\n", " disabled=False,\n", " layout={'height': '300px', 'width': '90%'}\n", ")\n", "\n", "bhvr_nlp_btn = widgets.Button(\n", " description='Score Behaviours',\n", " disabled=False,\n", " button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", " tooltip='Score Behaviours',\n", " icon='check',\n", " layout={'height': '70px', 'width': '250px'}\n", ")\n", "bhvr_agr_btn = widgets.Button(\n", " description='Validate Data',\n", " disabled=False,\n", " button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", " tooltip='Validate Data',\n", " icon='check',\n", " layout={'height': '70px', 'width': '250px'}\n", ")\n", "bhvr_eval_btn = widgets.Button(\n", " description='Evaluate Model',\n", " disabled=False,\n", " button_style='success', # 'success', 'info', 'warning', 'danger' or ''\n", " tooltip='Evaluate Model',\n", " icon='check',\n", " layout={'height': '70px', 'width': '250px'}\n", ")\n", "btn_box = widgets.HBox([bhvr_nlp_btn, bhvr_agr_btn, bhvr_eval_btn], \n", " layout={'width': '100%', 'height': '160%'})\n", "bhvr_outt = widgets.Output()\n", "bhvr_outt.layout.height = '100%'\n", "bhvr_outt.layout.width = '100%'\n", "bhvr_box = widgets.VBox([bhvr_text_input, btn_box, bhvr_outt], \n", " layout={'width': '100%', 'height': '160%'})\n", "dataset_rg_name = 'pbsp-page3-bhvr-argilla-ds'\n", "agrilla_df = None\n", "annotated = False\n", "sub_2_result_dfs = []\n", "def 
on_bhvr_button_next(b):\n", " global bhvr_onto_lst, cl_bhvr_onto_lst, emb_bhvr_onto_lst, agrilla_df\n", " with bhvr_outt:\n", " clear_output()\n", " bhvr_onto_lst = bhvr_onto_text_input.value.split(\"\\n\")\n", " cl_bhvr_onto_lst = preprocess(bhvr_onto_lst)\n", " orig_cl_dict = {x:y for x,y in zip(cl_bhvr_onto_lst, bhvr_onto_lst)}\n", " emb_bhvr_onto_lst = sentence_embeddings(cl_bhvr_onto_lst)\n", " add_to_collection()\n", " query = bhvr_text_input.value\n", " vbs = get_verb_phrases(query)\n", " cl_vbs = preprocess(vbs)\n", " emb_vbs = sentence_embeddings(cl_vbs)\n", " vb_ind = -1\n", " highlights = []\n", " highlight_scores = []\n", " result_dfs = []\n", " for query_vector in emb_vbs:\n", " vb_ind += 1\n", " hist = search_collection('behaviours', query_vector)\n", " hist_dict = [dict(x) for x in hist]\n", " scores = [x['score'] for x in hist_dict]\n", " payloads = [orig_cl_dict[x['payload']['phrase']] for x in hist_dict]\n", " result_df = pd.DataFrame({'Score': scores, 'Glossary': payloads})\n", " result_df = result_df[result_df['Score'] >= semantic_passing_score]\n", " if len(result_df) > 0:\n", " highlights.append(vbs[vb_ind])\n", " highlight_scores.append(result_df.Score.max())\n", " result_df['Phrase'] = [vbs[vb_ind]] * len(result_df)\n", " result_df = result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n", " result_dfs.append(result_df)\n", " else:\n", " continue\n", " ents = []\n", " colors = {}\n", " if len(highlights) > 0:\n", " ents = annotate_query(highlights, query)\n", " for ent in ents:\n", " colors[ent['label']] = '#ADD8E6'\n", " \n", " #setfit behaviour\n", " sentences = extract_sentences(query)\n", " cl_sentences = preprocess(sentences)\n", " topic_inds = get_sf_bhvr_topic(cl_sentences)\n", " topics = [ind_bhvr_topic_dict[i] for i in topic_inds]\n", " scores = get_sf_bhvr_topic_scores(cl_sentences)\n", " sf_bhvr_result_df = pd.DataFrame({'Phrase': sentences, 'Topic': topics, 'Score': scores})\n", " sf_bhvr_sub_result_df = 
sf_bhvr_result_df[sf_bhvr_result_df['Topic'] == 'BEHAVIOUR']\n", " sub_2_result_df = sf_bhvr_sub_result_df.copy()\n", " if len(sub_2_result_df) > 0:\n", " sub_2_result_dfs.append(sub_2_result_df)\n", " sf_bhvr_highlights = []\n", " sf_bhvr_ents = []\n", " if len(sf_bhvr_sub_result_df) > 0:\n", " sf_bhvr_highlights = sf_bhvr_sub_result_df['Phrase'].tolist()\n", " sf_bhvr_highlight_topics = sf_bhvr_sub_result_df['Topic'].tolist()\n", " sf_bhvr_highlight_scores = sf_bhvr_sub_result_df['Score'].tolist() \n", " sf_bhvr_ents = sf_annotate_query(sf_bhvr_highlights, query, sf_bhvr_highlight_topics)\n", " for ent, hs in zip(sf_bhvr_ents, sf_bhvr_highlight_scores):\n", " if hs >= passing_score:\n", " colors[ent['label']] = '#CCFFCC'\n", " else:\n", " colors[ent['label']] = '#FFCC66'\n", " options = {\"ents\": list(colors), \"colors\": colors}\n", " if len(sf_bhvr_ents) > 0:\n", " ents = ents + sf_bhvr_ents\n", " \n", " #regex frequency\n", " sf_freq_result_df = detect_frequency(sentences)\n", " sf_freq_sub_result_df = sf_freq_result_df[sf_freq_result_df['Topic'] == 'FREQUENCY']\n", " sub_2_result_df = sf_freq_sub_result_df.copy()\n", " if len(sub_2_result_df) > 0:\n", " sub_2_result_dfs.append(sub_2_result_df)\n", " sf_freq_highlights = []\n", " sf_freq_ents = []\n", " if len(sf_freq_sub_result_df) > 0:\n", " sf_freq_highlights = sf_freq_sub_result_df['Phrase'].tolist()\n", " sf_freq_highlight_topics = sf_freq_sub_result_df['Topic'].tolist()\n", " sf_freq_highlight_scores = sf_freq_sub_result_df['Score'].tolist() \n", " sf_freq_ents = sf_annotate_query(sf_freq_highlights, query, sf_freq_highlight_topics)\n", " for ent, hs in zip(sf_freq_ents, sf_freq_highlight_scores):\n", " if hs >= passing_score:\n", " colors[ent['label']] = '#FFFF00'\n", " else:\n", " colors[ent['label']] = '#FFCC66'\n", " options = {\"ents\": list(colors), \"colors\": colors}\n", " if len(sf_freq_ents) > 0:\n", " ents = ents + sf_freq_ents\n", " \n", " #regex duration\n", " sf_dur_result_df = 
detect_duration(sentences, sf_freq_result_df)\n", " sf_dur_sub_result_df = sf_dur_result_df[sf_dur_result_df['Topic'] == 'DURATION']\n", " sub_2_result_df = sf_dur_sub_result_df.copy()\n", " if len(sub_2_result_df) > 0:\n", " sub_2_result_dfs.append(sub_2_result_df)\n", " sf_dur_highlights = []\n", " sf_dur_ents = []\n", " if len(sf_dur_sub_result_df) > 0:\n", " sf_dur_highlights = sf_dur_sub_result_df['Phrase'].tolist()\n", " sf_dur_highlight_topics = sf_dur_sub_result_df['Topic'].tolist()\n", " sf_dur_highlight_scores = sf_dur_sub_result_df['Score'].tolist() \n", " sf_dur_ents = sf_annotate_query(sf_dur_highlights, query, sf_dur_highlight_topics)\n", " for ent, hs in zip(sf_dur_ents, sf_dur_highlight_scores):\n", " if hs >= passing_score:\n", " colors[ent['label']] = '#DDA0DD'\n", " else:\n", " colors[ent['label']] = '#FFCC66'\n", " options = {\"ents\": list(colors), \"colors\": colors}\n", " if len(sf_dur_ents) > 0:\n", " ents = ents + sf_dur_ents\n", " \n", " #setfit severity\n", " topic_inds = get_sf_sev_topic(sentences)\n", " topics = [ind_sev_topic_dict[i] for i in topic_inds]\n", " scores = get_sf_sev_topic_scores(sentences)\n", " sf_sev_result_df = pd.DataFrame({'Phrase': sentences, 'Topic': topics, 'Score': scores})\n", " sf_sev_sub_result_df = sf_sev_result_df[sf_sev_result_df['Topic'] == 'SEVERITY']\n", " sub_2_result_df = sf_sev_sub_result_df.copy()\n", " if len(sub_2_result_df) > 0:\n", " sub_2_result_dfs.append(sub_2_result_df)\n", " sf_sev_highlights = []\n", " sf_sev_ents = []\n", " if len(sf_sev_sub_result_df) > 0:\n", " sf_sev_highlights = sf_sev_sub_result_df['Phrase'].tolist()\n", " sf_sev_highlight_topics = sf_sev_sub_result_df['Topic'].tolist()\n", " sf_sev_highlight_scores = sf_sev_sub_result_df['Score'].tolist() \n", " sf_sev_ents = sf_annotate_query(sf_sev_highlights, query, sf_sev_highlight_topics)\n", " for ent, hs in zip(sf_sev_ents, sf_sev_highlight_scores):\n", " if hs >= passing_score:\n", " colors[ent['label']] = '#FFCCCB'\n", " 
else:\n", " colors[ent['label']] = '#FFCC66'\n", " options = {\"ents\": list(colors), \"colors\": colors}\n", " if len(sf_sev_ents) > 0:\n", " ents = ents + sf_sev_ents\n", " \n", " ex = [{\"text\": query,\n", " \"ents\": ents,\n", " \"title\": None}]\n", " if len(ents) > 0:\n", " title = \"Answer Highlights\"\n", " display(HTML(f'

{title}

'))\n", " html = displacy.render(ex, style=\"ent\", manual=True, options=options)\n", " display(HTML(html))\n", " if len(result_dfs) > 0:\n", " title = \"Subtopics\"\n", " display(HTML(f'

{title}

'))\n", " result_df = pd.concat(result_dfs).reset_index(drop = True)\n", " result_df = result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n", " sub_2_result_df = result_df.copy()\n", " sub_2_result_df['Topic'] = ['BEHAVIOUR'] * len(result_df)\n", " sub_2_result_df = sub_2_result_df[['Phrase', 'Topic', 'Score']].drop_duplicates().reset_index(drop=True)\n", " sub_2_result_dfs.append(sub_2_result_df)\n", " agg_df = result_df.groupby(result_df.Phrase).max()\n", " agg_df['Phrase'] = agg_df.index\n", " agg_df = agg_df.reset_index(drop=True)\n", " agg_df = agg_df.drop(columns=['Glossary'])\n", " result_df = pd.merge(result_df, agg_df, 'inner', ['Phrase', 'Score'])\n", " result_df = result_df[['Phrase', 'Glossary', 'Score']]\n", " result_df = result_df.set_index('Phrase')\n", " display(color(result_df))\n", " bhvr_tag = False\n", " freq_tag = False\n", " dur_tag = False\n", " sev_tag = False\n", " if len(sf_bhvr_sub_result_df) > 0:\n", " bhvr_tag = True\n", " title = \"Relevant Behaviours\"\n", " display(HTML(f'

{title}

'))\n", " result_df = sf_bhvr_sub_result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n", " result_df = result_df.set_index('Phrase')\n", " display(sf_bhvr_color(result_df))\n", " if len(sf_freq_sub_result_df) > 0:\n", " freq_tag = True\n", " title = \"Relevant Frequencies\"\n", " display(HTML(f'

{title}

'))\n", " result_df = sf_freq_sub_result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n", " result_df = result_df.set_index('Phrase')\n", " display(sf_freq_color(result_df))\n", " if len(sf_dur_sub_result_df) > 0:\n", " dur_tag = True\n", " title = \"Relevant Durations\"\n", " display(HTML(f'

{title}

'))\n", " result_df = sf_dur_sub_result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n", " result_df = result_df.set_index('Phrase')\n", " display(sf_dur_color(result_df))\n", " if len(sf_sev_sub_result_df) > 0:\n", " sev_tag = True\n", " title = \"Relevant Severities\"\n", " display(HTML(f'

{title}

'))\n", " result_df = sf_sev_sub_result_df.sort_values(by='Score', ascending=False).reset_index(drop=True)\n", " result_df = result_df.set_index('Phrase')\n", " display(sf_sev_color(result_df))\n", " title = \"Final Scores\"\n", " display(HTML(f'

{title}

'))\n", " display_final_df([bhvr_tag, freq_tag, dur_tag, sev_tag])\n", " if len(sub_2_result_dfs) > 0:\n", " sub_2_result_df = pd.concat(sub_2_result_dfs).reset_index(drop=True)\n", " null_df = get_null_class_df(sentences, sub_2_result_df)\n", " if len(null_df) > 0:\n", " sub_2_result_df = pd.concat([sub_2_result_df, null_df]).reset_index(drop=True)\n", " agrilla_df = sub_2_result_df.copy()\n", "\n", "def on_agr_button_next(b):\n", " global agrilla_df, annotated\n", " with bhvr_outt:\n", " clear_output()\n", " if agrilla_df is not None:\n", " # convert the dataframe to the structure accepted by argilla\n", " converted_df = convert_df(agrilla_df)\n", " # convert pandas dataframe to DatasetForTextClassification\n", " dataset_rg = rg.DatasetForTextClassification.from_pandas(converted_df)\n", " # delete the old DatasetForTextClassification from the Argilla web app if exists\n", " rg.delete(dataset_rg_name, workspace=\"admin\")\n", " # load the new DatasetForTextClassification into the Argilla web app\n", " rg.log(dataset_rg, name=dataset_rg_name, workspace=\"admin\")\n", " # Make sure all classes are present for annotation\n", " rg_settings = rg.TextClassificationSettings(label_schema=['BEHAVIOUR', \n", " 'FREQUENCY', \n", " 'DURATION', \n", " 'SEVERITY', \n", " 'NONE'])\n", " rg.configure_dataset(name=dataset_rg_name, workspace=\"admin\", settings=rg_settings)\n", " annotated = True\n", " else:\n", " display(Markdown(\"

Please score the answer first!

\"))\n", " \n", "def on_eval_button_next(b):\n", " global annotated\n", " with bhvr_outt:\n", " clear_output()\n", " if annotated:\n", " display(f1(dataset_rg_name).visualize())\n", " else:\n", " display(Markdown(\"

Please score the answer and validate the data first!

\"))\n", "\n", "bhvr_nlp_btn.on_click(on_bhvr_button_next)\n", "bhvr_agr_btn.on_click(on_agr_button_next)\n", "bhvr_eval_btn.on_click(on_eval_button_next)\n", "\n", "display(bhvr_label, bhvr_box)" ] }, { "cell_type": "code", "execution_count": null, "id": "a37ad293", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "hide_input": false, "kernelspec": { "display_name": "Python 3.9 (Argilla)", "language": "python", "name": "argilla" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.16" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": true, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": { "height": "calc(100% - 180px)", "left": "10px", "top": "150px", "width": "258.097px" }, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 5 }