import requests
import torch
from PIL import Image
from transformers import AriaProcessor, AriaForConditionalGeneration
from fastapi import FastAPI, Request
app = FastAPI()
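
# A possible alternative (a sketch, not wired into the app): load the model
# and processor once at import time instead of inside each handler, so the
# checkpoint is not re-initialized on every request.
#
# MODEL_ID = "rhymes-ai/Aria"
# MODEL = AriaForConditionalGeneration.from_pretrained(
#     MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16
# )
# PROCESSOR = AriaProcessor.from_pretrained(MODEL_ID)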
@app.get("/")
def greet_json():
return {"Hello": "World!"}

@app.post("/")
async def aria_image_to_text(request: Request):
    # Fetch the image referenced by the request body.
    data = await request.json()
    image_url = data.get("image_url")
    image = Image.open(requests.get(image_url, stream=True).raw)

    # The model and processor are re-loaded on every request, which is slow;
    # see the module-level caching sketch near the top of the file for an
    # alternative.
    model_id_or_path = "rhymes-ai/Aria"
    model = AriaForConditionalGeneration.from_pretrained(
        model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16
    )
    processor = AriaProcessor.from_pretrained(model_id_or_path)

    # Build a single-turn chat prompt that pairs the image with the question.
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"text": "what is the image?", "type": "text"},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=text, images=image, return_tensors="pt")
    inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
    inputs = inputs.to(model.device)

    # Generate, then decode only the tokens produced after the prompt.
    output = model.generate(
        **inputs,
        max_new_tokens=15,
        stop_strings=["<|im_end|>"],
        tokenizer=processor.tokenizer,
        do_sample=True,
        temperature=0.9,
    )
    output_ids = output[0][inputs["input_ids"].shape[1]:]
    response = processor.decode(output_ids, skip_special_tokens=True)
    return {"response": response}
@app.get("/aria-test")
def aria_test():
print(1)
model_id_or_path = "rhymes-ai/Aria"
print(2)
model = AriaForConditionalGeneration.from_pretrained(
model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16
)
print(3)
processor = AriaProcessor.from_pretrained(model_id_or_path)
print(4)
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
print(5)
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"text": "what is the image?", "type": "text"},
],
}
]
print(6)
text = processor.apply_chat_template(messages, add_generation_prompt=True)
print(7)
inputs = processor(text=text, images=image, return_tensors="pt")
print(8)
inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
print(9)
inputs.to(model.device)
print(10)
output = model.generate(
**inputs,
max_new_tokens=15,
stop_strings=["<|im_end|>"],
tokenizer=processor.tokenizer,
do_sample=True,
temperature=0.9,
)
output_ids = output[0][inputs["input_ids"].shape[1]:]
print(11)
response = processor.decode(output_ids, skip_special_tokens=True)
print(12)
return {"response": response}