# Hugging Face Spaces page residue (author avatar, commit message and id):
# Phoenix21 - "Update app.py" - commit bcdbab3 (verified)
#!/usr/bin/env python3
"""
Code Flow Analyzer with Gradio Interface - Hugging Face Spaces & Colab Compatible
A single-file application that uses LangChain agents with the Gemini model to analyze code structure
and generate Mermaid.js flowchart diagrams through a web interface.
"""
import ast
import re
import os
import traceback
import sys
from typing import Dict, Any, List, Tuple
import getpass
# Detect Google Colab: the ``google.colab`` package is only importable there.
try:
    import google.colab  # noqa: F401 - imported only to probe the runtime
except ImportError:
    IN_COLAB = False
    print("🟑 Running locally or in Hugging Face Spaces")
else:
    IN_COLAB = True
    print("🟒 Running in Google Colab")
# Install dependencies if in Colab
if IN_COLAB:
    import subprocess  # only needed for this one-off Colab install

    print("📦 Installing dependencies...")
    # Run pip through the current interpreter so packages land in the same
    # environment that executes this script (a bare ``pip`` on PATH may belong
    # to a different Python installation).
    _pip_result = subprocess.run(
        [
            sys.executable, "-m", "pip", "install", "-q",
            "gradio", "langchain", "langgraph", "langchain-google-genai",
        ],
        check=False,
    )
    if _pip_result.returncode == 0:
        print("✅ Dependencies installed")
    else:
        # Best effort: report the failure; the imports below will surface details.
        print(f"⚠️ pip exited with status {_pip_result.returncode}; imports below may fail")
import gradio as gr
from langchain.chat_models import init_chat_model
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.tools import tool
from langgraph.prebuilt import create_react_agent
from langgraph.checkpoint.memory import MemorySaver
# Sample programs offered as one-click examples in the UI.
# SAMPLE_PYTHON must be valid, properly indented Python: analyze_code_structure
# parses it with ``ast`` and only falls back to regex scanning on a SyntaxError,
# so an unindented sample would never exercise the AST path.
SAMPLE_PYTHON = '''def main():
    user_input = get_user_input()
    if user_input:
        result = process_data(user_input)
        if result > 0:
            display_result(result)
        else:
            show_error()
    else:
        show_help()

def get_user_input():
    return input("Enter data: ")

def process_data(data):
    for i in range(len(data)):
        if data[i].isdigit():
            return int(data[i])
    return -1

def display_result(result):
    print(f"Result: {result}")

def show_error():
    print("Error processing data")

def show_help():
    print("Please provide valid input")'''
