File size: 3,803 Bytes
6e91187
 
 
 
 
a93bd14
4573e09
740085d
6e91187
 
 
 
 
 
4573e09
 
 
a93bd14
4573e09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a93bd14
4573e09
9409942
d872a4a
 
4573e09
 
 
cadc76c
4573e09
 
8519f94
 
 
346f901
8519f94
3ef185b
8519f94
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4e2f52
 
cadc76c
d4e2f52
4573e09
d872a4a
6e91187
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f080981
6e91187
 
 
4573e09
6e91187
4573e09
6e91187
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import gradio as gr
import spaces
from PIL import Image
import torch
from transformers import AutoModelForCausalLM, AutoProcessor
import requests
import json
from io import BytesIO

# Run on the GPU when torch reports one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Single checkpoint id shared by the model and its processor.
MODEL_ID = "MiaoshouAI/Florence-2-base-PromptGen-v1.5"

# trust_remote_code=True lets transformers execute the modelling code shipped
# with the checkpoint repository.
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, trust_remote_code=True)
model = model.to(device)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)

# Task server endpoints; the task category is appended to each at call time.
SERVER_URL = 'http://43.156.72.113:8188'
FETCH_TASKS_URL = f'{SERVER_URL}/fetch/'
UPDATE_TASK_STATUS_URL = f'{SERVER_URL}/update/'

def fetch_task(category, fetch_all=False, timeout=30):
    """Fetch tasks of ``category`` from the task server.

    Sends a POST to ``FETCH_TASKS_URL + category`` with a ``fetch_all`` query
    parameter of ``'true'`` or ``'false'``.

    Args:
        category: Task category appended to the fetch URL (e.g. ``'img2text'``).
        fetch_all: Value forwarded to the server as the ``fetch_all`` flag.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The decoded JSON body on HTTP 200, otherwise ``None``. Network errors,
        non-200 responses and undecodable bodies are printed and yield ``None``.
    """
    params = {'fetch_all': 'true' if fetch_all else 'false'}

    # Without a timeout requests waits forever on a stalled server.
    try:
        response = requests.post(FETCH_TASKS_URL + category, params=params, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Failed to fetch tasks: {exc}")
        return None

    if response.status_code != 200:
        print(f"Failed to fetch tasks: {response.status_code} - {response.text}")
        return None

    # Every JSON decode error raised by requests is a ValueError subclass.
    try:
        return response.json()
    except ValueError as exc:
        print(f"Failed to fetch tasks: invalid JSON in response - {exc}")
        return None

def update_task_status(category, task_id, status, result=None, timeout=30):
    """Report the status of a task back to the task server.

    Sends a POST to ``UPDATE_TASK_STATUS_URL + category + '/<task_id>'`` with a
    JSON body holding ``status`` and, when ``result`` is truthy, ``result``.
    The outcome is printed; nothing is returned and request failures are
    logged rather than raised, so a reporting problem cannot abort the caller.

    Args:
        category: Task category appended to the update URL.
        task_id: Identifier of the task being updated.
        status: Status string forwarded to the server unchanged.
        result: Optional payload; omitted from the body when falsy.
        timeout: Seconds to wait for the server before giving up.
    """
    data = {'status': status}
    if result:
        data['result'] = result

    url = UPDATE_TASK_STATUS_URL + category + f'/{task_id}'
    # Without a timeout requests waits forever on a stalled server.
    try:
        response = requests.post(url, json=data, timeout=timeout)
    except requests.RequestException as exc:
        print(f"Failed to update task {task_id}: {exc}")
        return

    if response.status_code != 200:
        print(f"Failed to update task {task_id}: {response.status_code} - {response.text}")
        return

    # The update already succeeded; a non-JSON body must not turn that into an
    # exception, so fall back to the raw text for the log line.
    try:
        body = json.dumps(response.json(), indent=4)
    except ValueError:
        body = response.text
    print(f"Task {task_id} updated successfully: {body}")


def _fit_longest_side(size, max_side):
    """Return ``(width, height)`` scaled so the longer side equals ``max_side``."""
    width, height = size
    if width > height:
        new_width = max_side
        new_height = int((new_width / width) * height)
    else:
        new_height = max_side
        new_width = int((new_height / height) * width)
    # Extreme aspect ratios truncate the short side to 0, which resize rejects.
    return max(1, new_width), max(1, new_height)


@spaces.GPU(duration=150)
def infer():
    """Process every pending ``img2text`` task and report each result.

    For each fetched task the image at ``task['content']['url']`` is
    downloaded, resized so its longer side is 256 pixels, and passed to the
    model together with ``task['content']['prompt']``. The post-processed
    output is reported with status ``'Successed'``; any exception is reported
    with status ``'Failed'`` and processing continues with the next task.

    Returns:
        ``"No tasks found."`` when nothing was fetched, the newline-joined
        error messages when at least one task failed, otherwise
        ``"Successed! No pending tasks found."``.
    """
    img2text_tasks = fetch_task('img2text', fetch_all=True)

    if not img2text_tasks:
        return "No tasks found."

    errors = []
    for task in img2text_tasks:
        task_id = None
        try:
            task_id = task['id']
            image_url = task['content']['url']
            prompt = task['content']['prompt']

            # Bound the download and reject HTTP errors so an error page is
            # never handed to PIL as image data.
            image_response = requests.get(image_url, timeout=30)
            image_response.raise_for_status()
            image = Image.open(BytesIO(image_response.content)).convert("RGB")
            image = image.resize(_fit_longest_side(image.size, 256), Image.LANCZOS)

            inputs = processor(text=prompt, images=image, return_tensors="pt").to(device)
            generated_ids = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                do_sample=False,
                num_beams=3
            )

            # Special tokens are kept in the decoded text that is handed to
            # post_process_generation; the prompt is passed as its task.
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
            parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))

            update_task_status('img2text', task_id, 'Successed', {"text": parsed_answer})
        except Exception as e:
            # One bad task must not abandon the rest of the fetched batch.
            message = f"Error processing task {task_id}: {e}"
            print(message)
            errors.append(message)
            if task_id is not None:
                try:
                    update_task_status('img2text', task_id, 'Failed', {"error": str(e)})
                except Exception as report_error:
                    print(f"Failed to report failure of task {task_id}: {report_error}")

    if errors:
        return "\n".join(errors)
    return "Successed! No pending tasks found."


# Centre the single column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 800px;
}
"""

# One-button UI: clicking "Submit" runs infer() and shows its status string.
with gr.Blocks(css=css) as app:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# Tag The Image
        Get tag based on images using the Florence-2-base-PromptGen-v1.5 model.
        """)

        run_button = gr.Button("Submit", scale=0, elem_id="run-button")
        result = gr.Textbox(label="Generated Text", show_label=False)

    # infer takes no inputs; its return value fills the textbox.
    run_button.click(fn=infer, inputs=[], outputs=[result])

# queue() returns the Blocks instance, so the launch call can be chained.
app.queue().launch(show_error=True)