|
import requests |
|
import torch |
|
from PIL import Image |
|
|
|
from transformers import AriaProcessor, AriaForConditionalGeneration |
|
from fastapi import FastAPI, Request |
|
|
|
# FastAPI application instance; the route handlers below register themselves on it.
app = FastAPI()
|
|
|
@app.get("/")
def greet_json():
    """Health-check endpoint: respond with a static greeting payload."""
    greeting = {"Hello": "World!"}
    return greeting
|
|
|
@app.post("/")
async def aria_image_to_text(request: Request):
    """Describe an image with the Aria vision-language model.

    Expects a JSON body containing an ``image_url`` key. Downloads the
    image, runs Aria on it with a fixed "what is the image?" prompt, and
    returns ``{"response": <generated text>}``. On a missing URL an
    ``{"error": ...}`` payload is returned instead of crashing inside PIL.
    """
    data = await request.json()
    image_url = data.get("image_url")
    if not image_url:
        # Fail fast with a clear message rather than passing None to requests.
        return {"error": "missing 'image_url' in request body"}

    resp = requests.get(image_url, stream=True)
    # Surface download failures; otherwise an HTML error page would be
    # handed to PIL as if it were image bytes.
    resp.raise_for_status()
    image = Image.open(resp.raw)

    model_id_or_path = "rhymes-ai/Aria"
    # Load the (multi-GB) model and processor once and cache them on
    # app.state -- reloading them on every request is prohibitively slow.
    if not hasattr(app.state, "aria_model"):
        app.state.aria_model = AriaForConditionalGeneration.from_pretrained(
            model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16
        )
        app.state.aria_processor = AriaProcessor.from_pretrained(model_id_or_path)
    model = app.state.aria_model
    processor = app.state.aria_processor

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"text": "what is the image?", "type": "text"},
            ],
        }
    ]

    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=text, images=image, return_tensors="pt")
    # Model weights are bfloat16; match the pixel tensor dtype before the forward pass.
    inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
    inputs = inputs.to(model.device)

    output = model.generate(
        **inputs,
        max_new_tokens=15,
        stop_strings=["<|im_end|>"],
        tokenizer=processor.tokenizer,
        do_sample=True,
        temperature=0.9,
    )
    # Drop the prompt tokens; decode only the newly generated continuation.
    output_ids = output[0][inputs["input_ids"].shape[1]:]
    response = processor.decode(output_ids, skip_special_tokens=True)
    return {"response": response}
|
|
|
@app.get("/aria-test")
def aria_test():
    """Smoke-test endpoint: run Aria on a fixed COCO sample image.

    Returns ``{"response": <generated text>}`` for the hard-coded test
    image, using the same cached model/processor as the POST handler.
    """
    model_id_or_path = "rhymes-ai/Aria"
    # Load the (multi-GB) model and processor once and cache them on
    # app.state -- reloading them on every call is prohibitively slow.
    if not hasattr(app.state, "aria_model"):
        app.state.aria_model = AriaForConditionalGeneration.from_pretrained(
            model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16
        )
        app.state.aria_processor = AriaProcessor.from_pretrained(model_id_or_path)
    model = app.state.aria_model
    processor = app.state.aria_processor

    resp = requests.get(
        "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
    )
    # Surface download failures instead of feeding PIL a non-image body.
    resp.raise_for_status()
    image = Image.open(resp.raw)

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"text": "what is the image?", "type": "text"},
            ],
        }
    ]

    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=text, images=image, return_tensors="pt")
    # Model weights are bfloat16; match the pixel tensor dtype before the forward pass.
    inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
    inputs = inputs.to(model.device)

    output = model.generate(
        **inputs,
        max_new_tokens=15,
        stop_strings=["<|im_end|>"],
        tokenizer=processor.tokenizer,
        do_sample=True,
        temperature=0.9,
    )
    # Drop the prompt tokens; decode only the newly generated continuation.
    output_ids = output[0][inputs["input_ids"].shape[1]:]
    response = processor.decode(output_ids, skip_special_tokens=True)
    return {"response": response}
|
|