# JavaScript and Java example programs, stored one source line per tuple entry
# and joined with newlines (no trailing newline).
SAMPLE_JAVASCRIPT = "\n".join((
    "function calculateTotal(items) {",
    "let total = 0;",
    "for (let item of items) {",
    "if (item.price > 0) {",
    "total += item.price;",
    "}",
    "}",
    "return total;",
    "}",
    "function processOrder(order) {",
    "if (validateOrder(order)) {",
    "const total = calculateTotal(order.items);",
    "if (total > 100) {",
    "applyDiscount(order);",
    "}",
    "return generateReceipt(order);",
    "} else {",
    'throw new Error("Invalid order");',
    "}",
    "}",
    "function validateOrder(order) {",
    "return order && order.items && order.items.length > 0;",
    "}",
    "function applyDiscount(order) {",
    "order.discount = 0.1; // 10% discount",
    "}",
    "function generateReceipt(order) {",
    "return {",
    "items: order.items,",
    "total: calculateTotal(order.items),",
    "timestamp: new Date()",
    "};",
    "}",
))
SAMPLE_JAVA = "\n".join((
    "public class Calculator {",
    "public static void main(String[] args) {",
    "Calculator calc = new Calculator();",
    "int result = calc.performCalculation();",
    "calc.displayResult(result);",
    "}",
    "public int performCalculation() {",
    "int a = getUserInput();",
    "int b = getUserInput();",
    "if (a > b) {",
    "return multiply(a, b);",
    "} else {",
    "return add(a, b);",
    "}",
    "}",
    "private int add(int x, int y) {",
    "return x + y;",
    "}",
    "private int multiply(int x, int y) {",
    "return x * y;",
    "}",
    "private int getUserInput() {",
    "return 5; // Simplified for demo",
    "}",
    "private void displayResult(int result) {",
    'System.out.println("Result: " + result);',
    "}",
    "}",
))
# --- Gemini API Key Setup ---
def setup_api_key():
    """Locate the Google API key, prompting for it interactively in Colab.

    Lookup order: the ``GOOGLE_API_KEY`` environment variable, then (Colab only)
    a hidden ``getpass`` prompt. Surrounding whitespace is stripped, since pasted
    keys and secrets often carry a trailing newline. The cleaned key is written
    back to ``os.environ`` because the Gemini model is constructed without an
    explicit key and reads it from the environment.

    Returns:
        The API key string, or ``None`` when no key is available.
    """
    api_key = (os.getenv("GOOGLE_API_KEY") or "").strip()
    if api_key:
        print("✅ Google API key found")
    elif IN_COLAB:
        print("🔑 Please enter your Google API key:")
        print(" Get a key from: https://aistudio.google.com/app/apikey")
        api_key = getpass.getpass("GOOGLE_API_KEY: ").strip()
        if api_key:
            print("✅ API key set successfully")
        else:
            print("⚠️ No API key provided - agent features will be disabled")
    else:
        print("⚠️ GOOGLE_API_KEY not found in environment variables")
        print(" Set it with: export GOOGLE_API_KEY='your-key-here'")
        print(" In Hugging Face Spaces, use the 'Secrets' tab to set the key.")
    if api_key:
        os.environ["GOOGLE_API_KEY"] = api_key
    return api_key or None
# Setup API key
api_key = setup_api_key()

# Initialize LangChain components; they all stay None when no key is available
# or the model cannot be constructed, which disables the agent features.
GEMINI_MODEL_NAME = "gemini-2.0-flash"
model = None
memory = None
agent_executor = None
if api_key:
    try:
        model = ChatGoogleGenerativeAI(model=GEMINI_MODEL_NAME, temperature=0)
        print(f"✅ Gemini model initialized successfully: {GEMINI_MODEL_NAME}")
        # In-memory checkpointer, passed to create_react_agent when the agent is built.
        memory = MemorySaver()
    except Exception as e:
        print(f"❌ Could not initialize Gemini model: {e}")
        print(" Please check your API key and internet connection.")
        model = None
        memory = None
# --- Tool Definitions ---
# NOTE: each tool's docstring is the description shown to the LLM.
@tool
def analyze_code_structure(source_code: str) -> Dict[str, Any]:
    """
    Analyzes source code structure to identify functions, control flow, and dependencies.
    Returns structured data about the code that can be used to generate flow diagrams.
    """
    # Valid Python gets a real AST walk; anything else (JavaScript, Java, ...)
    # goes through the line-oriented regex scanner. Any other failure is
    # reported back as an {"error": ...} dict instead of raising.
    try:
        try:
            return _analyze_python_ast(ast.parse(source_code))
        except SyntaxError:
            return _analyze_code_text(source_code)
    except Exception as err:
        return {"error": f"Analysis error: {err!s}"}
