File size: 2,728 Bytes
b5ed9fd c800f36 b5ed9fd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import functools
import io
import threading

import requests
import torch
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import AriaProcessor, AriaForConditionalGeneration
# Application instance; the route decorators below register handlers on it.
app = FastAPI()
@app.get("/")
def greet_json():
    """Serve a fixed greeting payload for GET requests to the root path."""
    greeting = {"Hello": "World!"}
    return greeting
_ARIA_MODEL_ID = "rhymes-ai/Aria"
_ARIA_PROMPT = "what is the image?"
_ARIA_TEST_IMAGE_URL = "http://images.cocodataset.org/val2017/000000039769.jpg"
_IMAGE_FETCH_TIMEOUT = 10  # seconds; applies to connect and to each read

# Serialises the first model load so concurrent requests cannot each start
# their own from_pretrained() call.
_ARIA_LOAD_LOCK = threading.Lock()


@functools.lru_cache(maxsize=1)
def _load_aria_cached():
    """Instantiate the Aria model and processor (result is memoised)."""
    model = AriaForConditionalGeneration.from_pretrained(
        _ARIA_MODEL_ID, device_map="auto", torch_dtype=torch.bfloat16
    )
    processor = AriaProcessor.from_pretrained(_ARIA_MODEL_ID)
    return model, processor


def _load_aria():
    """Return the shared ``(model, processor)`` pair, loading it on first use.

    Loading is deferred to the first request instead of import time, and the
    result is reused afterwards so the weights are not reloaded per request.
    """
    with _ARIA_LOAD_LOCK:
        return _load_aria_cached()


def _fetch_image(url: str) -> Image.Image:
    """Download ``url`` and open the response body as a PIL image.

    Raises:
        requests.RequestException: the request failed, timed out, or the
            server answered with an error status.
        OSError: the body could not be identified as an image by PIL.
    """
    response = requests.get(url, timeout=_IMAGE_FETCH_TIMEOUT)
    response.raise_for_status()
    # Reading ``content`` pulls the whole (content-decoded) body into memory,
    # so PIL works on a seekable buffer rather than the raw socket stream.
    return Image.open(io.BytesIO(response.content))


def _describe_image(image: Image.Image) -> str:
    """Ask Aria the fixed prompt about ``image`` and return the decoded reply."""
    model, processor = _load_aria()
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image"},
                {"text": _ARIA_PROMPT, "type": "text"},
            ],
        }
    ]
    text = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=text, images=image, return_tensors="pt")
    # The model is loaded in bfloat16, so the pixel tensor is cast to match.
    inputs["pixel_values"] = inputs["pixel_values"].to(torch.bfloat16)
    inputs = inputs.to(model.device)
    output = model.generate(
        **inputs,
        max_new_tokens=15,
        stop_strings=["<|im_end|>"],
        tokenizer=processor.tokenizer,
        do_sample=True,
        temperature=0.9,
    )
    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_length = inputs["input_ids"].shape[1]
    output_ids = output[0][prompt_length:]
    return processor.decode(output_ids, skip_special_tokens=True)


@app.post("/")
async def aria_image_to_text(request: Request):
    """Describe the image referenced by ``image_url`` in the JSON request body.

    Expects a JSON object such as ``{"image_url": "https://..."}`` and returns
    ``{"response": <generated text>}``.

    Raises:
        HTTPException: 400 when the body is not a JSON object, when
            ``image_url`` is missing or not a non-empty string, or when the
            URL cannot be downloaded or decoded as an image.
    """
    try:
        data = await request.json()
    except ValueError as err:
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON."
        ) from err
    if not isinstance(data, dict):
        raise HTTPException(
            status_code=400, detail="Request body must be a JSON object."
        )

    image_url = data.get("image_url")
    if not isinstance(image_url, str) or not image_url.strip():
        raise HTTPException(
            status_code=400, detail="'image_url' must be a non-empty string."
        )

    # NOTE(review): the URL is caller-supplied and fetched server-side, so this
    # endpoint can be pointed at internal hosts (SSRF). Consider an allow-list
    # before exposing it publicly.
    try:
        # Blocking network I/O is pushed to a worker thread so the event loop
        # stays free to serve other requests.
        image = await run_in_threadpool(_fetch_image, image_url)
    except requests.RequestException as err:
        # Must be caught before OSError: RequestException derives from it.
        raise HTTPException(
            status_code=400, detail="Could not download 'image_url'."
        ) from err
    except OSError as err:
        raise HTTPException(
            status_code=400, detail="'image_url' did not return a readable image."
        ) from err

    response = await run_in_threadpool(_describe_image, image)
    return {"response": response}


@app.get("/aria-test")
def aria_test():
    """Run the same captioning pipeline on a fixed COCO sample image.

    Returns ``{"response": <generated text>}``. Download or decoding errors
    are not translated and propagate to the framework.
    """
    image = _fetch_image(_ARIA_TEST_IMAGE_URL)
    return {"response": _describe_image(image)}