File size: 929 Bytes
4e99448
94753b6
 
 
 
4e99448
b2041e0
94753b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import type { TextGenerationInput, TextGenerationOutput } from "@huggingface/tasks";
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";

export type { TextGenerationInput, TextGenerationOutput };

/**
 * Use to continue text from a prompt. This is a very generic task. Recommended model: gpt2 (it’s a simple model, but fun to play with).
 *
 * @param args - Base request arguments combined with the text-generation input; passed unchanged to `request`.
 * @param options - Optional request options. They are spread first, so `taskHint` is always set to "text-generation".
 * @returns The first element of the response array.
 * @throws InferenceOutputError when the response is not a non-empty array whose elements all have a string `generated_text`.
 */
export async function textGeneration(
	args: BaseArgs & TextGenerationInput,
	options?: Options
): Promise<TextGenerationOutput> {
	const res = await request<TextGenerationOutput[]>(args, {
		...options,
		taskHint: "text-generation",
	});
	// `every` is vacuously true on an empty array, so the length check is required:
	// without it an empty response would pass validation and the function would
	// resolve to `undefined` despite its declared return type.
	const isValidOutput =
		Array.isArray(res) && res.length > 0 && res.every((x) => typeof x?.generated_text === "string");
	if (!isValidOutput) {
		throw new InferenceOutputError("Expected Array<{generated_text: string}>");
	}
	return res[0];
}