inference-widgets / packages /tasks /src /library-ui-elements.ts
machineuser
Sync widgets demo
9d298eb
raw
history blame
21.1 kB
import type { ModelData } from "./model-data";
import type { ModelLibraryKey } from "./model-libraries";
/**
 * Elements configurable by a model library: the labels, links and code-snippet
 * builder that a library entry provides for a model page.
 */
export interface LibraryUiElement {
	/**
	 * Name displayed on the main
	 * call-to-action button on the model page.
	 */
	btnLabel: string;
	/**
	 * Name of the library's repository (e.g. "pytorch-image-models" for timm).
	 */
	repoName: string;
	/**
	 * URL of the library's repository (or project home page).
	 */
	repoUrl: string;
	/**
	 * Optional URL of the library's documentation.
	 */
	docsUrl?: string;
	/**
	 * Builds the code snippet(s) displayed on the model page for `model`.
	 * Each array entry is one complete, independent snippet.
	 */
	snippets: (model: ModelData) => string[];
}
/**
 * Drops the namespace part of a Hub model id.
 *
 * "org/name" yields "name"; an id without any "/" is returned unchanged.
 * Only the segment immediately after the first "/" is kept.
 */
function nameWithoutNamespace(modelId: string): string {
	const [head, tail] = modelId.split("/");
	return tail === undefined ? head : tail;
}
//#region snippets
/**
 * Snippet for `adapter-transformers`: instantiate the base model class named in
 * `config.adapter_transformers`, then attach this repo's adapter from the Hub.
 *
 * When the config does not name both the model class and the base model, a
 * single warning line is returned instead of Python that would literally
 * contain `undefined` (mirrors how other builders report missing config).
 */
const adapter_transformers = (model: ModelData): string[] => {
	const modelClass = model.config?.adapter_transformers?.model_class;
	const baseModelName = model.config?.adapter_transformers?.model_name;
	if (!modelClass || !baseModelName) {
		return [`# ⚠️ Adapter model class or base model name not specified in config.json`];
	}
	return [
		[
			`from transformers import ${modelClass}`,
			`model = ${modelClass}.from_pretrained("${baseModelName}")`,
			`model.load_adapter("${model.id}", source="hf")`,
		].join("\n"),
	];
};
/** Generic AllenNLP snippet: build a `Predictor` straight from the Hub path. */
const allennlpUnknown = (model: ModelData): string[] => {
	const snippet = [
		"import allennlp_models",
		"from allennlp.predictors.predictor import Predictor",
		`predictor = Predictor.from_path("hf://${model.id}")`,
	].join("\n");
	return [snippet];
};
/** AllenNLP question-answering snippet: predictor plus a sample passage/question call. */
const allennlpQuestionAnswering = (model: ModelData): string[] => {
	const sampleInput = `{"passage": "My name is Wolfgang and I live in Berlin", "question": "Where do I live?"}`;
	const snippet = [
		"import allennlp_models",
		"from allennlp.predictors.predictor import Predictor",
		`predictor = Predictor.from_path("hf://${model.id}")`,
		`predictor_input = ${sampleInput}`,
		"predictions = predictor.predict_json(predictor_input)",
	].join("\n");
	return [snippet];
};
/** Picks the AllenNLP snippet: the QA variant for "question-answering" models, otherwise the generic one. */
const allennlp = (model: ModelData): string[] =>
	model.tags?.includes("question-answering") ? allennlpQuestionAnswering(model) : allennlpUnknown(model);
/** Asteroid snippet: load the model through `BaseModel.from_pretrained`. */
const asteroid = (model: ModelData): string[] => {
	const lines = ["from asteroid.models import BaseModel", `model = BaseModel.from_pretrained("${model.id}")`];
	return [lines.join("\n")];
};
/**
 * Base model id from the model card's `base_model` field, or the placeholder
 * "fill-in-base-model" when the card does not provide one (null/undefined).
 */
