Spaces:
Sleeping
Sleeping
""" | |
Command Line Interface for GreggRecognition | |
""" | |
import argparse
import os
import sys
import traceback
from pathlib import Path
from typing import List

from .recognizer import GreggRecognition
def parse_args(): | |
"""Parse command line arguments""" | |
parser = argparse.ArgumentParser( | |
description="Recognize Gregg shorthand from images", | |
formatter_class=argparse.ArgumentDefaultsHelpFormatter | |
) | |
parser.add_argument( | |
"input", | |
help="Input image file or directory containing images" | |
) | |
parser.add_argument( | |
"--model", | |
choices=["image_to_text", "seq2seq"], | |
default="image_to_text", | |
help="Model type to use for recognition" | |
) | |
parser.add_argument( | |
"--model-path", | |
help="Path to custom model weights file" | |
) | |
parser.add_argument( | |
"--output", | |
help="Output file to save results (default: print to stdout)" | |
) | |
parser.add_argument( | |
"--device", | |
choices=["auto", "cpu", "cuda"], | |
default="auto", | |
help="Device to use for inference" | |
) | |
parser.add_argument( | |
"--batch-size", | |
type=int, | |
default=8, | |
help="Batch size for processing multiple images" | |
) | |
parser.add_argument( | |
"--beam-size", | |
type=int, | |
default=1, | |
help="Beam size for beam search (image_to_text model only)" | |
) | |
parser.add_argument( | |
"--temperature", | |
type=float, | |
default=1.0, | |
help="Temperature for sampling (seq2seq model only)" | |
) | |
parser.add_argument( | |
"--extensions", | |
nargs="+", | |
default=[".jpg", ".jpeg", ".png", ".bmp", ".tiff"], | |
help="Image file extensions to process when input is a directory" | |
) | |
parser.add_argument( | |
"--verbose", | |
action="store_true", | |
help="Enable verbose output" | |
) | |
return parser.parse_args() | |
def find_image_files(input_path: str, extensions: List[str]) -> List[str]:
    """Resolve *input_path* to a sorted list of image file paths.

    If *input_path* names a file, it is returned as-is without an extension
    check, so an explicitly chosen file is always processed. If it names a
    directory, its immediate entries (non-recursive) that are regular files
    and whose names end with one of *extensions* are returned. The extension
    comparison is case-insensitive (``.Jpg`` matches ``.jpg``).

    Args:
        input_path: Path to an image file or to a directory of images.
        extensions: Accepted file extensions, e.g. ``[".png", "jpg"]``. A
            missing leading dot is added and empty entries are ignored.

    Returns:
        The matching paths as strings, sorted.

    Raises:
        FileNotFoundError: If *input_path* is neither a file nor a directory.
    """
    path = Path(input_path)
    if path.is_file():
        return [str(path)]
    if not path.is_dir():
        raise FileNotFoundError(f"Input path does not exist: {path}")

    # Normalise once: lower-case and force a leading dot, so "png" cannot
    # match a dot-less name such as "mypng".
    suffixes = tuple(
        ext.lower() if ext.startswith(".") else "." + ext.lower()
        for ext in extensions
        if ext
    )
    # is_file() filters out sub-directories named like "scans.png".
    matches = [
        entry
        for entry in path.iterdir()
        if entry.is_file() and entry.name.lower().endswith(suffixes)
    ]
    return [str(entry) for entry in sorted(matches)]
def _log(message: str) -> None:
    """Write a diagnostic line to stderr so stdout carries only results."""
    print(message, file=sys.stderr)


def _recognize_all(recognizer, image_files: List[str], args) -> list:
    """Run recognition and return ``(image_path, text)`` pairs in input order.

    A single image goes through ``recognizer.recognize``; several go through
    ``recognizer.batch_recognize`` with ``args.batch_size``.

    Raises:
        RuntimeError: If the batch call returns a different number of results
            than images, which ``zip`` would otherwise silently truncate or
            mislabel.
    """
    if len(image_files) == 1:
        text = recognizer.recognize(
            image_files[0],
            beam_size=args.beam_size,
            temperature=args.temperature,
        )
        return [(image_files[0], text)]

    if args.verbose:
        _log(f"Processing {len(image_files)} images...")
    texts = list(
        recognizer.batch_recognize(
            image_files,
            batch_size=args.batch_size,
            beam_size=args.beam_size,
            temperature=args.temperature,
        )
    )
    if len(texts) != len(image_files):
        raise RuntimeError(
            f"Recognizer returned {len(texts)} results for {len(image_files)} images"
        )
    return list(zip(image_files, texts))


def _emit_results(results: list, output, single: bool, verbose: bool) -> None:
    """Write results to *output* as ``path<TAB>text`` lines, or print them.

    When printing to stdout, a lone image prints just its text. Several
    images are prefixed with their file name.
    """
    if output:
        with open(output, "w", encoding="utf-8") as handle:
            for image_path, text in results:
                handle.write(f"{image_path}\t{text}\n")
        if verbose:
            _log(f"Results saved to: {output}")
        return

    for image_path, text in results:
        if single:
            print(text)
        else:
            print(f"{os.path.basename(image_path)}: {text}")


def main():
    """Entry point for the command-line interface.

    Parses arguments, collects the input images, runs GreggRecognition over
    them and writes the recognized text to ``--output`` or stdout. Verbose
    messages and errors go to stderr, so stdout holds only results. Exits
    with status 1 when no images are found or any exception is raised.
    """
    args = parse_args()
    try:
        image_files = find_image_files(args.input, args.extensions)
        if not image_files:
            _log(f"No image files found in: {args.input}")
            sys.exit(1)  # SystemExit is not an Exception; not caught below.

        if args.verbose:
            _log(f"Found {len(image_files)} image file(s)")
            _log(f"Using model: {args.model}")
            _log(f"Device: {args.device}")

        recognizer = GreggRecognition(
            model_type=args.model,
            device=args.device,
            model_path=args.model_path,
        )
        if args.verbose:
            # Informational only: a model-info mapping without this key
            # must not abort the recognition run.
            num_params = recognizer.get_model_info().get("num_parameters")
            if num_params is not None:
                _log(f"Model parameters: {num_params:,}")

        results = _recognize_all(recognizer, image_files, args)
        _emit_results(results, args.output, len(image_files) == 1, args.verbose)
    except Exception as exc:
        # Top-level CLI boundary: report concisely and exit non-zero. The
        # full traceback is shown only when --verbose asks for details.
        if args.verbose:
            traceback.print_exc()
        detail = str(exc) or type(exc).__name__
        _log(f"Error: {detail}")
        sys.exit(1)
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()