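# NOTE: the text "Spaces:", "Sleeping", "Sleeping" that preceded this code appears
# to be web-page chrome (a hosting status banner) captured when the file was
# copied; it is kept here as a comment so the module still parses.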
import html
import json
import os

import markdown
import pandas as pd
import streamlit as st
# Page configuration: wide layout with the sidebar collapsed.
# Keep this ahead of any other st.* call (older Streamlit versions require it).
st.set_page_config(
    page_title="Chat Trajectory Viewer",
    page_icon="💬",  # was mis-encoded as "π¬" (UTF-8 bytes decoded as ISO-8859-7)
    layout="wide",
    initial_sidebar_state="collapsed",
)
# Default location of the trajectory data (a JSON Lines file: one object per line).
DEFAULT_DATA_PATH = "/home/lisabdunlap/LMM-Vibes/data/taubench/airline_data_oai_format.jsonl"


@st.cache_data(show_spinner=False)
def _read_jsonl(path, mtime):
    """Read a JSON Lines file into a DataFrame, cached across Streamlit reruns.

    Streamlit re-executes the whole script on every interaction, so without a
    cache each button click would re-parse the file. ``mtime`` is unused in the
    body; it is part of the cache key so that editing the file invalidates the
    cached copy. (It must not start with "_", or st.cache_data would not hash it.)
    Exceptions propagate and are therefore not cached.
    """
    return pd.read_json(path, lines=True)


def load_data(path=DEFAULT_DATA_PATH):
    """Load the trajectory JSONL file.

    Args:
        path: Local JSON Lines file to read. Defaults to DEFAULT_DATA_PATH.

    Returns:
        A pandas DataFrame with one row per line, or None if the file could not
        be read or parsed (the error is shown on the page via st.error).
    """
    try:
        return _read_jsonl(path, os.path.getmtime(path))
    except Exception as e:  # UI boundary: report any read/parse failure to the user
        st.error(f"Error loading data: {e}")
        return None
# Border colour for each message role; unknown roles fall back to _DEFAULT_ROLE_COLOR.
_ROLE_COLORS = {
    "system": "#ff6b6b",  # Red
    "user": "#4ecdc4",  # Teal
    "assistant": "#45b7d1",  # Blue
    "tool": "#96ceb4",  # Green
    "info": "#feca57",  # Yellow
}
_DEFAULT_ROLE_COLOR = "#95a5a6"  # Gray
# Pretty-printed JSON longer than this many characters also gets an st.json viewer.
_LARGE_JSON_CHARS = 500


def _pretty_json(value):
    """Pretty-print *value* as JSON, keeping non-ASCII characters readable."""
    return json.dumps(value, indent=2, ensure_ascii=False)


def _content_to_html(content):
    """Convert message content into an HTML fragment for the message card.

    - dict / list: pretty-printed JSON inside a <pre>, HTML-escaped so markup
      inside string values is displayed literally instead of being rendered.
    - str: rendered as Markdown (Python-Markdown passes raw HTML through).
    - None: a "(No content)" placeholder.
    - anything else: str() of the value, rendered as Markdown.
    """
    if isinstance(content, (dict, list)):
        escaped = html.escape(_pretty_json(content), quote=False)
        return (
            "<pre style='background: #f8f9fa; padding: 10px; border-radius: 4px; "
            f"overflow-x: auto;'>{escaped}</pre>"
        )
    if isinstance(content, str):
        return markdown.markdown(content, extensions=["nl2br", "fenced_code"])
    if content is None:
        return "<em>(No content)</em>"
    return markdown.markdown(str(content), extensions=["nl2br"])


def _message_card_html(role, color, body_html):
    """Return the bordered card markup shared by every message role.

    The role label is HTML-escaped; *body_html* is inserted unchanged. Every
    line starts at column 0: st.markdown parses its input as Markdown, where
    lines indented four or more spaces can be read as a code block instead of
    HTML once interpolated multi-line content defeats dedenting.
    """
    card_style = (
        f"border-left: 4px solid {color}; margin: 8px 0; background-color: #f8f9fa; "
        "padding: 12px; border-radius: 0 8px 8px 0; box-shadow: 0 1px 3px rgba(0,0,0,0.1);"
    )
    label_style = (
        "color: #666; font-size: 14px; font-weight: bold; margin-bottom: 8px; "
        "text-transform: uppercase; letter-spacing: 1px;"
    )
    body_style = (
        "color: #333; line-height: 1.6; "
        "font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;"
    )
    return "\n".join(
        [
            f'<div style="{card_style}">',
            f'<div style="{label_style}">{html.escape(role)}</div>',
            f'<div style="{body_style}">',
            body_html,
            "</div>",
            "</div>",
        ]
    )


def display_message(role, content, index):
    """Render a single chat message as a colour-coded card.

    Args:
        role: Message role (e.g. "system", "user", "assistant", "tool").
            Non-string values (including None) are converted with str().
        content: Message content. A str is rendered as Markdown, a dict/list as
            JSON, None as a placeholder; anything else is str()-ed.
        index: Position of the message in the conversation. Not used here;
            kept so existing call sites keep working.
    """
    role = str(role)
    color = _ROLE_COLORS.get(role.lower(), _DEFAULT_ROLE_COLOR)
    card_html = _message_card_html(role, color, _content_to_html(content))

    if role.lower() == "system":
        # System messages are shown inside an expander that starts collapsed.
        with st.expander(f"🔧 {role.upper()}", expanded=False):
            st.markdown(card_html, unsafe_allow_html=True)
    else:
        st.markdown(card_html, unsafe_allow_html=True)

    # Large JSON payloads additionally get Streamlit's interactive JSON viewer.
    if isinstance(content, (dict, list)) and len(_pretty_json(content)) > _LARGE_JSON_CHARS:
        with st.expander("View JSON in expandable format"):
            st.json(content)
# Page-wide CSS overrides: full-width layout, hidden sidebar, tighter spacing.
_CUSTOM_CSS = """
<style>
/* Force full width layout */
.main .block-container {
    max-width: 100% !important;
    padding-left: 0.5rem !important;
    padding-right: 0.5rem !important;
}
/* Target Streamlit's main container */
section[data-testid="stSidebar"] {
    display: none !important;
}
/* Make all content use full width */
.main > div {
    max-width: 100% !important;
    padding-left: 0.5rem !important;
    padding-right: 0.5rem !important;
}
/* Reduce spacing between all elements */
.element-container {
    margin-bottom: 0.25rem !important;
}
/* Target the main content area specifically */
div[data-testid="stVerticalBlock"] {
    max-width: 100% !important;
}
/* Override any default margins */
.stMarkdown {
    margin-bottom: 0.25rem !important;
}
/* Make sure columns use full width */
.row-widget.stHorizontal {
    max-width: 100% !important;
}
</style>
"""

# Header shown above each tool call's JSON (flush-left so Markdown keeps it as HTML).
_TOOL_CALL_HEADER_HTML = (
    '<div style="border-left: 4px solid #e67e22; margin: 5px 0 5px 20px; '
    'background-color: #fdf6e3; padding: 10px; border-radius: 0 5px 5px 0;">\n'
    '<h5 style="color: #666; font-size: 12px; font-weight: bold; '
    'margin-bottom: 5px; text-transform: uppercase;">TOOL CALL</h5>\n'
    "</div>"
)


def _go_previous():
    """Button callback: step back one trajectory, stopping at the first."""
    st.session_state.trajectory_index = max(0, st.session_state.trajectory_index - 1)


def _go_next(last_index):
    """Button callback: step forward one trajectory, stopping at *last_index*."""
    st.session_state.trajectory_index = min(last_index, st.session_state.trajectory_index + 1)


def _select_trajectory(df):
    """Render the Previous / selector / Next controls and return the chosen row index.

    The selectbox is keyed on ``trajectory_index`` and the buttons change that
    key in on_click callbacks, which run before the rerun. So the selector
    always shows the trajectory that is rendered below it.
    """
    last_index = len(df) - 1
    # Clamp a stale index (e.g. the data file shrank), and set it before the
    # widget is created; Streamlit forbids changing a widget key after that.
    clamped = min(max(st.session_state.get("trajectory_index", 0), 0), last_index)
    if st.session_state.get("trajectory_index") != clamped:
        st.session_state.trajectory_index = clamped

    col1, col2, col3 = st.columns([1, 2, 1])
    with col1:
        st.button("← Previous", on_click=_go_previous)
    with col2:
        st.selectbox(
            "Select Trajectory:",
            range(len(df)),
            key="trajectory_index",
            format_func=lambda x: f"Trajectory {x} (Task ID: {df.iloc[x]['task_id']}, Reward: {df.iloc[x]['reward']})",
        )
    with col3:
        st.button("Next →", on_click=_go_next, args=(last_index,))
    return st.session_state.trajectory_index


def _render_metadata(data):
    """Show task id, reward and message count for one trajectory as metrics."""
    col1, col2, col3 = st.columns(3)
    with col1:
        st.metric("Task ID", data["task_id"])
    with col2:
        st.metric("Reward", data["reward"])
    with col3:
        st.metric("Messages", len(data["messages"]))


def _render_tool_calls(tool_calls):
    """Render each tool call attached to a message as a header plus its JSON."""
    for tool_call in tool_calls:
        st.markdown(_TOOL_CALL_HEADER_HTML, unsafe_allow_html=True)
        st.json(
            {
                "function": tool_call.get("function", {}),
                "id": tool_call.get("id", ""),
                "type": tool_call.get("type", ""),
            }
        )


def main():
    """Render the trajectory browser: navigation, metadata and the conversation."""
    st.title("Chat Trajectory Viewer")
    st.markdown("Browse through airline booking conversation trajectories")
    st.markdown(_CUSTOM_CSS, unsafe_allow_html=True)

    df = load_data()
    if df is None:
        return  # load_data has already shown the error
    if df.empty:
        st.warning("No trajectories found in the data file.")
        return

    trajectory_index = _select_trajectory(df)
    data = df.iloc[trajectory_index]["oai_traj_format"]

    _render_metadata(data)
    st.markdown("---")

    st.subheader("Conversation Trajectory")
    for i, message in enumerate(data["messages"]):
        role = message.get("role", "unknown")
        display_message(role, message.get("content"), i)

        # `tool_calls` may be missing or explicitly null; treat both as "no calls".
        _render_tool_calls(message.get("tool_calls") or [])

        if role == "tool":
            tool_call_id = message.get("tool_call_id", "Unknown")
            name = message.get("name", "Unknown")
            st.caption(f"Tool: {name} | Call ID: {tool_call_id}")


if __name__ == "__main__":
    main()