File size: 2,480 Bytes
2ec2ebd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import gradio as gr
import sys
# sys.path.append("LaVi-Bridge/test")
# from llama2_unet_diffusion_lens import call_diffusion_lens
from diffusion_lens import get_images

import os
import subprocess

def prepare_images(images):
    """Post-process generated images before display.

    Currently a pass-through hook: the very same object that was passed in
    is handed back, with no copying or conversion.

    Args:
        images: The generated images, in whatever container the caller has.

    Returns:
        ``images`` itself (identity).
    """
    prepared = images
    return prepared


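# NOTE(review): if this is ever re-enabled, build the command as an argument
# list and call subprocess.run(..., shell=False); interpolating `prompt` into a
# shell string lets a crafted prompt (e.g. one containing a quote) inject commands.
# def call_diffusion_lens(prompt):
#     # os.chdir('LaVi-Bridge/test')
#     command = f"python -u llama2_unet.py --ckpt_dir 'LaVi-Bridge/llama2_unet' --output_dir 'output' --llama2_dir 'meta-llama/Llama-2-7b-hf' --prompt '{prompt}'"
#     subprocess.run(command, shell=True)
#     return 'output'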


def get_prompt(prompt):
    """Generate images for *prompt* via ``diffusion_lens.get_images``.

    Despite its name, this does not read a prompt from anywhere. It echoes the
    given prompt to stdout, forwards it to ``get_images`` and returns that
    call's result unchanged. The Gradio demo in this file uses it as the
    button's click handler, with a ``gr.Gallery`` as the output component.

    Args:
        prompt: Text prompt passed straight through to ``get_images``.

    Returns:
        Whatever ``get_images(prompt)`` returns.
        NOTE(review): a ``gr.Gallery`` output expects a list of images; confirm
        that ``get_images`` returns one rather than a single image.
    """
    print('prompt:', prompt)
    print('calling diffusion lens')
    return get_images(prompt)

if __name__ == '__main__':
    # Smoke run: push one fixed prompt through the generation pipeline,
    # bracketed by progress markers on stdout. The returned value is discarded.
    # NOTE(review): this performs a full generation on every script launch,
    # before the second `__main__` block further down starts the Gradio demo.
    # Confirm it is intended (e.g. as a model warm-up) rather than a leftover
    # debug call.
    print('starting')
    get_prompt("A photo of a cat")
    print('done')


# iface = gr.Interface(fn=get_prompt, inputs="text", outputs="image", title="Diffusion Lens")
# iface.launch()

import gradio as gr

def display_images(images):
    """Wrap every item of *images* in its own ``gr.Image`` component.

    Each item is handed to ``gr.Image`` as its first positional argument.
    Items are processed in iteration order, and a new list is returned.

    Args:
        images: Iterable of image values to wrap.

    Returns:
        A list with one ``gr.Image`` per input item (empty for empty input).
    """
    wrapped = []
    for img in images:
        wrapped.append(gr.Image(img))
    return wrapped

if __name__ == '__main__':
    # Gradio demo: a prompt box, a button that runs generation, and a gallery
    # that shows whatever `get_prompt` returns.
    with gr.Blocks() as demo:
        # Blocks event listeners take component objects as inputs; the "text"
        # string shortcut is only resolved by gr.Interface. The prompt therefore
        # needs its own Textbox, which also gives the user somewhere to type it.
        prompt_box = gr.Textbox(label="Prompt", placeholder="A photo of a cat")
        gallery = gr.Gallery(
            label="Generated images", show_label=False, elem_id="gallery",
            columns=[1], rows=[1], object_fit="contain", height="auto")
        btn = gr.Button("Generate images", scale=0)

        btn.click(get_prompt, inputs=prompt_box, outputs=gallery)

    demo.launch()