Spaces:
Running
Running
File size: 6,526 Bytes
30fabb4 7eda955 30fabb4 7eda955 30fabb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 |
import re
import requests
import textwrap
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class CoTStep:
    """One numbered step of a chain-of-thought (CoT) answer."""

    # Step number as an integer.
    number: int
    # Text of the reasoning step.
    content: str
@dataclass
class CoTResponse:
    """A full chain-of-thought (CoT) response: question, steps, and answer."""

    # The question the response belongs to.
    question: str
    # Reasoning steps as CoTStep objects.
    steps: List[CoTStep]
    # Final answer text; defaults to None when not supplied.
    answer: Optional[str] = None
@dataclass
class VisualizationConfig:
    """Text-layout settings used when rendering CoT text into diagram nodes."""

    # Maximum characters on one wrapped line (default 40).
    max_chars_per_line: int = 40
    # Maximum number of wrapped lines kept per node (default 4).
    max_lines: int = 4
    # Suffix intended to mark text that was cut off (default "...").
    truncation_suffix: str = "..."
class AnthropicAPI:
    """Minimal client for the Anthropic Messages HTTP endpoint.

    The instance stores the credentials, model name, endpoint URL, request
    headers, and the HTTP timeout used by ``generate_response``.
    """

    def __init__(self, api_key: str, model: str = "claude-3-opus-20240229", timeout: float = 120.0):
        """Store connection settings; performs no network I/O.

        Args:
            api_key: Key sent in the ``x-api-key`` request header.
            model: Model identifier placed in every request payload.
            timeout: Seconds passed to ``requests.post`` as its ``timeout``;
                without it a stalled connection would block indefinitely.
        """
        self.api_key = api_key
        self.model = model
        self.timeout = timeout
        self.base_url = "https://api.anthropic.com/v1/messages"
        self.headers = {
            "x-api-key": api_key,
            "anthropic-version": "2023-06-01",
            "content-type": "application/json"
        }

    def generate_response(self, prompt: str, max_tokens: int = 1024, prompt_format: Optional[str] = None) -> str:
        """Send ``prompt`` as a single user message and return the reply text.

        Args:
            prompt: The user prompt (or question, when ``prompt_format`` is given).
            max_tokens: Value sent as ``max_tokens`` in the request payload.
            prompt_format: Optional template containing ``{question}``. When it
                is empty or None the prompt is sent unchanged, so the built-in
                template in ``_format_prompt`` is not applied here.

        Returns:
            The ``text`` field of the first item in the response ``content``.

        Raises:
            Exception: On any failure (HTTP error, timeout, unexpected
                response shape); the original error is attached as the cause.
        """
        formatted_prompt = self._format_prompt(prompt, prompt_format) if prompt_format else prompt
        data = {
            "model": self.model,
            "messages": [{"role": "user", "content": formatted_prompt}],
            "max_tokens": max_tokens
        }
        try:
            response = requests.post(
                self.base_url,
                headers=self.headers,
                json=data,
                timeout=self.timeout,
            )
            response.raise_for_status()
            return response.json()["content"][0]["text"]
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise Exception(f"API call failed: {str(e)}") from e

    def _format_prompt(self, question: str, prompt_format: Optional[str] = None) -> str:
        """Build the prompt text for ``question``.

        Args:
            question: The question to embed in the prompt.
            prompt_format: Optional ``str.format`` template; it is formatted
                with the keyword ``question``.

        Returns:
            The formatted custom template, or the built-in step/answer
            template when no template is given.
        """
        if prompt_format:
            return prompt_format.format(question=question)
        # Default format if none provided
        return f"""Please answer the question using the following format, with each step clearly marked:
Question: {question}
Let's solve this step by step:
<step number="1">
[First step of reasoning]
</step>
<step number="2">
[Second step of reasoning]
</step>
<step number="3">
[Third step of reasoning]
</step>
... (add more steps as needed)
<answer>
[Final answer]
</answer>
Note:
1. Each step must be wrapped in XML tags <step>
2. Each step must have a number attribute
3. The final answer must be wrapped in <answer> tags
"""
def wrap_text(text: str, config: VisualizationConfig) -> str:
    """Wrap ``text`` into at most ``config.max_lines`` lines joined by ``<br>``.

    Args:
        text: Raw text to place in a diagram node.
        config: Supplies the wrap width (``max_chars_per_line``), the line
            limit (``max_lines``), and the ``truncation_suffix`` appended
            when lines have to be dropped.

    Returns:
        The wrapped lines joined with ``<br>``; an empty string for empty
        or whitespace-only input.
    """
    # Flatten newlines and swap double quotes for single quotes so the text
    # can sit inside a double-quoted node label.
    text = text.replace('\n', ' ').replace('"', "'")
    wrapped_lines = textwrap.wrap(text, width=config.max_chars_per_line)

    # A limit below 1 would leave no line to carry the suffix (and used to
    # raise IndexError), so always keep at least one line.
    line_limit = max(config.max_lines, 1)
    if len(wrapped_lines) > line_limit:
        suffix = config.truncation_suffix
        # Reserve room for the suffix; never let the slice bound go negative.
        keep = max(config.max_chars_per_line - len(suffix), 0)
        wrapped_lines = wrapped_lines[:line_limit]
        wrapped_lines[-1] = wrapped_lines[-1][:keep] + suffix
    return "<br>".join(wrapped_lines)
