|
import os |
|
import json |
|
import yaml |
|
from dotenv import load_dotenv |
|
import gradio as gr |
|
from smolagents import CodeAgent |
|
from smolagents.models import HfApiModel |
|
from tools.final_answer import FinalAnswerTool |
|
from tools.web_search import DuckDuckGoSearchTool |
|
from tools.visit_webpage import VisitWebpageTool |
|
from tools.vuln_search import VulnerabilitySearchTool |
|
|
|
|
|
# Runs at import time: load variables from a .env file into the process
# environment (python-dotenv). Presumably this is where the model API
# credentials come from — TODO confirm against the deployment setup.
load_dotenv()
|
|
|
def load_agent_config():
    """Load the agent configuration from ``agent.json``.

    The path is relative to the current working directory.

    Returns:
        The parsed JSON document, exactly as produced by ``json.load``.

    Raises:
        FileNotFoundError: If ``agent.json`` does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON text is UTF-8; be explicit so decoding does not depend on the
    # platform's locale default (e.g. cp1252 on Windows).
    with open('agent.json', 'r', encoding='utf-8') as f:
        return json.load(f)
|
|
|
def load_prompts():
    """Load the prompt templates from ``prompts.yaml``.

    The path is relative to the current working directory.

    Returns:
        The parsed YAML document, exactly as produced by ``yaml.safe_load``.

    Raises:
        FileNotFoundError: If ``prompts.yaml`` does not exist.
        yaml.YAMLError: If the file is not valid YAML.
    """
    # Decode explicitly as UTF-8 so prompt text with non-ASCII characters
    # reads the same regardless of the platform's locale default.
    with open('prompts.yaml', 'r', encoding='utf-8') as f:
        return yaml.safe_load(f)
|
|
|
def initialize_tools():
    """Instantiate every tool the agent can use.

    Returns:
        dict: Freshly constructed tool instances keyed by 'final_answer',
        'web_search', 'visit_webpage' and 'vuln_search', in that order.
    """
    tool_factories = (
        ('final_answer', FinalAnswerTool),
        ('web_search', DuckDuckGoSearchTool),
        ('visit_webpage', VisitWebpageTool),
        ('vuln_search', VulnerabilitySearchTool),
    )
    # Tools are constructed in declaration order; the dict keeps that order.
    return {key: factory() for key, factory in tool_factories}
|
|
|
def create_agent():
    """Create and configure the vulnerability agent.

    Reads the configuration via ``load_agent_config()`` and the prompt
    templates via ``load_prompts()``, builds an ``HfApiModel`` from
    ``config['agent_config']['model']`` and wraps it in a ``CodeAgent``
    equipped with the tools from ``initialize_tools()``.

    Returns:
        tuple: ``(agent, prompts)`` where ``agent`` is the configured
        ``CodeAgent`` and ``prompts`` is the loaded prompt-template document.

    Raises:
        KeyError: If a required key (``agent_config``, ``model``,
            ``model_id``, ``max_tokens``, ``temperature``, ``max_steps`` or
            ``verbosity_level``) is missing from the configuration.
    """
    config = load_agent_config()
    prompts = load_prompts()

    agent_config = config['agent_config']
    model_config = agent_config['model']
    model = HfApiModel(
        model_id=model_config['model_id'],
        max_tokens=model_config['max_tokens'],
        temperature=model_config['temperature'],
    )

    # initialize_tools() returns a name -> tool mapping, but CodeAgent takes a
    # list of tool instances. Passing the dict itself would make the agent
    # iterate over the string keys instead of the tools.
    tools = list(initialize_tools().values())

    agent = CodeAgent(
        model=model,
        tools=tools,
        max_steps=agent_config['max_steps'],
        verbosity_level=agent_config['verbosity_level'],
    )

    return agent, prompts
|
|
|
def process_query(query, analysis_type="general"):
    """Run the agent on a user query and return its answer.

    Args:
        query: The user's input. It is substituted into the prompt template
            as ``cve_id`` (vulnerability), ``target`` (threat) or ``query``
            (anything else).
        analysis_type: ``"vulnerability"`` selects the
            ``vulnerability_analysis`` template, ``"threat"`` selects
            ``threat_report``; any other value falls back to ``user_prompt``.

    Returns:
        Whatever ``agent.run`` returns for the formatted prompt, or a short
        hint string when ``query`` is ``None`` or blank.
    """
    # Guard first: a blank submission should not load config, build a model
    # and spend agent steps on an empty prompt.
    if query is None or not str(query).strip():
        return "Please enter a query to analyze."

    agent, prompts = create_agent()

    # analysis type -> (template key in prompts, placeholder name in template)
    routes = {
        "vulnerability": ("vulnerability_analysis", "cve_id"),
        "threat": ("threat_report", "target"),
    }
    template_key, placeholder = routes.get(analysis_type, ("user_prompt", "query"))
    formatted_prompt = prompts[template_key].format(**{placeholder: query})

    system_prompt = prompts['system_prompt']
    # NOTE(review): confirm that the installed smolagents `run()` accepts a
    # `system_prompt` keyword. Its documented signature has no such parameter;
    # the system prompt is normally configured when the agent is constructed.
    result = agent.run(formatted_prompt, system_prompt=system_prompt)

    return result
|
|
|
|
|
def create_interface():
    """Build the Gradio Blocks UI for the agent.

    The page has a heading and one row with two columns: the first holds the
    query textbox, the analysis-type radio and the "Analyze" button; the
    second holds the Markdown result. Clicking the button calls
    ``process_query`` with the textbox and radio values.

    Returns:
        The assembled ``gr.Blocks`` instance (not launched).
    """
    analysis_choices = ["general", "vulnerability", "threat"]

    with gr.Blocks(title="Vulnerability Intelligence Agent") as demo:
        gr.Markdown("# Vulnerability Intelligence Agent (VIA)")

        with gr.Row():
            # Input column.
            with gr.Column():
                query_box = gr.Textbox(
                    placeholder="Enter your security query...",
                    label="Query",
                )
                mode_picker = gr.Radio(
                    label="Analysis Type",
                    choices=analysis_choices,
                    value="general",
                )
                analyze_button = gr.Button("Analyze")

            # Output column.
            with gr.Column():
                result_view = gr.Markdown(label="Result")

        # Wire the button to the query handler.
        analyze_button.click(
            process_query,
            inputs=[query_box, mode_picker],
            outputs=result_view,
        )

    return demo
|
|
|
if __name__ == "__main__":
    # Build the UI and start the Gradio server when run as a script.
    create_interface().launch()