from functools import lru_cache

import requests
import torch
from fastapi import FastAPI, Request
from PIL import Image
from transformers import AriaForConditionalGeneration, AriaProcessor
app = FastAPI()


@app.get("/")
def greet_json():
    """Root health-check endpoint returning a static greeting payload."""
    return {"Hello": "World!"}
@app.post("/")
async def aria_image_to_text(request: Request):
print(1)
data = await request.json()
print(2)
image_url = data.get("image_url")
print(3)
print('image_url')
print(image_url)
image = Image.open(requests.get(image_url, stream=True).raw)
print(4)
model_id_or_path = "rhymes-ai/Aria"
print(5)
model = AriaForConditionalGeneration.from_pretrained(
model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16
)
print(6)
processor = AriaProcessor.from_pretrained(model_id_or_path)
print(7)
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"text": "what is the image?", "type": "text"},
],
}
]
print(8)
text = processor.apply_chat_template(messages, add_generation_prompt=True)
print(9)
inputs = processor(text=text, images=image, return_tensors="pt")
print(10)
inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
print(11)
inputs.to(model.device)
print(12)
output = model.generate(
**inputs,
max_new_tokens=15,
stop_strings=["<|im_end|>"],
tokenizer=processor.tokenizer,
do_sample=True,
temperature=0.9,
)
print(13)
output_ids = output[0][inputs["input_ids"].shape[1]:]
print(14)
response = processor.decode(output_ids, skip_special_tokens=True)
print(15)
return {"response": response}
@app.get("/aria-test")
def aria_test():
print(1)
model_id_or_path = "rhymes-ai/Aria"
print(2)
model = AriaForConditionalGeneration.from_pretrained(
model_id_or_path, device_map="auto", torch_dtype=torch.bfloat16
)
print(3)
processor = AriaProcessor.from_pretrained(model_id_or_path)
print(4)
image = Image.open(requests.get("http://images.cocodataset.org/val2017/000000039769.jpg", stream=True).raw)
print(5)
messages = [
{
"role": "user",
"content": [
{"type": "image"},
{"text": "what is the image?", "type": "text"},
],
}
]
print(6)
text = processor.apply_chat_template(messages, add_generation_prompt=True)
print(7)
inputs = processor(text=text, images=image, return_tensors="pt")
print(8)
inputs['pixel_values'] = inputs['pixel_values'].to(torch.bfloat16)
print(9)
inputs.to(model.device)
print(10)
output = model.generate(
**inputs,
max_new_tokens=15,
stop_strings=["<|im_end|>"],
tokenizer=processor.tokenizer,
do_sample=True,
temperature=0.9,
)
output_ids = output[0][inputs["input_ids"].shape[1]:]
print(11)
response = processor.decode(output_ids, skip_special_tokens=True)
print(12)
return {"response": response}