File size: 1,205 Bytes
850b0e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7a7833f
 
850b0e4
 
d59d1e6
 
 
 
 
 
 
850b0e4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import gradio as gr
import torch
from mario_gpt.dataset import MarioDataset
from mario_gpt.prompter import Prompter
from mario_gpt.lm import MarioLM
from mario_gpt.utils import view_level, convert_level_to_png

mario_lm = MarioLM()

# Prefer the GPU when one is present, but fall back to the CPU so the demo can
# still start on hosts without CUDA. Constructing torch.device('cuda') never
# fails by itself; it is the .to(device) move that raises when no CUDA device
# is available, so the availability check has to happen before choosing.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
mario_lm = mario_lm.to(device)
# Tile-image directory handed to convert_level_to_png when rendering levels.
TILE_DIR = "data/tiles"

def update(prompt, progress=gr.Progress(track_tqdm=True)):
    """Sample a Mario level for ``prompt`` and return it rendered as an image.

    Args:
        prompt: Free-text description from the textbox; it is sent to the
            model as a one-element batch of prompts.
        progress: Never read in the body. Declaring it with
            ``track_tqdm=True`` lets Gradio mirror the tqdm bar requested via
            ``use_tqdm=True`` into the UI.

    Returns:
        The first element of ``convert_level_to_png``'s result, which is the
        value shown by the ``gr.Image`` output.
    """
    # Fixed sampling settings for this demo.
    # NOTE(review): 1399 steps presumably yields a full-length level; confirm.
    sampled = mario_lm.sample(
        prompts=[prompt], num_steps=1399, temperature=2.0, use_tqdm=True
    )
    # squeeze() drops size-1 dims -- presumably the batch axis of this
    # single-prompt batch; confirm against MarioLM.sample's return shape.
    level_tokens = sampled.squeeze()
    rendered = convert_level_to_png(level_tokens, TILE_DIR, mario_lm.tokenizer)
    # NOTE(review): assumes element 0 of the result is the image; verify.
    return rendered[0]

# Demo UI: a prompt box, an image for the rendered level, a button that runs
# `update`, and a set of clickable example prompts wired to the same function.
with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Enter your MarioGPT prompt")
    level_image = gr.Image()
    btn = gr.Button("Generate level")
    btn.click(inputs=prompt, outputs=level_image, fn=update)
    gr.Examples(
        fn=update,
        inputs=prompt,
        outputs=level_image,
        examples=[
            "many pipes, many enemies, some blocks, high elevation",
            "little pipes, little enemies, many blocks, high elevation",
            "many pipes, some enemies",
            "no pipes, no enemies, many blocks",
        ],
        # Ask Gradio to cache the example outputs (per Gradio docs this runs
        # `fn` on the examples ahead of time rather than on each click).
        cache_examples=True,
    )
# Starts the server whenever this module is executed (no __main__ guard).
demo.launch()