from abc import abstractmethod
from dataclasses import asdict
from itertools import islice
from typing import Generator, List, Optional

from .card import ICLCard, TaskCard
from .common import CommonRecipe
from .operator import SourceOperator, StreamSource
from .random_utils import random
from .stream import MultiStream, Stream


class BaseFusion(SourceOperator):
    """
    BaseFusion operator that combines multiple streams into one.

    Subclasses implement ``fusion_generator`` to decide how the examples of a single
    split are drawn from ``origins``; ``process`` builds one fused Stream per split.

    Args:
        include_splits: List of splits to include. If None, all splits are included.
        origins: List of StreamSource objects whose splits are fused.
    """

    include_splits: Optional[List[str]] = None
    # Declared here because splits() reads it; subclasses may redeclare it.
    origins: List[StreamSource] = None

    @abstractmethod
    def fusion_generator(self, split) -> Generator:
        """Yield the fused examples for ``split``; implemented by subclasses."""

    def splits(self) -> List[str]:
        """
        Return the split names present in any origin.

        Names keep the order in which they are first encountered across origins and
        are filtered by ``include_splits`` when it is set.
        """
        found = []
        for origin in self.origins:
            for split_name in origin().keys():
                if split_name in found:
                    continue
                if self.include_splits is None or split_name in self.include_splits:
                    found.append(split_name)
        return found

    def process(
        self,
    ) -> MultiStream:
        # Each split becomes a Stream built from fusion_generator with that split as argument.
        return MultiStream(
            {split: Stream(self.fusion_generator, gen_kwargs={"split": split}) for split in self.splits()}
        )


class FixedFusion(BaseFusion):
    """
    FixedFusion operator that combines multiple streams into one based on a fixed number of examples per task.

    Origins are consumed one after another: the selected examples of the first origin
    are yielded, then those of the second, and so on.

    Args:
        origins: List of StreamSource objects.
        examples_per_task: Maximum number of examples taken from each origin. If None, all examples are returned.
            An origin with fewer examples than this contributes all of its examples.
        include_splits: List of splits to include. If None, all splits are included.
    """

    origins: List[StreamSource] = None
    examples_per_task: Optional[int] = None

    def verify(self):
        super().verify()
        assert self.origins is not None, "origins must be specified"

    def fusion_generator(self, split) -> Generator:
        for origin in self.origins:
            examples = origin()[split]
            if self.examples_per_task is None:
                yield from examples
            else:
                # islice stops cleanly when an origin runs short; a bare next() would leak
                # StopIteration, which PEP 479 converts into RuntimeError inside a generator.
                yield from islice(examples, self.examples_per_task)


class WeightedFusion(BaseFusion):
    """
    Fusion operator that combines multiple streams by randomly sampling from them according to weights.

    At each step an origin is drawn with probability proportional to its weight and its
    next example is yielded. Exhausted origins are dropped together with their weights,
    and sampling continues over the remaining ones.

    Args:
        origins: List of StreamSource objects.
        weights: List of weights for each origin.
        total_examples: Total number of examples to return. If None, examples are drawn until
            no origin with a positive weight has examples left.
        include_splits: List of splits to include. If None, all splits are included.
    """

    origins: List[StreamSource] = None
    weights: List[float] = None
    total_examples: Optional[int] = None

    def verify(self):
        super().verify()
        assert self.origins is not None, "origins must be specified"
        assert self.weights is not None, "weights must be specified"
        assert len(self.origins) == len(self.weights), "origins and weights must have the same length"

    def fusion_generator(self, split) -> Generator:
        iterators = [iter(origin()[split]) for origin in self.origins]
        # Copy so that removing exhausted origins never mutates the configured weights.
        weights = list(self.weights)
        emitted = 0
        # Origins with zero weight can never be drawn, so stop once only those remain
        # (random.choices rejects an all-zero weight vector).
        while iterators and any(weight > 0 for weight in weights):
            if self.total_examples is not None and emitted >= self.total_examples:
                return
            # Sample an index so the chosen iterator and its weight stay aligned on removal.
            index = random.choices(population=range(len(iterators)), weights=weights)[0]
            try:
                example = next(iterators[index])
            except StopIteration:
                del iterators[index]
                del weights[index]
                continue
            yield example
            emitted += 1


class TasksFusion(SourceOperator):
    """
    TasksFusion operator that combines multiple tasks into one.

    Each task card is turned into a CommonRecipe configured from ``config``; the recipes
    are then fused with a FixedFusion.

    Args:
        tasks: List of TaskCard objects.
        config: ICLCard object.
        examples_per_task: Number of examples per task. If None, all examples are returned.
        include_splits: List of splits to include. If None, all splits are included.
    """

    tasks: List[TaskCard]
    config: ICLCard
    examples_per_task: Optional[int] = None
    include_splits: Optional[List[str]] = None

    def prepare(self):
        # One recipe per task, all sharing the same ICL configuration.
        self.recipes = [CommonRecipe(card=task, **asdict(self.config)) for task in self.tasks]

        self.fusion = FixedFusion(
            origins=self.recipes, examples_per_task=self.examples_per_task, include_splits=self.include_splits
        )

    def process(self) -> MultiStream:
        return self.fusion()