from dataclasses import dataclass
from typing import List, Set



@dataclass
class GraphConfig:
    """Cost and memory parameters consumed by the pipeline scheduler.

    Validation runs in ``__post_init__``: the four cost fields must be plain
    ``int`` values and the three memory deltas must cancel out, so that a
    microbatch that went through F, B and W leaves no memory behind.

    Raises:
        AssertionError: if a cost is not exactly of type ``int`` or the memory
            deltas do not sum to zero (within floating-point rounding).
    """

    # Memory delta applied to a stage when one forward (F), backward (B) or
    # weight-gradient (W) pass is scheduled on it.
    mem_f: float = 2
    mem_b: float = -1
    mem_w: float = -1
    # Per-stage memory cap. A falsy value (None or 0) makes the scheduler
    # fall back to a bound derived from mem_f and the microbatch count.
    max_mem: float = None
    # Duration of one F / B / W pass.
    cost_f: int = 1
    cost_b: int = 1
    cost_w: int = 1
    # Delay added between dependent passes on neighbouring stages.
    cost_comm: int = 0
    # Divisor applied to timestamps when the schedule is rendered as text.
    print_scaling: int = 1

    def __post_init__(self):
        """Validate field types and the memory balance."""
        # `type(...) is int` (not isinstance) is deliberate: it also rejects
        # bool and any int subclass, exactly like the original checks.
        for field_name in ("cost_f", "cost_b", "cost_w", "cost_comm"):
            assert type(getattr(self, field_name)) is int, (
                f"{field_name} must be an int"
            )
        # Memory deltas are floats, so compare with a tolerance scaled to
        # their magnitude instead of exact equality (0.3 - 0.1 - 0.2 != 0.0
        # in binary floating point). The strict `<` keeps inf/nan rejected.
        mem_total = self.mem_f + self.mem_b + self.mem_w
        tolerance = 1e-9 * max(
            1.0, abs(self.mem_f), abs(self.mem_b), abs(self.mem_w)
        )
        assert abs(mem_total) < tolerance, (
            "mem_f + mem_b + mem_w must sum to zero"
        )

@dataclass(frozen=True)
class ScheduledNode:
    """Immutable record describing one entry of a stage's execution order.

    Instances compare by value and are hashable (frozen dataclass with the
    default ``eq=True``), so they can be used as dictionary keys.
    """

    # Node kind, e.g. 'F', 'B', 'W' or a communication tag such as
    # 'SEND_FORWARD' / 'RECV_BACKWARD'.
    type: str
    # Pipeline stage the node is executed on.
    stage: int
    # Index of the microbatch the node belongs to.
    minibatch: int
    # Scheduled start and end timestamps of the node.
    start_time: int
    completion_time: int
    # Flag set on selected RECV_FORWARD nodes when the final order is built.
    rollback: bool = False