def _analyze_python_ast(tree) -> Dict[str, Any]:
    """Collect functions, classes, imports and a name-based call graph from a Python AST.

    Args:
        tree: Module node returned by ``ast.parse``.

    Returns:
        A dict with these keys:

        - ``functions``: one dict per ``def`` / ``async def`` with ``name``,
          ``line``, ``args``, ``calls``, ``conditions`` and ``loops``.
        - ``classes``: name, line and direct method names of each class.
        - ``control_flows``: always left empty by this analyzer.
        - ``imports``: dotted import names.
        - ``call_graph``: maps a function name to the names it calls. Both plain
          calls ``f()`` and attribute calls ``obj.f()`` are recorded, keyed by
          the final name.

    Notes:
        Each function body is scanned with ``ast.walk``, so calls and branches
        inside nested functions are also attributed to the enclosing function.
        Functions sharing a name share one ``call_graph`` entry, and the last
        definition visited wins.
    """
    analysis = {
        "functions": [],
        "classes": [],
        "control_flows": [],
        "imports": [],
        "call_graph": {},
    }
    function_types = (ast.FunctionDef, ast.AsyncFunctionDef)
    loop_types = (ast.For, ast.AsyncFor, ast.While)

    class CodeAnalyzer(ast.NodeVisitor):
        def visit_FunctionDef(self, node):
            func_info = {
                "name": node.name,
                "line": node.lineno,
                "args": [arg.arg for arg in node.args.args],
                "calls": [],
                "conditions": [],
                "loops": [],
            }
            callees = []
            analysis["call_graph"][node.name] = callees
            for child in ast.walk(node):
                if isinstance(child, ast.Call):
                    if isinstance(child.func, ast.Name):
                        called = child.func.id
                    elif isinstance(child.func, ast.Attribute):
                        # Method calls (self.helper(), obj.run()) keyed by attribute name,
                        # so calls between methods show up in the call graph too.
                        called = child.func.attr
                    else:
                        continue  # e.g. calling a subscript or another call's result
                    func_info["calls"].append(called)
                    callees.append(called)
                elif isinstance(child, ast.If):
                    func_info["conditions"].append(f"if condition at line {child.lineno}")
                elif isinstance(child, loop_types):
                    loop_type = "while" if isinstance(child, ast.While) else "for"
                    func_info["loops"].append(f"{loop_type} loop at line {child.lineno}")
            analysis["functions"].append(func_info)
            self.generic_visit(node)

        # ``async def`` bodies are analysed exactly like ``def`` bodies.
        visit_AsyncFunctionDef = visit_FunctionDef

        def visit_ClassDef(self, node):
            analysis["classes"].append({
                "name": node.name,
                "line": node.lineno,
                "methods": [item.name for item in node.body if isinstance(item, function_types)],
            })
            self.generic_visit(node)

        def visit_Import(self, node):
            for alias in node.names:
                analysis["imports"].append(alias.name)
            self.generic_visit(node)

        def visit_ImportFrom(self, node):
            module = node.module or ""  # None for ``from . import x``
            for alias in node.names:
                analysis["imports"].append(f"{module}.{alias.name}")
            self.generic_visit(node)

    CodeAnalyzer().visit(tree)
    return analysis
# Regexes for the line-oriented fallback analyzer (applied to stripped lines).
_JS_FUNCTION_RE = re.compile(r'(?:export\s+)?(?:async\s+)?function\s+(\w+)\s*\(')
# "<modifiers> <return type> <name>(" - group 1 is the return type, group 2 the name.
_METHOD_RE = re.compile(r'(?:public|private|protected)?\s*(?:static)?\s*(\w+)\s+(\w+)\s*\(')
# Also matches lines that open with a closing brace: "} else {", "} else if (".
_CONDITION_RE = re.compile(r'(?:\}\s*)?(?:else\s+if|if|else|elif|switch)\s*[\(\{]')
_LOOP_RE = re.compile(r'(?:for|while|do)\s*[\(\{]')
# Keywords that can appear where _METHOD_RE expects a method name...
_NOT_A_METHOD_NAME = frozenset({'class', 'if', 'for', 'while', 'switch', 'catch'})
# ...or where it expects a return type, e.g. "return add(a, b);" is a call, not a declaration.
_NOT_A_RETURN_TYPE = frozenset({'return', 'throw', 'new', 'else', 'case', 'await', 'yield', 'typeof', 'delete'})


