{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "language_info": { "name": "python" } }, "cells": [ { "cell_type": "markdown", "source": [ "This Notebook is a Stable-diffusion tool which allows you to find similiar tokens from the SD 1.5 vocab.json that you can use for text-to-image generation." ], "metadata": { "id": "L7JTcbOdBPfh" } }, { "cell_type": "code", "source": [ "# Load the tokens into the colab\n", "!git clone https://huggingface.co/datasets/codeShare/sd_tokens\n", "import torch\n", "from torch import linalg as LA\n", "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", "%cd /content/sd_tokens\n", "token = torch.load('sd15_tensors.pt', map_location=device, weights_only=True)\n", "#-----#\n", "\n", "#Import the vocab.json\n", "import json\n", "import pandas as pd\n", "with open('vocab.json', 'r') as f:\n", " data = json.load(f)\n", "\n", "_df = pd.DataFrame({'count': data})['count']\n", "\n", "vocab = {\n", " value: key for key, value in _df.items()\n", "}\n", "#-----#\n", "\n", "# Define functions/constants\n", "NUM_TOKENS = 49407\n", "\n", "def absolute_value(x):\n", " return max(x, -x)\n", "\n", "def similarity(id_A , id_B):\n", " #Tensors\n", " A = token[id_A]\n", " B = token[id_B]\n", " #Tensor vector length (2nd order, i.e (a^2 + b^2 + ....)^(1/2)\n", " _A = LA.vector_norm(A, ord=2)\n", " _B = LA.vector_norm(B, ord=2)\n", " #----#\n", " result = torch.dot(A,B)/(_A*_B)\n", " similarity_pcnt = absolute_value(result.item()*100)\n", " similarity_pcnt_aprox = round(similarity_pcnt, 3)\n", " result = f'{similarity_pcnt_aprox} %'\n", " return result\n", "#----#" ], "metadata": { "id": "Ch9puvwKH1s3" }, "execution_count": null, "outputs": [] }, { "cell_type": "code", "source": [ "print(vocab[8922]) #the vocab item for ID 8922\n", "print(token[8922].shape) #dimension of the token" ], "metadata": { "id": "S_Yh9gH_OUA1" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Get the IDs from a prompt text.\n", "\n", "The prompt will be enclosed with the <|start-of-text|> and <|end-of-text|> tokens" ], "metadata": { "id": "f1-jS7YJApiO" } }, { "cell_type": "code", "source": [ "from transformers import AutoTokenizer\n", "tokenizer = AutoTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\", clean_up_tokenization_spaces = False)\n", "prompt= \"banana\" # @param {type:'string'}\n", "tokenizer_output = tokenizer(text = prompt)\n", "input_ids = tokenizer_output['input_ids']\n", "print(input_ids)\n", "id_A = input_ids[1]\n", "A = token[id_A]\n", "_A = LA.vector_norm(A, ord=2)" ], "metadata": { "id": "RPdkYzT2_X85" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "OPTIONAL : Add/subtract + normalize above result with another token" ], "metadata": { "id": "JKnz0aLFVGXc" } }, { "cell_type": "code", "source": [ "mix_with = \"\" # @param {type:'string'}\n", "mix_method = 'None' # @param [\"None\" , \"Average\", \"Subtract\"] {allow-input: true}\n", "w = 0.5 # @param {type:\"slider\", min:0, max:1, step:0.01}\n", "\n", "\n", "\n", "tokenizer_output = tokenizer(text = mix_with)\n", "input_ids = tokenizer_output['input_ids']\n", "id_C = input_ids[1]\n", "C = token[id_C]\n", "_C = LA.vector_norm(C, ord=2)\n", "\n", "if (mix_method == \"Average\"):\n", " A = w*A + (1-w)*C\n", " _A = LA.vector_norm(A, ord=2)\n", "\n", "if (mix_method == \"Subtract\"):\n", " tmp = w*A - (1-w)*C\n", " _tmp = LA.vector_norm(tmp, 
ord=2)\n", " A = tmp*((w*_A + (1-w)*_C)/_tmp)\n", " _A = LA.vector_norm(A, ord=2)\n", "\n", "\n" ], "metadata": { "id": "oXbNSRSKPgRr" }, "execution_count": 6, "outputs": [] }, { "cell_type": "markdown", "source": [ "Produce a list id IDs that are most similiar to the prompt ID at positiion 1 based on above result" ], "metadata": { "id": "3uBSZ1vWVCew" } }, { "cell_type": "code", "source": [ "\n", "dots = torch.zeros(NUM_TOKENS)\n", "for index in range(NUM_TOKENS):\n", " id_B = index\n", " B = token[id_B]\n", " _B = LA.vector_norm(B, ord=2)\n", " result = torch.dot(A,B)/(_A*_B)\n", " result = absolute_value(result.item())\n", " dots[index] = result\n", "\n", "sorted, indices = torch.sort(dots,dim=0 , descending=True)\n", "#----#\n", "if (mix_method == \"Average\"):\n", " print(f'Calculated all cosine-similarities between the average of token {vocab[id_A]} and {vocab[id_C]} with ID = {id_A} and mixed ID = {id_C} as a 1x{sorted.shape[0]} tensor')\n", "if (mix_method == \"Subtract\"):\n", " print(f'Calculated all cosine-similarities between the subtract of token {vocab[id_A]} and {vocab[id_C]} with ID = {id_A} and mixed ID = {id_C} as a 1x{sorted.shape[0]} tensor')\n", "if (mix_method == \"None\"):\n", " print(f'Calculated all cosine-similarities between the token {vocab[id_A]} with ID = {id_A} the rest of the {NUM_TOKENS} tokens as a 1x{sorted.shape[0]} tensor')" ], "metadata": { "id": "juxsvco9B0iV" }, "execution_count": null, "outputs": [] }, { "cell_type": "markdown", "source": [ "Print the sorted list from above result" ], "metadata": { "id": "y-Ig3glrVQC3" } }, { "cell_type": "code", "source": [ "list_size = 100 # @param {type:'number'}\n", "\n", "print_ID = False # @param {type:\"boolean\"}\n", "print_Similarity = True # @param {type:\"boolean\"}\n", "print_Name = True # @param {type:\"boolean\"}\n", "print_Divider = True # @param {type:\"boolean\"}\n", "\n", "for index in range(list_size):\n", " id = indices[index].item()\n", " if (print_Name):\n", " print(f'{vocab[id]}') # vocab item\n", " if (print_ID):\n", " print(f'ID = {id}') # IDs\n", " if (print_Similarity):\n", " print(f'similiarity = {round(sorted[index].item()*100,2)} %') # % value\n", " if (print_Divider):\n", " print('--------')" ], "metadata": { "id": "YIEmLAzbHeuo", "outputId": "843fbd7c-b208-49e0-9793-69bb36622c27", "colab": { "base_uri": "https://localhost:8080/" } }, "execution_count": 5, "outputs": [ { "output_type": "stream", "name": "stdout", "text": [ "banana\n", "similiarity = 74.26 %\n", "nude\n", "similiarity = 72.49 %\n", "bananas\n", "similiarity = 30.34 %\n", "nudes\n", "similiarity = 27.19 %\n", "banan\n", "similiarity = 25.08 %\n", "ðŁįĮ\n", "similiarity = 22.27 %\n", "naked\n", "similiarity = 22.12 %\n", "orange\n", "similiarity = 19.53 %\n", "cucumber\n", "similiarity = 17.36 %\n", "nutella\n", "similiarity = 17.33 %\n", "camel\n", "similiarity = 17.22 %\n", "eggplant\n", "similiarity = 17.13 %\n", "swimsuit\n", "similiarity = 16.62 %\n", "chicken\n", "similiarity = 16.38 %\n", "bikini\n", "similiarity = 16.08 %\n", "grape\n", "similiarity = 16.01 %\n", "ballerina\n", "similiarity = 16.01 %\n", "mango\n", "similiarity = 16.0 %\n", "manicure\n", "similiarity = 15.8 %\n", "pencil\n", "similiarity = 15.62 %\n", "yoga\n", "similiarity = 15.56 %\n", "indian\n", "similiarity = 15.51 %\n", "yellow\n", "similiarity = 15.51 %\n", "venus\n", "similiarity = 15.5 %\n", "snake\n", "similiarity = 15.41 %\n", "dunk\n", "similiarity = 15.39 %\n", "ters\n", "similiarity = 15.27 %\n", "underwear\n", "similiarity = 15.26 
%\n", "sunbathing\n", "similiarity = 15.15 %\n", "potato\n", "similiarity = 15.04 %\n", "milk\n", "similiarity = 14.91 %\n", "bamboo\n", "similiarity = 14.85 %\n", "selfie\n", "similiarity = 14.85 %\n", "features\n", "similiarity = 14.82 %\n", "know\n", "similiarity = 14.79 %\n", "oilpainting\n", "similiarity = 14.7 %\n", "reas\n", "similiarity = 14.63 %\n", "croissant\n", "similiarity = 14.61 %\n", "oranges\n", "similiarity = 14.59 %\n", "conversation\n", "similiarity = 14.57 %\n", "photoshoot\n", "similiarity = 14.55 %\n", "ery\n", "similiarity = 14.49 %\n", "pear\n", "similiarity = 14.42 %\n", "mcnam\n", "similiarity = 14.42 %\n", "dens\n", "similiarity = 14.38 %\n", "cigarette\n", "similiarity = 14.33 %\n", "tangerine\n", "similiarity = 14.3 %\n", "aluminum\n", "similiarity = 14.28 %\n", "plum\n", "similiarity = 14.28 %\n", "rape\n", "similiarity = 14.24 %\n", "apple\n", "similiarity = 14.2 %\n", "apd\n", "similiarity = 14.17 %\n", "safari\n", "similiarity = 14.09 %\n", "yolo\n", "similiarity = 14.06 %\n", "hoodie\n", "similiarity = 13.96 %\n", "cabaret\n", "similiarity = 13.91 %\n", "superman\n", "similiarity = 13.9 %\n", "saree\n", "similiarity = 13.86 %\n", "mommy\n", "similiarity = 13.78 %\n", "sausage\n", "similiarity = 13.76 %\n", "marshmallow\n", "similiarity = 13.75 %\n", "latex\n", "similiarity = 13.74 %\n", "blonde\n", "similiarity = 13.69 %\n", "champagne\n", "similiarity = 13.62 %\n", "parachute\n", "similiarity = 13.61 %\n", "stor\n", "similiarity = 13.58 %\n", "feminine\n", "similiarity = 13.55 %\n", "ayu\n", "similiarity = 13.5 %\n", "â̼ï¸ı\n", "similiarity = 13.45 %\n", "naked\n", "similiarity = 13.45 %\n", "poop\n", "similiarity = 13.44 %\n", "honeymoon\n", "similiarity = 13.41 %\n", "giraffe\n", "similiarity = 13.37 %\n", "zebra\n", "similiarity = 13.35 %\n", "mud\n", "similiarity = 13.35 %\n", "blanket\n", "similiarity = 13.34 %\n", "silly\n", "similiarity = 13.32 %\n", "animal\n", "similiarity = 13.31 %\n", "malayalam\n", "similiarity = 13.25 %\n", "mustache\n", "similiarity = 13.25 %\n", "mrc\n", "similiarity = 13.24 %\n", "yuri\n", "similiarity = 13.23 %\n", "japanese\n", "similiarity = 13.19 %\n", "gibbs\n", "similiarity = 13.16 %\n", "ðŁĻĤ\n", "similiarity = 13.15 %\n", "rhubarb\n", "similiarity = 13.14 %\n", "trac\n", "similiarity = 13.13 %\n", "polaroid\n", "similiarity = 13.08 %\n", "lunch\n", "similiarity = 13.04 %\n", "sandal\n", "similiarity = 13.03 %\n", "popart\n", "similiarity = 13.02 %\n", "kissing\n", "similiarity = 13.02 %\n", "funeral\n", "similiarity = 13.02 %\n", "runway\n", "similiarity = 13.01 %\n", "milk\n", "similiarity = 12.98 %\n", "tutu\n", "similiarity = 12.96 %\n", "flag\n", "similiarity = 12.95 %\n", "hours\n", "similiarity = 12.95 %\n", "monet\n", "similiarity = 12.91 %\n", "ali\n", "similiarity = 12.89 %\n" ] } ] }, { "cell_type": "markdown", "source": [ "Find the most similiar Tokens for given input" ], "metadata": { "id": "qqZ5DvfLBJnw" } }, { "cell_type": "markdown", "source": [ "Valid ID ranges for id_for_token_A / id_for_token_B are between 0 and 49407" ], "metadata": { "id": "kX72bAuhOtlT" } }, { "cell_type": "code", "source": [ "id_for_token_A = 4567 # @param {type:'number'}\n", "id_for_token_B = 4343 # @param {type:'number'}\n", "\n", "similarity_str = 'The similarity between tokens A and B is ' + similarity(id_for_token_A , id_for_token_B)\n", "\n", "print(similarity_str)" ], "metadata": { "id": "MwmOdC9cNZty" }, "execution_count": null, "outputs": [] } ] }