File size: 3,904 Bytes
72f684c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
# Hugging Face text-generation API reference: https://huggingface.co/docs/transformers/main_classes/text_generation
from starvector.validation.svg_validator_base import SVGValidator, register_validator
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from starvector.data.util import rasterize_svg

class SVGValDataset(Dataset):
    """Torch dataset that serves SVG samples loaded through ``load_dataset``.

    Each item pairs the raw SVG string with a rasterized, processor-encoded
    image of that SVG, plus the sample's filename and (optional) caption.
    """

    def __init__(self, dataset_name, config_name, split, im_size, num_samples, processor):
        """Load the requested split and optionally keep only its first samples.

        Args:
            dataset_name: Dataset identifier passed to ``load_dataset``.
            config_name: Dataset configuration name; when falsy, ``load_dataset``
                is called without a configuration.
            split: Split name passed to ``load_dataset``.
            im_size: Resolution passed to ``rasterize_svg``.
            num_samples: Number of leading samples to keep, or ``-1`` to keep
                the whole split. Values larger than the split are clamped to
                the split size.
            processor: Callable applied to the rasterized image; it is called
                with ``return_tensors="pt"`` and must return a mapping that
                contains ``'pixel_values'``.
        """
        self.dataset_name = dataset_name
        self.config_name = config_name
        self.split = split
        self.im_size = im_size
        self.num_samples = num_samples
        self.processor = processor

        if self.config_name:
            self.data = load_dataset(self.dataset_name, self.config_name, split=self.split)
        else:
            self.data = load_dataset(self.dataset_name, split=self.split)

        self.data = self._limit_samples(self.data, self.num_samples)

    @staticmethod
    def _limit_samples(data, num_samples):
        """Return the first ``num_samples`` rows of ``data`` (all rows for ``-1``)."""
        if num_samples == -1:
            return data
        # Clamp so that asking for more samples than the split holds yields
        # the whole split instead of an out-of-range selection error.
        return data.select(range(min(num_samples, len(data))))

    def __len__(self):
        """Return the number of samples kept from the split."""
        return len(self.data)

    def __getitem__(self, idx):
        """Build one sample.

        Args:
            idx: Index into the loaded split.

        Returns:
            A dict with keys ``'Svg'`` (SVG source string), ``'image'``
            (processor ``'pixel_values'`` with dimension 0 squeezed),
            ``'Filename'`` and ``'Caption'`` (empty string when the row has
            no ``'Caption'`` entry).
        """
        # Fetch the row once; every index into the dataset materializes the
        # row again, so repeated ``self.data[idx]`` lookups are wasted work.
        row = self.data[idx]
        svg_str = row['Svg']
        sample_id = row['Filename']
        image = rasterize_svg(svg_str, resolution=self.im_size)
        # presumably squeeze(0) drops the batch dimension the processor adds
        # for a single image — confirm against the processor in use.
        image = self.processor(image, return_tensors="pt")['pixel_values'].squeeze(0)
        caption = row.get('Caption', "")
        return {
            'Svg': svg_str,
            'image': image,
            'Filename': sample_id,
            'Caption': caption
        }
    
                
