File size: 3,352 Bytes
43fda8b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "m_Oa5BjYS_UF"
},
"source": [
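"# Drug recommendations via semantic search with FAISS (TensorFlow)\n",
"\n",
"Embed a free-text query (e.g. \"moderate acne\") with the `sentence-transformers/multi-qa-mpnet-base-dot-v1` encoder and retrieve the closest entries from a pre-embedded drug dataset using a FAISS index."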
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "sjgafp3ZS_UH"
},
"outputs": [],
"source": [
"# %pip (unlike !pip) installs into the environment of the running kernel.\n",
"# TFAutoModel needs a transformers release that still ships TensorFlow models;\n",
"# NOTE(review): v5 drops them, hence the <5 pin (quoted so the shell doesn't see a redirect).\n",
"%pip install -q datasets evaluate \"transformers[sentencepiece]<5\"\n",
"%pip install -q faiss-cpu"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "UWljEBCWS_UR"
},
"outputs": [],
"source": [
"import faiss\n",
"import pandas as pd\n",
"from datasets import load_from_disk\n",
"from transformers import AutoTokenizer, TFAutoModel\n",
"\n",
"# Treatment categories. NOTE(review): `Drugs` is not referenced by any cell in this\n",
"# notebook - presumably the category labels of the source dataset; confirm, or remove if unused.\n",
"Drugs = ['Acne', 'Adhd', 'Allergies', 'Anaemia', 'Angina', 'Appetite',\n",
" 'Arthritis', 'Constipation', 'Contraception', 'Dandruff',\n",
" 'Diabetes', 'Digestion', 'Fever', 'Fungal', 'General', 'Glaucoma',\n",
" 'Gout', 'Haematopoiesis', 'Haemorrhoid', 'Hyperpigmentation',\n",
" 'Hypertension', 'Hyperthyroidism', 'Hypnosis', 'Hypothyroidism',\n",
" 'Infection', 'Migraine', 'Osteoporosis', 'Pain', 'Psychosis',\n",
" 'Schizophrenia', 'Supplement', 'Thrombolysis', 'Viral', 'Wound']\n",
"\n",
"# One checkpoint drives both the encoder and its tokenizer.\n",
"model_ckpt = \"sentence-transformers/multi-qa-mpnet-base-dot-v1\"\n",
"# from_pt=True loads the checkpoint's PyTorch weights into the TensorFlow model class.\n",
"model = TFAutoModel.from_pretrained(model_ckpt, from_pt=True)\n",
"tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
"\n",
"def cls_pooling(model_output):\n",
"    \"\"\"Pool a batch by keeping only each sequence's first-token ([CLS]) hidden state.\"\"\"\n",
"    hidden_states = model_output.last_hidden_state\n",
"    first_token = hidden_states[:, 0]\n",
"    return first_token\n",
"\n",
"def get_embeddings(text_list):\n",
"    \"\"\"Embed a list of strings with the module-level `tokenizer` and `model`.\n",
"\n",
"    Inputs are padded to the longest text in the batch and truncated; each\n",
"    text is represented by its first-token hidden state (see `cls_pooling`).\n",
"    Returns a TF tensor with one row per input text.\n",
"    \"\"\"\n",
"    encoded_input = tokenizer(\n",
"        text_list, padding=True, truncation=True, return_tensors=\"tf\"\n",
"    )\n",
"    # BatchEncoding is a mapping, so it unpacks straight into the model call.\n",
"    model_output = model(**encoded_input)\n",
"    return cls_pooling(model_output)\n",
"\n",
"\n",
"# NOTE(review): assumes Google Drive is already mounted at /content/drive (Colab);\n",
"# no cell in this notebook mounts it.\n",
"DATASET_PATH = \"/content/drive/MyDrive/Drugs\"\n",
"embeddings_dataset = load_from_disk(DATASET_PATH)\n",
"# The checkpoint is a \"-dot\" model (scored by dot product) and cls_pooling does not\n",
"# normalize vectors, so index by inner product instead of FAISS's default L2 distance.\n",
"# With this index, `scores` are similarities: higher means a closer match.\n",
"embeddings_dataset.add_faiss_index(\n",
"    column=\"embeddings\", metric_type=faiss.METRIC_INNER_PRODUCT\n",
")\n",
"\n",
"def recommendations(question, k=5):\n",
"    \"\"\"Retrieve the `k` dataset rows whose stored embeddings best match `question`.\n",
"\n",
"    Args:\n",
"        question: free-text query, e.g. \"moderate acne\".\n",
"        k: number of neighbours to fetch from the FAISS index (default 5).\n",
"\n",
"    Returns:\n",
"        DataFrame with columns Drug_Name, Reason and scores, in the order the\n",
"        index ranks them (best match first).\n",
"    \"\"\"\n",
"    question_embedding = get_embeddings([question]).numpy()\n",
"    scores, samples = embeddings_dataset.get_nearest_examples(\n",
"        \"embeddings\", question_embedding, k=k\n",
"    )\n",
"    samples_df = pd.DataFrame.from_dict(samples)\n",
"    samples_df[\"scores\"] = scores\n",
"    # FAISS already returns neighbours best-first (ascending for an L2 index,\n",
"    # descending for an inner-product index). A fixed-direction re-sort would be\n",
"    # wrong for one of the two, so the index's ranking is kept as-is.\n",
"    return samples_df[['Drug_Name', 'Reason', 'scores']]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "fGRtwhvTcZ9t"
},
"outputs": [],
"source": [
"question = \"moderate acne\"  # free-text query to match against the drug dataset\n",
"recommendations(question=question)"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
|