File size: 2,517 Bytes
9c18e52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import json
import os
import time

import gradio as gr
import requests

from src.log import logger
from src.util import download_images

# AnyStory / DashScope settings, read once from the environment at import
# time. os.getenv yields None for any variable that is not set.
anystory_url, anystory_api_key, anystory_model = (
    os.getenv(env_name)
    for env_name in ("ANYSTORY_URL", "ANYSTORY_DS_API_KEY", "ANYSTORY_MODEL")
)


def call_anystory(image_urls, prompt, *, poll_interval=1.0, max_wait=None,
                  request_timeout=60):
    """Run an AnyStory generation task and return the downloaded images.

    The task is created asynchronously at ``ANYSTORY_URL`` and its status is
    polled until it reaches a terminal state. The image URLs of a successful
    task are then fetched with ``download_images``.

    Args:
        image_urls: Reference image URLs, sent as ``input.image_urls``.
        prompt: Text prompt, sent as ``input.prompt``.
        poll_interval: Seconds to sleep between status queries.
        max_wait: Optional cap, in seconds, on the total polling time.
            ``None`` (the default) polls until the task finishes.
        request_timeout: Timeout in seconds for each HTTP request, so a
            stalled connection cannot block the caller forever.

    Returns:
        The value returned by ``download_images`` for the generated URLs.

    Raises:
        gr.Error: If the task cannot be created or queried, ends in a
            non-success state, produces no image URLs, or exceeds
            ``max_wait``.
    """
    headers = {
        "Content-Type": "application/json",
        "Accept": "application/json",
        "Authorization": f"Bearer {anystory_api_key}",
        # Asynchronous mode: the create call returns a task_id that is polled.
        "X-DashScope-Async": "enable",
        "X-DashScope-DataInspection": "enable",
    }
    data = {
        "model": anystory_model,
        "input": {
            "image_urls": image_urls,
            "prompt": prompt,
        },
        "parameters": {},
    }

    # --- Create the task ---------------------------------------------------
    try:
        create_res = requests.post(
            anystory_url, data=json.dumps(data), headers=headers,
            timeout=request_timeout,
        )
    except requests.RequestException as err:
        logger.error(f'Fail to create Generation task: {err}')
        raise gr.Error("Fail to create Generation task.") from err
    if create_res.status_code != 200:
        logger.error(f'Fail to create Generation task: {create_res.content}')
        raise gr.Error("Fail to create Generation task.")

    task_id = create_res.json()['output']['task_id']
    logger.info(f"task_id: {task_id}: Create request success. Params: {data}")

    # --- Poll until the task reaches a terminal state -----------------------
    # Per the DashScope async-task docs, FAILED, CANCELED and UNKNOWN are all
    # terminal. Checking only FAILED would poll forever on the other two.
    failure_states = {"FAILED", "CANCELED", "UNKNOWN"}
    query_url = f'https://poc-dashscope.aliyuncs.com/api/v1/tasks/{task_id}'
    deadline = None if max_wait is None else time.monotonic() + max_wait
    while True:
        try:
            # NOTE(review): DashScope documents GET for task queries; POST is
            # kept because that is how this endpoint was called — confirm.
            query_res = requests.post(query_url, headers=headers,
                                      timeout=request_timeout)
        except requests.RequestException as err:
            logger.error(f'task_id: {task_id}: Fail to query task result: {err}')
            raise gr.Error("Fail to query task result.") from err
        if query_res.status_code != 200:
            logger.error(f'task_id: {task_id}: Fail to query task result: {query_res.content}')
            raise gr.Error("Fail to query task result.")

        body = query_res.json()
        status = body['output']['task_status']
        if status == "SUCCEEDED":
            logger.info(f"task_id: {task_id}: Generation task query success.")
            logger.info(f"task_id: {task_id}: {body}")
            break
        if status in failure_states:
            logger.error(f"task_id: {task_id}: Generation task ended with {status}: {body}")
            raise gr.Error("Fail to get results from Generation task.")
        if deadline is not None and time.monotonic() >= deadline:
            logger.error(f"task_id: {task_id}: Generation task still {status} after {max_wait}s.")
            raise gr.Error("Generation task timed out.")
        logger.debug(f"task_id: {task_id}: query result...")
        time.sleep(poll_interval)

    # Keep only result entries that carry a 'url'. Entries without one are
    # presumably per-image failures (confirm against the API response
    # format); they used to raise KeyError here.
    results = body['output'].get('results') or []
    img_urls = [item['url'] for item in results if 'url' in item]
    if not img_urls:
        logger.error(f"task_id: {task_id}: No image URLs in task result: {body}")
        raise gr.Error("Fail to get results from Generation task.")

    logger.info(f"task_id: {task_id}: download generated images.")
    img_data = download_images(img_urls)
    logger.info(f"task_id: {task_id}: Generate done.")
    return img_data