def _analyze_code_text(source_code: str) -> Dict[str, Any]:
    """Heuristic, regex-based structure scan for code that is not valid Python.

    Each stripped line is checked for a JavaScript ``function`` declaration or a
    Java/C-style ``<type> <name>(`` method declaration, and for lines that open
    a condition or a loop. No call relationships are detected, so every
    ``call_graph`` entry stays empty, and control structures are only listed in
    ``control_flows`` (not attached to individual functions).

    Args:
        source_code: Raw source text.

    Returns:
        Dict with the same keys as the Python AST analyzer: ``functions``,
        ``classes`` (always empty here), ``control_flows``, ``imports`` (always
        empty here) and ``call_graph``.
    """
    analysis = {
        "functions": [],
        "classes": [],
        "control_flows": [],
        "imports": [],
        "call_graph": {},
    }

    def add_function(name, line_no):
        analysis["functions"].append({
            "name": name,
            "line": line_no,
            "args": [],
            "calls": [],
            "conditions": [],
            "loops": [],
        })
        analysis["call_graph"][name] = []

    for line_no, raw_line in enumerate(source_code.split('\n'), 1):
        line = raw_line.strip()
        js_match = _JS_FUNCTION_RE.match(line)
        if js_match:
            add_function(js_match.group(1), line_no)
        else:
            method_match = _METHOD_RE.match(line)
            if method_match:
                return_type, name = method_match.groups()
                if return_type not in _NOT_A_RETURN_TYPE and name not in _NOT_A_METHOD_NAME:
                    add_function(name, line_no)
        if _CONDITION_RE.match(line):
            analysis["control_flows"].append(f"condition at line {line_no}")
        if _LOOP_RE.match(line):
            analysis["control_flows"].append(f"loop at line {line_no}")
    return analysis
def _mermaid_label(text: Any) -> str:
    """Return *text* as a double-quoted Mermaid node label.

    In unquoted node text, characters such as ``(``, ``)``, ``[`` and ``]`` are
    read as shape delimiters, so a label like ``main()`` breaks the diagram.
    Inside quotes they are literal. Embedded double quotes are written as the
    ``#quot;`` entity code.
    """
    return '"' + str(text).replace('"', "#quot;") + '"'


