File size: 3,128 Bytes
b5ed9fd
 
 
 
 
 
c800f36
 
 
 
 
b5ed9fd
 
 
 
cbe05b5
b5ed9fd
cbe05b5
b5ed9fd
cbe05b5
 
 
b5ed9fd
cbe05b5
b5ed9fd
cbe05b5
b5ed9fd
 
 
cbe05b5
b5ed9fd
cbe05b5
b5ed9fd
 
 
 
 
 
 
 
 
 
cbe05b5
b5ed9fd
cbe05b5
b5ed9fd
cbe05b5
b5ed9fd
cbe05b5
b5ed9fd
cbe05b5
b5ed9fd
 
 
 
 
 
 
 
cbe05b5
b5ed9fd
cbe05b5
b5ed9fd
cbe05b5
b5ed9fd
 
 
 
cbe05b5
b5ed9fd
cbe05b5
b5ed9fd
 
 
cbe05b5
b5ed9fd
cbe05b5
b5ed9fd
cbe05b5
b5ed9fd
 
 
 
 
 
 
 
 
cbe05b5
b5ed9fd
cbe05b5
b5ed9fd
cbe05b5
b5ed9fd
cbe05b5
b5ed9fd
cbe05b5
b5ed9fd
 
 
 
 
 
 
 
 
cbe05b5
b5ed9fd
cbe05b5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from functools import lru_cache

import requests
import torch
from PIL import Image

from fastapi import FastAPI, Request
from transformers import AriaProcessor, AriaForConditionalGeneration

app = FastAPI()

@app.get("/")
def greet_json():
    """Root health-check endpoint: return a static JSON greeting."""
    return dict(Hello="World!")

# Hugging Face model id for the Aria vision-language model.
_ARIA_MODEL_ID = "rhymes-ai/Aria"


@lru_cache(maxsize=1)
def _load_aria():
    """Load and cache the Aria model and processor.

    Loading pulls multi-gigabyte weights, so it must happen once per
    process, not once per request (the original handler reloaded on
    every POST).
    """
    model = AriaForConditionalGeneration.from_pretrained(
        _ARIA_MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16
    )
    processor = AriaProcessor.from_pretrained(_ARIA_MODEL_ID)
    return model, processor


@app.post("/")
async def aria_image_to_text(request: Request):
    """Describe the image at `image_url` from the JSON request body.

    Expects a payload like ``{"image_url": "<http url>"}`` and returns
    ``{"response": "<generated text>"}``, or ``{"error": ...}`` when the
    URL is missing or the image cannot be fetched.
    """
    data = await request.json()
    image_url = data.get("image_url")
    if not image_url:
        # Fail fast with a clear message instead of crashing on
        # requests.get(None, ...) deep inside the handler.
        return {"error": "missing 'image_url' in request body"}

    try:
        resp = requests.get(image_url, stream=True, timeout=30)
        resp.raise_for_status()
        image = Image.open(resp.raw)
    except (requests.RequestException, OSError) as exc:
        return {"error": f"could not fetch image: {exc}"}

    model, processor = _load_aria()

    # Aria chat format: one user turn containing the image slot plus the
    # text question.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"text": "what is the image?", "type": "text"},
            ],
        }
    ]

    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=text, images=image, return_tensors="pt")
    # Model weights are bfloat16; pixel values come back float32 by default.
    inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
    inputs = inputs.to(model.device)

    output = model.generate(
        **inputs,
        max_new_tokens=15,
        stop_strings=["<|im_end|>"],
        tokenizer=processor.tokenizer,
        do_sample=True,
        temperature=0.9,
    )
    # Drop the prompt tokens; decode only the newly generated continuation.
    output_ids = output[0][inputs["input_ids"].shape[1]:]
    response = processor.decode(output_ids, skip_special_tokens=True)
    return {"response": response}

@app.get("/aria-test")
def aria_test():
    """Smoke-test endpoint: run Aria on a fixed COCO image, return the caption.

    Loads the model and processor from scratch on every call, so this is
    slow and intended only for manually verifying the pipeline end to end.
    """
    model_id_or_path = "rhymes-ai/Aria"
    model = AriaForConditionalGeneration.from_pretrained(
        model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16
    )
    processor = AriaProcessor.from_pretrained(model_id_or_path)

    # Fixed, publicly hosted sample image (COCO val2017).
    resp = requests.get(
        "http://images.cocodataset.org/val2017/000000039769.jpg",
        stream=True,
        timeout=30,  # don't hang the worker if the host is unreachable
    )
    resp.raise_for_status()
    image = Image.open(resp.raw)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"text": "what is the image?", "type": "text"},
            ],
        }
    ]

    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=text, images=image, return_tensors="pt")
    # Model weights are bfloat16; cast pixel values to match.
    inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
    inputs = inputs.to(model.device)

    output = model.generate(
        **inputs,
        max_new_tokens=15,
        stop_strings=["<|im_end|>"],
        tokenizer=processor.tokenizer,
        do_sample=True,
        temperature=0.9,
    )
    # Strip the prompt tokens; decode only the generated continuation.
    output_ids = output[0][inputs["input_ids"].shape[1]:]
    response = processor.decode(output_ids, skip_special_tokens=True)
    return {"response": response}