function get_base_diffusers_model(model: ModelData): string {
	const declaredBase = model.cardData?.base_model;
	return declaredBase != null ? declaredBase : "fill-in-base-model";
}
/** BERTopic snippet: load the topic model with `BERTopic.load`. */
const bertopic = (model: ModelData): string[] => {
	const loadLine = `model = BERTopic.load("${model.id}")`;
	return [["from bertopic import BERTopic", loadLine].join("\n")];
};
/** Default diffusers snippet: load the repo as a `DiffusionPipeline`. */
const diffusers_default = (model: ModelData): string[] => {
	const lines = [
		"from diffusers import DiffusionPipeline",
		`pipeline = DiffusionPipeline.from_pretrained("${model.id}")`,
	];
	return [lines.join("\n")];
};
/** ControlNet snippet: load this repo as the ControlNet and plug it into a pipeline built from the base model. */
const diffusers_controlnet = (model: ModelData): string[] => {
	const baseModel = get_base_diffusers_model(model);
	const lines = [
		"from diffusers import ControlNetModel, StableDiffusionControlNetPipeline",
		`controlnet = ControlNetModel.from_pretrained("${model.id}")`,
		"pipeline = StableDiffusionControlNetPipeline.from_pretrained(",
		`"${baseModel}", controlnet=controlnet`,
		")",
	];
	return [lines.join("\n")];
};
/** LoRA snippet: load the base pipeline, then apply this repo's LoRA weights. */
const diffusers_lora = (model: ModelData): string[] => {
	const baseModel = get_base_diffusers_model(model);
	const lines = [
		"from diffusers import DiffusionPipeline",
		`pipeline = DiffusionPipeline.from_pretrained("${baseModel}")`,
		`pipeline.load_lora_weights("${model.id}")`,
	];
	return [lines.join("\n")];
};
/** Textual-inversion snippet: load the base pipeline, then this repo's textual-inversion embedding. */
const diffusers_textual_inversion = (model: ModelData): string[] => {
	const baseModel = get_base_diffusers_model(model);
	const lines = [
		"from diffusers import DiffusionPipeline",
		`pipeline = DiffusionPipeline.from_pretrained("${baseModel}")`,
		`pipeline.load_textual_inversion("${model.id}")`,
	];
	return [lines.join("\n")];
};
/**
 * Chooses the diffusers snippet from the model's tags. The first matching tag
 * in the order controlnet → lora → textual_inversion wins; otherwise the
 * default pipeline snippet is used.
 */
const diffusers = (model: ModelData): string[] => {
	const variants: ReadonlyArray<readonly [string, (m: ModelData) => string[]]> = [
		["controlnet", diffusers_controlnet],
		["lora", diffusers_lora],
		["textual_inversion", diffusers_textual_inversion],
	];
	const match = variants.find(([tag]) => model.tags?.includes(tag));
	const build = match ? match[1] : diffusers_default;
	return build(model);
};
/** ESPnet text-to-speech snippet: load `Text2Speech` and synthesize a sample sentence. */
const espnetTTS = (model: ModelData): string[] => {
	const lines = [
		"from espnet2.bin.tts_inference import Text2Speech",
		`model = Text2Speech.from_pretrained("${model.id}")`,
		`speech, *_ = model("text to generate speech from")`,
	];
	return [lines.join("\n")];
};
/**
 * ESPnet speech-recognition snippet: load `Speech2Text`, read a WAV file with
 * `soundfile` and transcribe it.
 *
 * `soundfile` is imported explicitly; without it the snippet's
 * `soundfile.read(...)` call raises NameError when run.
 */
const espnetASR = (model: ModelData): string[] => {
	const lines = [
		"import soundfile",
		"from espnet2.bin.asr_inference import Speech2Text",
		"model = Speech2Text.from_pretrained(",
		`"${model.id}"`,
		")",
		`speech, rate = soundfile.read("speech.wav")`,
		"text, *_ = model(speech)[0]",
	];
	return [lines.join("\n")];
};
/** Fallback ESPnet output when the model is neither TTS nor ASR. */
const espnetUnknown = (): string[] => {
	const message = "unknown model type (must be text-to-speech or automatic-speech-recognition)";
	return [message];
};
/** Chooses the ESPnet snippet by tag: TTS first, then ASR, else an "unknown model type" message. */
const espnet = (model: ModelData): string[] => {
	const tags = model.tags;
	if (tags?.includes("text-to-speech")) {
		return espnetTTS(model);
	}
	return tags?.includes("automatic-speech-recognition") ? espnetASR(model) : espnetUnknown();
};
/** fairseq snippet: load the model ensemble, config and task from the Hub. */
const fairseq = (model: ModelData): string[] => {
	const lines = [
		"from fairseq.checkpoint_utils import load_model_ensemble_and_task_from_hf_hub",
		"models, cfg, task = load_model_ensemble_and_task_from_hf_hub(",
		`"${model.id}"`,
		")",
	];
	return [lines.join("\n")];
};
/** Flair snippet: load the model as a `SequenceTagger`. */
const flair = (model: ModelData): string[] => {
	const lines = ["from flair.models import SequenceTagger", `tagger = SequenceTagger.load("${model.id}")`];
	return [lines.join("\n")];
};
/** Keras snippet: load via `huggingface_hub.from_pretrained_keras` (the snippet ends with a newline). */
const keras = (model: ModelData): string[] => {
	const lines = [
		"from huggingface_hub import from_pretrained_keras",
		`model = from_pretrained_keras("${model.id}")`,
		"", // produces the trailing newline of the original snippet
	];
	return [lines.join("\n")];
};
/** OpenCLIP snippet: model, transforms and tokenizer all resolved from the `hf-hub:` reference. */
const open_clip = (model: ModelData): string[] => {
	const hubRef = `hf-hub:${model.id}`;
	const lines = [
		"import open_clip",
		`model, preprocess_train, preprocess_val = open_clip.create_model_and_transforms('${hubRef}')`,
		`tokenizer = open_clip.get_tokenizer('${hubRef}')`,
	];
	return [lines.join("\n")];
};
/**
 * PaddleNLP snippet: load tokenizer and model with `from_hf_hub=True`.
 *
 * The model class is the first entry of `config.architectures`. When that is
 * missing, `AutoModel` is used and a warning comment is prepended. Private
 * repos additionally pass `use_auth_token=True`.
 */