@dataclass
class Graph:
    """Dependency graph of the F/B/W passes of a pipeline-parallel run.

    Every (pass type, stage, microbatch) triple is one node. Node ids are
    dense integers laid out type-major, then stage, then microbatch (see
    ``get_id``). Use ``build_graph`` to create a fully populated instance.
    """

    # Number of pipeline stages and number of microbatches.
    nstages: int
    nmb: int
    # Total node count; build_graph sets it to 3 * nstages * nmb.
    nnodes: int
    # Cost/memory parameters. Quoted forward reference so that defining this
    # class does not require GraphConfig to be defined first.
    config: "GraphConfig"
    # parents[id] is the set of node ids that must complete before `id`.
    parents: List[Set[int]] = None
    # name[id] is a short label such as 'F3', 'B0' or 'W7'.
    name: List[str] = None

    # ID mapping:
    # F[stage][minibatch]: 0 .. STAGE * MB
    # B[stage][minibatch]: STAGE * MB .. 2 * STAGE * MB
    # W[stage][minibatch]: 2 * STAGE * MB .. 3 * STAGE * MB

    def get_id(self, type, stage, mb):
        """Return the node id of pass `type` (0=F, 1=B, 2=W) for `stage`/`mb`."""
        return type * (self.nstages * self.nmb) + stage * self.nmb + mb

    def get_stage(self, id):
        """Return the stage that node `id` runs on."""
        return (id // self.nmb) % self.nstages

    def get_cost(self, id):
        """Return the configured duration of node `id` according to its type."""
        kind = id // (self.nstages * self.nmb)
        return [self.config.cost_f, self.config.cost_b, self.config.cost_w][kind]

    def get_mem(self, id):
        """Return the configured memory delta of node `id` according to its type."""
        kind = id // (self.nstages * self.nmb)
        return [self.config.mem_f, self.config.mem_b, self.config.mem_w][kind]

    @classmethod
    def build_graph(cls, nstages, nmb, config):
        """Create a graph with `parents` and `name` filled in for every node.

        Dependencies encoded in `parents`:
          * F: the same microbatch on the previous stage and the previous
            microbatch on the same stage.
          * B: the same microbatch on the next stage (or its own F on the
            last stage) and the previous microbatch's B on the same stage.
          * W: its own B and the previous microbatch's W on the same stage.

        Returns:
            An instance of `cls` (so subclasses get their own type).
        """
        nnodes = nstages * nmb * 3
        # Instantiate through `cls` rather than the hard-coded base class so
        # that calling build_graph on a subclass returns that subclass.
        g = cls(nstages=nstages, nmb=nmb, nnodes=nnodes, config=config)
        parents = []
        name = []
        # Iterate in id order (type-major), so list index == node id.
        for kind in range(3):
            for stage in range(nstages):
                for mb in range(nmb):
                    p = set()
                    if kind == 0:
                        name.append(f'F{mb}')
                        if stage > 0:
                            p.add(g.get_id(kind, stage - 1, mb))
                        if mb > 0:
                            p.add(g.get_id(kind, stage, mb - 1))
                    elif kind == 1:
                        name.append(f'B{mb}')
                        if stage == nstages - 1:
                            # Backward on the last stage starts from its own forward.
                            p.add(g.get_id(0, stage, mb))
                        else:
                            p.add(g.get_id(kind, stage + 1, mb))
                        if mb > 0:
                            p.add(g.get_id(kind, stage, mb - 1))
                    else:
                        name.append(f'W{mb}')
                        p.add(g.get_id(1, stage, mb))
                        if mb > 0:
                            p.add(g.get_id(kind, stage, mb - 1))
                    parents.append(p)

        g.name = name
        g.parents = parents
        return g

    # Manual ordering producing this kind of schedule:
    # fffffffbfbfbfbfbfbwbwbwbwbwbwbwwwwww
    #  fffffbfbfbfbfbfbfbfbwbwbwbwbwwwwwwww
    #   fffbfbfbfbfbfbfbfbfbfbwbwbwwwwwwwwww
    #    fbfbfbfbfbfbfbfbfbfbfbfbwwwwwwwwwwww
    # Returns the order index of each node on its own stage
    def manual_order(
        self, allow_bubble_before_first_b=False, prioritize_b=False, no_bubble_greedy=True
    ):
        """Greedily build a per-stage execution order for all nodes.

        The schedule is built microbatch by microbatch: a warm-up phase places
        forwards followed by the first backward on every stage; each later
        round places at most one forward and exactly one backward per stage,
        inserting W passes to free memory or to fill idle gaps; remaining W
        passes are appended at the end.

        Args:
            allow_bubble_before_first_b: selects the warm-up stop rule. When
                False, forwards are added until the stage's running time
                reaches the arrival time of its first backward. When True,
                a forward is only added if it finishes by that arrival time.
            prioritize_b: enables the heuristic that may skip a forward on a
                stage that is at least two forwards ahead of the next stage.
            no_bubble_greedy: when an idle gap before the next F/B is within
                the slack (largest stage bubble minus this stage's bubble),
                a W pass is inserted only if this flag is True; larger gaps
                always get a W pass.

        Returns:
            (order, t, best_time): `order[id]` is the position of node `id`
            in its stage's sequence, `t[id]` is its completion time, and
            `best_time` is the largest per-stage span from the start of the
            first F to the end of the last W.

        Raises:
            ValueError: if a pass cannot fit under the memory limit even after
                scheduling all pending W passes.
            AssertionError: if a placement exceeds the memory limit or the
                resulting times violate a dependency.
        """
        order = [0] * self.nnodes
        # Per-stage counters of already placed F / B / W passes.
        f = [0] * self.nstages
        b = [0] * self.nstages
        w = [0] * self.nstages
        # Per-stage: number of placed nodes (next order index), current
        # memory usage, and completion time of the last placed node.
        o = [0] * self.nstages
        m = [0] * self.nstages
        e = [0] * self.nstages
        # Completion time of every node.
        t = [0] * self.nnodes
        # A falsy config.max_mem falls back to 3 * nmb forward activations.
        max_mem = self.config.max_mem or self.get_mem(self.get_id(0, 0, 0)) * self.nmb * 3
        comm = self.config.cost_comm
        # Debug aid: per-stage string of placed pass letters.
        order_str = [""] * self.nstages
        # Idle time accumulated on each stage after its first forward.
        stage_bubble = [0] * self.nstages

        def get_max_bubble():
            # Bubbles are never negative, so this equals the original
            # "running max starting from 0" loop.
            return max(stage_bubble, default=0)

        def put(stage_j, type_k):
            # Place the next pending pass of type `type_k` on stage `stage_j`
            # and update all bookkeeping arrays.
            if type_k == 0:
                _i = f[stage_j]
            elif type_k == 1:
                _i = b[stage_j]
            else:
                _i = w[stage_j]
            _j = stage_j
            _id = self.get_id(type_k, _j, _i)
            _mem = self.get_mem(_id)
            _cost = self.get_cost(_id)
            assert m[_j] + _mem <= max_mem

            # Completion time if the pass starts right after the previous one.
            tmp = e[_j] + _cost
            no_bubble = tmp
            # F waits for the same microbatch on the previous stage; B waits
            # for the same microbatch on the next stage (plus comm delay).
            if _j > 0 and type_k == 0:
                tmp = max(tmp, t[self.get_id(0, _j - 1, _i)] + comm + _cost)
            if _j < self.nstages - 1 and type_k == 1:
                tmp = max(tmp, t[self.get_id(1, _j + 1, _i)] + comm + _cost)
            # Idle time before the stage's first forward is not counted.
            if f[_j] > 0:
                stage_bubble[_j] += tmp - no_bubble
            e[_j] = tmp
            t[_id] = tmp
            m[_j] += _mem
            order[_id] = o[_j]
            if type_k == 0:
                f[_j] += 1
            elif type_k == 1:
                b[_j] += 1
            else:
                w[_j] += 1
            o[_j] += 1
            fbw = "fbw"
            order_str[stage_j] += fbw[type_k]

        for i in range(self.nmb):
            if i == 0:
                # Warm-up: first forward on every stage, top to bottom.
                for j in range(self.nstages):
                    put(j, 0)
                # Extra forwards to run on each stage before its first B.
                f_required = [0] * self.nstages
                # Estimated completion time of the first B on the stage below.
                last_t = 0
                for j in range(self.nstages - 1, -1, -1):
                    if j == self.nstages - 1:
                        last_t = t[self.get_id(0, j, i)] + self.get_cost(self.get_id(1, j, i))
                        continue
                    mem = m[j]
                    cost = e[j]
                    while True:
                        f_id = self.get_id(0, j, f[j] + f_required[j])
                        if f[j] + f_required[j] < self.nmb and mem + self.get_mem(f_id) <= max_mem:
                            if allow_bubble_before_first_b:
                                if cost + self.get_cost(f_id) > last_t + comm:
                                    break
                            else:
                                if cost >= last_t + comm:
                                    break
                            mem += self.get_mem(f_id)
                            cost += self.get_cost(f_id)
                            f_required[j] += 1
                        else:
                            break
                    last_t = max(cost, last_t + comm) + self.get_cost(self.get_id(1, j, i))
                # Make warm-up depth strictly decrease from stage to stage,
                # unless the stage would already have placed every forward.
                for j in range(self.nstages):
                    while j > 0 and f_required[j] > 0 and f_required[j] >= f_required[j - 1] and f[j] + f_required[j] < self.nmb:
                        f_required[j] -= 1
                for j in range(self.nstages - 1, -1, -1):
                    for _ in range(f_required[j]):
                        put(j, 0)
                    put(j, 1)
                continue
            # Steady state: decide per stage whether to run one more forward.
            f_required = [0] * self.nstages
            for j in range(self.nstages):
                if f[j] >= self.nmb:
                    continue
                if j + 1 < self.nstages and f[j] >= f[j + 1] + 2 and prioritize_b:
                    # Time at which the next stage would finish its next F
                    # and B, plus the communication delay.
                    next_plus_fw = (
                        e[j + 1]
                        + self.get_cost(self.get_id(0, j + 1, f[j + 1]))
                        + self.get_cost(self.get_id(1, j + 1, b[j + 1]))
                        + comm
                    )
                    if e[j] >= next_plus_fw:
                        continue
                    f_id = self.get_id(0, j, f[j])
                    f_mem = self.get_mem(f_id)
                    # Cost of the W passes needed to make room for this F.
                    w_cost, w_cnt = 0, 0
                    mem_with_w = m[j] + f_mem
                    while mem_with_w > max_mem and w[j] + w_cnt < b[j]:
                        w_id = self.get_id(2, j, w[j] + w_cnt)
                        w_cost += self.get_cost(w_id)
                        mem_with_w += self.get_mem(w_id)
                        w_cnt += 1
                    if e[j] + self.get_cost(f_id) + w_cost <= next_plus_fw:
                        f_required[j] = 1
                        continue

                    # Reset: the check below uses w_cost == 0 (the alternative
                    # that accumulates all pending W costs is disabled).
                    w_cost, w_cnt = 0, 0
                    # mem_with_w = m[j]
                    # while w[j] + w_cnt < b[j]:
                    #     w_id = self.get_id(2, j, w[j] + w_cnt)
                    #     w_cost += self.get_cost(w_id)
                    #     mem_with_w += self.get_mem(w_id)
                    #     w_cnt += 1
                    # if e[j] + w_cost >= next_plus_fw:
                    #     continue
                    if next_plus_fw - (e[j] + w_cost) <= get_max_bubble() - stage_bubble[j]:
                        # TODO: can sample here
                        continue
                f_required[j] = 1
            # A stage only runs a forward if every later stage does too.
            for j in range(self.nstages - 2, -1, -1):
                f_required[j] = min(f_required[j], f_required[j + 1])
            for j in range(self.nstages):
                if f_required[j] == 0:
                    continue
                f_id = self.get_id(0, j, f[j])
                mem = self.get_mem(f_id)
                # Free memory with W passes until the forward fits.
                while m[j] + mem > max_mem:
                    if w[j] >= b[j]:
                        raise ValueError("Cannot fit memory")
                    put(j, 2)
                if j > 0:
                    # Fill the wait for the upstream forward with W passes
                    # that finish before its activation arrives.
                    while (
                        w[j] < b[j]
                        and e[j] + self.get_cost(self.get_id(2, j, w[j]))
                        <= t[self.get_id(0, j - 1, f[j])] + comm
                    ):
                        put(j, 2)
                    if w[j] < b[j] and e[j] < t[self.get_id(0, j - 1, f[j])] + comm:
                        # TODO: e[j] + self.get_cost(self.get_id(2, j, w[j])) > t[self.get_id(0, j - 1, f[j])] + comm
                        if (
                            t[self.get_id(0, j - 1, f[j])] + comm - e[j]
                            <= get_max_bubble() - stage_bubble[j]
                        ):
                            # TODO: can sample here
                            if no_bubble_greedy:
                                put(j, 2)
                        else:
                            put(j, 2)
                put(j, 0)
            # One backward per stage, from the last stage upwards.
            for j in range(self.nstages - 1, -1, -1):
                assert b[j] == i
                b_id = self.get_id(1, j, b[j])
                mem = self.get_mem(b_id)
                while m[j] + mem > max_mem:
                    if w[j] >= b[j]:
                        raise ValueError("Cannot fit memory")
                    put(j, 2)
                if j + 1 < self.nstages:
                    # Same gap-filling as above, waiting for the downstream B.
                    while (
                        w[j] < b[j]
                        and e[j] + self.get_cost(self.get_id(2, j, w[j]))
                        <= t[self.get_id(1, j + 1, i)] + comm
                    ):
                        put(j, 2)
                    if w[j] < b[j] and e[j] < t[self.get_id(1, j + 1, i)] + comm:
                        # TODO: e[j] + self.get_cost(self.get_id(2, j, w[j])) > t[self.get_id(1, j + 1, i)] + comm
                        if (
                            t[self.get_id(1, j + 1, i)] + comm - e[j]
                            <= get_max_bubble() - stage_bubble[j]
                        ):
                            # TODO: can sample here
                            if no_bubble_greedy:
                                put(j, 2)
                        else:
                            put(j, 2)
                # Once stage 0 has no forwards left, flush its pending W
                # passes before each backward.
                if j == 0 and f[j] == self.nmb:
                    while w[j] < b[j]:
                        put(j, 2)
                put(j, 1)

        # Append every W pass that has not been scheduled yet.
        for i in range(self.nstages):
            while w[i] < self.nmb:
                put(i, 2)
            # print(f"{' ' * i}{order_str[i]}  -> {e[i]}")

        # Sanity-check the completion times against all dependencies.
        for i in range(self.nstages):
            for j in range(self.nmb):
                f_id = self.get_id(0, i, j)
                b_id = self.get_id(1, i, j)
                w_id = self.get_id(2, i, j)
                f_cost = self.get_cost(f_id)
                b_cost = self.get_cost(b_id)
                w_cost = self.get_cost(w_id)
                assert t[b_id] >= t[f_id] + b_cost
                assert t[w_id] >= t[b_id] + w_cost, f"{i}-{j}, {t[w_id]} >= {t[b_id]} + {w_cost}"
                if i > 0:
                    assert t[f_id] >= t[self.get_id(0, i - 1, j)] + comm + f_cost, f"{i}-{j}"
                if i < self.nstages - 1:
                    assert t[b_id] >= t[self.get_id(1, i + 1, j)] + comm + b_cost

        # print(order)
        # Longest span on any stage: start of its first F to end of its last W.
        best_time = 0
        for i in range(self.nstages):
            time_i = (
                t[self.get_id(2, i, self.nmb - 1)]
                - t[self.get_id(0, i, 0)]
                + self.get_cost(self.get_id(0, i, 0))
            )
            best_time = max(best_time, time_i)

        return order, t, best_time


def initial_solution(graph):
    """Try every combination of `manual_order` flags and keep the fastest.

    All eight combinations of the three boolean options are evaluated in a
    fixed order (True before False, outermost option first); on ties the
    earliest combination wins. The winning schedule is printed.

    Returns:
        (best_time, order, complete_time) of the winning schedule.
    """
    choices = (True, False)
    winner = None  # (order, complete_time, best_time) of the best run so far
    for bubble_flag in choices:
        for b_flag in choices:
            for greedy_flag in choices:
                trial_order, trial_times, trial_span = graph.manual_order(
                    allow_bubble_before_first_b=bubble_flag,
                    prioritize_b=b_flag,
                    no_bubble_greedy=greedy_flag,
                )
                # Strict comparison: a later tie never replaces the winner.
                if winner is None or trial_span < winner[2]:
                    winner = (trial_order, trial_times, trial_span)

    order, complete_time, best_time = winner
    print_detail(graph, complete_time)
    print("-" * 20, best_time, "-" * 20)
    return best_time, order, complete_time


def print_detail(graph, F):
    """Print a text timeline of every stage plus the longest stage span.

    `F` maps node id to completion time. Times are divided by
    `graph.config.print_scaling` and truncated to get character columns.
    Each node is drawn as its type letter at both ends, the microbatch
    number (at most two digits) right after the first letter, and '-' for
    the rest; unused columns stay '.'.
    """
    letters = ['F', 'B', 'W']
    scale = graph.config.print_scaling
    spans = []
    for stage in range(graph.nstages):
        last_w_end = F[graph.get_id(2, stage, graph.nmb - 1)]
        row = ['.'] * int(last_w_end / scale)
        # Draw in id order so later nodes overwrite earlier ones on overlap.
        for kind, letter in enumerate(letters):
            for mb in range(graph.nmb):
                node = graph.get_id(kind, stage, mb)
                stop = int(F[node] / scale)
                begin = int((F[node] - graph.get_cost(node)) / scale)
                for col in range(begin, stop):
                    if col in (begin, stop - 1):
                        cell = letter
                    elif col == begin + 1:
                        # Tens digit for two-digit microbatches, else the digit.
                        cell = str(mb // 10) if mb >= 10 else str(mb)
                    elif col == begin + 2 and mb >= 10:
                        cell = str(mb % 10)
                    else:
                        cell = "-"
                    row[col] = cell
        first_f = graph.get_id(0, stage, 0)
        # Span from the start of the first F to the end of the last W.
        spans.append(last_w_end - F[first_f] + graph.get_cost(first_f))
        print("".join(row))
    print('Longest stage time: ', max(spans))


def ilp_results(graph, F):
    """Turn node completion times into per-stage ordered execution lists.

    For every stage the result holds the F/B/W compute nodes, the
    SEND/RECV nodes for activations and gradients exchanged with the
    neighbouring stages, and the post-validation nodes, sorted by start
    time. RECV_FORWARD nodes whose matching SEND_FORWARD precedes the
    sender's POST_VALIDATION node are flagged with ``rollback=True``.
    The schedule is also printed via ``print_detail``.

    Args:
        graph: graph providing `nstages`, `nmb`, `nnodes`, `config`,
            `get_id` and `get_cost`.
        F: indexable by node id, giving each node's completion time.

    Returns:
        A list with one list of ScheduledNode per stage, in execution order.

    Raises:
        AssertionError: if a stage has no forward finishing before the end of
            its first backward, or a rollback send has no matching receive.
    """
    typenames = ['F', 'B', 'W']
    end_time = [F[i] for i in range(graph.nnodes)]

    # Compute nodes of every stage, in id order.
    local_order = []
    for stage in range(graph.nstages):
        stage_nodes = []
        for type_idx, type_name in enumerate(typenames):
            for mb in range(graph.nmb):
                node_id = graph.get_id(type_idx, stage, mb)
                stage_nodes.append(
                    ScheduledNode(
                        type=type_name,
                        stage=stage,
                        minibatch=mb,
                        start_time=end_time[node_id] - graph.get_cost(node_id),
                        completion_time=end_time[node_id],
                    )
                )
        local_order.append(stage_nodes)

    # For each F/B, append a send/recv node. The timestamp of the recv node is
    # the same as that of the send node to guarantee a global order.
    # comm_id gives every communication node a tie-break rank; a send and its
    # matching recv share one id.
    comm_id = {}
    comm_id_counter = 0
    post_validation_time = 0
    for i in range(graph.nstages - 1, -1, -1):
        # Index of the last forward that completes before the first backward
        # of this stage completes (count of such forwards minus one).
        warmup_f_count = -1
        first_b_end = end_time[graph.get_id(1, i, 0)]
        for j in range(graph.nmb):
            if end_time[graph.get_id(0, i, j)] < first_b_end:
                warmup_f_count += 1
        assert warmup_f_count >= 0
        pv_id = warmup_f_count
        _id = graph.get_id(0, i, pv_id)
        _cost = graph.get_cost(_id)
        # Running maximum, walking from the last stage to the first, of the
        # start time of that forward minus the communication cost.
        post_validation_time = max(post_validation_time, end_time[_id] - _cost - graph.config.cost_comm)
        # post_validation_time = 0
        # print(i, pv_id, post_validation_time)
        for it in ["RECV_", "SEND_", ""]:
            # Stage 0 has nothing to send to; the last stage nothing to receive.
            if i == 0 and it == "SEND_":
                continue
            if i == graph.nstages - 1 and it == "RECV_":
                continue
            pv_node = ScheduledNode(
                type=it + "POST_VALIDATION",
                stage=i,
                minibatch=0,
                start_time=post_validation_time,
                completion_time=post_validation_time,
            )
            local_order[i].append(pv_node)
            comm_id[pv_node] = comm_id_counter
            comm_id_counter += 1

    for stage in range(graph.nstages):
        # Iterate over a snapshot: the loop body appends to
        # local_order[stage], and the appended SEND_* nodes must not be
        # visited again.
        for node in list(local_order[stage]):
            if node.type == 'F' and node.stage != graph.nstages - 1:
                send_node = ScheduledNode(
                    type='SEND_FORWARD',
                    stage=stage,
                    minibatch=node.minibatch,
                    start_time=node.completion_time,
                    completion_time=node.completion_time,  # TODO: consider comm cost in completion time
                )
                recv_node = ScheduledNode(
                    type='RECV_FORWARD',
                    stage=stage + 1,
                    minibatch=node.minibatch,
                    start_time=node.completion_time,
                    completion_time=node.completion_time,  # TODO: consider comm cost in completion time
                )
                local_order[stage].append(send_node)
                local_order[stage + 1].append(recv_node)
                comm_id[send_node] = comm_id_counter
                comm_id[recv_node] = comm_id_counter
                comm_id_counter += 1
            if node.type == 'B' and node.stage != 0:
                send_node = ScheduledNode(
                    type='SEND_BACKWARD',
                    stage=stage,
                    minibatch=node.minibatch,
                    start_time=node.completion_time,
                    completion_time=node.completion_time,  # TODO: consider comm cost in completion time
                )
                recv_node = ScheduledNode(
                    type='RECV_BACKWARD',
                    stage=stage - 1,
                    minibatch=node.minibatch,
                    start_time=node.completion_time,
                    completion_time=node.completion_time,  # TODO: consider comm cost in completion time
                )
                local_order[stage].append(send_node)
                local_order[stage - 1].append(recv_node)
                comm_id[send_node] = comm_id_counter
                comm_id[recv_node] = comm_id_counter
                comm_id_counter += 1

    # For nodes with the same timestamp on the same stage, communication is
    # prioritized over computation.
    def even_breaker(x: ScheduledNode):
        # Compute nodes are always delayed: they rank after every comm id.
        if x.type in ['F', 'B', 'W']:
            return comm_id_counter
        # For comm nodes, order by their unique comm id.
        return comm_id[x]

    for stage in range(graph.nstages):
        nodes = sorted(
            local_order[stage], key=lambda x: (x.start_time, even_breaker(x))
        )
        local_order[stage] = nodes
        # If a recv intersects with the previous computation, reorder them so
        # that the recv is executed before the computation and hence can be
        # overlapped with it.
        for i in range(1, len(nodes)):
            prev_node, cur_node = nodes[i - 1], nodes[i]
            if (
                prev_node.type in {'F', 'B', 'W'}
                and cur_node.type.startswith('RECV')
                and "POST_VALIDATION" not in cur_node.type
                and cur_node.start_time <= prev_node.completion_time
            ):
                nodes[i], nodes[i - 1] = prev_node, cur_node
        # print([(x.type, x.start_time, x.completion_time) for x in local_order[stage]])

    local_order_with_rollback = [[] for _ in range(graph.nstages)]
    for rank in range(graph.nstages):
        # Microbatches the previous stage sends forward before it reaches its
        # own POST_VALIDATION node.
        rollback_comm = set()
        if rank > 0:
            for node in local_order[rank - 1]:
                if node.type == "POST_VALIDATION":
                    break
                if node.type == "SEND_FORWARD":
                    rollback_comm.add(node.minibatch)
        for node in local_order[rank]:
            if node.type == "RECV_FORWARD" and node.minibatch in rollback_comm:
                rollback = True
                rollback_comm.remove(node.minibatch)
            else:
                rollback = False
            local_order_with_rollback[rank].append(ScheduledNode(
                type=node.type,
                stage=node.stage,
                minibatch=node.minibatch,
                start_time=node.start_time,
                completion_time=node.completion_time,
                rollback=rollback,
            ))
        # Every early send must have found its matching receive.
        assert len(rollback_comm) == 0
        # for node in local_order_with_rollback[rank]:
        #     print(f"{node.type}-{node.minibatch}-{int(node.rollback)}", end=', ')
        # print()

    print_detail(graph, end_time)
    return local_order_with_rollback


def auto_schedule(nstages, nmb, config):
    """Build the graph, pick the best heuristic schedule and return it.

    Returns the per-stage node lists produced by `ilp_results` for the
    completion times chosen by `initial_solution`.
    """
    graph = Graph.build_graph(nstages, nmb, config)
    # Only the completion times are needed downstream.
    _, _, complete_time = initial_solution(graph)
    return ilp_results(graph, complete_time)


if __name__ == "__main__":
    # Smaller variants of the first demo, kept for quick experiments:
    # auto_schedule(4, 12, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=10))
    # auto_schedule(4, 12, GraphConfig(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=14))
    # Each entry: (number of stages, number of microbatches, GraphConfig kwargs).
    demo_runs = (
        (24, 72, dict(cost_f=5, cost_b=6, cost_w=4, cost_comm=0, max_mem=100)),
        (
            4,
            12,
            dict(
                cost_f=5478,
                cost_b=5806,
                cost_w=3534,
                cost_comm=200,
                max_mem=32,
                print_scaling=1000,
            ),
        ),
        (32, 16, dict(cost_f=1, cost_b=1, cost_w=1, cost_comm=0, max_mem=64)),
    )
    for demo_stages, demo_microbatches, demo_kwargs in demo_runs:
        auto_schedule(demo_stages, demo_microbatches, GraphConfig(**demo_kwargs))