File size: 5,063 Bytes
30fabb4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import html
import re
import textwrap
from dataclasses import dataclass, field
from typing import List, Optional

from cot_reasoning import AnthropicAPI, VisualizationConfig

@dataclass
class ToTNode:
    """A single thought in a Tree of Thoughts.

    ``children`` is always a list: it defaults to a fresh empty list per
    instance, and an explicit ``children=None`` (still accepted for backward
    compatibility) is normalized to ``[]`` in ``__post_init__``.
    """
    # Identifier of the thought (parse_tot_response takes it from the node tag's id attribute).
    id: str
    # Text of the thought.
    content: str
    # ``id`` of the parent node, or ``None`` for a root node.
    parent_id: Optional[str] = None
    # Child nodes; each instance gets its own list via default_factory.
    children: List['ToTNode'] = field(default_factory=list)
    # Set to True when this node is matched to the final answer.
    is_answer: bool = False

    def __post_init__(self):
        # Callers may still pass children=None explicitly; keep the attribute a
        # list so code can append to it unconditionally.
        if self.children is None:
            self.children = []

@dataclass
class ToTResponse:
    """A complete, parsed Tree-of-Thoughts response."""
    # The question that was asked.
    question: str
    # Root of the thought tree, or ``None`` when no root node was found.
    root: Optional['ToTNode']
    # Text of the final answer, or ``None`` when the response had none.
    answer: Optional[str] = None

def parse_tot_response(response_text: str, question: str) -> ToTResponse:
    """Parse a Tree-of-Thoughts model response into a ToTResponse.

    Recognized markup:

    * ``<node id="..." parent="...">text</node>`` -- one thought. Attributes
      may appear in any order and unknown attributes are ignored. A node
      without a non-empty ``id`` is skipped; a missing or empty ``parent``
      marks a root. Node text is stripped of surrounding whitespace.
    * ``<answer>text</answer>`` -- the final answer; only the first one is used.

    Args:
        response_text: Raw model output.
        question: The question that was asked; stored unchanged.

    Returns:
        A ToTResponse whose ``root`` is the first parentless node in document
        order (``None`` if there is none) and whose ``answer`` is the stripped
        answer text (``None`` if no ``<answer>`` element is present). A node
        whose ``parent`` names an id that was never defined is not attached to
        the tree. When a non-empty answer exists, every node with non-empty
        content that occurs verbatim inside the answer gets ``is_answer=True``.
    """
    # Group 1: the tag's attribute list; group 2: the node text (may span lines).
    node_pattern = r'<node\b((?:\s+[\w-]+\s*=\s*"[^"]*")*)\s*>\s*(.*?)\s*</node>'
    attr_pattern = r'([\w-]+)\s*=\s*"([^"]*)"'
    nodes_dict = {}

    # First pass: create all nodes. A repeated id replaces the earlier node.
    for match in re.finditer(node_pattern, response_text, re.DOTALL):
        attrs = dict(re.findall(attr_pattern, match.group(1)))
        node_id = attrs.get('id')
        if not node_id:
            continue
        nodes_dict[node_id] = ToTNode(
            id=node_id,
            content=match.group(2).strip(),
            # parent="" is treated like a missing parent attribute.
            parent_id=attrs.get('parent') or None,
        )

    # Second pass: link children to their parents and pick the root.
    root = None
    for node in nodes_dict.values():
        if node.parent_id is None:
            # Keep the first parentless node as the root.
            if root is None:
                root = node
            continue
        parent = nodes_dict.get(node.parent_id)
        # Skip unknown parents and self-references (a node naming itself as
        # parent would otherwise become its own child).
        if parent is not None and parent is not node:
            parent.children.append(node)

    # Parse the answer if present.
    answer_match = re.search(r'<answer>\s*(.*?)\s*</answer>', response_text, re.DOTALL)
    answer = answer_match.group(1).strip() if answer_match else None

    if answer:
        # Mark the nodes whose text appears in the answer. Empty content is a
        # substring of every string, so it must not count as a match.
        for node in nodes_dict.values():
            if node.content and node.content in answer:
                node.is_answer = True

    return ToTResponse(question=question, root=root, answer=answer)