const paddlenlp = (model: ModelData): string[] => {
	const architecture = model.config?.architectures?.[0];
	const authArg = model.private ? ", use_auth_token=True" : "";
	const modelClass = architecture ? architecture : "AutoModel";
	const header = architecture
		? [`from paddlenlp.transformers import AutoTokenizer, ${architecture}`]
		: [`# ⚠️ Type of model unknown`, `from paddlenlp.transformers import AutoTokenizer, AutoModel`];
	const lines = [
		...header,
		"",
		`tokenizer = AutoTokenizer.from_pretrained("${model.id}"${authArg}, from_hf_hub=True)`,
		`model = ${modelClass}.from_pretrained("${model.id}"${authArg}, from_hf_hub=True)`,
	];
	return [lines.join("\n")];
};
/** pyannote.audio pipeline snippet: run on a whole file, then on an in-memory excerpt. */
const pyannote_audio_pipeline = (model: ModelData): string[] => {
	const lines = [
		"from pyannote.audio import Pipeline",
		`pipeline = Pipeline.from_pretrained("${model.id}")`,
		"# inference on the whole file",
		`pipeline("file.wav")`,
		"# inference on an excerpt",
		"from pyannote.core import Segment",
		"excerpt = Segment(start=2.0, end=5.0)",
		"from pyannote.audio import Audio",
		`waveform, sample_rate = Audio().crop("file.wav", excerpt)`,
		`pipeline({"waveform": waveform, "sample_rate": sample_rate})`,
	];
	return [lines.join("\n")];
};
/** pyannote.audio model snippet: wrap the model in `Inference` and run on a file and an excerpt. */
const pyannote_audio_model = (model: ModelData): string[] => {
	const lines = [
		"from pyannote.audio import Model, Inference",
		`model = Model.from_pretrained("${model.id}")`,
		"inference = Inference(model)",
		"# inference on the whole file",
		`inference("file.wav")`,
		"# inference on an excerpt",
		"from pyannote.core import Segment",
		"excerpt = Segment(start=2.0, end=5.0)",
		`inference.crop("file.wav", excerpt)`,
	];
	return [lines.join("\n")];
};
/** Uses the pipeline snippet for repos tagged "pyannote-audio-pipeline", else the model snippet. */
const pyannote_audio = (model: ModelData): string[] => {
	const isPipeline = model.tags?.includes("pyannote-audio-pipeline");
	return isPipeline ? pyannote_audio_pipeline(model) : pyannote_audio_model(model);
};
/** TensorFlowTTS text-to-mel snippet: processor plus model (the snippet ends with a newline). */
const tensorflowttsTextToMel = (model: ModelData): string[] => {
	const lines = [
		"from tensorflow_tts.inference import AutoProcessor, TFAutoModel",
		`processor = AutoProcessor.from_pretrained("${model.id}")`,
		`model = TFAutoModel.from_pretrained("${model.id}")`,
		"",
	];
	return [lines.join("\n")];
};
/** TensorFlowTTS vocoder snippet: load the model and run it on `mels` (the snippet ends with a newline). */
const tensorflowttsMelToWav = (model: ModelData): string[] => {
	const lines = [
		"from tensorflow_tts.inference import TFAutoModel",
		`model = TFAutoModel.from_pretrained("${model.id}")`,
		"audios = model.inference(mels)",
		"",
	];
	return [lines.join("\n")];
};
/** Generic TensorFlowTTS snippet: only load the model (the snippet ends with a newline). */
const tensorflowttsUnknown = (model: ModelData): string[] => {
	const lines = [
		"from tensorflow_tts.inference import TFAutoModel",
		`model = TFAutoModel.from_pretrained("${model.id}")`,
		"",
	];
	return [lines.join("\n")];
};
/** Chooses the TensorFlowTTS snippet by tag: "text-to-mel", then "mel-to-wav", else generic. */
const tensorflowtts = (model: ModelData): string[] => {
	const tags = model.tags;
	if (tags?.includes("text-to-mel")) {
		return tensorflowttsTextToMel(model);
	}
	return tags?.includes("mel-to-wav") ? tensorflowttsMelToWav(model) : tensorflowttsUnknown(model);
};
/** timm snippet: `create_model` with the `hf_hub:` prefix and pretrained weights. */
const timm = (model: ModelData): string[] => {
	const createCall = `timm.create_model("hf_hub:${model.id}", pretrained=True)`;
	return [["import timm", `model = ${createCall}`].join("\n")];
};
/**
 * skops snippet for pickle-format models: download the repo into
 * "path_to_folder", then `joblib.load` the model file from that folder.
 *
 * The load path is prefixed with the download folder. Previously the snippet
 * loaded `modelFile` relative to the working directory, where `download` never
 * put it. This is now consistent with `skopsFormat`.
 *
 * @param model - Hub model whose id is downloaded.
 * @param modelFile - Model filename from `config.sklearn.filename`.
 */
