File size: 3,123 Bytes
22fc8c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""SpaceLlama3.1 demo gradio app."""

"""SpaceLlama3.1 demo gradio app."""

import datetime
import functools
import logging
import os

import gradio as gr
import requests
import torch
import PIL.Image
from prismatic import load

# Markdown rendered at the top of the demo page (see create_app): a title,
# a row of project links, and a research-use disclaimer. The "\n\n" escapes
# are real newlines in the resulting string.
INTRO_TEXT = """SpaceLlama3.1 demo\n\n
| [Model](https://huggingface.co/remyxai/SpaceLlama3.1) 
| [GitHub](https://github.com/remyxai/VQASynth/tree/main) 
| [Demo](https://huggingface.co/spaces/remyxai/SpaceLlama3.1) 
| [Discord](https://discord.gg/DAy3P5wYJk) 
\n\n
**This is an experimental research model.** Make sure to add appropriate guardrails when using the model for applications.
"""

@functools.lru_cache(maxsize=1)
def _load_model(model_location):
    """Load the model at *model_location* and move it to the best device.

    The result is cached so repeated inference requests reuse the already
    loaded model instead of reloading it on every call. Only one model is
    kept (``maxsize=1``) so switching locations does not accumulate models
    in memory.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    vlm = load(model_location)
    vlm.to(device, dtype=torch.bfloat16)
    return vlm


def compute(image, prompt, model_location):
    """Run model inference on an image/prompt pair.

    Args:
        image: Either a filesystem path to an image (opened and converted to
            RGB here) or an already-loaded image object passed through as-is.
        prompt: The user's text prompt.
        model_location: Identifier handed to ``prismatic.load``.

    Returns:
        The generated text, truncated at the first ``"</s>"`` marker.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        raise gr.Error("Image required")

    logging.info('prompt="%s"', prompt)

    # Open the image file. ``convert`` returns an independent copy, so the
    # underlying file handle can be closed right away.
    if isinstance(image, str):
        with PIL.Image.open(image) as opened:
            image = opened.convert("RGB")

    # Reuse the cached model rather than loading it for every request.
    vlm = _load_model(model_location)

    # Prepare prompt
    prompt_builder = vlm.get_prompt_builder()
    prompt_builder.add_turn(role="human", message=prompt)
    prompt_text = prompt_builder.get_prompt()

    # Generate the text based on image and prompt
    generated_text = vlm.generate(
        image,
        prompt_text,
        do_sample=True,
        temperature=0.1,
        max_new_tokens=512,
        min_length=1,
    )
    # Keep only the text before the end-of-sequence marker, if present.
    output = generated_text.split("</s>")[0]

    logging.info('output="%s"', output)

    return output

def reset():
    """Return cleared values for the prompt box and the image input.

    The first element empties the prompt text, the second removes the image.
    """
    cleared_prompt = ""
    cleared_image = None
    return cleared_prompt, cleared_image

def create_app():
    """Create the demo UI.

    Builds a ``gr.Blocks`` layout with an image input, a prompt box, Run and
    Clear buttons, an output area and a few status lines, and wires up the
    event handlers.

    Returns:
        The constructed ``gr.Blocks`` app (not yet queued or launched).
    """
    # Model location
    model_location = "remyxai/SpaceLlama3.1"  # Update as needed

    def run_inference(image_path, prompt_text):
        """Run ``compute`` for the fixed model and adapt its result to the output.

        Event ``inputs`` must be Gradio components, so the model location is
        bound here instead of being passed through the inputs list. The
        output component is a ``HighlightedText``, which takes a list of
        ``(text, label)`` spans, so the plain string is wrapped as a single
        unlabeled span.
        """
        answer = compute(image_path, prompt_text, model_location)
        return [(answer, None)]

    def model_status():
        """Return the Markdown line shown in the model-info area."""
        # A single output component takes a single value, not a list.
        return f"Model `{model_location}` loaded."

    with gr.Blocks() as demo:
        # Main UI structure
        gr.Markdown(INTRO_TEXT)
        with gr.Row():
            image = gr.Image(value=None, label="Image", type="filepath", visible=True)  # input
            with gr.Column():
                prompt = gr.Textbox(value="", label="Prompt", visible=True)
                model_info = gr.Markdown(label="Model Info")
                run = gr.Button("Run", variant="primary")
                clear = gr.Button("Clear")
                highlighted_text = gr.HighlightedText(value="", label="Output", visible=True)

        # Button event handlers
        run.click(
            run_inference,
            [image, prompt],
            highlighted_text,
        )
        clear.click(reset, None, [prompt, image])

        # Status
        gr.Markdown(f"Startup: {datetime.datetime.now()}")
        gr.Markdown("GPU=?")
        demo.load(
            model_status,
            None,
            model_info,
        )

    return demo

if __name__ == "__main__":

    logging.basicConfig(
        level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"
    )

    # Dump the environment for debugging, but never write credentials to the
    # logs: values of variables whose names look secret-bearing are masked.
    secret_markers = ("TOKEN", "SECRET", "PASSWORD", "PASSWD", "KEY", "CREDENTIAL")
    for k, v in os.environ.items():
        if any(marker in k.upper() for marker in secret_markers):
            v = "<redacted>"
        logging.info('environ["%s"] = %r', k, v)

    create_app().queue().launch()