File size: 3,672 Bytes
a4da721
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from collections import deque
from typing import Set, Tuple, List
import heapq

def parse_maze(filename: str) -> Tuple[List[List[str]], Tuple[int, int], Tuple[int, int]]:
    """Read a maze file and locate its start (``'S'``) and end (``'E'``) cells.

    Each non-blank line becomes one row of single-character cells.
    Surrounding whitespace is stripped from each line. Blank lines, such as
    extra trailing newlines, are skipped so that every row of the grid is
    non-empty. Row indices in ``start``/``end`` therefore refer to rows of
    the returned grid, not to raw file line numbers.

    Args:
        filename: path of the maze text file.

    Returns:
        ``(maze, start, end)``. ``start`` and ``end`` are ``(row, col)``
        tuples, or ``None`` when the corresponding marker is absent. If a
        marker appears more than once, the last row containing it wins
        (first occurrence within that row).
    """
    maze = []
    start = end = None
    with open(filename, 'r', encoding='utf-8') as f:
        for line in f:
            row = list(line.strip())
            if not row:
                # An empty row would break neighbour lookups and shift indices.
                continue
            r = len(maze)
            if 'S' in row:
                start = (r, row.index('S'))
            if 'E' in row:
                end = (r, row.index('E'))
            maze.append(row)
    return maze, start, end

def get_normal_distance(maze: List[List[str]], start: Tuple[int, int], end: Tuple[int, int]) -> float:
    """Return the fewest moves from ``start`` to ``end`` without cheating.

    Movement is orthogonal, and ``'#'`` cells are walls. Every move costs
    exactly one step, so a plain breadth-first search yields shortest
    distances; a priority queue is unnecessary.

    Args:
        maze: grid of characters; rows may differ in length.
        start: ``(row, col)`` of the starting cell.
        end: ``(row, col)`` of the target cell.

    Returns:
        The step count (an ``int``), or ``float('inf')`` when ``end`` cannot
        be reached. The ``inf`` case is why the annotation is ``float``.
    """
    seen = {start}
    frontier = deque([(start, 0)])
    while frontier:
        (r, c), dist = frontier.popleft()
        if (r, c) == end:
            return dist
        for nr, nc in ((r, c + 1), (r + 1, c), (r, c - 1), (r - 1, c)):
            # Bounds are checked per row so ragged grids are handled safely.
            if (0 <= nr < len(maze) and 0 <= nc < len(maze[nr])
                    and maze[nr][nc] != '#' and (nr, nc) not in seen):
                seen.add((nr, nc))
                frontier.append(((nr, nc), dist + 1))
    return float('inf')

def _track_distances(maze: List[List[str]], source: Tuple[int, int]) -> dict:
    """Map every non-wall cell reachable from ``source`` to its BFS step count."""
    dist = {source: 0}
    queue = deque([source])
    while queue:
        r, c = queue.popleft()
        for nr, nc in ((r, c + 1), (r + 1, c), (r, c - 1), (r - 1, c)):
            if (0 <= nr < len(maze) and 0 <= nc < len(maze[nr])
                    and maze[nr][nc] != '#' and (nr, nc) not in dist):
                dist[(nr, nc)] = dist[(r, c)] + 1
                queue.append((nr, nc))
    return dist


def find_cheats(maze: List[List[str]], start: Tuple[int, int], end: Tuple[int, int],
                max_cheat_duration: int) -> List[int]:
    """Return the time saved by every distinct cheat that shortens the race.

    A cheat lets the racer ignore walls for at most ``max_cheat_duration``
    moves. It begins on a track cell ``a`` and ends on a track cell ``b``.
    Cheats are identified by their ``(a, b)`` pair, so each pair contributes
    at most one entry. While walls are ignored, the cheapest way from ``a``
    to ``b`` costs their Manhattan distance ``d``. The best race using that
    cheat therefore takes::

        dist_from_start[a] + d + dist_to_end[b]

    Both distances in that formula are ordinary wall-respecting BFS
    distances.

    Args:
        maze: grid of characters; ``'#'`` is a wall, anything else is track.
        start: ``(row, col)`` of the race start.
        end: ``(row, col)`` of the race end.
        max_cheat_duration: maximum number of moves a single cheat may last.

    Returns:
        A list (deliberately not a set) holding one positive saving per
        distinct cheat. Counting its entries counts cheats, not distinct
        saving values. The list is empty when the maze is empty,
        ``start``/``end`` is missing, or ``end`` is unreachable without
        cheating.
    """
    if not maze or start is None or end is None:
        return []

    from_start = _track_distances(maze, start)
    to_end = _track_distances(maze, end)
    normal_dist = from_start.get(end)
    if normal_dist is None:
        return []

    # Every displacement reachable within the cheat budget, with its move cost.
    offsets = [
        (dr, dc, abs(dr) + abs(dc))
        for dr in range(-max_cheat_duration, max_cheat_duration + 1)
        for dc in range(-(max_cheat_duration - abs(dr)), max_cheat_duration - abs(dr) + 1)
    ]

    saved_times = []
    for (r, c), dist_a in from_start.items():
        for dr, dc, cost in offsets:
            # to_end only holds reachable track cells, so this lookup also
            # rejects walls and out-of-bounds positions.
            dist_b = to_end.get((r + dr, c + dc))
            if dist_b is None:
                continue
            saving = normal_dist - (dist_a + cost + dist_b)
            if saving > 0:
                saved_times.append(saving)
    return saved_times

def solve(filename: str, min_saving: int = 100) -> Tuple[int, int]:
    """Solve both puzzle parts for the maze stored in ``filename``.

    Part 1 allows cheats lasting up to 2 moves; part 2 allows up to 20.
    Each part counts the entries returned by ``find_cheats`` whose saving is
    at least ``min_saving`` picoseconds.

    Args:
        filename: path of the maze file.
        min_saving: minimum saving for a cheat to be counted (default 100).

    Returns:
        ``(part1, part2)`` counts.
    """
    maze, start, end = parse_maze(filename)
    part1, part2 = [
        sum(1 for saved in find_cheats(maze, start, end, duration) if saved >= min_saving)
        for duration in (2, 20)
    ]
    return part1, part2

# Run the solver only when executed as a script, so the module can be
# imported (e.g. by tests) without reading input.txt or printing.
if __name__ == "__main__":
    part1, part2 = solve("input.txt")
    print(part1)
    print(part2)