const skopsPickle = (model: ModelData, modelFile: string): string[] => {
	const lines = [
		"import joblib",
		"from skops.hub_utils import download",
		`download("${model.id}", "path_to_folder")`,
		"model = joblib.load(",
		`"path_to_folder/${modelFile}"`,
		")",
		"# only load pickle files from sources you trust",
		"# read more about it here https://skops.readthedocs.io/en/stable/persistence.html",
	];
	return [lines.join("\n")];
};
/**
 * skops snippet for non-pickle formats: download the repo into
 * "path_to_folder" and load the file with `skops.io.load`.
 *
 * @param model - Hub model whose id is downloaded.
 * @param modelFile - Model filename from `config.sklearn.filename`.
 */
const skopsFormat = (model: ModelData, modelFile: string): string[] => {
	const lines = [
		"from skops.hub_utils import download",
		"from skops.io import load",
		`download("${model.id}", "path_to_folder")`,
		"# make sure model file is in skops format",
		"# if model is a pickle file, make sure it's from a source you trust",
		`model = load("path_to_folder/${modelFile}")`,
	];
	return [lines.join("\n")];
};
/** Snippet for non-skops scikit-learn repos: fetch "sklearn_model.joblib" and load it with joblib. */
const skopsJobLib = (model: ModelData): string[] => {
	const lines = [
		"from huggingface_hub import hf_hub_download",
		"import joblib",
		"model = joblib.load(",
		`hf_hub_download("${model.id}", "sklearn_model.joblib")`,
		")",
		"# only load pickle files from sources you trust",
		"# read more about it here https://skops.readthedocs.io/en/stable/persistence.html",
	];
	return [lines.join("\n")];
};
/**
 * Scikit-learn snippet selection.
 *
 * Repos not tagged "skops" get the plain joblib snippet. Skops repos need
 * `config.sklearn.filename`: a warning is returned without it. The pickle
 * snippet is used when `model_format` is "pickle", the skops.io one otherwise.
 */
