|
import os |
|
from typing import Dict, Any, Optional, List |
|
import nvdlib |
|
from smolagents.tools import Tool |
|
|
|
class VulnerabilitySearchTool(Tool):
    """smolagents tool that looks up vulnerabilities in the NVD.

    A query starting with ``CVE-`` (case-insensitive) is fetched directly by
    ID; any other query is sent to the NVD as a keyword search. Each hit is
    summarized as a dict with ``id``, ``description``, ``severity`` (the
    CVSS v3.1 base score, or ``None`` when the CVE has none), ``published``
    and ``references``.
    """

    name = "vuln_search"

    description = "Search for vulnerabilities in NVD (National Vulnerability Database)"

    # smolagents expects JSON-schema style type names here ("string",
    # "integer", "object", ...), not Python type names or typing objects.
    inputs = {
        'query': {'type': 'string', 'description': 'Search term or CVE ID'},
        # forward() gives max_results a default, so smolagents wants the
        # input declared nullable; 'default' is kept for documentation.
        'max_results': {
            'type': 'integer',
            'description': 'Maximum number of results',
            'default': 5,
            'nullable': True,
        },
    }

    output_type = "object"

    # Used when the agent passes max_results=None.
    _DEFAULT_MAX_RESULTS = 5

    def __init__(self):
        """Initialize the smolagents base and read the optional NVD API key."""
        # Tool.__init__ sets internal state (e.g. is_initialized) that
        # Tool.__call__ relies on, so it must run.
        super().__init__()
        # nvdlib has no global key setter; the key is sent with each request
        # (see search_nvd). None means "no key configured".
        self.nvd_api_key = os.getenv('NVD_API_KEY')

    def search_nvd(self, query: str, max_results: int) -> List[Dict[str, Any]]:
        """Search the NVD and return one summary dict per matching CVE.

        Args:
            query: A CVE identifier (e.g. ``CVE-2021-44228``, any case) for a
                direct lookup, or free text for a keyword search.
            max_results: Upper bound on keyword-search hits. ``None`` uses
                the default of 5; values below 1 are raised to 1. Ignored for
                direct CVE-ID lookups.

        Returns:
            A list of summary dicts (see ``_summarize``). On any failure a
            single-element list ``[{'error': ...}]`` is returned instead of
            raising, so the agent always receives a serializable answer.
        """
        normalized = (query or '').strip()
        if not normalized:
            return [{'error': "Error in NVD search: empty query"}]

        try:
            request: Dict[str, Any] = {}
            # Only send a key when one is configured.
            if self.nvd_api_key:
                request['key'] = self.nvd_api_key

            if normalized.upper().startswith('CVE-'):
                # Direct lookup by ID; the result is a list (empty if unknown).
                results = nvdlib.searchCVE(cveId=normalized.upper(), **request)
            else:
                results = nvdlib.searchCVE(
                    keywordSearch=normalized,
                    limit=self._coerce_limit(max_results),
                    **request,
                )

            return [self._summarize(cve) for cve in results]

        except Exception as e:
            # Tool boundary: report the failure to the agent instead of raising.
            return [{'error': f"Error in NVD search: {str(e)}"}]

    def forward(self, query: str, max_results: int = 5) -> Dict[str, Any]:
        """Run the NVD search and wrap the results with the query and source."""
        results = self.search_nvd(query, max_results)
        return {
            'query': query,
            'source': 'nvd',
            'results': results,
        }

    @classmethod
    def _coerce_limit(cls, max_results: Optional[int]) -> int:
        """Return *max_results* as a positive int, defaulting when it is None."""
        if max_results is None:
            return cls._DEFAULT_MAX_RESULTS
        return max(1, int(max_results))

    @classmethod
    def _summarize(cls, cve: Any) -> Dict[str, Any]:
        """Flatten one nvdlib CVE object into a plain dict.

        Optional fields that are missing yield ``''``, ``None`` or ``[]``
        rather than raising, so one sparse record cannot fail the whole search.
        """
        descriptions = getattr(cve, 'descriptions', None) or []
        # Prefer the English description; fall back to the first one listed.
        chosen = next(
            (d for d in descriptions if getattr(d, 'lang', None) == 'en'),
            descriptions[0] if descriptions else None,
        )
        return {
            'id': cve.id,
            'description': getattr(chosen, 'value', ''),
            'severity': cls._cvss31_base_score(cve),
            'published': getattr(cve, 'published', None),
            'references': [ref.url for ref in (getattr(cve, 'references', None) or [])],
        }

    @staticmethod
    def _cvss31_base_score(cve: Any) -> Optional[float]:
        """Return the first CVSS v3.1 base score of *cve*, or None if it has none."""
        metrics = getattr(cve, 'metrics', None)
        # Many CVEs carry only v2 / v3.0 / v4.0 metrics; treat that as "no score".
        v31 = getattr(metrics, 'cvssMetricV31', None) or []
        if not v31:
            return None
        return v31[0].cvssData.baseScore