Spaces:
Runtime error
Runtime error
| # Copyright 2020 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from argparse import ArgumentParser | |
| from ..pipelines import Pipeline, PipelineDataFormat, get_supported_tasks, pipeline | |
| from ..utils import logging | |
| from . import BaseTransformersCLICommand | |
| logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
def try_infer_format_from_ext(path: str):
    """Guess the pipeline data format from the trailing characters of ``path``.

    A falsy ``path`` (empty string or ``None``) yields ``"pipe"``. Otherwise the
    first entry of ``PipelineDataFormat.SUPPORTED_FORMATS`` that ``path`` ends
    with is returned.

    Raises:
        Exception: if ``path`` ends with none of the supported formats.
    """
    if not path:
        return "pipe"

    # Plain suffix comparison: no "." separator is required before the format name.
    matched = next(
        (candidate for candidate in PipelineDataFormat.SUPPORTED_FORMATS if path.endswith(candidate)),
        None,
    )
    if matched is not None:
        return matched

    raise Exception(
        f"Unable to determine file format from file extension {path}. "
        f"Please provide the format through --format {PipelineDataFormat.SUPPORTED_FORMATS}"
    )
def run_command_factory(args):
    """Build a ``RunCommand`` from parsed command-line arguments.

    The pipeline is created first from ``args.task``, ``args.model``,
    ``args.config``, ``args.tokenizer`` and ``args.device``. The data format is
    then either taken verbatim from ``args.format`` or, when that is
    ``"infer"``, derived from ``args.input``. Finally a reader is created via
    ``PipelineDataFormat.from_str``.

    Args:
        args: Parsed arguments exposing ``task``, ``model``, ``config``,
            ``tokenizer``, ``device``, ``format``, ``input``, ``output``,
            ``column`` and ``overwrite``.

    Returns:
        A ``RunCommand`` wrapping the pipeline and the reader.
    """
    nlp = pipeline(
        task=args.task,
        model=args.model or None,  # any falsy model value is normalised to None
        config=args.config,
        tokenizer=args.tokenizer,
        device=args.device,
    )

    if args.format == "infer":
        data_format = try_infer_format_from_ext(args.input)
    else:
        data_format = args.format

    reader = PipelineDataFormat.from_str(
        format=data_format,
        output_path=args.output,
        input_path=args.input,
        # Fall back to the pipeline's own input names when no column was given.
        column=args.column or nlp.default_input_names,
        overwrite=args.overwrite,
    )
    return RunCommand(nlp, reader)
class RunCommand(BaseTransformersCLICommand):
    """CLI ``run`` sub-command: pass every entry of a reader through a pipeline and save the results."""

    def __init__(self, nlp: Pipeline, reader: PipelineDataFormat):
        """Store the pipeline to call and the reader that supplies entries and saves outputs."""
        self._nlp = nlp
        self._reader = reader

    @staticmethod
    def register_subcommand(parser: ArgumentParser):
        """Register the ``run`` sub-command and its options on ``parser``.

        Declared as a static method because it takes no ``self``; without the
        decorator, calling it through an instance would bind the instance to
        ``parser``.

        Args:
            parser: Object exposing ``add_parser`` (NOTE(review): the annotation
                says ``ArgumentParser``, but ``add_parser`` is provided by the
                object returned from ``ArgumentParser.add_subparsers()`` -
                confirm against the caller).
        """
        run_parser = parser.add_parser("run", help="Run a pipeline through the CLI")
        run_parser.add_argument("--task", choices=get_supported_tasks(), help="Task to run")
        run_parser.add_argument("--input", type=str, help="Path to the file to use for inference")
        run_parser.add_argument("--output", type=str, help="Path to the file that will be used post to write results.")
        run_parser.add_argument("--model", type=str, help="Name or path to the model to instantiate.")
        run_parser.add_argument("--config", type=str, help="Name or path to the model's config to instantiate.")
        run_parser.add_argument(
            "--tokenizer", type=str, help="Name of the tokenizer to use. (default: same as the model name)"
        )
        run_parser.add_argument(
            "--column",
            type=str,
            help="Name of the column to use as input. (For multi columns input as QA use column1,columns2)",
        )
        run_parser.add_argument(
            "--format",
            type=str,
            default="infer",
            # "infer" is the default, so it must also be accepted when passed
            # explicitly; argparse only checks user-supplied values against choices.
            choices=["infer", *PipelineDataFormat.SUPPORTED_FORMATS],
            help="Input format to read from",
        )
        run_parser.add_argument(
            "--device",
            type=int,
            default=-1,
            help="Indicate the device to run onto, -1 indicates CPU, >= 0 indicates GPU (default: -1)",
        )
        run_parser.add_argument("--overwrite", action="store_true", help="Allow overwriting the output file.")
        # The factory is stored under ``func`` in the parsed namespace.
        run_parser.set_defaults(func=run_command_factory)

    def run(self):
        """Run the pipeline over every reader entry and save the collected outputs.

        Entries are expanded as keyword arguments when the reader reports
        ``is_multi_columns``; otherwise they are passed positionally. A dict
        result is appended as one item, any other result is spliced into the
        output list element by element.
        """
        nlp, outputs = self._nlp, []

        for entry in self._reader:
            output = nlp(**entry) if self._reader.is_multi_columns else nlp(entry)
            if isinstance(output, dict):
                outputs.append(output)
            else:
                outputs += output

        # Saving data
        if self._nlp.binary_output:
            binary_path = self._reader.save_binary(outputs)
            logger.warning(f"Current pipeline requires output to be in binary format, saving at {binary_path}")
        else:
            self._reader.save(outputs)