const sklearn = (model: ModelData): string[] => {
	if (!model.tags?.includes("skops")) {
		return skopsJobLib(model);
	}
	const fileName = model.config?.sklearn?.filename;
	const saveFormat = model.config?.sklearn?.model_format;
	if (!fileName) {
		return [`# ⚠️ Model filename not specified in config.json`];
	}
	return saveFormat === "pickle" ? skopsPickle(model, fileName) : skopsFormat(model, fileName);
};
/** fastai snippet: load a `Learner` via `huggingface_hub.from_pretrained_fastai`. */
const fastai = (model: ModelData): string[] => {
	const lines = ["from huggingface_hub import from_pretrained_fastai", `learn = from_pretrained_fastai("${model.id}")`];
	return [lines.join("\n")];
};
/** Sample Factory shell command that downloads the repo into ./train_dir. */
const sampleFactory = (model: ModelData): string[] => {
	const loader = "python -m sample_factory.huggingface.load_from_hub";
	return [`${loader} -r ${model.id} -d ./train_dir`];
};
/** sentence-transformers snippet: instantiate `SentenceTransformer` with the repo id. */
const sentenceTransformers = (model: ModelData): string[] => {
	const lines = [
		"from sentence_transformers import SentenceTransformer",
		`model = SentenceTransformer("${model.id}")`,
	];
	return [lines.join("\n")];
};
/**
 * spaCy snippet: pip-install the wheel from the repo, then show both
 * `spacy.load(...)` and importing the package as a module. The package name is
 * the model id without its namespace.
 */
const spacy = (model: ModelData): string[] => {
	const packageName = nameWithoutNamespace(model.id);
	const lines = [
		`!pip install https://huggingface.co/${model.id}/resolve/main/${packageName}-any-py3-none-any.whl`,
		"# Using spacy.load().",
		"import spacy",
		`nlp = spacy.load("${packageName}")`,
		"# Importing as module.",
		`import ${packageName}`,
		`nlp = ${packageName}.load()`,
	];
	return [lines.join("\n")];
};
/** SpanMarker snippet: load with `SpanMarkerModel.from_pretrained`. */
const span_marker = (model: ModelData): string[] => {
	const lines = ["from span_marker import SpanMarkerModel", `model = SpanMarkerModel.from_pretrained("${model.id}")`];
	return [lines.join("\n")];
};
/**
 * Stanza snippet: the language code is the model name without namespace, with
 * the first "stanza-" occurrence removed.
 */
const stanza = (model: ModelData): string[] => {
	const language = nameWithoutNamespace(model.id).replace("stanza-", "");
	const lines = ["import stanza", `stanza.download("${language}")`, `nlp = stanza.Pipeline("${language}")`];
	return [lines.join("\n")];
};
/**
 * Maps a SpeechBrain interface class name to the inference method the snippet
 * calls on a file. Unknown interfaces yield `undefined`.
 */
const speechBrainMethod = (speechbrainInterface: string) => {
	const methodByInterface = {
		EncoderClassifier: "classify_file",
		EncoderDecoderASR: "transcribe_file",
		EncoderASR: "transcribe_file",
		SpectralMaskEnhancement: "enhance_file",
		SepformerSeparation: "separate_file",
	} as const;
	// Own-property check so inherited names like "toString" are not treated as interfaces.
	return Object.prototype.hasOwnProperty.call(methodByInterface, speechbrainInterface)
		? methodByInterface[speechbrainInterface as keyof typeof methodByInterface]
		: undefined;
};
/**
 * SpeechBrain snippet built from `config.speechbrain.interface`. Returns a
 * warning comment when the interface is absent (undefined) or has no known
 * inference method.
 */
const speechbrain = (model: ModelData): string[] => {
	const iface = model.config?.speechbrain?.interface;
	if (iface === undefined) {
		return [`# interface not specified in config.json`];
	}
	const method = speechBrainMethod(iface);
	if (method === undefined) {
		return [`# interface in config.json invalid`];
	}
	const lines = [
		`from speechbrain.pretrained import ${iface}`,
		`model = ${iface}.from_hparams(`,
		`"${model.id}"`,
		")",
		`model.${method}("file.wav")`,
	];
	return [lines.join("\n")];
};
/**
 * 🤗 Transformers snippets derived from `model.transformersInfo`.
 *
 * Always returns a "load model directly" snippet (processor + auto model when
 * a processor is known, otherwise the auto model alone). When the model has a
 * `pipeline_tag`, a `pipeline(...)` snippet is placed first. If the model
 * declares a custom class, every load call gets `trust_remote_code=True`.
 */
