codeShare committed (verified)
Commit f378257 · 1 Parent(s): 8a0aaf3

Upload sd_token_similarity_calculator.ipynb

Files changed (1):
  sd_token_similarity_calculator.ipynb (+172 −69)
sd_token_similarity_calculator.ipynb CHANGED
@@ -116,10 +116,28 @@
  "metadata": {
  "id": "Ch9puvwKH1s3",
  "collapsed": true,
- "cellView": "form"
+ "cellView": "form",
+ "outputId": "aa58503f-8e68-43bf-d73b-3eb877ae10e4",
+ "colab": {
+ "base_uri": "https://localhost:8080/"
+ }
  },
- "execution_count": null,
- "outputs": []
+ "execution_count": 1,
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": [
+ "Cloning into 'sd_tokens'...\n",
+ "remote: Enumerating objects: 10, done.\u001b[K\n",
+ "remote: Counting objects: 100% (7/7), done.\u001b[K\n",
+ "remote: Compressing objects: 100% (7/7), done.\u001b[K\n",
+ "remote: Total 10 (delta 1), reused 0 (delta 0), pack-reused 3 (from 1)\u001b[K\n",
+ "Unpacking objects: 100% (10/10), 306.93 KiB | 5.48 MiB/s, done.\n",
+ "/content/sd_tokens\n"
+ ]
+ }
+ ]
  },
  {
  "cell_type": "code",
@@ -272,56 +290,23 @@
  "outputs": []
  },
  {
- "cell_type": "code",
+ "cell_type": "markdown",
  "source": [
- "# @title 💫 Compare Text encodings\n",
- "\n",
- "prompt_A = \"banana\" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
- "prompt_B = \"\" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
- "use_token_padding = True # @param {type:\"boolean\"}\n",
+ "The image interrogator below appends CLIP tokens to either end of the 'must_contain' text, and seeks to maximize similarity with the image encoding.\n",
  "\n",
- "from transformers import CLIPProcessor, CLIPModel\n",
- "\n",
- "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\" , clean_up_tokenization_spaces = True)\n",
+ "It takes a long while to check all the tokens (too long!), so this cell only samples a range of the 49K available tokens.\n",
  "\n",
- "model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
- "\n",
- "ids_A = processor.tokenizer(text=prompt_A, padding=use_token_padding, return_tensors=\"pt\")\n",
- "text_encoding_A = model.get_text_features(**ids_A)\n",
- "\n",
- "\n",
- "ids_B = processor.tokenizer(text=prompt_B, padding=use_token_padding, return_tensors=\"pt\")\n",
- "text_encoding_B = model.get_text_features(**ids_B)\n",
- "\n",
- "similarity_str = 'The similarity between the text_encoding for A:\"' + prompt_A + '\" and B: \"' + prompt_B +'\" is ' + token_similarity(text_encoding_A[0] , text_encoding_B[0])\n",
- "\n",
- "\n",
- "print(similarity_str)\n",
- "#outputs = model(**inputs)\n",
- "#logits_per_image = outputs.logits_per_image # this is the image-text similarity score\n",
- "#probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities"
- ],
- "metadata": {
- "id": "QQOjh5BvnG8M",
- "collapsed": true,
- "cellView": "form"
- },
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "markdown",
- "source": [
- "You can write an url or upload a file locally from your device to use as reference. The image will by saved in the 'sd_tokens' folder. Note that the 'sd_tokens' folder will be deleted upon exiting this runtime."
+ "You can run this cell, then paste the result into the 'must_contain' box, and then run the cell again.\n",
+ "\n"
  ],
  "metadata": {
- "id": "hyK423TQCRup"
+ "id": "IUCuV9RtQpBn"
  }
  },
  {
  "cell_type": "code",
  "source": [
- "# @title 🪐🖼️ -> 📝 Image to prompt : Add single token to existing prompt to match image\n",
+ "# @title 🪐🖼️ -> 📝 Image to prompt : Create suggestions of things to add to prompt to match image\n",
  "from google.colab import files\n",
  "def upload_files():\n",
  " from google.colab import files\n",
@@ -331,7 +316,7 @@
  " return list(uploaded.keys())\n",
  "#Get image\n",
  "# You can use \"http://images.cocodataset.org/val2017/000000039769.jpg\" for testing\n",
- "url = \"\" # @param {\"type\":\"string\",\"placeholder\":\"leave empty for local upload (scroll down to see it)\"}\n",
+ "url = \"http://images.cocodataset.org/val2017/000000039769.jpg\" # @param {\"type\":\"string\",\"placeholder\":\"leave empty for local upload (scroll down to see it)\"}\n",
  "\n",
  "colab_image_path = \"\" # @param {\"type\":\"string\",\"placeholder\":\"(optional) Write colab image path to load from\"}\n",
  "from PIL import Image\n",
@@ -369,19 +354,19 @@
  "\n",
  "# @markdown Set conditions for the output\n",
  "must_start_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
- "must_contain = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
+ "must_contain = \"banana \" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
  "must_end_with = \"\" # @param {\"type\":\"string\",\"placeholder\":\"write a text\"}\n",
  "token_B = must_contain\n",
  "\n",
  "# @markdown Limit the search\n",
  "use_token_padding = True # @param {type:\"boolean\"}\n",
- "start_search_at_ID = 12500 # @param {type:\"slider\", min:0, max: 49407, step:100}\n",
- "search_range = 500 # @param {type:\"slider\", min:0, max: 2000, step:100}\n",
- "restrictions = 'Suffix only' # @param [\"None\", \"Suffix only\", \"Prefix only\"]\n",
+ "start_search_at_ID = 27700 # @param {type:\"slider\", min:0, max: 49407, step:100}\n",
+ "search_range = 288 # @param {type:\"slider\", min:100, max: 2000, step:0}\n",
+ "restrictions = 'None' # @param [\"None\", \"Suffix only\", \"Prefix only\"]\n",
  "\n",
  "# @markdown Limit char size of included token\n",
- "min_char_size = 3 # @param {type:\"slider\", min:0, max: 50, step:1}\n",
- "char_range = 5 # @param {type:\"slider\", min:0, max: 50, step:1}\n",
+ "min_char_size = 3 # @param {type:\"slider\", min:0, max: 20, step:1}\n",
+ "char_range = 14 # @param {type:\"slider\", min:0, max: 20, step:1}\n",
  "\n",
  "#Tokenize input B\n",
  "from transformers import AutoTokenizer\n",
@@ -397,14 +382,26 @@
  "\n",
  "dots = torch.zeros(RANGE)\n",
  "is_BC = torch.zeros(RANGE)\n",
+ "\n",
+ "import re\n",
+ "\n",
  "for index in range(RANGE):\n",
  " id_C = START + index\n",
  " C = token[id_C]\n",
  " _C = LA.vector_norm(C, ord=2)\n",
  " name_C = vocab[id_C]\n",
  "\n",
+ " is_Prefix = 0\n",
+ "\n",
+ "\n",
+ " #Skip if non-AZ characters are found (r\"\\W\" is the intended pattern; the original \"\\W/g\" carried a stray JavaScript /g flag and never matched)\n",
+ " if re.search(r\"\\W\", name_C.replace('</w>', '')):\n",
+ " continue\n",
+ "\n",
+ "\n",
  " # Decide if we should process prefix/suffix tokens\n",
  " if name_C.find('</w>')<=-1:\n",
+ " is_Prefix = 1\n",
  " if restrictions != \"Prefix only\":\n",
  " continue\n",
  " else:\n",
@@ -420,8 +417,8 @@
  " #-----#\n",
  "\n",
  " name_CB = must_start_with + name_C + name_B + must_end_with\n",
- " if restrictions == \"Prefix only\":\n",
- " name_CB = must_start_with + name_C + '-' + name_B + must_end_with\n",
+ " if is_Prefix>0:\n",
+ " name_CB = must_start_with + ' ' + name_C.strip() + '-' + name_B.strip() + ' ' + must_end_with\n",
  " #-----#\n",
  " ids_CB = processor.tokenizer(text=name_CB, padding=use_token_padding, return_tensors=\"pt\")\n",
  " text_encoding_CB = model.get_text_features(**ids_CB)\n",
@@ -469,37 +466,143 @@
  "print('')\n",
  "print(f'These token pairings within the range ID = {START} to ID = {START + RANGE} most closely match the text_encoding for {prompt_A} : ')\n",
  "print('')\n",
- "\n",
+ "#----#\n",
+ "aheads = \"{\"\n",
+ "trails = \"{\"\n",
+ "tmp = \"\"\n",
+ "#----#\n",
+ "max_sim_ahead = 0\n",
+ "max_sim_trail = 0\n",
+ "sim = 0\n",
+ "max_name_ahead = ''\n",
+ "max_name_trail = ''\n",
+ "#----#\n",
  "for index in range(min(list_size,RANGE)):\n",
  " id = START + indices[index].item()\n",
- " if (print_Name):\n",
- " if(is_BC[index]>0):\n",
- " print(must_start_with + name_B + vocab[id] + must_end_with)\n",
- " else:\n",
- " if restrictions == \"Prefix only\":\n",
- " print(must_start_with + vocab[id] + '-' + name_B + must_end_with)\n",
- " else:\n",
- " print(must_start_with + vocab[id] + name_B + must_end_with)\n",
- " if (print_ID):\n",
- " print(f'ID = {id}') # IDs\n",
- " if (print_Similarity):\n",
- " print(f'similiarity = {round(sorted[index].item()*100,2)} %')\n",
- " if (print_Divider):\n",
- " print('--------')\n",
+ " name = vocab[id]\n",
+ " #-----#\n",
+ " if (name.find('</w>')<=-1):\n",
+ " name = name + '-'\n",
+ " else:\n",
+ " name = name.replace('</w>', ' ')\n",
+ " if(is_BC[index]>0):\n",
+ " trails = trails + name + \"|\"\n",
+ " else:\n",
+ " aheads = aheads + name + \"|\"\n",
+ " #----#\n",
+ " sim = sorted[index].item()\n",
  "\n",
+ " if(is_BC[index]>0):\n",
+ " if sim>max_sim_ahead:\n",
+ " max_sim_ahead = sim\n",
+ " max_name_ahead = name\n",
+ " else:\n",
+ " if sim>max_sim_trail:\n",
+ " max_sim_trail = sim\n",
+ " max_name_trail = name\n",
  "\n",
+ "#------#\n",
+ "trails = (trails + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
+ "aheads = (aheads + \"&&&&\").replace(\"|&&&&\", \"}\").replace(\"</w>\", \" \").replace(\"{&&&&\", \"\")\n",
+ "max_sim_ahead=max_sim_ahead*100\n",
+ "max_sim_trail=max_sim_trail*100 # fixed: this line previously re-assigned max_sim_ahead, so one percentage overwrote the other before printing\n",
+ "#-----#\n",
+ "print(f\"place these items ahead of prompt : {aheads}\")\n",
+ "print(\"\")\n",
+ "print(f\"place these items behind the prompt : {trails}\")\n",
+ "print(\"\")\n",
+ "print(f\"max_similarity = {max_sim_ahead} % when using '{max_name_ahead + must_contain}' \")\n",
+ "print(\"\")\n",
+ "print(f\"max_similarity = {max_sim_trail} % when using '{must_contain + max_name_trail}' \")\n",
+ "#-----#\n",
+ "#STEP 2\n",
+ "import random\n",
+ "\n",
+ "names = {}\n",
+ "\n",
+ "NUM_PERMUTATIONS = 4 # 0 1 2 3\n",
+ "dots = torch.zeros(NUM_PERMUTATIONS)\n",
+ "for index in range(NUM_PERMUTATIONS):\n",
+ " name = must_start_with\n",
+ " if index == 0 : name = name + must_contain\n",
+ " if index == 1 : name = name + max_name_ahead + must_contain\n",
+ " if index == 2 : name = name + must_contain + max_name_trail\n",
+ " if index == 3 : name = name + max_name_ahead + must_contain + max_name_trail\n",
+ " name = name + must_end_with\n",
+ " #----#\n",
+ " ids_B = processor.tokenizer(text=name, padding=use_token_padding, return_tensors=\"pt\")\n",
+ " text_encoding_B = model.get_text_features(**ids_B)\n",
+ " B = text_encoding_B[0]\n",
+ " _B = LA.vector_norm(B, ord=2)\n",
+ " dots[index] = torch.dot(A,B)/(_A*_B)\n",
+ " names[index] = name\n",
+ "#------#\n",
  "\n",
+ "sorted, indices = torch.sort(dots,dim=0 , descending=True)\n",
  "\n",
- "\n"
+ "for index in range(NUM_PERMUTATIONS):\n",
+ " print(names[indices[index].item()])\n",
+ " print(f'similarity = {round(sorted[index].item()*100,2)} %')\n",
+ " print('------')\n",
+ "\n",
+ "\n",
+ "\n",
+ ""
  ],
  "metadata": {
  "collapsed": true,
- "cellView": "form",
  "id": "fi0jRruI0-tu"
  },
  "execution_count": null,
  "outputs": []
  },
