machineuser
Sync widgets demo
9d298eb
raw
history blame
3.96 kB
import type { ModelData } from "../model-data.js";
import type { PipelineType } from "../pipelines.js";
import { getModelInputSnippet } from "./inputs.js";
/**
 * Python snippet for zero-shot classification: posts the model's example input together with a
 * fixed list of `candidate_labels` and decodes the JSON response.
 *
 * The generated code relies on `requests`, `API_URL` and `headers`, which
 * `getPythonInferenceSnippet` emits as a preamble.
 */
export const snippetZeroShotClassification = (model: ModelData): string =>
	// The Python body must be indented (it is emitted verbatim); the blank line after the function
	// lets the snippet also be pasted into an interactive Python REPL.
	`def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

output = query({
    "inputs": ${getModelInputSnippet(model)},
    "parameters": {"candidate_labels": ["refund", "legal", "faq"]},
})`;
/**
 * Generic Python snippet: posts `{"inputs": <example input>}` as JSON and decodes the JSON response.
 *
 * The generated code relies on `requests`, `API_URL` and `headers`, which
 * `getPythonInferenceSnippet` emits as a preamble.
 */
export const snippetBasic = (model: ModelData): string =>
	// Python body indentation is significant and is emitted verbatim from this literal.
	`def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

output = query({
    "inputs": ${getModelInputSnippet(model)},
})`;
/**
 * Python snippet for file-based pipelines: reads a local file in binary mode, posts its raw bytes
 * as the request body and decodes the JSON response.
 *
 * The example input returned by `getModelInputSnippet` is used as the filename argument. The
 * generated code relies on `requests`, `API_URL` and `headers` from the preamble emitted by
 * `getPythonInferenceSnippet`.
 */
export const snippetFile = (model: ModelData): string =>
	// Two indentation levels: the function body and the nested `with` body.
	`def query(filename):
    with open(filename, "rb") as f:
        data = f.read()
    response = requests.post(API_URL, headers=headers, data=data)
    return response.json()

output = query(${getModelInputSnippet(model)})`;
/**
 * Python snippet for text-to-image: posts the prompt as JSON, keeps the raw response bytes
 * (`response.content`) and shows how to open them with PIL.
 *
 * The generated code relies on `requests`, `API_URL` and `headers` from the preamble emitted by
 * `getPythonInferenceSnippet`.
 */
export const snippetTextToImage = (model: ModelData): string =>
	// Python body indentation is significant and is emitted verbatim from this literal.
	`def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.content

image_bytes = query({
    "inputs": ${getModelInputSnippet(model)},
})

# You can access the image with PIL.Image for example
import io
from PIL import Image
image = Image.open(io.BytesIO(image_bytes))`;
/**
 * Python snippet for text-to-speech / text-to-audio models.
 *
 * The Transformers TTS pipeline and the api-inference-community (AIC) pipeline outputs have
 * diverged with the latest update to inference-api (IA): Transformers IA returns a byte object
 * (wav file), whereas AIC returns wav and sampling_rate. The snippet therefore branches on
 * `model.library_name`.
 *
 * The generated code relies on `requests`, `API_URL` and `headers` from the preamble emitted by
 * `getPythonInferenceSnippet`.
 */
export const snippetTextToAudio = (model: ModelData): string => {
	if (model.library_name === "transformers") {
		// Raw bytes: keep `response.content` and hand it straight to IPython's Audio widget.
		return `def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.content

audio_bytes = query({
    "inputs": ${getModelInputSnippet(model)},
})

# You can access the audio with IPython.display for example
from IPython.display import Audio
Audio(audio_bytes)`;
	}

	// AIC: decode JSON and unpack two values.
	// NOTE(review): assumes the AIC JSON response unpacks into exactly (audio, sampling_rate) —
	// confirm against the AIC response schema.
	return `def query(payload):
    response = requests.post(API_URL, headers=headers, json=payload)
    return response.json()

audio, sampling_rate = query({
    "inputs": ${getModelInputSnippet(model)},
})

# You can access the audio with IPython.display for example
from IPython.display import Audio
Audio(audio, rate=sampling_rate)`;
};
/**
 * Dispatch table from pipeline type to the function that renders that pipeline's Python body.
 *
 * `snippetBasic` posts JSON `inputs`; `snippetFile` uploads the raw bytes of a local file; the
 * remaining renderers are task-specific. A pipeline type absent from this table has no Python
 * snippet.
 */
export const pythonSnippets: Partial<Record<PipelineType, (model: ModelData) => string>> = {
	// Entries are kept in the same order as js/src/lib/interfaces/Types.ts
	"text-classification": snippetBasic,
	"token-classification": snippetBasic,
	"table-question-answering": snippetBasic,
	"question-answering": snippetBasic,
	"zero-shot-classification": snippetZeroShotClassification,
	"translation": snippetBasic,
	"summarization": snippetBasic,
	"conversational": snippetBasic,
	"feature-extraction": snippetBasic,
	"text-generation": snippetBasic,
	"text2text-generation": snippetBasic,
	"fill-mask": snippetBasic,
	"sentence-similarity": snippetBasic,
	"automatic-speech-recognition": snippetFile,
	"text-to-image": snippetTextToImage,
	"text-to-speech": snippetTextToAudio,
	"text-to-audio": snippetTextToAudio,
	"audio-to-audio": snippetFile,
	"audio-classification": snippetFile,
	"image-classification": snippetFile,
	"image-to-text": snippetFile,
	"object-detection": snippetFile,
	"image-segmentation": snippetFile,
};
/**
 * Builds a copy-pasteable Python example that calls the Inference API for `model`.
 *
 * The result is a fixed preamble (`import requests`, `API_URL`, `headers`) followed by the
 * pipeline-specific body from `pythonSnippets`. The body is empty when the model has no
 * `pipeline_tag` or its tag has no registered renderer.
 *
 * @param model - its `id` is embedded in the URL; its `pipeline_tag` selects the body renderer
 * @param accessToken - embedded verbatim as the Bearer token when non-empty; otherwise the header
 *   uses an f-string that references a Python variable `API_TOKEN` the user must define
 * @returns the generated Python source
 */
export function getPythonInferenceSnippet(model: ModelData, accessToken: string): string {
	const tag = model.pipeline_tag;
	const renderBody = tag && tag in pythonSnippets ? pythonSnippets[tag] : undefined;
	const body = renderBody?.(model) ?? "";
	const authValue = accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`;

	return `import requests
API_URL = "https://api-inference.huggingface.co/models/${model.id}"
headers = {"Authorization": ${authValue}}
${body}`;
}
/**
 * Reports whether `model` has a pipeline-specific Python body: true exactly when it carries a
 * non-empty `pipeline_tag` that is a key of `pythonSnippets`.
 */
export function hasPythonInferenceSnippet(model: ModelData): boolean {
	const tag = model.pipeline_tag;
	if (!tag) {
		return false;
	}
	return tag in pythonSnippets;
}