def create_mermaid_diagram(tot_response: ToTResponse, config: VisualizationConfig) -> str:
    """Render a ToTResponse as a Mermaid flowchart inside ``<div class="mermaid">``.

    The question becomes node ``Q`` and the thought tree hangs below it. If an
    answer is present it becomes node ``Answer``, and every leaf of the tree
    points to it (or ``Q`` does, when there is no tree). Thought nodes flagged
    ``is_answer`` get the ``answer`` class.

    Thought nodes are emitted under generated identifiers (``N0``, ``N1``, ...)
    instead of the model-supplied ids. Those ids may contain characters that
    are not valid in Mermaid identifiers, or may collide with ``Q``/``Answer``.

    Args:
        tot_response: Parsed response to draw.
        config: Text-wrapping limits passed through to ``wrap_text``.

    Returns:
        The diagram markup, one statement per line.
    """
    diagram = ['<div class="mermaid">', 'graph TD']

    # Add question node
    question_content = wrap_text(tot_response.question, config)
    diagram.append(f'    Q["{question_content}"]')

    leaf_refs = []    # Mermaid ids of childless thought nodes (linked to the answer)
    answer_refs = []  # Mermaid ids of nodes flagged is_answer (styled as 'answer')
    next_index = 0

    def add_node_and_children(node: ToTNode, parent_ref: Optional[str] = None):
        """Emit *node* and its subtree in pre-order, linking it from *parent_ref*."""
        nonlocal next_index
        ref = f'N{next_index}'
        next_index += 1

        content = wrap_text(node.content, config)
        diagram.append(f'    {ref}["{content}"]')
        if parent_ref:
            diagram.append(f'    {parent_ref} --> {ref}')
        if node.is_answer:
            answer_refs.append(ref)

        if node.children:
            for child in node.children:
                add_node_and_children(child, ref)
        else:
            leaf_refs.append(ref)

    # Build tree structure below the question
    if tot_response.root:
        add_node_and_children(tot_response.root, 'Q')

    # Add final answer node if answer exists
    if tot_response.answer:
        answer_content = wrap_text(tot_response.answer, config)
        diagram.append(f'    Answer["{answer_content}"]')
        # Every leaf feeds the answer; with no tree at all, link the question
        # directly so the answer node is not left disconnected.
        for ref in leaf_refs or ['Q']:
            diagram.append(f'    {ref} --> Answer')
        diagram.append('    class Answer final_answer;')

    # Add styles
    diagram.extend([
        '    classDef default fill:#f9f9f9,stroke:#333,stroke-width:2px;',
        '    classDef question fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;',
        '    classDef answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
        '    classDef final_answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
        '    class Q question;',
        '    linkStyle default stroke:#666,stroke-width:2px;'
    ])
    # Highlight the thought nodes that were matched to the answer.
    diagram.extend(f'    class {ref} answer;' for ref in answer_refs)

    diagram.append('</div>')
    return '\n'.join(diagram)

def wrap_text(text: str, config: 'VisualizationConfig') -> str:
    """Prepare *text* for use as a quoted Mermaid node label.

    Newlines are flattened to spaces, and double quotes become single quotes
    because a ``"`` would terminate the ``["..."]`` label. The text is
    word-wrapped to ``config.max_chars_per_line`` columns and joined with
    ``<br>``. If it needs more than ``config.max_lines`` lines, only the first
    lines are kept and the last kept line is shortened to end in ``...``. A
    non-positive ``max_lines`` is treated as 1.

    ``&``, ``<`` and ``>`` are HTML-escaped. The label is embedded in an HTML
    ``<div>``, so arbitrary text must not be parsed as markup; only the
    ``<br>`` separators are emitted as tags.

    Args:
        text: Raw label text.
        config: Object providing ``max_chars_per_line`` and ``max_lines``.

    Returns:
        The wrapped, escaped label ('' for empty or whitespace-only text).

    Raises:
        ValueError: From ``textwrap.wrap`` if ``max_chars_per_line`` is not positive.
    """
    text = text.replace('\n', ' ').replace('"', "'")
    width = config.max_chars_per_line
    wrapped_lines = textwrap.wrap(text, width=width)

    max_lines = max(1, config.max_lines)
    if len(wrapped_lines) > max_lines:
        wrapped_lines = wrapped_lines[:max_lines]
        # Leave room for the ellipsis; clamp at 0 so tiny widths don't slice from the end.
        wrapped_lines[-1] = wrapped_lines[-1][:max(width - 3, 0)] + "..."

    # Escape after wrapping/truncating so entities neither count toward the
    # width nor get cut in half by the truncation.
    return "<br>".join(html.escape(line, quote=False) for line in wrapped_lines)