@register_validator
class StarVectorHFSVGValidator(SVGValidator):
    """SVG validator backed by a model loaded with ``AutoModelForCausalLM``.

    The model is loaded with ``trust_remote_code=True`` and generation is
    delegated to ``generate_im2svg`` / ``generate_text2svg`` on its inner
    ``model`` attribute, selected by ``self.task``.
    """

    def __init__(self, config):
        """Load the model and cache its tokenizer and ``</svg>`` token id.

        Args:
            config: Run configuration. This constructor reads
                ``config.model.torch_dtype`` (``'bfloat16'``, ``'float16'`` or
                ``'float32'``), ``config.model.from_checkpoint``,
                ``config.model.name`` and ``config.run.device``.

        Raises:
            KeyError: If ``config.model.torch_dtype`` is not one of the
                supported names.
        """
        super().__init__(config)
        # Initialize HuggingFace model and tokenizer here
        self.torch_dtype = {
            'bfloat16': torch.bfloat16,
            'float16': torch.float16,
            'float32': torch.float32
        }[config.model.torch_dtype]

        # NOTE(review): the branch tests config.model.from_checkpoint but loads
        # from self.resume_from_checkpoint, which is not assigned in this
        # class — presumably set by SVGValidator.__init__; confirm both refer
        # to the same path.
        if config.model.from_checkpoint:
            model_source = self.resume_from_checkpoint
        else:
            model_source = config.model.name
        self.model = AutoModelForCausalLM.from_pretrained(
            model_source,
            trust_remote_code=True,
            torch_dtype=self.torch_dtype,
        ).to(config.run.device)

        self.tokenizer = self.model.model.svg_transformer.tokenizer
        # NOTE(review): takes the first id produced for "</svg>"; if the
        # tokenizer prepends special tokens this would not be the end-of-SVG
        # token — verify against the tokenizer configuration.
        self.svg_end_token_id = self.tokenizer.encode("</svg>")[0]

    def get_dataloader(self):
        """Create ``self.dataset`` and ``self.dataloader`` from ``self.config.dataset``.

        The loader iterates in dataset order (``shuffle=False``). Nothing is
        returned; the results are stored on the instance.
        """
        dataset_cfg = self.config.dataset
        # NOTE(review): self.processor is not assigned anywhere in this class;
        # presumably the base class provides it — confirm.
        self.dataset = SVGValDataset(
            dataset_cfg.dataset_name,
            dataset_cfg.config_name,
            dataset_cfg.split,
            dataset_cfg.im_size,
            dataset_cfg.num_samples,
            self.processor,
        )
        self.dataloader = DataLoader(
            self.dataset,
            batch_size=dataset_cfg.batch_size,
            shuffle=False,
            num_workers=dataset_cfg.num_workers,
        )

    def release_memory(self):
        """Drop the SVG transformer's model and tokenizer and free cached GPU memory."""
        import gc

        # Clear references to free GPU memory
        self.model.model.svg_transformer.tokenizer = None
        self.model.model.svg_transformer.model = None

        # Collect unreachable objects first: tensors that are only kept alive
        # by reference cycles still pin their CUDA blocks, and empty_cache()
        # can only return blocks that no live tensor occupies.
        gc.collect()

        # Force CUDA garbage collection
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    def generate_svg(self, batch, generate_config):
        """Generate SVGs for ``batch`` with the task-specific model method.

        Args:
            batch: Mapping holding an ``'image'`` tensor. The entry is replaced
                in place by a copy on ``self.config.run.device`` with dtype
                ``self.torch_dtype``.
            generate_config: Keyword arguments forwarded to the generation
                method. A ``temperature`` of ``0`` is translated to greedy
                decoding (``temperature=1.0``, ``do_sample=False``). The
                caller's dict is left unmodified.

        Returns:
            The output of ``generate_im2svg`` or ``generate_text2svg``
            depending on ``self.task``; an empty list for any other task.
        """
        # Work on a copy so the greedy-decoding override below does not leak
        # into the caller's dict and silently affect later calls.
        generate_config = dict(generate_config)
        if generate_config.get('temperature') == 0:
            generate_config['temperature'] = 1.0
            generate_config['do_sample'] = False

        # Use the configured device (the one the model was moved to in
        # __init__) rather than a hard-coded 'cuda', so inputs and model
        # always end up on the same device.
        batch['image'] = batch['image'].to(
            device=self.config.run.device, dtype=self.torch_dtype
        )

        if self.task == 'im2svg':
            return self.model.model.generate_im2svg(batch=batch, **generate_config)
        if self.task == 'text2svg':
            return self.model.model.generate_text2svg(batch=batch, **generate_config)
        return []