Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 1,927 Bytes
94753b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
import { InferenceOutputError } from "../../lib/InferenceOutputError";
import { getDefaultTask } from "../../lib/getDefaultTask";
import type { BaseArgs, Options } from "../../types";
import { request } from "../custom/request";
/**
 * Arguments for {@link featureExtraction}: everything in `BaseArgs` (the task reads `model`
 * and `accessToken` from it) plus the text whose features should be extracted.
 */
export type FeatureExtractionArgs = BaseArgs & {
	/**
	 * Text to extract features from — either one string or an array of strings.
	 *
	 * @example "That is a happy person"
	 */
	inputs: string | Array<string>;
};
/**
 * Response of a feature-extraction call: a (possibly nested) array of floats.
 *
 * How deeply the array is nested depends on whether a single string or a list of strings was
 * sent, and on whether an automatic reduction (usually mean pooling, for instance) was applied.
 * The model's README should explain which shape it returns.
 *
 * NOTE(review): `featureExtraction` validates responses nested up to four levels deep
 * (`number[][][][]`), one level more than this type admits — confirm whether the type
 * should be widened.
 */
export type FeatureExtractionOutput = Array<number | number[] | number[][]>;
/**
 * Checks that `arr` is a (possibly ragged) nested array whose leaves are all numbers.
 *
 * At every level either all elements are arrays (checked one level deeper) or all elements
 * must be numbers, so a mixed level such as `[1, [2]]` is rejected. Empty arrays are accepted
 * at any level up to `maxDepth`, because `every` is vacuously true on an empty array.
 *
 * @param arr - Array to inspect.
 * @param maxDepth - Deepest 0-based level at which an array may still appear.
 * @param curDepth - Level of `arr` itself; top-level callers pass 0.
 * @returns `true` when the structure matches, `false` otherwise.
 */
function isNumArrayRec(arr: unknown[], maxDepth: number, curDepth = 0): boolean {
	if (curDepth > maxDepth) {
		return false;
	}
	if (arr.every((x): x is unknown[] => Array.isArray(x))) {
		// `arr` is narrowed to unknown[][] here by the type-predicate overload of `every`.
		return arr.every((sub) => isNumArrayRec(sub, maxDepth, curDepth + 1));
	}
	return arr.every((x) => typeof x === "number");
}

/**
 * This task reads some text and outputs raw float values, that are usually consumed as part of a semantic database/semantic search.
 *
 * When `args.model` is set, the model's default task is looked up first. If that default task
 * is "sentence-similarity", the request is sent with `forceTask: "feature-extraction"`.
 *
 * @param args - Model, credentials and the text input(s).
 * @param options - Extra request options, forwarded to `request` (and to `getDefaultTask`).
 * @returns The response, once it has been checked to be an array of numbers nested at most
 *   four levels deep.
 * @throws {InferenceOutputError} If the response does not have that shape.
 */
export async function featureExtraction(
	args: FeatureExtractionArgs,
	options?: Options
): Promise<FeatureExtractionOutput> {
	const defaultTask = args.model ? await getDefaultTask(args.model, args.accessToken, options) : undefined;
	const res = await request<FeatureExtractionOutput>(args, {
		...options,
		taskHint: "feature-extraction",
		// Presumably so a sentence-similarity model returns raw features rather than its default
		// output — confirm against `request`.
		...(defaultTask === "sentence-similarity" && { forceTask: "feature-extraction" }),
	});
	// The generic argument to `request` is erased at compile time, so the shape must be checked
	// at runtime. maxDepth 3 admits arrays up to number[][][][], matching the error message.
	if (!(Array.isArray(res) && isNumArrayRec(res, 3, 0))) {
		throw new InferenceOutputError("Expected Array<number[][][] | number[][] | number[] | number>");
	}
	return res;
}
|