|
import torch |
|
import torch.nn as nn |
|
from starvector.model.models.starvector_base import StarVectorBase |
|
from transformers import AutoProcessor |
|
|
|
class StarVectorStarCoder(StarVectorBase):
    """StarVector V1 model variant that uses a StarCoder SVG transformer.

    Specializes ``StarVectorBase`` by supplying the V1-specific SVG
    transformer factory, token-embedding lookup, and SVG text preparation,
    and by loading an ``AutoProcessor`` for the configured checkpoint.
    """

    def __init__(self, config, **kwargs):
        """Initialize the base model, then load the matching processor.

        Args:
            config: Model configuration. Forwarded to ``StarVectorBase`` and
                used to locate the processor via ``config._name_or_path``.
            **kwargs: Extra keyword arguments forwarded unchanged to
                ``StarVectorBase.__init__``.
        """
        super().__init__(config, **kwargs)

        # NOTE(review): ``_name_or_path`` is a private config attribute that
        # is presumably filled in when the config comes from a saved
        # checkpoint; a hand-built config may leave it empty, making this
        # load fail — confirm callers always pass a loaded config.
        self.processor = AutoProcessor.from_pretrained(config._name_or_path)

    def _get_svg_transformer(self, config, **kwargs):
        """Build the V1 SVG transformer.

        Args:
            config: Model configuration passed to ``StarCoderModel``.
            **kwargs: Extra keyword arguments passed to ``StarCoderModel``.

        Returns:
            A newly constructed ``StarCoderModel``.
        """
        # Imported at call time rather than at module level; presumably to
        # keep the StarCoder dependency lazy or avoid an import cycle —
        # confirm before moving it.
        from starvector.model.llm.starcoder import StarCoderModel

        return StarCoderModel(config, **kwargs)

    def _get_embeddings(self, input_ids):
        """V1 specific embedding method.

        Passes ``input_ids`` through the ``wte`` module of the nested
        StarCoder transformer (presumably its token-embedding layer —
        confirm against ``StarCoderModel``).

        Args:
            input_ids: Token ids; assumed to be an integer tensor such as
                (batch, seq) — TODO confirm against callers.

        Returns:
            The output of ``wte(input_ids)``.
        """
        return self.svg_transformer.transformer.transformer.wte(input_ids)

    def _get_svg_text(self, svg_list):
        """V1 specific SVG text preparation.

        Appends the SVG tokenizer's ``eos_token`` to every SVG string.

        Args:
            svg_list: Iterable of SVG source strings.

        Returns:
            A new list with ``eos_token`` appended to each string; the input
            is not modified.
        """
        # The EOS token does not change across elements, so resolve the
        # attribute chain once instead of once per SVG string.
        eos_token = self.svg_transformer.tokenizer.eos_token
        return [svg + eos_token for svg in svg_list]
|
|