import requests
import torch
from PIL import Image
from transformers import AriaProcessor, AriaForConditionalGeneration
from fastapi import FastAPI, Request

app = FastAPI()

# Load the model and processor once at startup rather than inside each request
# handler; reloading the checkpoint on every call would dominate request latency.
MODEL_ID = "rhymes-ai/Aria"
model = AriaForConditionalGeneration.from_pretrained(
    MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16
)
processor = AriaProcessor.from_pretrained(MODEL_ID)


def run_aria(image: Image.Image, prompt: str = "what is the image?") -> str:
    """Run a single image-and-prompt turn through Aria and return the decoded text."""
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=text, images=image, return_tensors="pt")
    # Cast pixel values to match the model's bfloat16 weights.
    inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
    # BatchFeature.to() returns a new object, so the result must be reassigned;
    # calling it without assignment leaves the input tensors on the CPU.
    inputs = inputs.to(model.device)
    output = model.generate(
        **inputs,
        max_new_tokens=15,
        stop_strings=["<|im_end|>"],
        tokenizer=processor.tokenizer,
        do_sample=True,
        temperature=0.9,
    )
    # Strip the prompt tokens so only the newly generated text is decoded.
    output_ids = output[0][inputs["input_ids"].shape[1]:]
    return processor.decode(output_ids, skip_special_tokens=True)


@app.get("/")
def greet_json():
    return {"Hello": "World!"}


@app.post("/")
async def aria_image_to_text(request: Request):
    data = await request.json()
    image_url = data.get("image_url")
    image = Image.open(requests.get(image_url, stream=True).raw)
    return {"response": run_aria(image)}


@app.get("/aria-test")
def aria_test():
    # Smoke test against a fixed COCO validation image.
    image = Image.open(
        requests.get(
            "http://images.cocodataset.org/val2017/000000039769.jpg", stream=True
        ).raw
    )
    return {"response": run_aria(image)}
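
# Example usage, as a sketch: this assumes the file is saved as app.py and
# served with uvicorn; the module name, host, and port are illustrative
# assumptions, not taken from this file.
#
#   uvicorn app:app --host 0.0.0.0 --port 8000
#
#   curl -X POST http://localhost:8000/ \
#        -H "Content-Type: application/json" \
#        -d '{"image_url": "http://images.cocodataset.org/val2017/000000039769.jpg"}'
#
#   curl http://localhost:8000/aria-test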