@tool
def generate_mermaid_diagram(analysis_data: Dict[str, Any]) -> str:
    """
    Generates a Mermaid.js flowchart diagram from code analysis data.
    Creates a visual representation of the code flow including function calls and control structures.
    """
    # NOTE: the docstring above is the tool description the LLM sees.
    #
    # Layout: Start -> Main -> each function in definition order -> End.
    # Each function node gets up to 2 decision diamonds and 1 loop node, and at
    # most about 3 dotted "calls" edges are drawn between known functions.
    if "error" in analysis_data:
        error_label = _mermaid_label(f"❌ {analysis_data['error']}")
        return f"flowchart TD\n    Error[{error_label}]"
    functions = analysis_data.get("functions", [])
    call_graph = analysis_data.get("call_graph", {})
    if not functions:
        return """flowchart TD
    Start([🚀 Program Start]) --> NoFunc[No Functions Found]
    NoFunc --> End([🏁 Program End])
    classDef startEnd fill:#e1f5fe,stroke:#01579b,stroke-width:2px
    classDef warning fill:#fff3e0,stroke:#e65100,stroke-width:2px
    class Start,End startEnd
    class NoFunc warning"""

    mermaid_lines = ["flowchart TD", "    Start([🚀 Program Start]) --> Main"]
    func_nodes: List[str] = []
    # Exact function name -> node id of its first definition, used for call
    # edges. (Matching names as substrings of node ids would link e.g. a
    # function called ``f`` to every ``F<i>_...`` node.)
    node_for_name: Dict[str, str] = {}
    for i, func in enumerate(functions):
        func_name = func["name"]
        safe_name = re.sub(r'[^a-zA-Z0-9_]', '_', func_name)
        node_id = f"F{i}_{safe_name}"
        func_nodes.append(node_id)
        node_for_name.setdefault(func_name, node_id)
        func_label = _mermaid_label(f"⚙️ {func_name}()")
        mermaid_lines.append(f"    {node_id}[{func_label}]")

        for j, _condition in enumerate(func.get("conditions", [])[:2]):
            cond_id = f"{node_id}_C{j}"
            # ``{{`` / ``}}`` are f-string escapes: this emits a {rhombus} node.
            mermaid_lines.append(f"    {node_id} --> {cond_id}{{🤔 Decision}}")
            mermaid_lines.append(f"    {cond_id} -->|Yes| {node_id}_Y{j}[✅ True Path]")
            mermaid_lines.append(f"    {cond_id} -->|No| {node_id}_N{j}[❌ False Path]")
        for j, loop in enumerate(func.get("loops", [])[:1]):
            loop_id = f"{node_id}_L{j}"
            loop_type = "🔄 Loop" if "for" in loop else "⏰ While Loop"
            mermaid_lines.append(f"    {node_id} --> {loop_id}[{loop_type}]")
            mermaid_lines.append(f"    {loop_id} --> {loop_id}")  # self-loop marks repetition

    # Main flow: functions chained in definition order (a simplification).
    mermaid_lines.append(f"    Main --> {func_nodes[0]}")
    for current, following in zip(func_nodes, func_nodes[1:]):
        mermaid_lines.append(f"    {current} --> {following}")
    mermaid_lines.append(f"    {func_nodes[-1]} --> End([🏁 Program End])")

    # Call relationships between functions defined in the analysed code only;
    # builtins and library calls have no node and are skipped before limiting.
    call_count = 0
    for caller, callees in call_graph.items():
        if call_count >= 3:
            break
        caller_node = node_for_name.get(caller)
        if caller_node is None:
            continue
        known_callees = [c for c in dict.fromkeys(callees) if c in node_for_name]
        for callee in known_callees[:2]:
            callee_node = node_for_name[callee]
            if callee_node != caller_node:  # skip self-recursion edges
                mermaid_lines.append(f"    {caller_node} -.->|calls| {callee_node}")
                call_count += 1

    mermaid_lines.extend([
        "",
        "    classDef startEnd fill:#e1f5fe,stroke:#01579b,stroke-width:3px,color:#000",
        "    classDef process fill:#f3e5f5,stroke:#4a148c,stroke-width:2px,color:#000",
        "    classDef decision fill:#fff3e0,stroke:#e65100,stroke-width:2px,color:#000",
        "    classDef success fill:#e8f5e8,stroke:#2e7d32,stroke-width:2px,color:#000",
        "    classDef error fill:#ffebee,stroke:#c62828,stroke-width:2px,color:#000",
        "",
        "    class Start,End startEnd",
        f"    class {','.join(func_nodes)} process",
    ])
    return "\n".join(mermaid_lines)
@tool
def calculate_complexity_score(analysis_data: Dict[str, Any]) -> int:
    """
    Calculates a complexity score for the code based on various metrics.
    Higher scores indicate more complex code structure.
    """
    # NOTE: the docstring above is the tool description the LLM sees.
    # Weighted sum capped at 100: each function 3, condition 4, loop 3,
    # call 1, parameter 1, and each class 5. Error payloads score 0.
    if "error" in analysis_data:
        return 0
    funcs = analysis_data.get("functions", [])
    per_function = sum(
        4 * len(f.get("conditions", []))
        + 3 * len(f.get("loops", []))
        + len(f.get("calls", []))
        + len(f.get("args", []))
        for f in funcs
    )
    total = 3 * len(funcs) + per_function + 5 * len(analysis_data.get("classes", []))
    return min(total, 100)
# Create the agent if model is available. A failure here must not crash the
# app at import time: the UI still loads, with agent features disabled.
agent_executor = None
if model and memory:
    tools = [analyze_code_structure, generate_mermaid_diagram, calculate_complexity_score]
    try:
        agent_executor = create_react_agent(model, tools, checkpointer=memory)
        print("✅ LangChain agent created successfully")
    except Exception as e:
        print(f"❌ Could not create LangChain agent: {e}")
if agent_executor is None:
    print("❌ LangChain agent not available")