const transformers = (model: ModelData): string[] => {
	const info = model.transformersInfo;
	if (!info) {
		return [`# ⚠️ Type of model unknown`];
	}
	const extraArgs = info.custom_class ? ", trust_remote_code=True" : "";
	const loadLines = ["# Load model directly"];
	if (info.processor) {
		let processorVar = "processor";
		if (info.processor === "AutoTokenizer") {
			processorVar = "tokenizer";
		} else if (info.processor === "AutoFeatureExtractor") {
			processorVar = "extractor";
		}
		loadLines.push(
			`from transformers import ${info.processor}, ${info.auto_model}`,
			"",
			`${processorVar} = ${info.processor}.from_pretrained("${model.id}"${extraArgs})`
		);
	} else {
		loadLines.push(`from transformers import ${info.auto_model}`);
	}
	loadLines.push(`model = ${info.auto_model}.from_pretrained("${model.id}"${extraArgs})`);
	const autoSnippet = loadLines.join("\n");
	if (!model.pipeline_tag) {
		return [autoSnippet];
	}
	const pipelineSnippet = [
		"# Use a pipeline as a high-level helper",
		"from transformers import pipeline",
		"",
		`pipe = pipeline("${model.pipeline_tag}", model="${model.id}"${extraArgs})`,
	].join("\n");
	return [pipelineSnippet, autoSnippet];
};
/** Transformers.js snippet: `pipeline(task, modelId)`. It needs a `pipeline_tag`; otherwise a warning is returned. */
const transformersJS = (model: ModelData): string[] => {
	const task = model.pipeline_tag;
	if (!task) {
		return [`// ⚠️ Unknown pipeline tag`];
	}
	const libName = "@xenova/transformers";
	const lines = [
		`// npm i ${libName}`,
		`import { pipeline } from '${libName}';`,
		"// Allocate pipeline",
		`const pipe = await pipeline('${task}', '${model.id}');`,
	];
	return [lines.join("\n")];
};
/**
 * Maps a PEFT `task_type` to the suffix used in `AutoModelFor<Suffix>`.
 * Missing or unrecognized task types yield `undefined`.
 */
const peftTask = (peftTaskType?: string) => {
	const suffixByTask = {
		CAUSAL_LM: "CausalLM",
		SEQ_2_SEQ_LM: "Seq2SeqLM",
		TOKEN_CLS: "TokenClassification",
		SEQ_CLS: "SequenceClassification",
	} as const;
	if (peftTaskType === undefined || !Object.prototype.hasOwnProperty.call(suffixByTask, peftTaskType)) {
		return undefined;
	}
	return suffixByTask[peftTaskType as keyof typeof suffixByTask];
};
/**
 * PEFT snippet: load the base model named in `config.peft.base_model_name`
 * with the task-specific auto class, then wrap it with this repo's adapter.
 * The task type is validated before the base model.
 */
const peft = (model: ModelData): string[] => {
	const peftConfig = model.config?.peft;
	const taskSuffix = peftTask(peftConfig?.task_type);
	if (!taskSuffix) {
		return [`Task type is invalid.`];
	}
	const baseModel = peftConfig?.base_model_name;
	if (!baseModel) {
		return [`Base model is not found.`];
	}
	const autoClass = `AutoModelFor${taskSuffix}`;
	const lines = [
		"from peft import PeftModel, PeftConfig",
		`from transformers import ${autoClass}`,
		`config = PeftConfig.from_pretrained("${model.id}")`,
		`model = ${autoClass}.from_pretrained("${baseModel}")`,
		`model = PeftModel.from_pretrained(model, "${model.id}")`,
	];
	return [lines.join("\n")];
};
/** fastText snippet: download "model.bin" from the repo and load it. */
const fasttext = (model: ModelData): string[] => {
	const download = `hf_hub_download("${model.id}", "model.bin")`;
	const lines = ["from huggingface_hub import hf_hub_download", "import fasttext", `model = fasttext.load_model(${download})`];
	return [lines.join("\n")];
};
/** stable-baselines3 snippet: `load_from_hub` with a filename placeholder the user must fill in. */
const stableBaselines3 = (model: ModelData): string[] => {
	const lines = [
		"from huggingface_sb3 import load_from_hub",
		"checkpoint = load_from_hub(",
		`repo_id="${model.id}",`,
		'filename="{MODEL FILENAME}.zip",',
		")",
	];
	return [lines.join("\n")];
};
/**
 * Snippet for a NeMo domain. Only "ASR" is handled; any other domain yields
 * `undefined` so the caller can fall back.
 */
