import type { ModelData } from "../model-data.js";
import type { PipelineType } from "../pipelines.js";
import { getModelInputSnippet } from "./inputs.js";

export const snippetZeroShotClassification = (model: ModelData): string =>
	`def query(payload):
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.json()

output = query({
	"inputs": ${getModelInputSnippet(model)},
	"parameters": {"candidate_labels": ["refund", "legal", "faq"]},
})`;

export const snippetBasic = (model: ModelData): string =>
	`def query(payload):
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.json()

output = query({
	"inputs": ${getModelInputSnippet(model)},
})`;

export const snippetFile = (model: ModelData): string =>
	`def query(filename):
	with open(filename, "rb") as f:
		data = f.read()
	response = requests.post(API_URL, headers=headers, data=data)
	return response.json()

output = query(${getModelInputSnippet(model)})`;

export const snippetTextToImage = (model: ModelData): string =>
	`def query(payload):
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.content

image_bytes = query({
	"inputs": ${getModelInputSnippet(model)},
})

# You can access the image with PIL.Image, for example
import io
from PIL import Image

image = Image.open(io.BytesIO(image_bytes))`;

export const snippetTextToAudio = (model: ModelData): string => {
	// The Transformers TTS pipeline and the api-inference-community (AIC) pipeline outputs have
	// diverged with the latest update to the Inference API (IA): Transformers via IA returns a
	// byte object (a wav file), whereas AIC returns the wav data together with its sampling_rate.
	if (model.library_name === "transformers") {
		return `def query(payload):
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.content

audio_bytes = query({
	"inputs": ${getModelInputSnippet(model)},
})

# You can access the audio with IPython.display, for example
from IPython.display import Audio

Audio(audio_bytes)`;
	} else {
		return `def query(payload):
	response = requests.post(API_URL, headers=headers, json=payload)
	return response.json()

audio, sampling_rate = query({
	"inputs": ${getModelInputSnippet(model)},
})

# You can access the audio with IPython.display, for example
from IPython.display import Audio

Audio(audio, rate=sampling_rate)`;
	}
};

export const pythonSnippets: Partial<Record<PipelineType, (model: ModelData) => string>> = {
	// Same order as in js/src/lib/interfaces/Types.ts
	"text-classification": snippetBasic,
	"token-classification": snippetBasic,
	"table-question-answering": snippetBasic,
	"question-answering": snippetBasic,
	"zero-shot-classification": snippetZeroShotClassification,
	translation: snippetBasic,
	summarization: snippetBasic,
	conversational: snippetBasic,
	"feature-extraction": snippetBasic,
	"text-generation": snippetBasic,
	"text2text-generation": snippetBasic,
	"fill-mask": snippetBasic,
	"sentence-similarity": snippetBasic,
	"automatic-speech-recognition": snippetFile,
	"text-to-image": snippetTextToImage,
	"text-to-speech": snippetTextToAudio,
	"text-to-audio": snippetTextToAudio,
	"audio-to-audio": snippetFile,
	"audio-classification": snippetFile,
	"image-classification": snippetFile,
	"image-to-text": snippetFile,
	"object-detection": snippetFile,
	"image-segmentation": snippetFile,
};

// Builds a complete, ready-to-run Python snippet for the given model: the shared
// `requests` boilerplate plus the task-specific query body from `pythonSnippets`.
export function getPythonInferenceSnippet(model: ModelData, accessToken: string): string {
	const body =
		model.pipeline_tag && model.pipeline_tag in pythonSnippets
			? pythonSnippets[model.pipeline_tag]?.(model) ?? ""
			: "";

	return `import requests

API_URL = "https://api-inference.huggingface.co/models/${model.id}"
headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}

${body}`;
}

export function hasPythonInferenceSnippet(model: ModelData): boolean {
	return !!model.pipeline_tag && model.pipeline_tag in pythonSnippets;
}