OpkaGames's picture
Upload folder using huggingface_hub
870ab6b
"""This module should not be used directly as its API is subject to change. Instead,
please use the `gr.Interface.from_pipeline()` function."""
from __future__ import annotations
from typing import TYPE_CHECKING
from gradio import components
if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from transformers import pipelines
def load_from_pipeline(pipeline: pipelines.base.Pipeline) -> dict:
    """
    Gets the appropriate Interface kwargs for a given Hugging Face transformers.Pipeline.
    Parameters:
        pipeline (transformers.Pipeline): the transformers.Pipeline from which to create an interface
    Returns:
        (dict): a dictionary of kwargs that can be used to construct an Interface object
            (keys "inputs", "outputs", "fn" and "title")
    Raises:
        ImportError: if the `transformers` package is not installed.
        ValueError: if `pipeline` is not a transformers.Pipeline, or is a pipeline type
            that has no mapping in this function.
    """
    try:
        import transformers
        from transformers import pipelines
    except ImportError as ie:
        raise ImportError(
            "transformers not installed. Please try `pip install transformers`"
        ) from ie
    if not isinstance(pipeline, pipelines.base.Pipeline):
        raise ValueError("pipeline must be a transformers.Pipeline")

    # Handle the different pipelines. The hasattr() checks make sure the pipeline exists in the
    # version of the transformers library that the user has installed.
    # Each "preprocess" returns either a dict (forwarded to the pipeline as keyword arguments)
    # or a list (forwarded as positional arguments); see `fn` below.
    # Subclasses must be tested before their base class: SummarizationPipeline and
    # TranslationPipeline are checked before Text2TextGenerationPipeline.
    if hasattr(transformers, "AudioClassificationPipeline") and isinstance(
        pipeline, pipelines.audio_classification.AudioClassificationPipeline
    ):
        pipeline_info = {
            "inputs": components.Audio(
                source="microphone",
                type="filepath",
                label="Input",
                render=False,
            ),
            "outputs": components.Label(label="Class", render=False),
            "preprocess": lambda i: {"inputs": i},
            "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
        }
    elif hasattr(transformers, "AutomaticSpeechRecognitionPipeline") and isinstance(
        pipeline,
        pipelines.automatic_speech_recognition.AutomaticSpeechRecognitionPipeline,
    ):
        pipeline_info = {
            "inputs": components.Audio(
                source="microphone", type="filepath", label="Input", render=False
            ),
            "outputs": components.Textbox(label="Output", render=False),
            "preprocess": lambda i: {"inputs": i},
            "postprocess": lambda r: r["text"],
        }
    elif hasattr(transformers, "FeatureExtractionPipeline") and isinstance(
        pipeline, pipelines.feature_extraction.FeatureExtractionPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input", render=False),
            "outputs": components.Dataframe(label="Output", render=False),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: r[0],
        }
    elif hasattr(transformers, "FillMaskPipeline") and isinstance(
        pipeline, pipelines.fill_mask.FillMaskPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input", render=False),
            "outputs": components.Label(label="Classification", render=False),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: {i["token_str"]: i["score"] for i in r},
        }
    elif hasattr(transformers, "ImageClassificationPipeline") and isinstance(
        pipeline, pipelines.image_classification.ImageClassificationPipeline
    ):
        pipeline_info = {
            "inputs": components.Image(
                type="filepath", label="Input Image", render=False
            ),
            "outputs": components.Label(label="Classification", render=False),
            "preprocess": lambda i: {"images": i},
            "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
        }
    elif hasattr(transformers, "QuestionAnsweringPipeline") and isinstance(
        pipeline, pipelines.question_answering.QuestionAnsweringPipeline
    ):
        pipeline_info = {
            "inputs": [
                components.Textbox(lines=7, label="Context", render=False),
                components.Textbox(label="Question", render=False),
            ],
            "outputs": [
                components.Textbox(label="Answer", render=False),
                components.Label(label="Score", render=False),
            ],
            "preprocess": lambda c, q: {"context": c, "question": q},
            "postprocess": lambda r: (r["answer"], r["score"]),
        }
    elif hasattr(transformers, "SummarizationPipeline") and isinstance(
        pipeline, pipelines.text2text_generation.SummarizationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(lines=7, label="Input", render=False),
            "outputs": components.Textbox(label="Summary", render=False),
            "preprocess": lambda x: {"inputs": x},
            "postprocess": lambda r: r[0]["summary_text"],
        }
    elif hasattr(transformers, "TextClassificationPipeline") and isinstance(
        pipeline, pipelines.text_classification.TextClassificationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input", render=False),
            "outputs": components.Label(label="Classification", render=False),
            "preprocess": lambda x: [x],
            "postprocess": lambda r: {i["label"].split(", ")[0]: i["score"] for i in r},
        }
    elif hasattr(transformers, "TextGenerationPipeline") and isinstance(
        pipeline, pipelines.text_generation.TextGenerationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input", render=False),
            "outputs": components.Textbox(label="Output", render=False),
            "preprocess": lambda x: {"text_inputs": x},
            "postprocess": lambda r: r[0]["generated_text"],
        }
    elif hasattr(transformers, "TranslationPipeline") and isinstance(
        pipeline, pipelines.text2text_generation.TranslationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input", render=False),
            "outputs": components.Textbox(label="Translation", render=False),
            "preprocess": lambda x: [x],
            "postprocess": lambda r: r[0]["translation_text"],
        }
    elif hasattr(transformers, "Text2TextGenerationPipeline") and isinstance(
        pipeline, pipelines.text2text_generation.Text2TextGenerationPipeline
    ):
        pipeline_info = {
            "inputs": components.Textbox(label="Input", render=False),
            "outputs": components.Textbox(label="Generated Text", render=False),
            "preprocess": lambda x: [x],
            "postprocess": lambda r: r[0]["generated_text"],
        }
    elif hasattr(transformers, "ZeroShotClassificationPipeline") and isinstance(
        pipeline, pipelines.zero_shot_classification.ZeroShotClassificationPipeline
    ):
        pipeline_info = {
            "inputs": [
                components.Textbox(label="Input", render=False),
                components.Textbox(
                    label="Possible class names (comma-separated)", render=False
                ),
                components.Checkbox(label="Allow multiple true classes", render=False),
            ],
            "outputs": components.Label(label="Classification", render=False),
            "preprocess": lambda i, c, m: {
                "sequences": i,
                "candidate_labels": c,
                "multi_label": m,
            },
            # Pair each returned label with its score.
            "postprocess": lambda r: dict(zip(r["labels"], r["scores"])),
        }
    elif hasattr(transformers, "DocumentQuestionAnsweringPipeline") and isinstance(
        pipeline,
        pipelines.document_question_answering.DocumentQuestionAnsweringPipeline,  # type: ignore
    ):
        pipeline_info = {
            "inputs": [
                components.Image(type="filepath", label="Input Document", render=False),
                components.Textbox(label="Question", render=False),
            ],
            "outputs": components.Label(label="Label", render=False),
            "preprocess": lambda img, q: {"image": img, "question": q},
            "postprocess": lambda r: {i["answer"]: i["score"] for i in r},
        }
    elif hasattr(transformers, "VisualQuestionAnsweringPipeline") and isinstance(
        pipeline, pipelines.visual_question_answering.VisualQuestionAnsweringPipeline
    ):
        pipeline_info = {
            "inputs": [
                components.Image(type="filepath", label="Input Image", render=False),
                components.Textbox(label="Question", render=False),
            ],
            "outputs": components.Label(label="Score", render=False),
            "preprocess": lambda img, q: {"image": img, "question": q},
            "postprocess": lambda r: {i["answer"]: i["score"] for i in r},
        }
    elif hasattr(transformers, "ImageToTextPipeline") and isinstance(
        pipeline, pipelines.image_to_text.ImageToTextPipeline  # type: ignore
    ):
        pipeline_info = {
            "inputs": components.Image(
                type="filepath", label="Input Image", render=False
            ),
            "outputs": components.Textbox(label="Text", render=False),
            "preprocess": lambda i: {"images": i},
            "postprocess": lambda r: r[0]["generated_text"],
        }
    else:
        raise ValueError(f"Unsupported pipeline type: {type(pipeline)}")

    def fn(*params):
        """Run the pipeline on the Interface's input values and format its result."""
        data = pipeline_info["preprocess"](*params)
        # Dispatch on the shape preprocess produced, not on the pipeline class.
        # An isinstance() check against Text2TextGenerationPipeline would also match its
        # subclass SummarizationPipeline. Summarization's preprocess returns a dict, and
        # `*`-unpacking a dict would pass the key "inputs" as the text instead of the value.
        if isinstance(data, dict):
            data = pipeline(**data)
        else:
            data = pipeline(*data)
        return pipeline_info["postprocess"](data)

    # The Interface kwargs are the component specs plus `fn`; preprocess/postprocess are
    # internal to `fn` (which still reads them from `pipeline_info`).
    interface_info = {
        key: value
        for key, value in pipeline_info.items()
        if key not in ("preprocess", "postprocess")
    }
    interface_info["fn"] = fn
    # Define the title of the Interface from the underlying model's class name.
    interface_info["title"] = pipeline.model.__class__.__name__
    return interface_info