import { OpenAPIRoute, contentJson } from 'chanfana';
import { Context } from 'hono';
import { z } from 'zod';
import { nanoid } from 'nanoid';
import { SQSClient, SendMessageCommand } from '@aws-sdk/client-sqs';
import { Bindings, TaskStatus } from '../types';
// NOTE(review): this constant is not referenced anywhere in the visible part of
// this file, and nothing here establishes what the value 8 scales — confirm the
// intended use (or remove it) before relying on it.
const PREMIUM_SCALE = 8;
/**
 * Tells whether a decoded JWT payload carries a truthy
 * `experimentalPremiumAccess` flag.
 *
 * A missing or non-object payload is treated as "not premium" rather than
 * raising a `TypeError`, so callers can pass whatever the request context holds.
 *
 * TODO: placeholder check (the original left this as an unspecified TODO).
 *
 * @param jwtPayload - Decoded JWT payload; may be absent.
 * @returns `true` only when the flag is present and truthy.
 */
const isPremium = (jwtPayload: unknown): boolean => {
  if (typeof jwtPayload !== 'object' || jwtPayload === null) {
    return false;
  }
  return 'experimentalPremiumAccess' in jwtPayload && Boolean(jwtPayload.experimentalPremiumAccess);
};
/**
 * Second half of the premium gate evaluated by the request handler.
 * Placeholder: currently always answers `false` (TODO...).
 */
const isNeedPremiun = (): boolean => false;
/**
 * OpenAPI route that creates a task.
 *
 * The handler:
 *  1. validates the JSON body (`id` optional nanoid, `prompt`, `lora_name`, `cover_size`);
 *  2. rejects the request with 403 when the premium gate fails;
 *  3. rejects the request with 409 when the task-status Durable Object already
 *     holds a status for the task id;
 *  4. sends the task to the SQS queue configured in the environment;
 *  5. writes a billing data point when a `BILLING` binding is present;
 *  6. stores a `WAITING` status for the task and returns the result as JSON.
 */
export class CreateTask extends OpenAPIRoute {
  schema = {
    summary: '新建任务',
    // A quoted string literal may not span physical lines, so the two
    // sentences are joined with an explicit line break.
    description:
      '新建一个任务,把任务发送到 Queue 中等待处理。\n' +
      'id为可选项,需要上传nanoid,如果不上传则后端自动生成。',
    request: {
      body: contentJson({
        id: z.string().nanoid().optional(),
        prompt: z.string(),
        lora_name: z.string(),
        cover_size: z.string(),
        // image_url: z.string().url(),
      }),
    },
  };

  /**
   * Handles a create-task request.
   *
   * @param c - Hono context exposing the worker bindings (`c.env`) and the
   *   `jwtPayload` context variable.
   * @returns A JSON response: the value returned by `setStatus` on success,
   *   `{ error }` with status 403 (premium required) or 409 (task already exists).
   */
  async handle(c: Context<{ Bindings: Bindings }>) {
    const data = await this.getValidatedData();
    const jwtPayload = c.get('jwtPayload');

    if (!isPremium(jwtPayload) && isNeedPremiun()) {
      return c.json({ error: 'Premium Access Required' }, { status: 403 });
    }

    // Use the client-supplied id when present, otherwise generate one.
    const task_id = data.body.id || nanoid();

    const doId = c.env.TASK_STATUS_DURABLE_OBJECT.idFromName('a1d-flux-task-status');
    const stub = c.env.TASK_STATUS_DURABLE_OBJECT.get(doId);

    // Look up the id that will actually be used; `data.body.id` is undefined
    // whenever the client omitted it.
    const status = await stub.getStatus(task_id);
    if (status?.id) {
      return c.json({ error: 'Task already exist' }, { status: 409 });
    }

    const client = new SQSClient({
      region: c.env.AWS_REGION,
      credentials: {
        // use wrangler secrets to provide these values
        accessKeyId: c.env.AWS_ACCESS_KEY_ID,
        secretAccessKey: c.env.AWS_SECRET_ACCESS_KEY,
      },
    });

    const prompt = data.body.prompt;
    const lora_name = data.body.lora_name;
    const cover_size = data.body.cover_size;

    // NOTE(review): the caller's Authorization header is copied into the queue
    // message as-is — presumably for the consumer to reuse; confirm that the
    // queue is an acceptable place for this credential.
    const msg = {
      prompt,
      lora_name,
      cover_size,
      task_id,
      authorization: c.req.header('Authorization'),
    };

    const send = new SendMessageCommand({
      QueueUrl: c.env.AWS_SQS_QUEUE_URL,
      MessageBody: JSON.stringify(msg),
    });

    // NOTE(review): the message is enqueued before the WAITING status is
    // stored below; if a consumer can update the status in between, this
    // ordering may overwrite it — confirm against the consumer.
    await client.send(send);

    // Billing is optional: nothing is written when the binding is absent.
    c.env.BILLING?.writeDataPoint({
      blobs: [jwtPayload.experimentalPremiumAccess, prompt, lora_name, cover_size],
      doubles: [],
      indexes: [jwtPayload.userId],
    });

    const result = await stub.setStatus(task_id, {
      id: task_id,
      status: TaskStatus.WAITING,
      timestamp: Date.now(),
    });

    return c.json(result);
  }
}