StarVector Validation

This module provides validation functionality for StarVector models, allowing evaluation of SVG generation quality across different model architectures and generation backends.

Overview

The validation framework consists of:

  1. A base SVGValidator class that handles common validation logic
  2. Specific validator implementations for different backends:
    • StarVectorHFSVGValidator: Uses HuggingFace generation API
    • StarVectorVLLMValidator: Uses vLLM for faster generation
    • StarVectorVLLMAPIValidator: Uses vLLM through REST API

1. Running Validation

Using HuggingFace Backend

# StarVector-1B
python starvector/validation/validate.py \
config=configs/generation/hf/starvector-1b/im2svg.yaml \
dataset.name=starvector/svg-stack

# StarVector-8B 
python starvector/validation/validate.py \
config=configs/generation/hf/starvector-8b/im2svg.yaml \
dataset.name=starvector/svg-stack

vLLM Backend

To use the vLLM backend (StarVectorVLLMAPIValidator), first install our StarVector fork of vLLM:

git clone https://github.com/starvector/vllm.git
cd vllm
pip install -e .

Then launch validation using the vLLM config file (it uses StarVectorVLLMValidator):

# StarVector-1B
python starvector/validation/validate.py \
config=configs/generation/vllm/starvector-1b/im2svg.yaml \
dataset.name=starvector/svg-stack

# StarVector-8B
python starvector/validation/validate.py \
config=configs/generation/vllm/starvector-8b/im2svg.yaml \
dataset.name=starvector/svg-stack
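
The commands above pass settings as key=value arguments, with dotted keys such as dataset.name addressing nested config entries. How validate.py parses them is not shown in this README. The following is a minimal, standard-library sketch of the general idea; apply_overrides is a hypothetical helper and not part of StarVector.

```python
import copy


def apply_overrides(config, overrides):
    """Return a copy of `config` with `section.key=value` overrides applied.

    Hypothetical helper for illustration only; values are kept as strings.
    """
    merged = copy.deepcopy(config)
    for item in overrides:
        dotted_key, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"expected key=value, got {item!r}")
        *parents, leaf = dotted_key.split(".")
        node = merged
        for part in parents:
            # Create intermediate sections on demand.
            node = node.setdefault(part, {})
        node[leaf] = value
    return merged


base = {"dataset": {"batch_size": 1}}
merged = apply_overrides(base, ["dataset.name=starvector/svg-stack"])
print(merged)
```

The input mapping is deep-copied, so the base configuration loaded from YAML stays unchanged and can be reused across runs.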

2. Creating a New SVG Validator

To create a new validator for a different model or generation backend:

  1. Create a new class inheriting from SVGValidator
  2. Implement required abstract methods:
    • __init__(self, config): Initialize the validator with the given config
    • get_dataloader(self): Get the dataloader for the given dataset
    • generate_svg(self, batch, generate_config): Generate SVGs from the input batch and return them as a list of strings
  3. Add the new validator to the registry in starvector/validation/__init__.py

Example:

from .svg_validator_base import SVGValidator, register_validator

@register_validator
class MyNewValidator(SVGValidator):
    def __init__(self, config):
        super().__init__(config)
        # Initialize your model/client here

    def generate_svg(self, batch, generate_config):
        # Implement generation logic for one batch using generate_config,
        # and return a list of generated SVG strings
        raise NotImplementedError

    def get_dataloader(self):
        # Build and return the dataloader for the configured dataset
        raise NotImplementedError
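
The register_validator decorator in the example suggests a registry that maps names to validator classes. Its actual implementation is not shown in this README. As a minimal sketch of how such a decorator-based registry could work, the code below uses a stand-in base class; VALIDATOR_REGISTRY, get_validator_class, and EchoValidator are hypothetical names chosen for illustration.

```python
from abc import ABC, abstractmethod

# Hypothetical registry mapping class names to validator classes.
VALIDATOR_REGISTRY = {}


class SVGValidator(ABC):
    """Stand-in for the real base class; only the abstract interface is sketched."""

    def __init__(self, config):
        self.config = config

    @abstractmethod
    def get_dataloader(self): ...

    @abstractmethod
    def generate_svg(self, batch, generate_config): ...


def register_validator(cls):
    """Record the class under its own name and return it unchanged."""
    VALIDATOR_REGISTRY[cls.__name__] = cls
    return cls


def get_validator_class(name):
    """Look up a registered validator class by name (hypothetical helper)."""
    if name not in VALIDATOR_REGISTRY:
        raise KeyError(f"unknown validator {name!r}; known: {sorted(VALIDATOR_REGISTRY)}")
    return VALIDATOR_REGISTRY[name]


@register_validator
class EchoValidator(SVGValidator):
    """Toy validator that returns a fixed SVG for every item in the batch."""

    def get_dataloader(self):
        return [["sample-0", "sample-1"]]

    def generate_svg(self, batch, generate_config):
        return ["<svg></svg>" for _ in batch]


validator = get_validator_class("EchoValidator")(config={})
svgs = validator.generate_svg(["sample-0", "sample-1"], generate_config={})
```

Because the base class marks its methods as abstract, a subclass that forgets to implement one of them fails at instantiation rather than midway through a validation run.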

Key Features

The validation framework provides:

  • Automatic metrics calculation and logging
  • WandB integration for experiment tracking
  • Temperature sweep for exploring generation parameters
  • Comparison plot generation
  • Batch processing with configurable settings
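
To illustrate the temperature sweep listed above, the loop below runs the same batch at several temperatures and records an average score per setting. It is a sketch under assumptions, not the framework's code: generate and score are hypothetical callables standing in for a validator's generation and metric logic, and the parameter names mirror the generation_params keys used in this README.

```python
def temperature_sweep(generate, score, batch, base_params, temperatures):
    """Run `generate` once per temperature and return {temperature: mean score}."""
    results = {}
    for temperature in temperatures:
        # Copy the base parameters so runs differ only in temperature.
        params = {**base_params, "temperature": temperature}
        outputs = generate(batch, params)
        scores = [score(output) for output in outputs]
        # Guard against an empty batch instead of dividing by zero.
        results[temperature] = sum(scores) / len(scores) if scores else float("nan")
    return results


seen_temperatures = []


def fake_generate(batch, params):
    """Toy generator: records the temperature and returns one SVG per item."""
    seen_temperatures.append(params["temperature"])
    return ["<svg></svg>" for _ in batch]


def fake_score(svg):
    """Toy metric: 1.0 if the output looks like an SVG, else 0.0."""
    return 1.0 if svg.startswith("<svg") else 0.0


base_params = {"top_p": 0.9, "max_length": 1024}
sweep = temperature_sweep(fake_generate, fake_score, ["a", "b"], base_params, [0.0, 0.2, 0.5])
print(sweep)
```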

Configuration

Validation is configured through YAML files in configs/generation/. Key configuration sections:

model:
  name: "model_name"  # HF model name or path
  task: "im2svg"      # Task type
  torch_dtype: "float16"  # Model precision

dataset:
  dataset_name: "svg-stack"  # Dataset to validate on
  batch_size: 1
  num_workers: 4

generation_params:
  temperature: 0.2
  top_p: 0.9
  max_length: 1024
  # ... other generation parameters

run:
  report_to: "wandb"  # Logging backend
  out_dir: "outputs"  # Output directory
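
A quick sanity check on a loaded config can catch a missing section before a long run starts. The sketch below assumes the YAML has already been loaded into a plain dict, and that all four top-level sections shown above are required; that requirement and the helper missing_sections are assumptions made for illustration.

```python
# Top-level sections shown in this README; treating them as required is an assumption.
EXPECTED_SECTIONS = ("model", "dataset", "generation_params", "run")


def missing_sections(config):
    """Return the names of expected top-level sections absent from `config`."""
    return [name for name in EXPECTED_SECTIONS if name not in config]


# A config as it might look after loading the YAML, with the `run` section left out.
config = {
    "model": {"name": "model_name", "task": "im2svg", "torch_dtype": "float16"},
    "dataset": {"dataset_name": "svg-stack", "batch_size": 1, "num_workers": 4},
    "generation_params": {"temperature": 0.2, "top_p": 0.9, "max_length": 1024},
}
missing = missing_sections(config)
print(missing)
```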

Output Structure

The validator creates the following directory structure:

out_dir/
├── {model}_{dataset}_{timestamp}/
│   ├── config.yaml           # Run configuration
│   ├── results/
│   │   ├── results_avg.json  # Average metrics
│   │   └── all_results.csv   # Per-sample metrics
│   └── {sample_id}/          # Per-sample outputs
│       ├── metadata.json
│       ├── {sample_id}.svg
│       ├── {sample_id}_raw.svg
│       ├── {sample_id}_gt.svg
│       ├── {sample_id}_generated.png
│       ├── {sample_id}_original.png
│       └── {sample_id}_comparison.png
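
Given this layout, the aggregate results of a run can be read back with the standard library. This is a minimal sketch: load_run_results is a hypothetical helper, and it makes no assumption about which metric names or CSV columns the validator writes.

```python
import csv
import json
from pathlib import Path


def load_run_results(run_dir):
    """Return (average metrics, per-sample rows) for one run directory."""
    results_dir = Path(run_dir) / "results"
    with open(results_dir / "results_avg.json", encoding="utf-8") as f:
        averages = json.load(f)
    with open(results_dir / "all_results.csv", newline="", encoding="utf-8") as f:
        # Each row becomes a dict keyed by the CSV header; values stay strings.
        rows = list(csv.DictReader(f))
    return averages, rows
```

Call it with the path of a {model}_{dataset}_{timestamp} directory. CSV values are returned as strings, so convert numeric columns before aggregating them.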