Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import type { ModelData } from "../model-data.js"; | |
import type { PipelineType } from "../pipelines.js"; | |
import { getModelInputSnippet } from "./inputs.js"; | |
/**
 * Builds the Python snippet body for zero-shot classification models.
 *
 * The emitted code defines a `query` helper that POSTs a JSON payload and returns the
 * decoded JSON response, then calls it with the model's example input and a fixed list
 * of candidate labels. `API_URL` and `headers` are not defined by this body; they come
 * from the preamble emitted by `getPythonInferenceSnippet`.
 *
 * The Python text is assembled line by line so that the indentation of the `def` body
 * is explicit and cannot be lost when this file is reformatted.
 *
 * @param model - Model whose example input is embedded via `getModelInputSnippet`
 *   (presumably already rendered as a Python/JSON literal — verify against `./inputs.js`).
 * @returns Python source text, without a trailing newline.
 */
export const snippetZeroShotClassification = (model: ModelData): string =>
	[
		"def query(payload):",
		"    response = requests.post(API_URL, headers=headers, json=payload)",
		"    return response.json()",
		"",
		"output = query({",
		`    "inputs": ${getModelInputSnippet(model)},`,
		'    "parameters": {"candidate_labels": ["refund", "legal", "faq"]},',
		"})",
	].join("\n");
/**
 * Builds the generic Python snippet body used by pipelines that take a JSON payload
 * containing only an `inputs` field.
 *
 * The emitted code defines a `query` helper that POSTs the payload as JSON and returns
 * the decoded JSON response. `API_URL` and `headers` are not defined by this body; they
 * come from the preamble emitted by `getPythonInferenceSnippet`.
 *
 * The Python text is assembled line by line so that the indentation of the `def` body
 * is explicit and cannot be lost when this file is reformatted.
 *
 * @param model - Model whose example input is embedded via `getModelInputSnippet`
 *   (presumably already rendered as a Python/JSON literal — verify against `./inputs.js`).
 * @returns Python source text, without a trailing newline.
 */
export const snippetBasic = (model: ModelData): string =>
	[
		"def query(payload):",
		"    response = requests.post(API_URL, headers=headers, json=payload)",
		"    return response.json()",
		"",
		"output = query({",
		`    "inputs": ${getModelInputSnippet(model)},`,
		"})",
	].join("\n");
/**
 * Builds the Python snippet body for pipelines whose input is a local file
 * (audio and image tasks in `pythonSnippets`).
 *
 * The emitted `query` helper reads the file in binary mode, POSTs the raw bytes as the
 * request body and returns the decoded JSON response. `API_URL` and `headers` are not
 * defined by this body; they come from the preamble emitted by `getPythonInferenceSnippet`.
 *
 * The Python text is assembled line by line so that the nesting of the `def` and `with`
 * bodies is explicit and cannot be lost when this file is reformatted.
 *
 * @param model - Model whose example input is embedded via `getModelInputSnippet`
 *   (presumably a quoted file name for these pipelines — verify against `./inputs.js`).
 * @returns Python source text, without a trailing newline.
 */
export const snippetFile = (model: ModelData): string =>
	[
		"def query(filename):",
		'    with open(filename, "rb") as f:',
		"        data = f.read()",
		"    response = requests.post(API_URL, headers=headers, data=data)",
		"    return response.json()",
		"",
		`output = query(${getModelInputSnippet(model)})`,
	].join("\n");
/**
 * Builds the Python snippet body for text-to-image models.
 *
 * Unlike the JSON-returning snippets, the emitted `query` helper returns the raw response
 * bytes (`response.content`), which the snippet then opens with `PIL.Image`.
 * `API_URL` and `headers` are not defined by this body; they come from the preamble
 * emitted by `getPythonInferenceSnippet`.
 *
 * The Python text is assembled line by line so that the indentation of the `def` body
 * is explicit and cannot be lost when this file is reformatted.
 *
 * @param model - Model whose example input is embedded via `getModelInputSnippet`
 *   (presumably already rendered as a Python/JSON literal — verify against `./inputs.js`).
 * @returns Python source text, without a trailing newline.
 */
export const snippetTextToImage = (model: ModelData): string =>
	[
		"def query(payload):",
		"    response = requests.post(API_URL, headers=headers, json=payload)",
		"    return response.content",
		"",
		"image_bytes = query({",
		`    "inputs": ${getModelInputSnippet(model)},`,
		"})",
		"# You can access the image with PIL.Image for example",
		"import io",
		"from PIL import Image",
		"image = Image.open(io.BytesIO(image_bytes))",
	].join("\n");
/**
 * Builds the Python snippet body for text-to-speech / text-to-audio models.
 *
 * Two variants exist because the response format depends on the serving library:
 * - `library_name === "transformers"`: the response is treated as raw audio bytes.
 * - anything else: the response is decoded as JSON and unpacked into
 *   `audio, sampling_rate`.
 *
 * `API_URL` and `headers` are not defined by this body; they come from the preamble
 * emitted by `getPythonInferenceSnippet`.
 *
 * The Python text is assembled line by line so that the indentation of the `def` body
 * is explicit and cannot be lost when this file is reformatted.
 *
 * @param model - Model whose `library_name` selects the variant and whose example input
 *   is embedded via `getModelInputSnippet` (presumably already rendered as a Python/JSON
 *   literal — verify against `./inputs.js`).
 * @returns Python source text, without a trailing newline.
 */