def _detect_language(source_code: str) -> str:
    """Guess the language of *source_code* from cheap substring markers.

    Markers are checked from most to least specific. ``const`` is also a C++
    keyword, and ``import`` also begins Java and JavaScript statements, so those
    weak markers are tested only after the C++ and Java ones.
    """
    if "#include" in source_code or "std::" in source_code:
        return "C++"
    if "def " in source_code:
        return "Python"
    if ("public " in source_code and "class " in source_code) or "System.out" in source_code:
        return "Java"
    if "function " in source_code or "const " in source_code or "let " in source_code:
        return "JavaScript"
    if "import " in source_code:
        return "Python"
    return "Unknown"


def _message_text(content: Any) -> str:
    """Flatten a chat message's ``content`` into one string.

    LangChain message content may be a plain string or a list of parts (strings,
    or dicts carrying a ``"text"`` entry); the regex extraction needs a string.
    Non-text parts are dropped.
    """
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for part in content:
            if isinstance(part, str):
                parts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                parts.append(part["text"])
        return "\n".join(parts)
    return "" if content is None else str(content)


def analyze_code_with_agent(source_code: str, language: str = "auto") -> Tuple[str, str, List[str], int, str]:
    """Analyze *source_code* with the LangChain agent and extract diagram, summary and stats.

    If the agent's reply yields neither a diagram nor a summary (including when
    it returns no messages), the three tools are run locally as a fallback.

    Args:
        source_code: Code pasted by the user.
        language: Language display name, or ``"auto"`` to guess it with simple
            substring markers. Only used to phrase the prompt.

    Returns:
        ``(mermaid_diagram, analysis_summary, functions_found, complexity_score,
        error_message)``. ``error_message`` is ``""`` on success. On an
        exception the other fields are empty or zero.
    """
    if not source_code or not source_code.strip():
        return "", "No code provided", [], 0, "Please enter some source code to analyze"
    if not agent_executor:
        return "", "Agent not available", [], 0, "❌ LangChain agent not initialized. Please check your GOOGLE_API_KEY"

    import contextlib
    import uuid

    # One fresh checkpointer thread per request. The old ``hash(code) % 10000``
    # id let different users' requests collide and share message history, and it
    # made repeated analyses of the same code accumulate prior conversations.
    thread_id = f"session_{uuid.uuid4().hex}"
    try:
        if language == "auto":
            language = _detect_language(source_code)
        config = {
            "configurable": {"thread_id": thread_id},
            "recursion_limit": 100,
        }
        # Refined prompt for better tool use
        prompt = f"""
You are a code analysis expert. Analyze the following {language} source code.
Your task is to:
1. Use the 'analyze_code_structure' tool with the full source code provided below.
2. Use the 'generate_mermaid_diagram' tool with the output of the first tool.
3. Use the 'calculate_complexity_score' tool with the output of the first tool.
4. Provide a brief, human-readable summary of the analysis, including the generated Mermaid diagram, complexity score, and a list of functions found.
5. Present the final result in a clear, easy-to-read format.
Source Code to Analyze:
```{language.lower()}
{source_code}
```
"""
        result = agent_executor.invoke(
            {"messages": [{"role": "user", "content": prompt}]},
            config
        )
        messages = result["messages"] if result and "messages" in result else None
        # An empty reply falls through to the local fallback below instead of
        # returning None (which would break the caller's tuple unpacking).
        response_content = _message_text(messages[-1].content) if messages else ""

        # Extract the Mermaid diagram (tolerates "\r\n" and spaces after the tag).
        mermaid_match = re.search(r'```mermaid[ \t]*\r?\n(.*?)\r?\n?```', response_content, re.DOTALL)
        mermaid_diagram = mermaid_match.group(1).strip() if mermaid_match else ""

        # First number on the same line as "complexity"; the tool caps scores at 100.
        complexity_match = re.search(r'complexity.*?(\d+)', response_content, re.IGNORECASE)
        complexity_score = min(int(complexity_match.group(1)), 100) if complexity_match else 0

        # "Functions found: a, b, c" - drop Markdown emphasis/backticks around names.
        func_match = re.search(r'Functions found:[ \t*]*([^\n]+)', response_content, re.IGNORECASE)
        functions_found = []
        if func_match:
            functions_found = [
                name.strip(" \t*`-") for name in func_match.group(1).split(',')
                if name.strip(" \t*`-")
            ]
        if not functions_found:
            # Fallback: take names straight from the local analysis.
            analysis_result = analyze_code_structure.invoke({"source_code": source_code})
            functions_found = [f["name"] for f in analysis_result.get("functions", [])]

        # Clean up the response for summary
        summary = re.sub(r'```mermaid.*?```', '', response_content, flags=re.DOTALL)
        summary = re.sub(r'flowchart TD.*?(?=\n\n|\Z)', '', summary, flags=re.DOTALL)
        summary = summary.strip()

        if not mermaid_diagram and not summary:
            # Last resort: run the tools locally without the LLM.
            analysis_result = analyze_code_structure.invoke({"source_code": source_code})
            mermaid_diagram = generate_mermaid_diagram.invoke({"analysis_data": analysis_result})
            complexity_score = calculate_complexity_score.invoke({"analysis_data": analysis_result})
            functions_found = [f["name"] for f in analysis_result.get("functions", [])]
            summary = "Agent failed to provide a detailed summary, but a fallback analysis was successful."
        return mermaid_diagram, summary, functions_found, complexity_score, ""
    except Exception as e:
        error_msg = f"❌ Analysis failed: {str(e)}"
        print(f"Error details: {traceback.format_exc()}")
        return "", "", [], 0, error_msg
    finally:
        # The in-memory checkpointer would otherwise retain this one-shot thread.
        # delete_thread only exists in newer langgraph releases, hence the guard;
        # cleanup is best-effort and must not mask the result.
        if memory is not None and hasattr(memory, "delete_thread"):
            with contextlib.suppress(Exception):
                memory.delete_thread(thread_id)
