Spaces:
Runtime error
Runtime error
| # -*- coding: utf-8 -*- | |
| """message_bottle.ipynb | |
| Automatically generated by Colab. | |
| """ | |
| DEVICE = 'cpu' | |
| import matplotlib.pyplot as plt | |
| import matplotlib | |
| import argparse | |
| import glob | |
| import logging | |
| import os | |
| import pickle | |
| import random | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from tqdm import tqdm, trange | |
| from types import SimpleNamespace | |
| import sys | |
| sys.path.append('./Optimus/code/examples/big_ae/') | |
| sys.path.append('./Optimus/code/') | |
| from pytorch_transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, BertConfig | |
| from pytorch_transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2ForLatentConnector | |
| from pytorch_transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer | |
| from pytorch_transformers import XLNetLMHeadModel, XLNetTokenizer | |
| from pytorch_transformers import TransfoXLLMHeadModel, TransfoXLTokenizer | |
| from pytorch_transformers import BertForLatentConnector, BertTokenizer | |
| from modules import VAE | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| torch.set_float32_matmul_precision('high') | |
| from tqdm import tqdm | |
| ################################################ | |
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Mask low-probability entries of a 1-D logits vector, in place.

    Two independent filters are applied, in this order:

    * top-k: when ``top_k > 0``, every logit strictly smaller than the k-th
      largest logit is overwritten with ``filter_value``.
    * nucleus (top-p): when ``top_p > 0.0``, logits are ranked in descending
      order and the shortest prefix whose softmax mass exceeds ``top_p`` is
      kept; every later entry is overwritten with ``filter_value``
      (Holtzman et al., http://arxiv.org/abs/1904.09751).

    Adapted from https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317

    Args:
        logits: 1-D tensor of vocabulary logits; it is modified in place.
        top_k: number of highest logits to keep; 0 disables top-k.
        top_p: cumulative-probability threshold; 0.0 disables nucleus filtering.
        filter_value: value written into masked positions.

    Returns:
        The same ``logits`` tensor object, after masking.
    """
    # Only a single, unbatched distribution is supported.
    assert logits.dim() == 1
    k = min(top_k, logits.size(-1))  # never ask topk for more entries than exist
    if k > 0:
        kth_largest = torch.topk(logits, k)[0][..., -1, None]
        logits[logits < kth_largest] = filter_value
    if top_p > 0.0:
        ranked_logits, ranked_positions = torch.sort(logits, descending=True)
        mass_so_far = F.softmax(ranked_logits, dim=-1).cumsum(dim=-1)
        drop_ranked = mass_so_far > top_p
        # Shift right by one so the token that first crosses the threshold is
        # kept, and the single most likely token is never dropped.
        drop_ranked[..., 1:] = drop_ranked[..., :-1].clone()
        drop_ranked[..., 0] = 0
        logits[ranked_positions[drop_ranked]] = filter_value
    return logits
def sample_sequence_conditional(model, length, context, past=None, num_samples=1, temperature=1, top_k=0, top_p=0.0, device='cpu', decoder_tokenizer=None):
    """Autoregressively sample tokens from ``model`` conditioned on ``past``.

    Starting from the ``context`` token ids, one token per step is drawn from
    the temperature-scaled, top-k/top-p filtered distribution at the last
    position. Generation stops right after the ``<EOS>`` token (as encoded by
    ``decoder_tokenizer``) is sampled, or after ``length`` new tokens,
    whichever happens first.

    Args:
        model: decoder called as ``model(input_ids=..., past=...)``; element 0
            of its output is indexed ``[row, position, vocab]`` for logits.
        length: maximum number of new tokens. ``None`` or a value <= 0 removes
            the cap, so only ``<EOS>`` ends generation.
        context: list of prompt token ids.
        past: forwarded unchanged to ``model`` on every step (the caller in
            this file passes a latent code here).
        num_samples: number of rows ``context`` is repeated into. Only row 0 is
            sampled and a single column is appended per step, so only 1 works.
        temperature: divisor applied to the logits before filtering.
        top_k, top_p: forwarded to ``top_k_top_p_filtering``.
        device: device of the generated id tensor.
        decoder_tokenizer: used to look up the ``<EOS>`` id; when None,
            generation stops only at ``length``.

    Returns:
        LongTensor ``(num_samples, len(context) + n_new)`` containing the
        context followed by the sampled ids (including ``<EOS>`` if reached).

    Raises:
        ValueError: if there is neither a positive ``length`` nor a
            ``decoder_tokenizer``, i.e. no way to ever stop.
    """
    context = torch.tensor(context, dtype=torch.long, device=device)
    generated = context.unsqueeze(0).repeat(num_samples, 1)
    # Look the <EOS> id up once instead of re-encoding it on every step.
    eos_id = decoder_tokenizer.encode('<EOS>')[0] if decoder_tokenizer is not None else None
    max_new = length if length is not None and length > 0 else None
    if eos_id is None and max_new is None:
        raise ValueError('sample_sequence_conditional needs a positive length or a decoder_tokenizer to know when to stop')
    n_new = 0
    with torch.no_grad():
        # Bounded by ``length``: previously this was ``while True`` and could
        # spin forever (growing ``generated``) if <EOS> was never sampled.
        while max_new is None or n_new < max_new:
            outputs = model(input_ids=generated, past=past)
            next_token_logits = outputs[0][0, -1, :] / temperature
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token.unsqueeze(0)), dim=1)
            n_new += 1
            if eos_id is not None and next_token.item() == eos_id:
                break
    return generated
def latent_code_from_text(text):
    """Encode ``text`` into a latent code with the VAE's BERT encoder.

    The text is tokenized with the module-level ``tokenizer_encoder`` and
    wrapped in ids 101 and 102 (presumably BERT's [CLS]/[SEP] for the
    ``bert-base-cased`` vocabulary loaded in this file — TODO confirm). The
    encoder's pooled output is projected by ``model_vae.encoder.linear`` and
    split into (mean, logvar); the mean is returned, so no sampling happens.

    Returns:
        (latent_z, coded_length): the posterior mean (``mean.squeeze(1)``,
        batch of one) and the number of ids fed to the encoder, including the
        two wrapper ids.
    """
    token_ids = [101] + tokenizer_encoder.encode(text) + [102]
    # Build the id tensor directly as int64 on DEVICE instead of going
    # through a float32 tensor and casting afterwards.
    x0 = torch.tensor([token_ids], dtype=torch.long, device=DEVICE)
    with torch.no_grad():
        # Positions with id > 0 are attended (0 is presumably BERT's [PAD]).
        pooled_hidden_fea = model_vae.encoder(x0, attention_mask=(x0 > 0).float())[1]
        mean, _logvar = model_vae.encoder.linear(pooled_hidden_fea).chunk(2, -1)
        latent_z = mean.squeeze(1)
    return latent_z, len(token_ids)
| # args | |
def text_from_latent_code(latent_z):
    """Decode a latent code back into text with the VAE's GPT-2 decoder.

    Sampling starts from the ``<BOS>`` token and passes ``latent_z`` as the
    decoder's ``past`` conditioning (temperature 0.5, top-k 100, top-p 0.98).
    The decoded string is split on whitespace and its first and last pieces
    are dropped — presumably the ``<BOS>``/``<EOS>`` markers; verify against
    the tokenizer's decode output.
    """
    bos_ids = tokenizer_decoder.encode('<BOS>')
    sampled = sample_sequence_conditional(
        model=model_vae.decoder,
        context=bos_ids,
        past=latent_z,
        # Forwarded as ``length``; whether it bounds generation is decided by
        # sample_sequence_conditional (<EOS> is the intended stop signal).
        length=128,
        temperature=.5,
        top_k=100,
        top_p=.98,
        device=DEVICE,
        decoder_tokenizer=tokenizer_decoder,
    )
    decoded = tokenizer_decoder.decode(sampled[0, :].tolist(), clean_up_tokenization_spaces=True)
    words = decoded.split()
    return ' '.join(words[1:-1])
| ################################################ | |
# Load model
# Config / model / tokenizer classes for the two halves of the Optimus VAE:
# a BERT encoder and a GPT-2 decoder, both with latent connectors.
MODEL_CLASSES = {
    'gpt2': (GPT2Config, GPT2ForLatentConnector, GPT2Tokenizer),
    'bert': (BertConfig, BertForLatentConnector, BertTokenizer)
}
latent_size = 768
model_path = './checkpoint-31250/checkpoint-full-31250/'
encoder_path = './checkpoint-31250/checkpoint-encoder-31250/'
decoder_path = './checkpoint-31250/checkpoint-decoder-31250/'
block_size = 100

# Load a trained Encoder model and vocabulary that you have fine-tuned
encoder_config_class, encoder_model_class, encoder_tokenizer_class = MODEL_CLASSES['bert']
model_encoder = encoder_model_class.from_pretrained(encoder_path, latent_size=latent_size)
# NOTE(review): a cased vocabulary with do_lower_case=True — presumably this
# matches how the checkpoint was fine-tuned; confirm before changing it.
tokenizer_encoder = encoder_tokenizer_class.from_pretrained('bert-base-cased', do_lower_case=True)
model_encoder.to(DEVICE)
if block_size <= 0:
    block_size = tokenizer_encoder.max_len_single_sentence  # Our input block size will be the max possible for the model
block_size = min(block_size, tokenizer_encoder.max_len_single_sentence)

# Load a trained Decoder model and vocabulary that you have fine-tuned
decoder_config_class, decoder_model_class, decoder_tokenizer_class = MODEL_CLASSES['gpt2']
model_decoder = decoder_model_class.from_pretrained(decoder_path, latent_size=latent_size)
tokenizer_decoder = decoder_tokenizer_class.from_pretrained('gpt2', do_lower_case=False)
model_decoder.to(DEVICE)
if block_size <= 0:
    block_size = tokenizer_decoder.max_len_single_sentence  # Our input block size will be the max possible for the model
block_size = min(block_size, tokenizer_decoder.max_len_single_sentence)

# Load full model
output_full_dir = '/home/ryn_mote/Misc/generative_recommender/text_space/'  # not referenced elsewhere in this file
# map_location remaps tensors onto DEVICE. Without it, a checkpoint saved from
# GPU tensors cannot be deserialized on a host without CUDA.
checkpoint = torch.load(os.path.join(model_path, 'training.bin'), map_location=DEVICE)

# Chunyuan: Add Padding token to GPT2
special_tokens_dict = {'pad_token': '<PAD>', 'bos_token': '<BOS>', 'eos_token': '<EOS>'}
num_added_toks = tokenizer_decoder.add_special_tokens(special_tokens_dict)
print('We have added', num_added_toks, 'tokens to GPT2')
model_decoder.resize_token_embeddings(len(tokenizer_decoder))  # Notice: resize_token_embeddings expect to receive the full size of the new vocabulary, i.e. the length of the tokenizer.
assert tokenizer_decoder.pad_token == '<PAD>'

# Evaluation
model_vae = VAE(model_encoder, model_decoder, tokenizer_encoder, tokenizer_decoder, SimpleNamespace(**{'latent_size': latent_size, 'device': DEVICE}))
model_vae.load_state_dict(checkpoint['model_state_dict'])
# load_state_dict copies the weights into the model, so the checkpoint dict
# (which may also hold optimizer state) is no longer needed.
del checkpoint
print("Pre-trained Optimus is successfully loaded")
model_vae.to(DEVICE).to(torch.bfloat16)
# Inference only: put every submodule in eval mode explicitly.
model_vae.eval()
model_vae = torch.compile(model_vae)

# Smoke test: round-trip one sentence through the encoder and the decoder.
l = latent_code_from_text('A photo of a mountain.')[0]
t = text_from_latent_code(l)
print(t, l, l.shape)
| ################################################ | |
| import gradio as gr | |
| import numpy as np | |
| from sklearn.svm import SVC | |
| from sklearn.inspection import permutation_importance | |
| from sklearn import preprocessing | |
| import pandas as pd | |
| import random | |
| import time | |
# dtype used for the classifier direction handed to generate().
dtype = torch.bfloat16
torch.set_grad_enabled(False)
# Unique prompts from the CSV's second column; non-string cells (e.g. NaN for
# empty rows) are skipped.
prompt_list = [p for p in set(pd.read_csv('./twitter_prompts.csv').iloc[:, 1].tolist()) if isinstance(p, str)]
start_time = time.time()
| ####################### Setup Model | |
| # TODO put back | |
| # @spaces.GPU() | |
def generate(prompt, in_embs=None,):
    """Decode text from a latent; return the text and its re-encoded latent.

    With a non-empty ``prompt``, the prompt is encoded with
    ``latent_code_from_text``. If ``in_embs`` is also given, it is rescaled so
    its largest absolute entry is 0.6 and added to the prompt latent. With an
    empty ``prompt``, ``in_embs`` alone is rescaled the same way, moved to
    DEVICE and cast to bfloat16.

    A 5-bin histogram of the resulting latent is saved to
    ``real_im_emb_plot.jpg``. The latent is then decoded with
    ``text_from_latent_code`` (``<unk>`` removed, whitespace collapsed), and the
    decoded text is encoded again.

    Returns:
        (text, latent): the decoded text and the latent of that text, on CPU.

    Raises:
        ValueError: if ``prompt`` is empty and ``in_embs`` is None.
    """
    if in_embs is not None:
        # Rescale so the largest-magnitude entry is 0.6.
        in_embs = in_embs / in_embs.abs().max() * .6
    if prompt != '':
        print(prompt)
        prompt_latent = latent_code_from_text(prompt)[0]
        in_embs = prompt_latent if in_embs is None else in_embs.to(DEVICE) + prompt_latent
    else:
        print('From embeds.')
        if in_embs is None:
            raise ValueError('generate() needs a non-empty prompt or in_embs')
        in_embs = in_embs.to(DEVICE).to(torch.bfloat16)
    plt.close('all')
    plt.hist(in_embs.detach().to('cpu').to(torch.float).numpy().flatten(), bins=5)
    plt.savefig('real_im_emb_plot.jpg')
    text = ' '.join(text_from_latent_code(in_embs).replace('<unk>', '').split())
    in_embs = latent_code_from_text(text)[0]
    print(text)
    return text, in_embs.to('cpu')
| ####################### | |
# TODO add to state instead of shared across all
# Counts next_one() calls across ALL sessions; every third call mixes a random
# prompt from prompt_list into generation.
glob_idx = 0


def _balance_training_set(embs, ys):
    """Align, pad and prune the rated latents before fitting the classifier.

    ``embs`` and ``ys`` are extended in place with small random padding
    examples. The returned lists are new, shuffled within each class
    (positives first, then negatives), and still paired by index.
    """
    # Align the two lists first: every later step indexes ys by positions in
    # embs. (The original ran this check only after both lists had been
    # rebuilt to the same length, where it could never trigger.)
    while len(ys) > len(embs):
        print('ys are longer than embs; popping latest rating')
        ys.pop(-1)
    if len(embs) > len(ys):
        print('embs are longer than ys; dropping unrated embeddings')
        del embs[len(ys):]
    # handle case where every instance of calibration prompts is 'Neither' or 'Like' or 'Dislike'
    if len(set(ys)) <= 1:
        embs.append(.01*torch.randn(latent_size))
        embs.append(.01*torch.randn(latent_size))
        ys.append(0)
        ys.append(1)
    # Pad early sessions with three small random negatives, drawn
    # independently (`[t] * 3` would repeat one tensor three times).
    if len(ys) < 10:
        embs += [.01*torch.randn(latent_size) for _ in range(3)]
        ys += [0] * 3
    pos_indices = [i for i, y in enumerate(ys) if y == 1]
    neg_indices = [i for i, y in enumerate(ys) if y == 0]
    # the embs & ys stay tied by index but we shuffle to drop randomly
    random.shuffle(pos_indices)
    random.shuffle(neg_indices)
    if len(neg_indices) > 25:
        neg_indices = neg_indices[1:]  # drop one random negative per call
    print(len(pos_indices), len(neg_indices))
    keep = pos_indices + neg_indices
    return [embs[i] for i in keep], [ys[i] for i in keep]


def _preference_direction(embs, ys):
    """Fit a linear SVM on standardized latents and return its weight vector.

    Returns ``SVC.coef_`` as a double tensor (one row for two classes).
    """
    # Cast to float32 before converting to NumPy: NumPy has no bfloat16, and
    # latents produced by the bfloat16 VAE would otherwise fail to convert.
    feature_embs = torch.stack([e.to('cpu', torch.float32) for e in embs]).numpy()
    scaler = preprocessing.StandardScaler().fit(feature_embs)
    feature_embs = scaler.transform(feature_embs)
    chosen_y = np.array(ys)
    print('Gathering coefficients')
    lin_class = SVC(max_iter=50000, kernel='linear', class_weight='balanced', C=.1).fit(feature_embs, chosen_y)
    coef_ = torch.tensor(lin_class.coef_, dtype=torch.double)
    print(coef_.shape, 'COEF')
    print('Gathered')
    return coef_


def next_one(embs, ys, calibrate_prompts):
    """Produce the next text for the user to rate.

    While ``calibrate_prompts`` is non-empty, its first prompt is popped and
    passed to ``generate``. Afterwards ("roaming"), a linear SVM is fit to the
    rated latents and its weight vector is passed to ``generate``; on every
    third call (by the global ``glob_idx``) a random prompt from
    ``prompt_list`` is added as well.

    The new (not yet rated) latent is appended to ``embs``.

    Returns:
        (text, embs, ys, calibrate_prompts) for the Gradio state.
    """
    global glob_idx
    glob_idx = glob_idx + 1
    with torch.no_grad():
        if len(calibrate_prompts) > 0:
            print('######### Calibrating with sample prompts #########')
            prompt = calibrate_prompts.pop(0)
            text, img_embs = generate(prompt)
            # generate() returns a batch-of-one latent; list += iterates its
            # rows, so this appends a single latent vector.
            embs += img_embs
            print(len(embs))
            return text, embs, ys, calibrate_prompts
        print('######### Roaming #########')
        embs, ys = _balance_training_set(embs, ys)
        coef_ = _preference_direction(embs, ys)
        rng_prompt = random.choice(prompt_list)
        im_emb = coef_.to(dtype=dtype)
        prompt = '' if glob_idx % 3 != 0 else rng_prompt
        text, im_emb = generate(prompt, im_emb)
        embs += im_emb
        return text, embs, ys, calibrate_prompts
def start(_, embs, ys, calibrate_prompts):
    """Begin a session: show the first text and swap which buttons are active.

    The first argument (the Start button's value) is ignored. The three
    rating buttons become interactive and Start is disabled.
    """
    text, embs, ys, calibrate_prompts = next_one(embs, ys, calibrate_prompts)
    like_btn = gr.Button(value='Like (L)', interactive=True)
    neither_btn = gr.Button(value='Neither (Space)', interactive=True)
    dislike_btn = gr.Button(value='Dislike (A)', interactive=True)
    start_btn = gr.Button(value='Start', interactive=False)
    return [like_btn, neither_btn, dislike_btn, start_btn, text, embs, ys, calibrate_prompts]
def choose(text, choice, embs, ys, calibrate_prompts):
    """Record the user's rating of the current text and produce the next one.

    ``choice`` is the clicked button's label. 'Like (L)' records 1. Any label
    other than 'Like (L)' or 'Neither (Space)' records 0. 'Neither (Space)'
    records nothing and discards the latest (unrated) latent from ``embs``.
    A ``text`` of None is always recorded as 0.

    Returns:
        (text, embs, ys, calibrate_prompts) from ``next_one``.
    """
    if choice == 'Neither (Space)':
        return next_one(embs[:-1], ys, calibrate_prompts)
    rating = 1 if choice == 'Like (L)' else 0
    # if we detected NSFW, leave that area of latent space regardless of how they rated chosen.
    # TODO skip allowing rating
    if text is None:
        print('NSFW -- choice is disliked')
        rating = 0
    ys.append(rating)
    return next_one(embs, ys, calibrate_prompts)
# Page styling passed to gr.Blocks(css=...): caps the app width, centers the
# description block, and defines the `fade-in-out` animation that the
# js_head script adds to the rating buttons when they are clicked.
css = '''.gradio-container{max-width: 700px !important}
#description{text-align: center}
#description h1, #description h3{display: block}
#description p{margin-top: 0}
.fade-in-out {animation: fadeInOut 3s forwards}
@keyframes fadeInOut {
0% {
background: var(--bg-color);
}
100% {
background: var(--button-secondary-background-fill);
}
}
'''
# Client-side script passed to gr.Blocks(head=...):
# - keyboard shortcuts: 'a'/'A' clicks #dislike, Space clicks #neither,
#   'l'/'L' clicks #like;
# - clicking one of those buttons sets --bg-color (red / green / grey) and
#   restarts the `fade-in-out` CSS animation on it.
# NOTE(review): the Space handler does not call event.preventDefault(), so the
# browser's default Space action (page scroll, or activating a focused button)
# may also run — confirm this cannot double-submit a rating.
# NOTE(review): document.body must exist when this runs; that depends on how
# Gradio injects `head` content — confirm.
js_head = '''
<script>
document.addEventListener('keydown', function(event) {
if (event.key === 'a' || event.key === 'A') {
// Trigger click on 'dislike' if 'A' is pressed
document.getElementById('dislike').click();
} else if (event.key === ' ' || event.keyCode === 32) {
// Trigger click on 'neither' if Spacebar is pressed
document.getElementById('neither').click();
} else if (event.key === 'l' || event.key === 'L') {
// Trigger click on 'like' if 'L' is pressed
document.getElementById('like').click();
}
});
function fadeInOut(button, color) {
button.style.setProperty('--bg-color', color);
button.classList.remove('fade-in-out');
void button.offsetWidth; // This line forces a repaint by accessing a DOM property
button.classList.add('fade-in-out');
button.addEventListener('animationend', () => {
button.classList.remove('fade-in-out'); // Reset the animation state
}, {once: true});
}
document.body.addEventListener('click', function(event) {
const target = event.target;
if (target.id === 'dislike') {
fadeInOut(target, '#ff1717');
} else if (target.id === 'like') {
fadeInOut(target, '#006500');
} else if (target.id === 'neither') {
fadeInOut(target, '#cccccc');
}
});
</script>
'''
# Gradio UI: a read-only textbox for the current text, three rating buttons
# (disabled until Start is clicked) and a Start button. Per-session data
# (embs, ys, calibrate_prompts) lives in gr.State.
with gr.Blocks(css=css, head=js_head) as demo:
    gr.Markdown('''# Compass
### Generative Recommenders for Exploration of Text
Explore the latent space without prompting based on your preferences. Learn more on [the write-up](https://rynmurdock.github.io/posts/2024/3/generative_recomenders/).
''', elem_id="description")
    embs = gr.State([])
    ys = gr.State([])
    # Prompts shown (in order) before the app starts roaming from ratings.
    calibrate_prompts = gr.State([
        'the moon is melting into my glass of tea',
        'a sea slug -- pair of claws scuttling -- jelly fish glowing',
        'an adorable creature. It may be a goblin or a pig or a slug.',
        'an animation about a gorgeous nebula',
        'a sketch of an impressive mountain by da vinci',
        'a watercolor painting: the octopus writhes',
    ])
    with gr.Row(elem_id='output-image'):
        text = gr.Textbox(interactive=False, elem_id="text")
    with gr.Row(equal_height=True):
        b3 = gr.Button(value='Dislike (A)', interactive=False, elem_id="dislike")
        b2 = gr.Button(value='Neither (Space)', interactive=False, elem_id="neither")
        b1 = gr.Button(value='Like (L)', interactive=False, elem_id="like")
    # Each rating button passes its own label to choose() as `choice`.
    b1.click(
        choose,
        [text, b1, embs, ys, calibrate_prompts],
        [text, embs, ys, calibrate_prompts]
    )
    b2.click(
        choose,
        [text, b2, embs, ys, calibrate_prompts],
        [text, embs, ys, calibrate_prompts]
    )
    b3.click(
        choose,
        [text, b3, embs, ys, calibrate_prompts],
        [text, embs, ys, calibrate_prompts]
    )
    with gr.Row():
        b4 = gr.Button(value='Start')
        b4.click(start,
                 [b4, embs, ys, calibrate_prompts],
                 [b1, b2, b3, b4, text, embs, ys, calibrate_prompts])
    with gr.Row():
        # `</div>` (no space): `</ div>` is parsed as a bogus comment, so the
        # divs were never closed.
        html = gr.HTML('''<div style='text-align:center; font-size:20px'>You will calibrate for several prompts and then roam. </div><br><br><br>
<div style='text-align:center; font-size:14px'>Note that while the model is unlikely to produce NSFW text, this may still occur, and users should avoid NSFW content when rating.
</div>
<br><br>
<div style='text-align:center; font-size:14px'>Thanks to @multimodalart for their contributions to the demo, esp. the interface and @maxbittker for feedback.
</div>''')
demo.launch(share=True)