File size: 2,516 Bytes
c0dd699
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import { OpenAPIRoute, contentJson } from 'chanfana';
import { Context } from 'hono';
import { z } from 'zod';
import { nanoid } from 'nanoid';
import { SQSClient, SendMessageCommand } from '@aws-sdk/client-sqs';

import { Bindings, TaskStatus } from '../types';

// NOTE(review): not exported and not referenced anywhere in this file, so it is
// currently dead code — remove it or wire it up.
const PREMIUM_SCALE = 8;

/**
 * Whether a decoded JWT payload grants premium access.
 *
 * Any truthy `experimentalPremiumAccess` claim counts as premium. A missing
 * payload (`undefined`/`null`) is treated as non-premium instead of throwing
 * a TypeError on property access.
 *
 * @param jwtPayload - decoded JWT payload (the caller passes `c.get('jwtPayload')`).
 * @returns `true` when the claim is truthy, otherwise `false`.
 */
const isPremium = (jwtPayload: { experimentalPremiumAccess?: unknown } | null | undefined): boolean => {
	return Boolean(jwtPayload?.experimentalPremiumAccess); // TODO: replace with the real entitlement check
};
/**
 * Whether creating a task currently requires premium access.
 *
 * Placeholder: unconditionally answers `false`, so the premium gate in
 * `CreateTask.handle` never rejects a request yet. The identifier's spelling
 * ("Premiun") is kept because `CreateTask.handle` calls it by this name.
 */
const isNeedPremiun = (): boolean => false; // TODO: implement the real requirement rule

/**
 * OpenAPI route that creates a task and pushes it onto an AWS SQS queue.
 *
 * Request body (validated by `schema`): an optional client-supplied `id`
 * (must be a nanoid), plus `prompt`, `lora_name` and `cover_size` strings.
 *
 * Responses:
 *  - 403 `{ error: 'Premium Access Required' }` when premium is required and the
 *    caller's JWT payload does not grant it;
 *  - 409 `{ error: 'Task already exist' }` when the task-status Durable Object
 *    already holds a record with the chosen id;
 *  - otherwise the JSON-serialized value returned by the Durable Object's `setStatus`.
 */
export class CreateTask extends OpenAPIRoute {
	schema = {
		// "Create task"
		summary: '新建任务',
		// "Create a task and send it to the Queue to await processing. `id` is optional:
		//  if provided it must be a nanoid; if omitted, the backend generates one."
		description:
			'新建一个任务,把任务发送到 Queue 中等待处理。<br> id为可选项,需要上传nanoid,如果不上传则后端自动生成。',

		request: {
			body: contentJson({
				id: z.string().nanoid().optional(),
				prompt: z.string(),
				lora_name: z.string(),
				cover_size: z.string(),
				// image_url: z.string().url(),
			}),
		},
	};

	/**
	 * Validate the request, enqueue the task on SQS, and record it as WAITING.
	 *
	 * @param c - Hono context carrying the Worker bindings (task-status Durable Object
	 *   namespace, AWS region/credentials/queue URL, optional `BILLING` binding).
	 * @returns a JSON response (see the class documentation for status codes).
	 * @throws whatever the SQS client or the Durable Object stub throws; nothing here
	 *   catches those errors.
	 */
	async handle(c: Context<{ Bindings: Bindings }>) {
		const data = await this.getValidatedData<typeof this.schema>();
		const jwtPayload = c.get('jwtPayload');
		if (!isPremium(jwtPayload) && isNeedPremiun()) {
			return c.json({ error: 'Premium Access Required' }, { status: 403 });
		}

		const { prompt, lora_name, cover_size } = data.body;
		// The schema already rejects malformed ids, so only a missing id needs a fallback.
		const task_id = data.body.id ?? nanoid();

		// Every task created by this route uses the same named Durable Object instance for status.
		const doId = c.env.TASK_STATUS_DURABLE_OBJECT.idFromName('a1d-flux-task-status');
		const stub = c.env.TASK_STATUS_DURABLE_OBJECT.get(doId);

		// Look up the id that will actually be used. (Querying `data.body.id` looked up
		// `undefined` whenever the id was server-generated.)
		// NOTE(review): this check-then-set is not atomic; two concurrent requests with the
		// same client-supplied id can both get past it.
		const existing = await stub.getStatus(task_id);
		if (existing?.id) {
			return c.json({ error: 'Task already exist' }, { status: 409 });
		}

		const client = new SQSClient({
			region: c.env.AWS_REGION,
			credentials: {
				// use wrangler secrets to provide these global variables
				accessKeyId: c.env.AWS_ACCESS_KEY_ID,
				secretAccessKey: c.env.AWS_SECRET_ACCESS_KEY,
			},
		});

		const msg = {
			prompt,
			lora_name,
			cover_size,
			task_id,
			// NOTE(review): the caller's Authorization header is forwarded verbatim into the
			// queue message; confirm the consumer needs it and that the queue is trusted.
			authorization: c.req.header('Authorization'),
		};
		await client.send(
			new SendMessageCommand({
				QueueUrl: c.env.AWS_SQS_QUEUE_URL,
				MessageBody: JSON.stringify(msg),
			}),
		);

		// Record the status right after enqueueing, so a failure while writing the billing
		// data point below cannot leave an already-queued task without a status record.
		// NOTE(review): a consumer that updates the status before this write lands would be
		// overwritten with WAITING.
		const result = await stub.setStatus(task_id, {
			id: task_id,
			status: TaskStatus.WAITING,
			timestamp: Date.now(),
		});

		// NOTE(review): Analytics Engine blobs must be strings/ArrayBuffers/null; confirm
		// that `experimentalPremiumAccess` is a string claim.
		c.env.BILLING?.writeDataPoint({
			blobs: [jwtPayload?.experimentalPremiumAccess, prompt, lora_name, cover_size],
			doubles: [],
			indexes: [jwtPayload?.userId],
		});

		return c.json(result);
	}
}