const nemoDomainResolver = (domain: string, model: ModelData): string[] | undefined => {
	if (domain !== "ASR") {
		return undefined;
	}
	const lines = [
		"import nemo.collections.asr as nemo_asr",
		`asr_model = nemo_asr.models.ASRModel.from_pretrained("${model.id}")`,
		`transcriptions = asr_model.transcribe(["file.wav"])`,
	];
	return [lines.join("\n")];
};
/** ML-Agents CLI command that downloads the repo into ./downloads. */
const mlAgents = (model: ModelData): string[] => {
	const repoFlag = `--repo-id="${model.id}"`;
	return [`mlagents-load-from-hf ${repoFlag} --local-dir="./downloads"`];
};
/**
 * NeMo snippet: resolves the model's tags to a NeMo domain. Only the
 * "automatic-speech-recognition" tag (→ "ASR") is resolved here; anything else
 * produces a warning comment.
 */
const nemo = (model: ModelData): string[] => {
	const resolved = model.tags?.includes("automatic-speech-recognition")
		? nemoDomainResolver("ASR", model)
		: undefined;
	return resolved ?? [`# tag did not correspond to a valid NeMo domain.`];
};
/** Pythae snippet: load the model with `AutoModel.load_from_hf_hub`. */
const pythae = (model: ModelData): string[] => {
	const lines = ["from pythae.models import AutoModel", `model = AutoModel.load_from_hf_hub("${model.id}")`];
	return [lines.join("\n")];
};
//#endregion
/**
 * UI metadata per model library, keyed by `ModelLibraryKey`: button label,
 * repository name/URL, optional docs URL and the snippet builder.
 *
 * The type is `Partial`, so not every `ModelLibraryKey` has an entry.
 * A trailing `as const` was dropped: with an explicit type annotation on the
 * declaration it had no effect on the exported type.
 */