# --- Gradio Interface Setup ---
def create_gradio_interface():
    """Build the Gradio Blocks UI and wire the analyze button to the agent.

    Returns:
        The ``gr.Blocks`` instance; launching it is left to the caller.
    """

    def analyze_code_gradio(code, language):
        """Run the analysis for the UI and format its five output components.

        Returns a tuple ordered like the ``outputs`` list passed to
        ``analyze_btn.click``: (status, mermaid code, summary, stats, functions).
        """
        if not code or not code.strip():
            return (
                "Please enter some code to analyze",
                "",
                "No analysis performed",
                "Functions: 0 | Complexity: 0/100",
                "",
            )
        mermaid, summary, functions, complexity, error = analyze_code_with_agent(code, language)
        if error:
            return (
                error,
                "",
                "Analysis failed",
                "Functions: 0 | Complexity: 0/100",
                "",
            )
        # The "Mermaid Code" box is meant to be pasted into mermaid.live, which
        # expects bare Mermaid source; wrapping it in Markdown ``` fences would
        # make that paste fail to parse.
        mermaid_display = mermaid or "No diagram generated"
        functions_display = f"**Functions Found:** {', '.join(functions)}" if functions else "No functions detected"
        stats_display = f"Functions: {len(functions)} | Complexity: {complexity}/100"
        return (
            "✅ Analysis completed successfully!",
            mermaid_display,
            summary,
            stats_display,
            functions_display,
        )

    with gr.Blocks(
        title="🔄 Code Flow Analyzer",
        theme=gr.themes.Soft(),
        css="""
.gradio-container {
    max-width: 1200px !important;
}
.code-input {
    font-family: 'Monaco', 'Menlo', 'Ubuntu Mono', monospace !important;
}
""",
    ) as interface:
        gr.Markdown("""
# 🔄 Code Flow Analyzer

**LangChain Agent + Mermaid.js** • Visualize Your Code Flow

This tool uses AI agents to analyze your source code and generate visual flowchart diagrams.
""")
        # API status banner
        model_info = " (Gemini LLM)" if agent_executor and model else ""
        api_status = (
            f"🟒 Gemini LangChain Agent Ready{model_info}"
            if agent_executor
            else "🔴 Agent Not Available (Check GOOGLE_API_KEY)"
        )
        gr.Markdown(f"**Status:** {api_status}")

        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### 📝 Source Code Input")
                language_dropdown = gr.Dropdown(
                    choices=["auto", "Python", "JavaScript", "Java", "C++", "Other"],
                    value="auto",
                    label="Programming Language",
                    info="Auto-detection usually works well",
                )
                code_input = gr.TextArea(
                    placeholder="Paste your source code here...",
                    lines=15,
                    label="Source Code",
                    elem_classes=["code-input"],
                )
                with gr.Row():
                    gr.Examples(
                        examples=[
                            [SAMPLE_PYTHON, "Python"],
                            [SAMPLE_JAVASCRIPT, "JavaScript"],
                            [SAMPLE_JAVA, "Java"],
                        ],
                        inputs=[code_input, language_dropdown],
                        label="Quick Examples",
                    )
                analyze_btn = gr.Button(
                    "🚀 Analyze Code Flow",
                    variant="primary",
                    size="lg",
                )
            with gr.Column(scale=1):
                gr.Markdown("### 📊 Analysis Results")
                status_output = gr.Textbox(
                    label="Status",
                    interactive=False,
                )
                stats_output = gr.Textbox(
                    label="Statistics",
                    interactive=False,
                )
                functions_output = gr.Markdown(
                    label="Functions Found",
                )

        with gr.Row():
            with gr.Column():
                gr.Markdown("### 🎨 Generated Mermaid Diagram")
                mermaid_output = gr.Textbox(
                    label="Mermaid Code",
                    lines=15,
                    max_lines=20,
                    interactive=True,
                    show_copy_button=True,
                )
                gr.Markdown("""
**💡 How to visualize:**
1. Copy the Mermaid code above
2. Visit [mermaid.live](https://mermaid.live)
3. Paste and see your code flow diagram!

**📱 For Colab users:**
- The Mermaid code above shows your program's flow structure
- Copy it to mermaid.live for a beautiful visual diagram
- Try the examples above to see different code patterns
""")

        with gr.Row():
            with gr.Column():
                gr.Markdown("### 📋 Analysis Summary")
                summary_output = gr.Textbox(
                    label="AI Agent Analysis",
                    lines=8,
                    interactive=False,
                )

        # Output order must match the tuple returned by analyze_code_gradio.
        analyze_btn.click(
            fn=analyze_code_gradio,
            inputs=[code_input, language_dropdown],
            outputs=[status_output, mermaid_output, summary_output, stats_output, functions_output],
        )

        # Footer
        environment_info = "Google Colab" if IN_COLAB else "Hugging Face Spaces or Local Environment"
        gr.Markdown(f"""
---
**🛠️ Running in:** {environment_info}

**📦 Dependencies:** gradio, langchain, langgraph, langchain-google-genai

**🔧 Powered by:** LangChain Agents, Google Gemini, Mermaid.js, Gradio

**🆓 Get Google API Key:** [aistudio.google.com/app/apikey](https://aistudio.google.com/app/apikey)
""")
    return interface
def main():
    """Print a startup banner, then build and launch the Gradio app."""
    print("🔄 Code Flow Analyzer with Gradio")
    print("=" * 50)
    print(f"🌐 Environment: {'Google Colab' if IN_COLAB else 'Hugging Face Spaces or Local'}")
    if agent_executor:
        print("✅ LangChain agent ready")
    else:
        print("❌ LangChain agent not available")
        if IN_COLAB:
            print(" 💡 Restart this cell and enter your GOOGLE_API_KEY when prompted")
        else:
            print(" 💡 Please set your GOOGLE_API_KEY as an environment variable or secret")
    print("\n🚀 Starting Gradio interface...")

    interface = create_gradio_interface()
    # Public share link only in Colab; local/Spaces serve the app directly.
    interface.launch(
        share=IN_COLAB,
        debug=False,
        height=600,
        show_error=True,
    )


# Auto-run if in Colab or when script is executed directly
if __name__ == "__main__" or IN_COLAB:
    main()