File size: 5,499 Bytes
6a0e448
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import base64
import re
import textwrap
from io import BytesIO

from PIL import Image, ImageDraw, ImageFont

from proxy_lite.environments.environment_base import Action, Observation
from proxy_lite.recorder import Run


def _decode_observation_image(image_data: str, resize_factor: int) -> "Image.Image":
    """Decode a base64-encoded image to RGB and scale it down by ``resize_factor``."""
    with Image.open(BytesIO(base64.b64decode(image_data))) as source:
        # convert() returns a new image, so the source file handle can be closed safely.
        rgb = source.convert("RGB")
    # Clamp each side to at least 1px: integer division yields 0 for images smaller than
    # ``resize_factor``, and Pillow rejects zero-sized resize targets.
    return rgb.resize((max(1, rgb.width // resize_factor), max(1, rgb.height // resize_factor)))


def _format_action_text(action_text: str) -> str:
    """Replace tagged action text by headed OBSERVATION/THINKING sections when both tags are present."""
    observation_match = re.search(r"<observation>(.*?)</observation>", action_text, re.DOTALL)
    observation_content = observation_match.group(1).strip() if observation_match else None

    thinking_match = re.search(r"<thinking>(.*?)</thinking>", action_text, re.DOTALL)
    thinking_content = thinking_match.group(1).strip() if thinking_match else None

    if observation_content and thinking_content:
        return f"**OBSERVATION**\n{observation_content}\n\n**THINKING**\n{thinking_content}"
    # Otherwise the raw action text (tags included, if any) is shown unchanged.
    return action_text


def _wrap_text(text: str, width: int) -> str:
    """Wrap each line of ``text`` to ``width`` characters while keeping its explicit line breaks."""
    # textwrap.fill() on the whole string would replace the "\n" / "\n\n" separators between
    # sections with spaces, so wrap every line independently (blank lines stay blank).
    return "\n".join(textwrap.fill(line, width=width) for line in text.splitlines())


def _render_text_panel(
    text: str, width: int, height: int, font: "ImageFont.ImageFont", max_chars_per_line: int = 40
) -> "Image.Image":
    """Draw ``text`` centred in black on a white ``width`` x ``height`` panel."""
    panel = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(panel)
    wrapped_text = _wrap_text(text, max_chars_per_line)

    try:
        # multiline_textbbox returns (left, top, right, bottom).
        left, top, right, bottom = draw.multiline_textbbox((0, 0), wrapped_text, font=font)
        text_width, text_height = right - left, bottom - top
    except AttributeError:
        # Fallback for older Pillow versions without multiline_textbbox: measure line by line.
        lines = wrapped_text.splitlines() or [wrapped_text]
        line_sizes = [draw.textsize(line, font=font) for line in lines]
        text_width = max(line_width for line_width, _ in line_sizes)
        text_height = sum(line_height for _, line_height in line_sizes)

    # Clamp to the panel origin so text larger than the panel is cut at the right/bottom
    # instead of losing its beginning off the top/left edge.
    text_x = max(0, (width - text_width) // 2)
    text_y = max(0, (height - text_height) // 2)
    draw.multiline_text((text_x, text_y), wrapped_text, fill="black", font=font, align="center")
    return panel


def create_run_gif(
    run: Run, output_path: str, white_panel_width: int = 300, duration: int = 1500, resize_factor: int = 4
) -> None:
    """
    Generate a gif from the Run object's history.

    For each Observation record that carries an image, the image is decoded from its
    base64-encoded string and scaled down by ``resize_factor``. If the record immediately
    after it is an Action, that action's text is drawn onto a white panel. When the text
    contains both ``<observation>...</observation>`` and ``<thinking>...</thinking>``
    sections, only those sections are shown, each under a heading. The observation image
    and the panel are concatenated horizontally to produce one frame. Observations
    without an image produce no frame.

    Parameters:
        run (Run): A Run object with its history containing Observation and Action records.
        output_path (str): The path where the GIF will be saved.
        white_panel_width (int): The width in pixels of the white panel for displaying text.
                                 Defaults to 300.
        duration (int): Display time of each frame in milliseconds. Defaults to 1500.
        resize_factor (int): The integer factor to scale the observation image down by.
                             Must be at least 1. Defaults to 4.

    Raises:
        ValueError: If ``resize_factor`` is smaller than 1, or if no frames were generated
                    from the history.
    """
    if resize_factor < 1:
        raise ValueError(f"resize_factor must be a positive integer, got {resize_factor!r}")

    font = ImageFont.load_default()  # loaded once and reused for every frame
    history = run.history
    frames = []
    for index, record in enumerate(history):
        # Actions are only ever rendered as the caption of the Observation just before them,
        # so every record that is not an Observation with an image is skipped here.
        if not isinstance(record, Observation) or not record.state.image:
            continue
        obs_img = _decode_observation_image(record.state.image, resize_factor)

        following = history[index + 1] if index + 1 < len(history) else None
        action_text = following.text if isinstance(following, Action) and following.text else ""

        panel = _render_text_panel(_format_action_text(action_text), white_panel_width, obs_img.height, font)

        # Concatenate the observation image (left) and the text panel (right).
        combined_frame = Image.new("RGB", (obs_img.width + white_panel_width, obs_img.height))
        combined_frame.paste(obs_img, (0, 0))
        combined_frame.paste(panel, (obs_img.width, 0))
        frames.append(combined_frame)

    if not frames:
        raise ValueError("No frames were generated from the Run object's history.")
    frames[0].save(output_path, save_all=True, append_images=frames[1:], duration=duration, loop=0)


# Example usage:
if __name__ == "__main__":
    import sys

    # Usage: python <this file> [RUN_ID] [OUTPUT_GIF]
    # Both arguments are optional; the defaults reproduce the original example.
    # (Run is already imported at module level, so no re-import is needed here.)
    run_id = sys.argv[1] if len(sys.argv) > 1 else "0abdb4cb-f289-48b0-ba13-35ed1210f7c1"
    output_gif_path = sys.argv[2] if len(sys.argv) > 2 else "trajectory.gif"

    dummy_run = Run.load(run_id)

    # Approximate step count, assuming the history alternates Observation/Action records.
    num_steps = len(dummy_run.history) // 2
    print(f"Number of steps: {num_steps}")
    create_run_gif(dummy_run, output_gif_path, duration=1000)
    print(f"Trajectory GIF saved to {output_gif_path}")