export const snippetTextToAudio = (model: ModelData): string => {
	// Transformers TTS pipeline and api-inference-community (AIC) pipeline outputs have diverged
	// with the latest update to inference-api (IA).
	// Transformers IA returns a byte object (wav file), whereas AIC returns wav and sampling_rate.
	if (model.library_name === "transformers") {
		return [
			"def query(payload):",
			"    response = requests.post(API_URL, headers=headers, json=payload)",
			"    return response.content",
			"",
			"audio_bytes = query({",
			`    "inputs": ${getModelInputSnippet(model)},`,
			"})",
			"# You can access the audio with IPython.display for example",
			"from IPython.display import Audio",
			"Audio(audio_bytes)",
		].join("\n");
	}

	return [
		"def query(payload):",
		"    response = requests.post(API_URL, headers=headers, json=payload)",
		"    return response.json()",
		"",
		"audio, sampling_rate = query({",
		`    "inputs": ${getModelInputSnippet(model)},`,
		"})",
		"# You can access the audio with IPython.display for example",
		"from IPython.display import Audio",
		"Audio(audio, rate=sampling_rate)",
	].join("\n");
};
/**
 * Maps a pipeline type to the function that renders the Python snippet body for it.
 *
 * The map is partial: pipeline types without an entry have no Python snippet, which is
 * what `hasPythonInferenceSnippet` reports and why `getPythonInferenceSnippet` falls back
 * to an empty body.
 */
export const pythonSnippets: Partial<Record<PipelineType, (model: ModelData) => string>> = {
	// Same order as in js/src/lib/interfaces/Types.ts
	"text-classification": snippetBasic,
	"token-classification": snippetBasic,
	"table-question-answering": snippetBasic,
	"question-answering": snippetBasic,
	"zero-shot-classification": snippetZeroShotClassification,
	translation: snippetBasic,
	summarization: snippetBasic,
	conversational: snippetBasic,
	"feature-extraction": snippetBasic,
	"text-generation": snippetBasic,
	"text2text-generation": snippetBasic,
	"fill-mask": snippetBasic,
	"sentence-similarity": snippetBasic,
	// File-based inputs: the snippet uploads raw bytes instead of a JSON payload.
	"automatic-speech-recognition": snippetFile,
	"text-to-image": snippetTextToImage,
	"text-to-speech": snippetTextToAudio,
	"text-to-audio": snippetTextToAudio,
	"audio-to-audio": snippetFile,
	"audio-classification": snippetFile,
	"image-classification": snippetFile,
	"image-to-text": snippetFile,
	"object-detection": snippetFile,
	"image-segmentation": snippetFile,
};
/**
 * Renders the complete Python inference snippet for a model: a preamble that imports
 * `requests` and defines `API_URL` and `headers`, followed by the pipeline-specific body
 * from `pythonSnippets`.
 *
 * @param model - Model to render for. `model.id` is embedded in the API URL and
 *   `model.pipeline_tag` selects the body. When the tag is missing or has no entry in
 *   `pythonSnippets`, the body is empty and only the preamble is returned.
 * @param accessToken - Token to embed literally in the `Authorization` header. When
 *   empty, the snippet instead references a Python variable named `API_TOKEN`, which the
 *   emitted code does not define and the reader is expected to provide.
 * @returns Python source text.
 */
export function getPythonInferenceSnippet(model: ModelData, accessToken: string): string {
	// Own-property lookup only: a bare `in` check also matches inherited keys such as
	// "constructor", which would then be invoked as if it were a snippet builder.
	const buildBody =
		model.pipeline_tag && Object.prototype.hasOwnProperty.call(pythonSnippets, model.pipeline_tag)
			? pythonSnippets[model.pipeline_tag]
			: undefined;
	const body = buildBody?.(model) ?? "";

	const authorization = accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`;

	return [
		"import requests",
		"",
		`API_URL = "https://api-inference.huggingface.co/models/${model.id}"`,
		`headers = {"Authorization": ${authorization}}`,
		"",
		body,
	].join("\n");
}
/**
 * Reports whether a Python inference snippet can be rendered for the model, i.e. whether
 * its `pipeline_tag` is set and has its own entry in `pythonSnippets`.
 *
 * @param model - Model whose `pipeline_tag` is checked.
 * @returns `true` when `pythonSnippets` has an entry for the model's pipeline type.
 */
export function hasPythonInferenceSnippet(model: ModelData): boolean {
	// Own-property check: the `in` operator would also report inherited keys
	// (e.g. "toString", "constructor") as supported pipelines.
	return !!model.pipeline_tag && Object.prototype.hasOwnProperty.call(pythonSnippets, model.pipeline_tag);
}