export const MODEL_LIBRARIES_UI_ELEMENTS: Partial<Record<ModelLibraryKey, LibraryUiElement>> = {
	"adapter-transformers": {
		btnLabel: "Adapter Transformers",
		repoName: "adapter-transformers",
		repoUrl: "https://github.com/Adapter-Hub/adapter-transformers",
		docsUrl: "https://huggingface.co/docs/hub/adapter-transformers",
		snippets: adapter_transformers,
	},
	allennlp: {
		btnLabel: "AllenNLP",
		repoName: "AllenNLP",
		repoUrl: "https://github.com/allenai/allennlp",
		docsUrl: "https://huggingface.co/docs/hub/allennlp",
		snippets: allennlp,
	},
	asteroid: {
		btnLabel: "Asteroid",
		repoName: "Asteroid",
		repoUrl: "https://github.com/asteroid-team/asteroid",
		docsUrl: "https://huggingface.co/docs/hub/asteroid",
		snippets: asteroid,
	},
	bertopic: {
		btnLabel: "BERTopic",
		repoName: "BERTopic",
		repoUrl: "https://github.com/MaartenGr/BERTopic",
		snippets: bertopic,
	},
	diffusers: {
		btnLabel: "Diffusers",
		repoName: "🤗/diffusers",
		repoUrl: "https://github.com/huggingface/diffusers",
		docsUrl: "https://huggingface.co/docs/hub/diffusers",
		snippets: diffusers,
	},
	espnet: {
		btnLabel: "ESPnet",
		repoName: "ESPnet",
		repoUrl: "https://github.com/espnet/espnet",
		docsUrl: "https://huggingface.co/docs/hub/espnet",
		snippets: espnet,
	},
	fairseq: {
		btnLabel: "Fairseq",
		repoName: "fairseq",
		repoUrl: "https://github.com/pytorch/fairseq",
		snippets: fairseq,
	},
	flair: {
		btnLabel: "Flair",
		repoName: "Flair",
		repoUrl: "https://github.com/flairNLP/flair",
		docsUrl: "https://huggingface.co/docs/hub/flair",
		snippets: flair,
	},
	keras: {
		btnLabel: "Keras",
		repoName: "Keras",
		repoUrl: "https://github.com/keras-team/keras",
		docsUrl: "https://huggingface.co/docs/hub/keras",
		snippets: keras,
	},
	nemo: {
		btnLabel: "NeMo",
		repoName: "NeMo",
		repoUrl: "https://github.com/NVIDIA/NeMo",
		snippets: nemo,
	},
	open_clip: {
		btnLabel: "OpenCLIP",
		repoName: "OpenCLIP",
		repoUrl: "https://github.com/mlfoundations/open_clip",
		snippets: open_clip,
	},
	paddlenlp: {
		btnLabel: "paddlenlp",
		repoName: "PaddleNLP",
		repoUrl: "https://github.com/PaddlePaddle/PaddleNLP",
		docsUrl: "https://huggingface.co/docs/hub/paddlenlp",
		snippets: paddlenlp,
	},
	peft: {
		btnLabel: "PEFT",
		repoName: "PEFT",
		repoUrl: "https://github.com/huggingface/peft",
		snippets: peft,
	},
	"pyannote-audio": {
		btnLabel: "pyannote.audio",
		repoName: "pyannote-audio",
		repoUrl: "https://github.com/pyannote/pyannote-audio",
		snippets: pyannote_audio,
	},
	"sentence-transformers": {
		btnLabel: "sentence-transformers",
		repoName: "sentence-transformers",
		repoUrl: "https://github.com/UKPLab/sentence-transformers",
		docsUrl: "https://huggingface.co/docs/hub/sentence-transformers",
		snippets: sentenceTransformers,
	},
	sklearn: {
		btnLabel: "Scikit-learn",
		repoName: "Scikit-learn",
		repoUrl: "https://github.com/scikit-learn/scikit-learn",
		snippets: sklearn,
	},
	fastai: {
		btnLabel: "fastai",
		repoName: "fastai",
		repoUrl: "https://github.com/fastai/fastai",
		docsUrl: "https://huggingface.co/docs/hub/fastai",
		snippets: fastai,
	},
	spacy: {
		btnLabel: "spaCy",
		repoName: "spaCy",
		repoUrl: "https://github.com/explosion/spaCy",
		docsUrl: "https://huggingface.co/docs/hub/spacy",
		snippets: spacy,
	},
	"span-marker": {
		btnLabel: "SpanMarker",
		repoName: "SpanMarkerNER",
		repoUrl: "https://github.com/tomaarsen/SpanMarkerNER",
		docsUrl: "https://huggingface.co/docs/hub/span_marker",
		snippets: span_marker,
	},
	speechbrain: {
		btnLabel: "speechbrain",
		repoName: "speechbrain",
		repoUrl: "https://github.com/speechbrain/speechbrain",
		docsUrl: "https://huggingface.co/docs/hub/speechbrain",
		snippets: speechbrain,
	},
	stanza: {
		btnLabel: "Stanza",
		repoName: "stanza",
		repoUrl: "https://github.com/stanfordnlp/stanza",
		docsUrl: "https://huggingface.co/docs/hub/stanza",
		snippets: stanza,
	},
	tensorflowtts: {
		btnLabel: "TensorFlowTTS",
		repoName: "TensorFlowTTS",
		repoUrl: "https://github.com/TensorSpeech/TensorFlowTTS",
		snippets: tensorflowtts,
	},
	timm: {
		btnLabel: "timm",
		repoName: "pytorch-image-models",
		repoUrl: "https://github.com/rwightman/pytorch-image-models",
		docsUrl: "https://huggingface.co/docs/hub/timm",
		snippets: timm,
	},
	transformers: {
		btnLabel: "Transformers",
		repoName: "🤗/transformers",
		repoUrl: "https://github.com/huggingface/transformers",
		docsUrl: "https://huggingface.co/docs/hub/transformers",
		snippets: transformers,
	},
	"transformers.js": {
		btnLabel: "Transformers.js",
		repoName: "transformers.js",
		repoUrl: "https://github.com/xenova/transformers.js",
		docsUrl: "https://huggingface.co/docs/hub/transformers-js",
		snippets: transformersJS,
	},
	fasttext: {
		btnLabel: "fastText",
		repoName: "fastText",
		repoUrl: "https://fasttext.cc/",
		snippets: fasttext,
	},
	"sample-factory": {
		btnLabel: "sample-factory",
		repoName: "sample-factory",
		repoUrl: "https://github.com/alex-petrenko/sample-factory",
		docsUrl: "https://huggingface.co/docs/hub/sample-factory",
		snippets: sampleFactory,
	},
	"stable-baselines3": {
		btnLabel: "stable-baselines3",
		repoName: "stable-baselines3",
		repoUrl: "https://github.com/huggingface/huggingface_sb3",
		docsUrl: "https://huggingface.co/docs/hub/stable-baselines3",
		snippets: stableBaselines3,
	},
	"ml-agents": {
		btnLabel: "ml-agents",
		repoName: "ml-agents",
		repoUrl: "https://github.com/huggingface/ml-agents",
		docsUrl: "https://huggingface.co/docs/hub/ml-agents",
		snippets: mlAgents,
	},
	pythae: {
		btnLabel: "pythae",
		repoName: "pythae",
		repoUrl: "https://github.com/clementchadebec/benchmark_VAE",
		snippets: pythae,
	},
};