import gradio as gr
import spaces
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
import requests
import json
from io import BytesIO
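
# Florence-2 ships custom modeling code with the checkpoint, so trust_remote_code
# must be enabled when loading both the model and the processor.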
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("MiaoshouAI/Florence-2-base-PromptGen-v1.5", trust_remote_code=True)
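
# Endpoints of the external task-queue service this Space polls:
# POST /fetch/<category> pulls tasks, POST /update/<category>/<task_id> reports results.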
SERVER_URL = 'http://43.156.72.113:8188'
FETCH_TASKS_URL = SERVER_URL + '/fetch/'
UPDATE_TASK_STATUS_URL = SERVER_URL + '/update/'
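
# A fetched task is expected to look roughly like this (shape inferred from the
# field accesses in infer(); the real server schema may differ):
#   {"id": 42, "content": {"url": "https://example.com/img.png", "prompt": "<GENERATE_TAGS>"}}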
def fetch_task(category, fetch_all=False):
    """Poll the task server for pending tasks in the given category."""
    params = {'fetch_all': 'true' if fetch_all else 'false'}
    response = requests.post(FETCH_TASKS_URL + category, params=params)
    if response.status_code == 200:
        return response.json()
    print(f"Failed to fetch tasks: {response.status_code} - {response.text}")
    return None
def update_task_status(category, task_id, status, result=None):
    """Report a task's status (and optional result payload) back to the server."""
    data = {'status': status}
    # Use an explicit None check so an empty-but-valid result is still sent.
    if result is not None:
        data['result'] = result
    response = requests.post(UPDATE_TASK_STATUS_URL + category + f'/{task_id}', json=data)
    if response.status_code == 200:
        print(f"Task {task_id} updated successfully: {json.dumps(response.json(), indent=4)}")
    else:
        print(f"Failed to update task {task_id}: {response.status_code} - {response.text}")
# spaces.GPU requests ZeroGPU hardware for up to 150 s per call on Hugging Face Spaces.
@spaces.GPU(duration=150)
def infer():
    # Drain the 'img2text' queue: caption/tag every pending image task.
    img2text_tasks = fetch_task('img2text', fetch_all=True)
    if not img2text_tasks:
        return "No tasks found."
    for task in img2text_tasks:
        try:
            image_url = task['content']['url']
            # The prompt doubles as the Florence-2 task token, e.g. "<GENERATE_TAGS>".
            prompt = task['content']['prompt']
            image_response = requests.get(image_url)
            image_response.raise_for_status()
            image = Image.open(BytesIO(image_response.content)).convert("RGB")
            # Scale the image so its longer edge is 256 px, preserving aspect ratio.
            max_size = 256
            width, height = image.size
            if width > height:
                new_width = max_size
                new_height = int((new_width / width) * height)
            else:
                new_height = max_size
                new_width = int((new_height / height) * width)
            image = image.resize((new_width, new_height), Image.LANCZOS)
            inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                do_sample=False,
                num_beams=3
            )
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
            parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
            # 'Successed' is kept verbatim: it is the status literal this task server expects.
            update_task_status('img2text', task['id'], 'Successed', {"text": parsed_answer})
        except Exception as e:
            # Record the failure on the server, then surface it in the UI and stop.
            print(f"Error processing task {task['id']}: {e}")
            update_task_status('img2text', task['id'], 'Failed', {"error": str(e)})
            return f"Error processing task {task['id']}: {e}"
    return "Success: all fetched tasks processed."
css = """
#col-container {
margin: 0 auto;
max-width: 800px;
}
"""
with gr.Blocks(css=css) as app:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# Tag The Image
Generate tags for images using the Florence-2-base-PromptGen-v1.5 model.
""")
        run_button = gr.Button("Submit", scale=0, elem_id="run-button")
        result = gr.Textbox(label="Generated Text", show_label=False)
        gr.on(
            triggers=[run_button.click],
            fn=infer,
            inputs=[],
            outputs=[result]
        )

app.queue()
app.launch(show_error=True)