{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "e0TsG2okIaQ_",
    "outputId": "742d6ccc-8272-4a14-ef1e-c07710e2bfdb"
   },
   "outputs": [
    {
     "ename": "ModuleNotFoundError",
     "evalue": "No module named 'fastbook'",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mModuleNotFoundError\u001b[0m                       Traceback (most recent call last)",
      "\u001b[0;32m/tmp/ipykernel_268282/1933282452.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mfastbook\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      2\u001b[0m \u001b[0mfastbook\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msetup_book\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'fastbook'"
     ]
    }
   ],
   "source": [
    "import fastbook\n",
    "fastbook.setup_book()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import fastai\n",
    "from fastai import *\n",
    "from fastai.basic_train import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import fastai\n",
    "from fastai.tabular import *\n",
    "from fastai.text import *\n",
    "from fastai.vision import *\n",
    "from fastai import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "h78mKJN7IibS"
   },
   "outputs": [],
   "source": [
    "from fastbook import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YaMYb4UiIqNG"
   },
   "outputs": [],
   "source": [
    "path = Path('gdrive/MyDrive/anime-image-labeller/safebooru')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "IF8LSz3kI1F1"
   },
   "outputs": [],
   "source": [
    "\"\"\"\n",
    "Get the prediction labels and their accuracies, then return the results as a dictionary.\n",
    "\n",
    "[obj] - tensor matrix containing the predicted accuracy given from the model\n",
    "[learn] - fastai learner needed to get the labels\n",
    "[thresh] - minimum accuracy threshold to returning results\n",
    "\"\"\"\n",
    "def get_pred_classes(obj, learn, thresh):\n",
    "    labels = []\n",
    "    # get list of classes from csv--replace\n",
    "    with open('classes.txt', 'r') as f:\n",
    "      for line in f:\n",
    "        labels.append(line.strip('\\n'))\n",
    "\n",
    "    predictions = {}\n",
    "    x=0\n",
    "    for item in obj:\n",
    "        acc= round(item.item(), 3)\n",
    "        if acc > thresh:\n",
    "            predictions[labels[x]] = round(acc, 3)\n",
    "        x+=1\n",
    "\n",
    "    predictions =sorted(predictions.items(), key=lambda x: x[1], reverse=True)\n",
    "\n",
    "    return predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YaVTkhcDSwGl"
   },
   "outputs": [],
   "source": [
    "def get_x(r): return 'images'/r['img_name']\n",
    "def get_y(r): return [t for t in r['tags'].split(' ') if t in pop_tags]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "eN0og22RJ0xW"
   },
   "outputs": [],
   "source": [
    "learn = load_learner('model-large-40e.pkl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Q8geXEEmJCVz"
   },
   "outputs": [],
   "source": [
    "def predict_single_img(imf, thresh=0.2, learn=learn):\n",
    "  \n",
    "  img = PILImage.create(imf)\n",
    "\n",
    "  #img.show() #show image\n",
    "  _, _, pred_pct = learn.predict(img) #predict while ignoring first 2 array inputs\n",
    "  img.show() #show image\n",
    "  return str(get_pred_classes(pred_pct, learn, thresh))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 227
    },
    "id": "XuwlpTtoKF_G",
    "outputId": "2fefdc83-cb6a-472f-99ed-6f1b3c059c24"
   },
   "outputs": [],
   "source": [
    "predict_single_img('test/midriff.jpg')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 643
    },
    "id": "XJsy9FPeG2BI",
    "outputId": "9b6125e9-4b16-47e2-c1ad-d8e7caa3c2fa"
   },
   "outputs": [],
   "source": [
    "iface = gr.Interface(fn=predict_single_img, \n",
    "                     inputs=[\"image\",\"number\"], \n",
    "                     outputs=\"text\")\n",
    "iface.launch()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "name": "Anime Image Label Inference.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python [conda env:fastai2]",
   "language": "python",
   "name": "conda-env-fastai2-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}