+ {
+ "cell_type": "code",
+ "source": [
+ "# @title 💫 Compare Text encodings\n",
+ "\n",
+ "prompt_A = \"banana\" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
+ "prompt_B = \"\" # @param {\"type\":\"string\",\"placeholder\":\"Write a prompt\"}\n",
+ "use_token_padding = True # @param {type:\"boolean\"}\n",
+ "\n",
+ "from transformers import CLIPProcessor, CLIPModel\n",
+ "\n",
+ "processor = CLIPProcessor.from_pretrained(\"openai/clip-vit-large-patch14\" , clean_up_tokenization_spaces = True)\n",
+ "\n",
+ "model = CLIPModel.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
+ "\n",
+ "ids_A = processor.tokenizer(text=prompt_A, padding=use_token_padding, return_tensors=\"pt\")\n",
+ "text_encoding_A = model.get_text_features(**ids_A)\n",
+ "\n",
+ "\n",
+ "ids_B = processor.tokenizer(text=prompt_B, padding=use_token_padding, return_tensors=\"pt\")\n",
+ "text_encoding_B = model.get_text_features(**ids_B)\n",
+ "\n",
+ "similarity_str = 'The similarity between the text_encoding for A:\"' + prompt_A + '\" and B: \"' + prompt_B +'\" is ' + token_similarity(text_encoding_A[0] , text_encoding_B[0])\n",
+ "\n",
+ "\n",
+ "print(similarity_str)\n",
+ "#outputs = model(**inputs)\n",
+ "#logits_per_image = outputs.logits_per_image # this is the image-text similarity score\n",
+ "#probs = logits_per_image.softmax(dim=1) # we can take the softmax to get the label probabilities"
+ ],
+ "metadata": {
+ "id": "QQOjh5BvnG8M",
+ "collapsed": true,
+ "cellView": "form"
+ },
+ "execution_count": null,
+ "outputs": []
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "You can write a URL or upload a file locally from your device to use as reference. The image will be saved in the 'sd_tokens' folder. Note that the 'sd_tokens' folder will be deleted upon exiting this runtime."
+ ],
+ "metadata": {
+ "id": "hyK423TQCRup"
+ }
+ },
  {
  "cell_type": "code",
  "source": [