Transformers
remyx
Inference Endpoints
File size: 1,776 Bytes
6199e94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
import argparse
import requests
import torch
from PIL import Image
from pathlib import Path
from prismatic import load

def main(model_location, user_prompt, image_source, request_timeout=30.0):
    """Run a single prompt/image pair through a pretrained VLM and print the reply.

    Args:
        model_location: Local path, or ID to auto-download from the HF Hub,
            passed straight to ``prismatic.load``.
        user_prompt: Text used as the single "human" turn of the conversation.
        image_source: Either an ``http://``/``https://`` URL or a local path
            of the image to show the model.
        request_timeout: Seconds to wait for the image server before giving
            up. Only used when ``image_source`` is a URL.

    Raises:
        requests.exceptions.HTTPError: If the image URL answers with a
            4xx/5xx status.
        requests.exceptions.Timeout: If the image server does not respond
            within ``request_timeout`` seconds.
    """
    # Imported locally so this function does not depend on the file's
    # top-level import block being changed.
    import io

    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Load a pretrained VLM (either local path, or ID to auto-download from the HF Hub)
    vlm = load(model_location)
    vlm.to(device, dtype=torch.bfloat16)

    # Load the image from URL or local path
    if image_source.startswith(("http://", "https://")):
        # A timeout keeps a stalled server from hanging the script forever.
        response = requests.get(image_source, timeout=request_timeout)
        # Fail with a clear HTTP error instead of letting PIL choke on an
        # HTML error page.
        response.raise_for_status()
        # `.content` is fully downloaded and has any gzip/deflate
        # Content-Encoding already decoded, unlike the raw socket stream.
        image = Image.open(io.BytesIO(response.content)).convert("RGB")
    else:
        # `convert` returns a new, fully loaded image, so the file handle can
        # be released as soon as the conversion is done.
        with Image.open(image_source) as source_image:
            image = source_image.convert("RGB")

    # Build prompt
    prompt_builder = vlm.get_prompt_builder()
    prompt_builder.add_turn(role="human", message=user_prompt)
    prompt_text = prompt_builder.get_prompt()

    # Generate!
    generated_text = vlm.generate(
        image,
        prompt_text,
        do_sample=True,
        temperature=0.1,
        max_new_tokens=512,
        min_length=1,
    )
    # Keep only the text before the first end-of-sequence marker, if present.
    generated_text = generated_text.split("</s>")[0]

    print("PROMPT TEXT: ", user_prompt)
    print("GENERATED TEXT: ", generated_text)

if __name__ == "__main__":
    # Command-line entry point: all three options are mandatory strings and
    # are forwarded unchanged to main().
    cli = argparse.ArgumentParser(description="Process an image and prompt with a pretrained VLM model.")

    # (flag, help text) pairs, registered in the order they appear in --help.
    required_options = (
        ("--model_location", "The location of the pretrained VLM model."),
        ("--user_prompt", "The prompt to process."),
        ("--image_source", "The URL or local path of the image."),
    )
    for flag, help_text in required_options:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    options = cli.parse_args()

    main(
        model_location=options.model_location,
        user_prompt=options.user_prompt,
        image_source=options.image_source,
    )