thon | |
def postprocess(self, model_outputs, top_k=5):
    """Convert raw model outputs into class probabilities.

    Args:
        model_outputs: Mapping that holds the model's raw scores under the
            ``"logits"`` key. Presumably a tensor type exposing ``.softmax``
            (e.g. ``torch.Tensor``) -- confirm against the caller.
        top_k: Accepted in the signature but not applied yet.

    Returns:
        The softmax of the logits, taken over the last dimension.
    """
    logits = model_outputs["logits"]
    # TODO: honor ``top_k`` (e.g. keep only the k highest-scoring classes);
    # for now the full probability distribution is returned unchanged.
    probabilities = logits.softmax(-1)
    return probabilities
def _sanitize_parameters(self, **kwargs): | |
preprocess_kwargs = {} | |
if "maybe_arg" in kwargs: | |
preprocess_kwargs["maybe_arg"] = kwargs["maybe_arg"] | |
postprocess_kwargs = {} | |
if "top_k" in kwargs: | |
postprocess_kwargs["top_k"] = kwargs["top_k"] | |
return preprocess_kwargs, {}, postprocess_kwargs | |
Keep the inputs and outputs as simple as possible, and ideally JSON-serializable: this makes the pipeline easy to use | |
without requiring users to understand new kinds of objects. |