# hf https://huggingface.co/docs/transformers/main_classes/text_generation
from starvector.validation.svg_validator_base import SVGValidator, register_validator
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from starvector.data.util import rasterize_svg
class SVGValDataset(Dataset):
    """Validation dataset that rasterizes SVG markup into processed image tensors.

    Loads a HuggingFace dataset split (optionally a named config), optionally
    truncates it to ``num_samples`` rows, and on access rasterizes each row's
    SVG string into an image tensor via the supplied ``processor``.
    """

    def __init__(self, dataset_name, config_name, split, im_size, num_samples, processor):
        self.dataset_name = dataset_name
        self.config_name = config_name
        self.split = split
        self.im_size = im_size
        self.num_samples = num_samples
        self.processor = processor
        if self.config_name:
            self.data = load_dataset(self.dataset_name, self.config_name, split=self.split)
        else:
            self.data = load_dataset(self.dataset_name, split=self.split)
        if self.num_samples != -1:
            # Clamp so that requesting more samples than the split holds
            # does not raise inside `select`.
            self.data = self.data.select(range(min(self.num_samples, len(self.data))))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Fetch the row once instead of re-indexing the dataset per field.
        sample = self.data[idx]
        svg_str = sample['Svg']
        sample_id = sample['Filename']
        image = rasterize_svg(svg_str, resolution=self.im_size)
        image = self.processor(image, return_tensors="pt")['pixel_values'].squeeze(0)
        # Not every dataset config carries captions; default to empty string.
        caption = sample.get('Caption', "")
        return {
            'Svg': svg_str,
            'image': image,
            'Filename': sample_id,
            'Caption': caption
        }
@register_validator
class StarVectorHFSVGValidator(SVGValidator):
    """SVG validator backed by a StarVector HuggingFace checkpoint.

    Loads the model via ``AutoModelForCausalLM`` (with ``trust_remote_code``)
    in the configured dtype on the configured device, and generates SVG from
    either images (``im2svg``) or text prompts (``text2svg``).
    """

    def __init__(self, config):
        super().__init__(config)
        # Map the config string to a torch dtype; unknown values raise KeyError.
        self.torch_dtype = {
            'bfloat16': torch.bfloat16,
            'float16': torch.float16,
            'float32': torch.float32
        }[config.model.torch_dtype]
        if config.model.from_checkpoint:
            # NOTE(review): `config.model.from_checkpoint` is only used as a
            # flag here — the actual path comes from `self.resume_from_checkpoint`,
            # presumably set by the SVGValidator base class; confirm.
            self.model = AutoModelForCausalLM.from_pretrained(
                self.resume_from_checkpoint, trust_remote_code=True,
                torch_dtype=self.torch_dtype).to(config.run.device)
        else:
            self.model = AutoModelForCausalLM.from_pretrained(
                config.model.name, trust_remote_code=True,
                torch_dtype=self.torch_dtype).to(config.run.device)
        self.tokenizer = self.model.model.svg_transformer.tokenizer
        # NOTE(review): assumes "</svg>" encodes to a single token — only the
        # first id is kept as the generation stop token; confirm with tokenizer.
        self.svg_end_token_id = self.tokenizer.encode("</svg>")[0]

    def get_dataloader(self):
        """Build ``self.dataset`` and ``self.dataloader`` from the validation config."""
        ds_cfg = self.config.dataset
        self.dataset = SVGValDataset(
            ds_cfg.dataset_name, ds_cfg.config_name, ds_cfg.split,
            ds_cfg.im_size, ds_cfg.num_samples, self.processor)
        self.dataloader = DataLoader(
            self.dataset, batch_size=ds_cfg.batch_size,
            shuffle=False, num_workers=ds_cfg.num_workers)

    def release_memory(self):
        """Drop model sub-module references and force CUDA memory cleanup."""
        self.model.model.svg_transformer.tokenizer = None
        self.model.model.svg_transformer.model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()

    def generate_svg(self, batch, generate_config):
        """Generate SVG strings for a batch according to ``self.task``.

        Returns the model's generation output, or ``[]`` for an unknown task.
        """
        # Copy before mutating: the original code altered the caller's dict,
        # so a temperature of 0 silently turned into 1.0 for every later call.
        generate_config = dict(generate_config)
        if generate_config['temperature'] == 0:
            # Temperature 0 is invalid for sampling; fall back to greedy decoding.
            generate_config['temperature'] = 1.0
            generate_config['do_sample'] = False
        # Use the configured device instead of hard-coded 'cuda', consistent
        # with the device the model was moved to in __init__.
        batch['image'] = batch['image'].to(self.config.run.device).to(self.torch_dtype)
        outputs = []
        if self.task == 'im2svg':
            outputs = self.model.model.generate_im2svg(batch=batch, **generate_config)
        elif self.task == 'text2svg':
            outputs = self.model.model.generate_text2svg(batch=batch, **generate_config)
        return outputs