File size: 2,728 Bytes
b5ed9fd
 
 
 
 
 
c800f36
 
 
 
 
b5ed9fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import asyncio
import io
import threading
from urllib.parse import urlparse

import requests
import torch
from PIL import Image

from transformers import AriaProcessor, AriaForConditionalGeneration
from fastapi import FastAPI, HTTPException, Request

# FastAPI application object; the route decorators in this module register
# their handlers on it.
app = FastAPI()

@app.get("/")
def greet_json():
    """Answer ``GET /`` with a fixed greeting mapping."""
    greeting = {"Hello": "World!"}
    return greeting

_ARIA_MODEL_ID = "rhymes-ai/Aria"
_ARIA_TEST_IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
_IMAGE_FETCH_TIMEOUT_SECONDS = 10

# Serialises the one-time model load and every generation call. The model and
# processor are shared by all requests, and inference is dispatched to worker
# threads (see asyncio.to_thread in aria_image_to_text).
# NOTE(review): generation is serialised conservatively because one model
# instance is shared; relax this if concurrent generate() calls are confirmed
# safe for this model.
_aria_lock = threading.Lock()
_aria_bundle = None


def _load_aria_locked():
    """Return the shared ``(model, processor)`` pair, loading it on first use.

    The caller must hold ``_aria_lock``. Loading happens once per process
    instead of once per request.
    """
    global _aria_bundle
    if _aria_bundle is None:
        model = AriaForConditionalGeneration.from_pretrained(
            _ARIA_MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16
        )
        processor = AriaProcessor.from_pretrained(_ARIA_MODEL_ID)
        _aria_bundle = (model, processor)
    return _aria_bundle


def _is_http_url(value):
    """Return True when *value* is a string with an http or https scheme."""
    if not isinstance(value, str):
        return False
    try:
        scheme = urlparse(value).scheme
    except ValueError:
        # urlparse rejects some malformed inputs (e.g. a broken IPv6 host).
        return False
    return scheme in ("http", "https")


def _fetch_image(image_url):
    """Download *image_url* and return it as a fully decoded PIL image.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-success HTTP status.
        OSError: when the downloaded bytes cannot be decoded as an image.
    """
    # NOTE(security): callers may pass a client-supplied URL. Only the scheme
    # is validated by the endpoint; deployments exposed to untrusted clients
    # should also restrict target hosts (SSRF) and cap the download size.
    response = requests.get(image_url, timeout=_IMAGE_FETCH_TIMEOUT_SECONDS)
    response.raise_for_status()
    image = Image.open(io.BytesIO(response.content))
    # Decode now so that corrupt data fails here rather than inside the model
    # preprocessing step.
    image.load()
    return image


def _describe_image(image):
    """Ask the Aria model "what is the image?" and return the decoded reply.

    Generation samples (``do_sample=True``, ``temperature=0.9``) at most 15
    new tokens and stops on ``<|im_end|>``; only the tokens generated after
    the prompt are decoded.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"text": "what is the image?", "type": "text"},
            ],
        }
    ]

    with _aria_lock:
        model, processor = _load_aria_locked()

        text = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(text=text, images=image, return_tensors="pt")
        # Match the dtype the model weights were loaded with.
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
        inputs = inputs.to(model.device)

        output = model.generate(
            **inputs,
            max_new_tokens=15,
            stop_strings=["<|im_end|>"],
            tokenizer=processor.tokenizer,
            do_sample=True,
            temperature=0.9,
        )
        # The generated sequence starts with the prompt tokens; drop them.
        prompt_length = inputs["input_ids"].shape[1]
        return processor.decode(output[0][prompt_length:], skip_special_tokens=True)


@app.post("/")
async def aria_image_to_text(request: Request):
    """Describe the image referenced by the ``image_url`` field of the JSON body.

    Returns:
        ``{"response": <generated text>}``.

    Raises:
        HTTPException: status 400 when the body is not valid JSON, when
            ``image_url`` is missing or is not an http(s) URL, or when the
            image cannot be downloaded or decoded.
    """
    try:
        data = await request.json()
    except ValueError as err:  # json.JSONDecodeError is a ValueError
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON."
        ) from err

    image_url = data.get("image_url") if isinstance(data, dict) else None
    if not _is_http_url(image_url):
        raise HTTPException(
            status_code=400,
            detail="'image_url' must be an http(s) URL string.",
        )

    # The download and the model call are blocking; run them on a worker
    # thread so the event loop keeps serving other requests.
    try:
        image = await asyncio.to_thread(_fetch_image, image_url)
    except (requests.RequestException, OSError) as err:
        raise HTTPException(
            status_code=400,
            detail="Could not load an image from 'image_url'.",
        ) from err

    response = await asyncio.to_thread(_describe_image, image)
    return {"response": response}


@app.get("/aria-test")
def aria_test():
    """Run the image-description pipeline on a fixed COCO sample image.

    Returns:
        ``{"response": <generated text>}``.
    """
    image = _fetch_image(_ARIA_TEST_IMAGE_URL)
    return {"response": _describe_image(image)}