def parse_cot_response(response_text: str, question: str) -> CoTResponse:
    """
    Build a CoTResponse from the tagged text of a model reply.

    Args:
        response_text: Raw reply text containing <step number="N"> and
            <answer> tags
        question: The question the reply belongs to

    Returns:
        CoTResponse with the steps ordered by number and the first
        <answer> body, or None as the answer when no <answer> tag is found
    """
    # DOTALL lets a step or answer body span several lines.
    step_matches = re.finditer(
        r'<step number="(\d+)">\s*(.*?)\s*</step>', response_text, re.DOTALL
    )
    ordered_steps = sorted(
        (
            CoTStep(number=int(found.group(1)), content=found.group(2).strip())
            for found in step_matches
        ),
        key=lambda item: item.number,
    )

    answer_found = re.search(r'<answer>\s*(.*?)\s*</answer>', response_text, re.DOTALL)
    if answer_found is None:
        final_answer = None
    else:
        final_answer = answer_found.group(1).strip()

    return CoTResponse(question=question, steps=ordered_steps, answer=final_answer)
def create_mermaid_diagram(cot_response: CoTResponse, config: VisualizationConfig) -> str:
    """
    Render a CoT response as Mermaid "graph TD" markup inside a div.

    The graph has a question node Q, one node S<number> per step linked in
    list order, and an answer node A when the response has an answer.

    Args:
        cot_response: CoTResponse holding the question, steps, and answer
        config: VisualizationConfig passed to wrap_text for every label

    Returns:
        The markup lines joined with newlines
    """
    steps = cot_response.steps
    step_ids = [f'S{step.number}' for step in steps]

    lines = ['<div class="mermaid">', 'graph TD']

    # Question node, linked to the first step when there is one.
    question_label = wrap_text(cot_response.question, config)
    lines.append(f' Q["{question_label}"]')
    if step_ids:
        lines.append(f' Q --> {step_ids[0]}')

    # Each step node is followed by the edge to its successor, if any.
    last_index = len(step_ids) - 1
    for position, (current_id, step) in enumerate(zip(step_ids, steps)):
        step_label = wrap_text(step.content, config)
        lines.append(f' {current_id}["{step_label}"]')
        if position < last_index:
            lines.append(f' {current_id} --> {step_ids[position + 1]}')

    # Answer node hangs off the last step, or off the question if no steps.
    if cot_response.answer:
        answer_label = wrap_text(cot_response.answer, config)
        lines.append(f' A["{answer_label}"]')
        source_id = step_ids[-1] if step_ids else 'Q'
        lines.append(f' {source_id} --> A')

    # Styling directives for the default, question, and answer classes.
    lines += [
        ' classDef default fill:#f9f9f9,stroke:#333,stroke-width:2px;',
        ' classDef question fill:#e3f2fd,stroke:#1976d2,stroke-width:2px;',
        ' classDef answer fill:#d4edda,stroke:#28a745,stroke-width:2px;',
        ' class Q question;',
        ' class A answer;',
        ' linkStyle default stroke:#666,stroke-width:2px;',
        '</div>',
    ]
    return '\n'.join(lines)