from pathlib import Path
from typing import Dict, NamedTuple, Union

import numpy as np
import torch

NULL_CHAR = '\x00'


class PrimingData(NamedTuple):
    """Combines the data required for priming HandwritingRNN sampling."""
    stroke_tensors: torch.Tensor      # (batch_size, num_prime_strokes, 3)
    char_seq_tensors: torch.Tensor    # (batch_size, num_prime_chars)
    char_seq_lengths: torch.Tensor    # (batch_size,)


def construct_alphabet_list(alphabet_string: str) -> list[str]:
    """Builds the full alphabet list, with the NULL character at index 0."""
    if not isinstance(alphabet_string, str):
        raise TypeError("alphabet_string must be a string")
    return [NULL_CHAR] + list(alphabet_string)


def get_alphabet_map(alphabet_list: list[str]) -> Dict[str, int]:
    """Creates a char-to-index map from the full alphabet list."""
    return {char: idx for idx, char in enumerate(alphabet_list)}
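
# Hedged usage sketch (not part of the module): how the two alphabet helpers
# compose. The alphabet string "abc" here is an assumption for illustration;
# the real alphabet comes from the sampling configuration.
#
#     alphabet = construct_alphabet_list("abc")   # ['\x00', 'a', 'b', 'c']
#     char_map = get_alphabet_map(alphabet)       # {'\x00': 0, 'a': 1, 'b': 2, 'c': 3}
#     assert char_map[NULL_CHAR] == 0             # NULL doubles as the EOS/padding index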
def encode_text(text: str, char_to_index_map: Dict[str, int],
                max_length: int, add_eos: bool = True, eos_char_index: int = 0
                ) -> tuple[np.ndarray, int]:
    """Encodes a text string into a padded sequence of integer indices.

    Unknown characters map to eos_char_index; the result is padded (or
    truncated) to max_length and returned with shape (1, max_length).
    """
    encoded = [char_to_index_map.get(c, eos_char_index) for c in text]
    if add_eos:
        encoded.append(eos_char_index)
    true_length = len(encoded)
    if true_length <= max_length:
        padded_encoded = np.full(max_length, eos_char_index, dtype=np.int64)
        padded_encoded[:true_length] = encoded
    else:
        padded_encoded = np.array(encoded[:max_length], dtype=np.int64)
        true_length = max_length
    return np.array([padded_encoded]), true_length
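
# Hedged sketch: encoding with char_map from the sketch above. With
# max_length=6 and add_eos=True, "ab" becomes [1, 2, 0, 0, 0, 0] (the
# trailing zeros are the EOS/padding index) with a true length of 3.
#
#     encoded, length = encode_text("ab", char_map, max_length=6)
#     # encoded.shape == (1, 6); length == 3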
def convert_offsets_to_absolute_coords(stroke_offsets: list[list[float]]) -> list[list[float]]:
    """Converts per-step (dx, dy, pen) offsets into absolute (x, y, pen) coordinates."""
    if not stroke_offsets:
        return []
    # Convert to numpy for vectorized operations; float dtype matches the
    # annotated return type even when integer offsets are passed in.
    strokes_array = np.array(stroke_offsets, dtype=np.float64)
    # Cumulative sums turn relative offsets into absolute positions.
    strokes_array[:, 0] = np.cumsum(strokes_array[:, 0])  # cumulative dx -> x
    strokes_array[:, 1] = np.cumsum(strokes_array[:, 1])  # cumulative dy -> y
    return strokes_array.tolist()
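
# Hedged sketch: offsets to absolute coordinates. The third column (pen
# lift) passes through unchanged.
#
#     convert_offsets_to_absolute_coords([[1, 0, 0], [1, 1, 0], [0, 2, 1]])
#     # -> [[1.0, 0.0, 0.0], [2.0, 1.0, 0.0], [2.0, 3.0, 1.0]]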
def load_np_strokes(stroke_path: Union[Path, str]) -> np.ndarray:
    """Loads a stroke sequence from stroke_path."""
    stroke_path = Path(stroke_path)
    if not stroke_path.exists():
        raise FileNotFoundError(f"Style strokes file not found at {stroke_path}")
    return np.load(stroke_path)


def load_text(text_path: Union[Path, str]) -> str:
    """Loads text from text_path."""
    text_path = Path(text_path)
    if not text_path.exists():
        raise FileNotFoundError(f"Text file not found at {text_path}")
    if not text_path.is_file():
        raise IsADirectoryError(f"Path is a directory, not a file: {text_path}")
    try:
        with open(text_path, 'r', encoding='utf-8') as f:
            return f.read()
    except (OSError, UnicodeDecodeError) as e:
        raise IOError(f"Error reading text file {text_path}: {e}") from e


def load_priming_data(style: int) -> tuple[str, np.ndarray]:
    """Loads the priming text and stroke array for a given style index."""
    priming_text = load_text(f"./styles/style{style}.txt")
    priming_strokes = load_np_strokes(f"./styles/style{style}.npy")
    return priming_text, priming_strokes
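
# Hedged end-to-end sketch: how these helpers might be wired into the
# PrimingData tuple defined above. The alphabet string, max_length, and the
# ./styles/style0.{txt,npy} files are assumptions for illustration; actual
# values come from the sampling pipeline.
#
#     text, strokes = load_priming_data(0)                 # assumes style 0 exists on disk
#     alphabet = construct_alphabet_list("abc")            # assumed alphabet
#     char_map = get_alphabet_map(alphabet)
#     chars, length = encode_text(text, char_map, max_length=64)
#     priming = PrimingData(
#         stroke_tensors=torch.from_numpy(strokes).float().unsqueeze(0),  # (1, T, 3)
#         char_seq_tensors=torch.from_numpy(chars),                       # (1, 64)
#         char_seq_lengths=torch.tensor([length]),                        # (1,)
#     )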