File size: 2,317 Bytes
eea2f4b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import os
from typing import Dict, Any, Optional, List
import nvdlib
from smolagents.tools import Tool

class VulnerabilitySearchTool(Tool):
    """smolagents tool that looks up CVE records in the NVD through ``nvdlib``.

    A query that looks like a CVE identifier (``CVE-...``) is resolved as a
    direct ID lookup; anything else is sent as a keyword search. Every hit is
    flattened into a plain dict (see ``_format_cve``) so the agent receives
    simple serialisable data instead of nvdlib objects.
    """

    name = "vuln_search"
    description = "Search for vulnerabilities in NVD (National Vulnerability Database)"
    # smolagents expects its own type names here ("string", "integer",
    # "object", ...) rather than Python type names, and expects an input whose
    # ``forward`` parameter has a default value to be flagged as nullable.
    inputs = {
        'query': {'type': 'string', 'description': 'Search term or CVE ID'},
        'max_results': {
            'type': 'integer',
            'description': 'Maximum number of results',
            'default': 5,
            'nullable': True,
        },
    }
    output_type = "object"

    def __init__(self):
        """Initialise the base tool and read the optional NVD API key.

        The key is taken from the ``NVD_API_KEY`` environment variable. It is
        stored and handed to each ``nvdlib.searchCVE`` call via ``key=``;
        nvdlib has no global "set key" function to call here.
        """
        super().__init__()
        # Treat an empty variable the same as an unset one.
        self.nvd_api_key = os.getenv('NVD_API_KEY') or None

    @staticmethod
    def _format_cve(cve: Any) -> Dict[str, Any]:
        """Flatten one nvdlib CVE record into a plain dictionary.

        Missing pieces (no description, no CVSS v3.1 metric, no references)
        yield ``None`` / an empty list instead of raising, so that one sparse
        record cannot turn a whole result set into an error.

        Args:
            cve: A CVE object as returned by ``nvdlib.searchCVE``.

        Returns:
            Dict with keys ``id``, ``description``, ``severity`` (CVSS v3.1
            base score or ``None``), ``published`` and ``references``.
        """
        descriptions = getattr(cve, 'descriptions', None) or []
        # ``metrics`` may be absent, or present without a v3.1 entry (e.g.
        # records scored only with another CVSS version).
        v31_metrics = getattr(getattr(cve, 'metrics', None), 'cvssMetricV31', None) or []
        references = getattr(cve, 'references', None) or []

        return {
            'id': cve.id,
            'description': descriptions[0].value if descriptions else None,
            'severity': v31_metrics[0].cvssData.baseScore if v31_metrics else None,
            'published': getattr(cve, 'published', None),
            'references': [ref.url for ref in references],
        }

    def search_nvd(self, query: str, max_results: int) -> List[Dict[str, Any]]:
        """Search vulnerabilities in NVD.

        Args:
            query: A CVE identifier (matched case-insensitively on the
                ``CVE-`` prefix) or a free-text keyword.
            max_results: Upper bound on the number of keyword-search hits.
                Not applied to a direct CVE-ID lookup.

        Returns:
            A list of dicts as produced by ``_format_cve``. On any failure
            the list holds a single ``{'error': ...}`` entry instead; errors
            are reported to the caller rather than raised.
        """
        try:
            term = query.strip()
            api_key = getattr(self, 'nvd_api_key', None)

            if term.upper().startswith('CVE-'):
                # Direct lookup: searchCVE returns a (possibly empty) list.
                matches = nvdlib.searchCVE(cveId=term.upper(), key=api_key)
            else:
                matches = nvdlib.searchCVE(
                    keywordSearch=term,
                    limit=max_results,
                    key=api_key,
                )

            return [self._format_cve(cve) for cve in matches]

        except Exception as e:
            # Deliberately broad: this is the tool boundary, and the failure
            # (network, rate limit, bad input) is returned as data.
            return [{'error': f"Error in NVD search: {str(e)}"}]

    def forward(self, query: str, max_results: int = 5) -> Dict[str, Any]:
        """Run the NVD search and wrap the hits with query metadata.

        Args:
            query: Search term or CVE ID.
            max_results: Maximum number of results; ``None`` falls back to 5
                because the input is declared nullable.

        Returns:
            Dict with the original ``query``, the fixed ``source`` tag
            ``'nvd'`` and the ``results`` list from ``search_nvd``.
        """
        if max_results is None:
            max_results = 5

        results = self.search_nvd(query, max_results)

        return {
            'query': query,
            'source': 'nvd',
            'results': results
        }