{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from args import args, config\n",
    "from items_dataset import items_dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from models import Model_Crf, Model_Softmax\n",
    "from transformers import AutoTokenizer\n",
    "from tqdm import tqdm\n",
    "import prediction\n",
    "import torch\n",
    "import math"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filesystem location of the trained CRF checkpoint.\n",
    "directory = \"../model/\"\n",
    "model_name = \"roberta_CRF.pt\"\n",
    "# Use the first GPU when available, otherwise fall back to CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda', 0)\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "# Build the CRF tagging model and restore its trained weights onto the chosen device.\n",
    "model = Model_Crf(config).to(device)\n",
    "model.load_state_dict(state_dict=torch.load(directory + model_name, map_location=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Single-sample inference payload; span_labels starts empty and is produced by the model.\n",
    "input_dict = [{\"span_labels\":[]}]\n",
    "# Demo input: a Traditional-Chinese article (about disciplining cats) used as raw prediction text.\n",
    "input_dict[0][\"original_text\"] = \"\"\"貓咪犯錯後,以下5種懲罰方法很有效,飼主可以試試樂享網 2021-03-06 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 繼續閱讀 貓咪雖然高冷,但也是會犯錯的,那貓咪犯錯後,怎麼懲罰它才最有效呢?今天就來說一些懲罰貓咪最有效的5個方法!1、把痛感形成條件反射 這裡說的是「痛感」,而不是「暴打」。在貓咪犯錯後,寵主不需要打它,可以彈鼻頭或者是輕拍它的頭頂,給它造成痛感,這樣讓貓咪有一些畏懼心理,知道你在懲罰它。這樣時間長了,貓咪就會形成條件反射,以後就會少犯錯了。  2、大聲呵斥比起打貓,大聲呵斥貓咪會更加有效。因為貓咪對聲音很敏感,它能從主人的語氣中判斷主人的情緒,當大聲呵斥它的時候,它往往會楞一下,這時你繼續大聲呵斥它,那它就會明白你不允許它做這件事,這樣犯錯地方幾率就會減少了。  3、限制自由限制自由說白了,就是把貓咪關進籠子裡。因為貓咪都是很愛外出玩耍,當它犯錯咯,主人可以把它關進籠子裡,不搭理它,讓它自己反思。但要注意,這個方法不能經常用,而且不能把貓咪關進籠子太久。  4、利用水都知道貓咪很怕水的,所以當貓咪犯錯後,寵主也可以利用水來懲罰貓咪,這也是很效果的方法。寵主可以給貓咪臉上或是頭頂噴一些水,從而讓貓知道這個行為是錯誤的,以後就不會再犯了。  5、冷淡雖然貓咪不是很粘主人,但它還是很愛主人的,所以在貓咪犯錯後,寵主也可以採取冷淡的方法來懲罰貓。對貓咪採取不理睬、不靠近、不擁抱等策略,這樣貓咪就會知道自己錯了。當然懲罰的時間不要太長,不然貓咪就會以為你不愛它了。\"\"\"\n",
    "# Tokenizer must match the pretrained backbone named in args.pre_model_name.\n",
    "tokenizer = AutoTokenizer.from_pretrained(args.pre_model_name, add_prefix_space=True)\n",
    "prediction_dataset = items_dataset(tokenizer, input_dict, args.label_dict)\n",
    "# BUGFIX: shuffle=False — this is an inference pass, and downstream reassembly of\n",
    "# overflowed chunks relies on deterministic batch order; the second loader in this\n",
    "# notebook already uses shuffle=False for the same reason.\n",
    "prediction_loader = DataLoader(prediction_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=prediction_dataset.collate_fn)\n",
    "# Run the CRF model over the text and collect its sentence-level predictions.\n",
    "predict_data = prediction.test_predict(prediction_loader, device, model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Attach a per-sentence prediction table to every sample, then inspect the first one.\n",
    "prediction.add_sentence_table(predict_data)\n",
    "print(predict_data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def prepare_span_data(dataset):\n",
    "    \"\"\"Derive span labels from each sample's predicted sentence table and\n",
    "    move the text back under the 'original_text' key expected by items_dataset.\"\"\"\n",
    "    for sample in dataset:\n",
    "        # cal_agreement_span is used statically here: the instance slot receives None.\n",
    "        sample[\"span_labels\"] = items_dataset.cal_agreement_span(\n",
    "            None, agreement_table=sample[\"predict_sentence_table\"], min_agree=1, max_agree=2)\n",
    "        # Rename the key in one step: pop removes 'text_a' and yields its value.\n",
    "        sample[\"original_text\"] = sample.pop(\"text_a\")\n",
    "prepare_span_data(predict_data)\n",
    "# Re-tokenize the first-pass predictions so they can feed the span-ranking model.\n",
    "tokenizer = AutoTokenizer.from_pretrained(args.pre_model_name, add_prefix_space=True)\n",
    "prediction_dataset = items_dataset(tokenizer, predict_data, args.label_dict)\n",
    "# shuffle=False: span ranking below depends on deterministic sample order.\n",
    "prediction_loader = DataLoader(prediction_dataset, batch_size=args.batch_size, shuffle=False, collate_fn=prediction_dataset.collate_fn)\n",
    "\n",
    "# Sanity-check the first sample's text and derived spans.\n",
    "sample_idx = 0\n",
    "print(predict_data[sample_idx][\"original_text\"])\n",
    "print(predict_data[sample_idx][\"span_labels\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Filesystem location of the trained softmax-head checkpoint.\n",
    "directory = \"../model/\"\n",
    "model_name = \"roberta_softmax.pt\"\n",
    "# Use the first GPU when available, otherwise fall back to CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = torch.device('cuda', 0)\n",
    "else:\n",
    "    device = torch.device('cpu')\n",
    "# Build the softmax tagging model and restore its trained weights onto the chosen device.\n",
    "model = Model_Softmax(config).to(device)\n",
    "model.load_state_dict(state_dict=torch.load(directory + model_name, map_location=device))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rank_spans(test_loader, device, model, reverse=True):\n",
    "  \"\"\"Score each labeled span with e**(average per-token log likelihood).\n",
    "\n",
    "  Args:\n",
    "      test_loader: DataLoader yielding tokenized batches that include an\n",
    "          'overflow_to_sample_mapping' relating chunks back to samples.\n",
    "      device: torch device the model's inputs should be moved to.\n",
    "      model: token-classification model returning per-token logits.\n",
    "      reverse: when True, sort spans by descending probability.\n",
    "\n",
    "  Returns:\n",
    "      List with one dict per sample:\n",
    "      {\"original_text\": str, \"span_ranked\": [[span_text, probability], ...]}.\n",
    "  \"\"\"\n",
    "  model.eval()\n",
    "  result = []\n",
    "  \n",
    "  for i, test_batch in enumerate(tqdm(test_loader)):\n",
    "    batch_text = test_batch['batch_text']\n",
    "    input_ids = test_batch['input_ids'].to(device)\n",
    "    token_type_ids = test_batch['token_type_ids'].to(device)\n",
    "    attention_mask = test_batch['attention_mask'].to(device)\n",
    "    labels = test_batch['labels']\n",
    "    crf_mask = test_batch[\"crf_mask\"].to(device)\n",
    "    # Maps each overflow chunk index to the sample it was split from.\n",
    "    sample_mapping = test_batch[\"overflow_to_sample_mapping\"]\n",
    "    output = model(input_ids=input_ids, token_type_ids=token_type_ids, attention_mask=attention_mask, labels=None, crf_mask=crf_mask)\n",
    "    # Per-token class probabilities (softmax over the last/label dimension).\n",
    "    output = torch.nn.functional.softmax(output[0], dim=-1)\n",
    "    \n",
    "    #make result of every sample\n",
    "    sample_id = 0\n",
    "    sample_result= {\"original_text\" : test_batch['batch_text'][sample_id], \"span_ranked\" : []}\n",
    "    for batch_id in range(len(sample_mapping)):\n",
    "      change_sample = False\n",
    "        \n",
    "      #make sure status\n",
    "      # A new id in the mapping means the previous sample's chunks are exhausted.\n",
    "      if sample_id != sample_mapping[batch_id]: change_sample = True\n",
    "      if change_sample:\n",
    "        sample_id = sample_mapping[batch_id]\n",
    "        result.append(sample_result)\n",
    "        sample_result= {\"original_text\" : test_batch['batch_text'][sample_id], \"span_ranked\" : []}\n",
    "        \n",
    "      # Recover (start, end) token spans from this chunk's label sequence.\n",
    "      encoded_spans = items_dataset.cal_agreement_span(None, agreement_table=labels[batch_id], min_agree=1, max_agree=2)\n",
    "      #print(encoded_spans)\n",
    "      for encoded_span in encoded_spans:\n",
    "        #calculate span loss\n",
    "        # ('lenght' is a pre-existing typo for 'length', kept to avoid code changes.)\n",
    "        span_lenght = encoded_span[1]-encoded_span[0]\n",
    "        #print(span_lenght)\n",
    "        # Log-probabilities for the span's tokens; index 1 = Begin tag, index 2 = Inside tag.\n",
    "        span_prob_table = torch.log(output[batch_id][encoded_span[0]:encoded_span[1]])\n",
    "        if not change_sample and encoded_span[0]==0 and batch_id!=0: #span cross two tensors\n",
    "          # NOTE(review): accumulates onto span_loss carried over from the previous chunk,\n",
    "          # but that value was already divided by its own length — confirm this is intended.\n",
    "          span_loss += span_prob_table[0][1] #Begin\n",
    "        else:\n",
    "          span_loss = span_prob_table[0][1] #Begin\n",
    "        for token_id in range(1, span_prob_table.shape[0]):\n",
    "          span_loss+=span_prob_table[token_id][2] #Inside\n",
    "        # Average log-likelihood per token of the span.\n",
    "        span_loss /= span_lenght\n",
    "        \n",
    "        #span decode\n",
    "        # +1 presumably skips the leading special token when mapping back to characters — TODO confirm.\n",
    "        decode_start = test_batch[batch_id].token_to_chars(encoded_span[0]+1)[0]\n",
    "        decode_end =  test_batch[batch_id].token_to_chars(encoded_span[1])[0]+1\n",
    "        #print((decode_start, decode_end))\n",
    "        span_text = test_batch['batch_text'][sample_mapping[batch_id]][decode_start:decode_end]\n",
    "        if not change_sample and encoded_span[0]==0 and batch_id!=0: #span cross two tensors\n",
    "          # Span continues from the previous chunk: merge its text into the last stored span.\n",
    "          presample = sample_result[\"span_ranked\"].pop(-1)\n",
    "          sample_result[\"span_ranked\"].append([presample[0]+span_text, math.e**float(span_loss)])\n",
    "        else:\n",
    "          sample_result[\"span_ranked\"].append([span_text, math.e**float(span_loss)])\n",
    "    result.append(sample_result)\n",
    "    \n",
    "  #sorted spans by probability\n",
    "  for sample in result:\n",
    "    sample[\"span_ranked\"] = sorted(sample[\"span_ranked\"], key=lambda x:x[1], reverse=reverse)\n",
    "  return result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Rank candidate spans for every sample and display them, highest probability first.\n",
    "span_ranked = rank_spans(prediction_loader, device, model)\n",
    "for ranked_sample in span_ranked:\n",
    "    print(ranked_sample[\"original_text\"])\n",
    "    print(ranked_sample[\"span_ranked\"])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "for_project",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "7d6017e34087523a14d1e41a3fef2927de5697dc5dbb9b7906df99909cc5c8a1"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}