File size: 3,352 Bytes
43fda8b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "m_Oa5BjYS_UF"
      },
      "source": [
        "# Semantic search with FAISS (TensorFlow)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sjgafp3ZS_UH"
      },
      "outputs": [],
      "source": [
        "!pip install datasets evaluate transformers[sentencepiece]\n",
        "!pip install faiss-cpu"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "UWljEBCWS_UR"
      },
      "outputs": [],
      "source": [
        "import pandas as pd\n",
        "from datasets import load_from_disk\n",
        "from transformers import AutoTokenizer, TFAutoModel\n",
        "\n",
        "Drugs = ['Acne', 'Adhd', 'Allergies', 'Anaemia', 'Angina', 'Appetite',\n",
        "       'Arthritis', 'Constipation', 'Contraception', 'Dandruff',\n",
        "       'Diabetes', 'Digestion', 'Fever', 'Fungal', 'General', 'Glaucoma',\n",
        "       'Gout', 'Haematopoiesis', 'Haemorrhoid', 'Hyperpigmentation',\n",
        "       'Hypertension', 'Hyperthyroidism', 'Hypnosis', 'Hypothyroidism',\n",
        "       'Infection', 'Migraine', 'Osteoporosis', 'Pain', 'Psychosis',\n",
        "       'Schizophrenia', 'Supplement', 'Thrombolysis', 'Viral', 'Wound']\n",
        "\n",
        "model_ckpt = \"sentence-transformers/multi-qa-mpnet-base-dot-v1\"\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
        "model = TFAutoModel.from_pretrained(model_ckpt, from_pt=True)\n",
        "\n",
        "def cls_pooling(model_output):\n",
        "    return model_output.last_hidden_state[:, 0]\n",
        "\n",
        "def get_embeddings(text_list):\n",
        "    encoded_input = tokenizer(\n",
        "        text_list, padding=True, truncation=True, return_tensors=\"tf\"\n",
        "    )\n",
        "    encoded_input = {k: v for k, v in encoded_input.items()}\n",
        "    model_output = model(**encoded_input)\n",
        "    return cls_pooling(model_output)\n",
        "\n",
        "\n",
        "embeddings_dataset = load_from_disk(\"/content/drive/MyDrive/Drugs\")\n",
        "embeddings_dataset.add_faiss_index(column=\"embeddings\")\n",
        "\n",
        "def recommendations(question):\n",
        "  question_embedding = get_embeddings([question]).numpy()\n",
        "  scores, samples = embeddings_dataset.get_nearest_examples(\n",
        "      \"embeddings\", question_embedding, k=5\n",
        "  )\n",
        "  samples_df = pd.DataFrame.from_dict(samples)\n",
        "  samples_df[\"scores\"] = scores\n",
        "  samples_df.sort_values(\"sc>ores\", ascending=False, inplace=True)\n",
        "  return samples_df[['Drug_Name', 'Reason', 'scores']]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fGRtwhvTcZ9t"
      },
      "outputs": [],
      "source": [
        "question = \"moderate acne\"\n",
        "recommendations(question)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}