File size: 738 Bytes
268c7f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from typing import Optional, List
from datasets import load_dataset
import random


class PromptLoader:
    """Lazily load Stable Diffusion prompts and draw reproducible random samples.

    Prompts come from the ``prompt`` column of the ``train`` split of the
    Hugging Face dataset ``daspartho/stable-diffusion-prompts``. The dataset is
    loaded on the first call to :meth:`load_data` and cached on the instance.
    """

    def __init__(self, seed: int = 42) -> None:
        """Create a loader whose sampling is reproducible for a given *seed*.

        Args:
            seed: Seed for the instance-local ``random.Random`` used by
                :meth:`load_data`; the global ``random`` state is not used.
        """
        self.randomizer = random.Random(seed)
        # Cached prompt list; ``None`` means "not loaded yet".
        self.data: Optional[List[str]] = None

    def _load_data(self) -> None:
        """Fetch the prompt column and cache it as a plain list on ``self.data``."""
        dataset = load_dataset("daspartho/stable-diffusion-prompts")
        # Materialize as a real list so ``self.data`` matches its annotation and
        # is a proper sequence for ``random.sample``, whatever column type the
        # installed ``datasets`` version returns.
        self.data = list(dataset["train"]["prompt"])

    def load_data(self, size: Optional[int] = None) -> List[str]:
        """Return all prompts, or a random sample of *size* prompts.

        Args:
            size: Number of prompts to draw without replacement. ``None``
                returns every prompt; ``0`` returns an empty list.

        Returns:
            A new list of prompts; mutating it does not affect the cache.

        Raises:
            ValueError: If *size* is negative or exceeds the number of
                available prompts.
        """
        # Explicit ``is None`` so an empty (but loaded) dataset is not
        # re-fetched on every call.
        if self.data is None:
            self._load_data()

        if size is None:
            # Return a copy so callers cannot reorder/mutate the cached data
            # and thereby change the results of later samples.
            return list(self.data)
        if size < 0:
            raise ValueError("size must be non-negative")
        if size > len(self.data):
            raise ValueError("Not enough samples available!")
        return self.randomizer.sample(self.data, size)