# Aria / app.py
# Author: Paul DAMPFHOEFFER
# Commit b5ed9fd: "feat: init flash api"
# File view: raw | history | blame — 2.73 kB
import asyncio
import functools
import threading
from io import BytesIO

import requests
import torch
from fastapi import FastAPI, HTTPException, Request
from PIL import Image
from transformers import AriaProcessor, AriaForConditionalGeneration

# The ASGI application object every route below is registered on.
app = FastAPI()


@app.get("/")
def greet_json():
    """Return a constant JSON greeting from the root GET route."""
    greeting = dict(Hello="World!")
    return greeting
# Hugging Face Hub id (or local path) of the model served by this app.
ARIA_MODEL_ID = "rhymes-ai/Aria"
# Fixed sample image used by the /aria-test smoke-test endpoint.
_TEST_IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
# Seconds to wait on the image download before giving up (requests has no default timeout).
_IMAGE_FETCH_TIMEOUT = 30

# Serializes model loading and generation. Both endpoints now run this work in
# worker threads; without the lock, simultaneous first requests could each load
# a full copy of the model, and concurrent generate() calls would contend for
# the same GPU memory.
_ARIA_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _load_aria():
    """Load the Aria model and processor once per process and reuse them.

    Returns:
        ``(model, processor)``. A failed load raises and is not cached, so the
        next request retries.
    """
    model = AriaForConditionalGeneration.from_pretrained(
        ARIA_MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16
    )
    processor = AriaProcessor.from_pretrained(ARIA_MODEL_ID)
    return model, processor


def _fetch_image(image_url):
    """Download ``image_url`` and fully decode it as a PIL image.

    Raises:
        requests.RequestException: on network errors, timeouts, or non-2xx status.
        OSError: if the payload is not a decodable image (PIL raises
            ``UnidentifiedImageError``, an ``OSError`` subclass, or ``OSError``
            for truncated data).
    """
    with requests.get(image_url, timeout=_IMAGE_FETCH_TIMEOUT) as resp:
        # Fail fast instead of handing an HTML error page to PIL.
        resp.raise_for_status()
        # .content applies Content-Encoding decoding, unlike the raw stream.
        payload = resp.content
    image = Image.open(BytesIO(payload))
    # Force decoding now so corrupt/truncated data fails here, not inside the model.
    image.load()
    return image


def _describe_image(image, prompt="what is the image?", max_new_tokens=15):
    """Run Aria on ``image`` with a single-turn text ``prompt``.

    Sampling is enabled (``do_sample=True``, ``temperature=0.9``), so output is
    not deterministic across calls.

    Returns:
        The decoded completion text, with the prompt tokens and special tokens removed.
    """
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"text": prompt, "type": "text"},
            ],
        }
    ]
    with _ARIA_LOCK:
        model, processor = _load_aria()
        text = processor.apply_chat_template(messages, add_generation_prompt=True)
        inputs = processor(text=text, images=image, return_tensors="pt")
        # The model was loaded in bfloat16, so the image tensor must match.
        inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
        inputs = inputs.to(model.device)
        output = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            stop_strings=["<|im_end|>"],
            tokenizer=processor.tokenizer,
            do_sample=True,
            temperature=0.9,
        )
    # generate() echoes the prompt tokens; keep only the newly generated ones.
    output_ids = output[0][inputs["input_ids"].shape[1]:]
    return processor.decode(output_ids, skip_special_tokens=True)


@app.post("/")
async def aria_image_to_text(request: Request):
    """Describe the image referenced by the JSON body's ``image_url`` field.

    Expects a JSON object such as ``{"image_url": "https://..."}`` and returns
    ``{"response": <generated text>}``.

    Raises:
        HTTPException(400): if the body is not JSON, is not an object, lacks a
            non-empty string ``image_url``, or the image cannot be fetched/decoded.
    """
    try:
        data = await request.json()
    except ValueError as err:  # json.JSONDecodeError / UnicodeDecodeError
        raise HTTPException(status_code=400, detail="Request body must be valid JSON.") from err

    image_url = data.get("image_url") if isinstance(data, dict) else None
    if not isinstance(image_url, str) or not image_url:
        raise HTTPException(
            status_code=400,
            detail="Body must be a JSON object with a non-empty 'image_url' string.",
        )

    # NOTE(security): image_url is caller-controlled, so this server will fetch
    # arbitrary URLs (SSRF risk). Restrict allowed hosts if exposed publicly.
    try:
        # Blocking network I/O; run it off the event loop.
        image = await asyncio.to_thread(_fetch_image, image_url)
    except (requests.RequestException, OSError) as err:
        raise HTTPException(
            status_code=400, detail="Could not fetch or decode the image at image_url."
        ) from err

    # Model inference is blocking and slow; keep the event loop responsive.
    response = await asyncio.to_thread(_describe_image, image)
    return {"response": response}


@app.get("/aria-test")
def aria_test():
    """Smoke-test the model on a fixed COCO sample image.

    Returns ``{"response": <generated text>}``. A sync ``def`` route, so FastAPI
    runs it in its thread pool and the blocking calls below are acceptable.
    """
    image = _fetch_image(_TEST_IMAGE_URL)
    return {"response": _describe_image(image)}