LittleApple-fp16's picture
Upload 88 files
4f8ad24
import glob
import os
import pathlib
import random
import re
from typing import Iterator
from PIL import UnidentifiedImageError
from imgutils.data import load_image
from .base import RootDataSource
from ..model import ImageItem
class LocalSource(RootDataSource):
    """Data source yielding every loadable image found under a local directory.

    :param directory: Root directory to scan.
    :param recursive: When true, every subdirectory is visited via
        :func:`os.walk`; otherwise only the direct children of ``directory``
        are considered.
    :param shuffle: When true, the full list of candidate files is collected
        first and shuffled with :func:`random.shuffle` before any image is
        loaded; when false, files are streamed in filesystem order.
    """

    def __init__(self, directory: str, recursive: bool = True, shuffle: bool = False):
        self.directory = directory
        self.recursive = recursive
        self.shuffle = shuffle

    @staticmethod
    def _group_name(directory: str) -> str:
        # Collapse each run of non-word characters and underscores into one '_'
        # and trim the ends, turning a directory path into an identifier-like id.
        return re.sub(r'[\W_]+', '_', directory).strip('_')

    def _iter_files(self):
        """Yield ``(file_path, group_name)`` pairs in filesystem order.

        ``group_name`` is derived from the directory that directly contains
        the file (recursive mode) or from ``self.directory`` (flat mode).
        """
        if self.recursive:
            for directory, _, files in os.walk(self.directory):
                group_name = self._group_name(directory)
                for file in files:
                    yield os.path.join(directory, file), group_name
        else:
            group_name = self._group_name(self.directory)
            for file in os.listdir(self.directory):
                path = os.path.join(self.directory, file)
                # os.listdir also returns subdirectories; skip them so they are
                # never handed to the image loader (os.walk's file list in the
                # recursive branch already excludes directories).
                if os.path.isfile(path):
                    yield path, group_name

    def _actual_iter_files(self):
        """Yield the pairs from :meth:`_iter_files`, shuffled if ``self.shuffle``.

        Shuffling requires the whole listing up front; without it the pairs
        are streamed lazily so iteration can start before the scan finishes.
        """
        if not self.shuffle:
            yield from self._iter_files()
            return
        lst = list(self._iter_files())
        random.shuffle(lst)
        yield from lst

    def _iter(self) -> Iterator[ImageItem]:
        """Yield an :class:`ImageItem` for every file that loads as an image.

        Files whose loading raises :class:`UnidentifiedImageError` are skipped.
        If the loaded item already carries truthy ``meta`` it is kept as-is;
        otherwise a meta dict with the absolute ``path``, the ``group_id`` and
        the bare ``filename`` is built.
        """
        # Iterate through _actual_iter_files (not _iter_files) so that the
        # ``shuffle`` constructor option actually takes effect.
        for file, group_name in self._actual_iter_files():
            try:
                origin_item = ImageItem.load_from_image(file)
                # Force the image data to be read now (presumably a lazily
                # decoded PIL image) rather than on first use downstream.
                origin_item.image.load()
            except UnidentifiedImageError:
                continue

            meta = origin_item.meta or {
                'path': os.path.abspath(file),
                'group_id': group_name,
                'filename': os.path.basename(file),
            }
            yield ImageItem(origin_item.image, meta)
class LocalTISource(RootDataSource):
    """Data source for a flat directory of images with optional sidecar tag files.

    For every regular file ``<id>.<ext>`` directly inside ``directory`` that
    loads as an image, a sibling ``<id>.txt`` (if present) is read as a
    comma-separated tag list, and every tag is given weight ``1.0``.

    :param directory: Directory to scan (not recursive).
    """

    def __init__(self, directory: str):
        self.directory = directory

    @staticmethod
    def _parse_tags(full_text: str) -> dict:
        """Turn comma-separated tag text into a ``{tag: 1.0}`` mapping.

        Surrounding whitespace (e.g. the trailing newline editors usually
        write) is stripped first, and empty entries produced by an empty file
        or stray/trailing commas are dropped, so ``''`` is never a tag.
        """
        words = re.split(r'\s*,\s*', full_text.strip())
        return {word: 1.0 for word in words if word}

    def _iter(self) -> Iterator[ImageItem]:
        """Yield an :class:`ImageItem` per loadable image, with tags in its meta.

        Non-files and files whose loading raises
        :class:`UnidentifiedImageError` are skipped. The meta dict holds the
        absolute ``path``, ``group_id`` (derived from the directory),
        ``filename`` and ``tags`` (empty when no sidecar exists).
        """
        group_name = re.sub(r'[\W_]+', '_', self.directory).strip('_')
        # Escape the directory so glob metacharacters in its name ('[', '*',
        # '?') match literally; only the trailing '*' is meant as a wildcard.
        pattern = os.path.join(glob.escape(self.directory), '*')
        for path in glob.glob(pattern):
            if not os.path.isfile(path):
                continue
            try:
                image = load_image(path)
            except UnidentifiedImageError:
                # Presumably also how the .txt sidecars themselves are skipped.
                continue

            id_ = os.path.splitext(os.path.basename(path))[0]
            txt_file = os.path.join(self.directory, f'{id_}.txt')
            if os.path.exists(txt_file):
                # 'utf-8-sig' drops a leading BOM if present and is otherwise
                # identical to 'utf-8', so the BOM never leaks into the first tag.
                full_text = pathlib.Path(txt_file).read_text(encoding='utf-8-sig')
                tags = self._parse_tags(full_text)
            else:
                tags = {}

            meta = {
                'path': os.path.abspath(path),
                'group_id': group_name,
                'filename': os.path.basename(path),
                'tags': tags,
            }
            yield ImageItem(image, meta)