reimagine-it / app.py
Alberto Carmona
Track error cloning the repo
7ec5667
raw
history blame
6.74 kB
# import torch
# import torch.nn as nn
# import numpy as np
# import json
# import captioning.utils.opts as opts
# import captioning.models as models
# import captioning.utils.misc as utils
# import pytorch_lightning as pl
import gradio as gr
# from diffusers import LDMTextToImagePipeline
# # import PIL.Image
import random
# import os
# # Checkpoint class
# class ModelCheckpoint(pl.callbacks.ModelCheckpoint):
# def on_keyboard_interrupt(self, trainer, pl_module):
# # Save model when keyboard interrupt
# filepath = os.path.join(self.dirpath, self.prefix + 'interrupt.ckpt')
# self._save_model(filepath)
# device = 'cpu' #@param ["cuda", "cpu"] {allow-input: true}
# reward = 'clips_grammar'
# cfg = f'./configs/phase2/clipRN50_{reward}.yml'
# print("Loading cfg from", cfg)
# opt = opts.parse_opt(parse=False, cfg=cfg)
# import gdown
# url = "https://drive.google.com/drive/folders/1nSX9aS7pPK4-OTHYtsUD_uEkwIQVIV7W"
# gdown.download_folder(url, quiet=True, use_cookies=False, output="save/")
# url = "https://drive.google.com/uc?id=1HNRE1MYO9wxmtMHLC8zURraoNFu157Dp"
# gdown.download(url, quiet=True, use_cookies=False, output="data/")
# dict_json = json.load(open('./data/cocotalk.json'))
# print(dict_json.keys())
# ix_to_word = dict_json['ix_to_word']
# vocab_size = len(ix_to_word)
# print('vocab size:', vocab_size)
# seq_length = 1
# opt.vocab_size = vocab_size
# opt.seq_length = seq_length
# opt.batch_size = 1
# opt.vocab = ix_to_word
# model = models.setup(opt)
# del opt.vocab
# ckpt_path = opt.checkpoint_path + '-last.ckpt'
# print("Loading checkpoint from", ckpt_path)
# raw_state_dict = torch.load(
# ckpt_path,
# map_location=device)
# strict = True
# state_dict = raw_state_dict['state_dict']
# if '_vocab' in state_dict:
# model.vocab = utils.deserialize(state_dict['_vocab'])
# del state_dict['_vocab']
# elif strict:
# raise KeyError
# if '_opt' in state_dict:
# saved_model_opt = utils.deserialize(state_dict['_opt'])
# del state_dict['_opt']
# # Make sure the saved opt is compatible with the curren topt
# need_be_same = ["caption_model",
# "rnn_type", "rnn_size", "num_layers"]
# for checkme in need_be_same:
# if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
# getattr(opt, checkme) in ['updown', 'topdown']:
# continue
# assert getattr(saved_model_opt, checkme) == getattr(
# opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
# elif strict:
# raise KeyError
# res = model.load_state_dict(state_dict, strict)
# print(res)
# model = model.to(device)
# model.eval();
# import clip
# from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
# from PIL import Image
# from timm.models.vision_transformer import resize_pos_embed
# clip_model, clip_transform = clip.load("RN50", jit=False, device=device)
# preprocess = Compose([
# Resize((448, 448), interpolation=Image.BICUBIC),
# CenterCrop((448, 448)),
# ToTensor()
# ])
# image_mean = torch.Tensor([0.48145466, 0.4578275, 0.40821073]).to(device).reshape(3, 1, 1)
# image_std = torch.Tensor([0.26862954, 0.26130258, 0.27577711]).to(device).reshape(3, 1, 1)
# num_patches = 196 #600 * 1000 // 32 // 32
# pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, clip_model.visual.attnpool.positional_embedding.shape[-1], device=device),)
# pos_embed.weight = resize_pos_embed(clip_model.visual.attnpool.positional_embedding.unsqueeze(0), pos_embed)
# clip_model.visual.attnpool.positional_embedding = pos_embed
# # End below
# print('Loading the model: CompVis/ldm-text2im-large-256')
# ldm_pipeline = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
# def generate_image_from_text(prompt, steps=100, seed=42, guidance_scale=6.0):
# print('RUN: generate_image_from_text')
# torch.cuda.empty_cache()
# generator = torch.manual_seed(seed)
# images = ldm_pipeline([prompt], generator=generator, num_inference_steps=steps, eta=0.3, guidance_scale=guidance_scale)["sample"]
# return images[0]
# def generate_text_from_image(img):
# print('RUN: generate_text_from_image')
# with torch.no_grad():
# image = preprocess(img)
# image = torch.tensor(np.stack([image])).to(device)
# image -= image_mean
# image /= image_std
# tmp_att, tmp_fc = clip_model.encode_image(image)
# tmp_att = tmp_att[0].permute(1, 2, 0)
# tmp_fc = tmp_fc[0]
# att_feat = tmp_att
# fc_feat = tmp_fc
# # Inference configurations
# eval_kwargs = {}
# eval_kwargs.update(vars(opt))
# verbose = eval_kwargs.get('verbose', True)
# verbose_beam = eval_kwargs.get('verbose_beam', 0)
# verbose_loss = eval_kwargs.get('verbose_loss', 1)
# # dataset = eval_kwargs.get('dataset', 'coco')
# beam_size = eval_kwargs.get('beam_size', 1)
# sample_n = eval_kwargs.get('sample_n', 1)
# remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
# with torch.no_grad():
# fc_feats = torch.zeros((1,0)).to(device)
# att_feats = att_feat.view(1, 196, 2048).float().to(device)
# att_masks = None
# # forward the model to also get generated samples for each image
# # Only leave one feature for each image, in case duplicate sample
# tmp_eval_kwargs = eval_kwargs.copy()
# tmp_eval_kwargs.update({'sample_n': 1})
# seq, seq_logprobs = model(
# fc_feats, att_feats, att_masks, opt=tmp_eval_kwargs, mode='sample')
# seq = seq.data
# sents = utils.decode_sequence(model.vocab, seq)
# return sents[0]
# def generate_drawing_from_image(img, steps=100, seed=42, guidance_scale=6.0):
# print('RUN: generate_drawing_from_image')
# caption = generate_text_from_image(img)
# gen_image = generate_image_from_text(caption, steps=steps, seed=seed, guidance_scale=guidance_scale)
# return gen_image
random_seed = random.randint(0, 2147483647)
def test_fn(**kwargs):
return None
gr.Interface(
# generate_drawing_from_image,
test_fn,
inputs=[
gr.Image(type="pil"),
gr.inputs.Slider(1, 100, label='Inference Steps', default=50, step=1),
gr.inputs.Slider(0, 2147483647, label='Seed', default=random_seed, step=1),
gr.inputs.Slider(1.0, 20.0, label='Guidance Scale - how much the prompt will influence the results', default=6.0, step=0.1),
],
outputs=gr.Image(shape=[256,256], type="pil", elem_id="output_image"),
css="#output_image{width: 256px}",
).launch()