from datetime import datetime
import re
import socket
import ssl
import time
import logging
import asyncio
from urllib.parse import urlparse, parse_qs

import requests
import aiohttp
import dns.resolver
import whois
import pandas as pd
from tld import get_tld
from googlesearch import search
from catboost import CatBoostClassifier
from lime.lime_tabular import LimeTabularExplainer

# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Feature extraction functions
def extract_url_features(url):
    features = {}
    try:
        # Basic URL features
        features['qty_dot_url'] = url.count('.')
        features['qty_slash_url'] = url.count('/')
        features['qty_at_url'] = url.count('@')
        features['qty_space_url'] = url.count(' ')
        features['qty_plus_url'] = url.count('+')
        features['qty_dollar_url'] = url.count('$')
        features['length_url'] = len(url)
        features['qty_equal_url'] = url.count('=')
        features['qty_asterisk_url'] = url.count('*')
        features['qty_percent_url'] = url.count('%')
        features['qty_exclamation_url'] = url.count('!')
        features['qty_questionmark_url'] = url.count('?')
        features['qty_tilde_url'] = url.count('~')
        features['qty_hyphen_url'] = url.count('-')
        features['qty_hashtag_url'] = url.count('#')
        features['qty_underline_url'] = url.count('_')
        tld_pattern = r'\.(com|org|net|gov|edu|io|xyz|info|biz|co|uk|us|ca|au|in|cn|jp|ru|de|fr|it|nl|es|ch|se|no|dk|fi|pl|tr|br|za|mx|kr|sg|my|th|hk|vn|ar|cl|pe|nz|il|pk)'
        features['qty_tld_url'] = len(re.findall(tld_pattern, url, re.IGNORECASE))
        features['email_in_url'] = 1 if '@' in url else 0
        features['url_google_index'] = is_url_indexed(url)
        features['url_shortened'] = 1 if 'bit.ly' in url or 'tinyurl.com' in url else 0
        features['qty_comma_url'] = url.count(',')
        features['qty_and_url'] = url.count('&')

        # Extract domain from URL
        domain = urlparse(url).netloc if '://' in url else url.split('/')[0]

        # Domain features
        features['qty_underline_domain'] = domain.count('_')
        features['qty_equal_domain'] = domain.count('=')
        features['qty_exclamation_domain'] = domain.count('!')
        features['qty_comma_domain'] = domain.count(',')
        features['qty_hashtag_domain'] = domain.count('#')
        features['qty_vowels_domain'] = sum(1 for c in domain if c in 'aeiouAEIOU')
        features['server_client_domain'] = 1 if 'server' in domain or 'client' in domain else 0
        features['qty_dot_domain'] = domain.count('.')
        features['domain_in_ip'] = 1 if is_ip_address(domain) else 0
        features['domain_length'] = len(domain)
        features['qty_hyphen_domain'] = domain.count('-')
        features['qty_percent_domain'] = domain.count('%')
        features['qty_at_domain'] = domain.count('@')
        features['domain_spf'] = get_spf_record(domain)
        features['domain_google_index'] = is_domain_indexed(domain)

        # Directory features (first path segment)
        directory = url.split('/')[3] if len(url.split('/')) > 3 else ""
        features['qty_underline_directory'] = directory.count('_')
        features['qty_equal_directory'] = directory.count('=')
        features['qty_exclamation_directory'] = directory.count('!')
        features['qty_comma_directory'] = directory.count(',')
        features['qty_hashtag_directory'] = directory.count('#')
        features['directory_length'] = len(directory)
        features['qty_space_directory'] = directory.count(' ')
        features['qty_tilde_directory'] = directory.count('~')
        features['qty_dollar_directory'] = directory.count('$')
        features['qty_plus_directory'] = directory.count('+')
        features['qty_and_directory'] = directory.count('&')
        features['qty_slash_directory'] = directory.count('/')
        features['qty_dot_directory'] = directory.count('.')
        features['qty_asterisk_directory'] = directory.count('*')
        features['qty_at_directory'] = directory.count('@')
        features['qty_questionmark_directory'] = directory.count('?')
        features['qty_hyphen_directory'] = directory.count('-')
        features['qty_percent_directory'] = directory.count('%')

        # File features (second path segment)
        file = url.split('/')[4] if len(url.split('/')) > 4 else ""
        features['qty_underline_file'] = file.count('_')
        features['qty_and_file'] = file.count('&')
        features['qty_dollar_file'] = file.count('$')
        features['qty_questionmark_file'] = file.count('?')
        features['qty_equal_file'] = file.count('=')
        features['qty_slash_file'] = file.count('/')
        features['qty_exclamation_file'] = file.count('!')
        features['qty_comma_file'] = file.count(',')
        features['qty_hashtag_file'] = file.count('#')
        features['file_length'] = len(file)
        features['qty_tilde_file'] = file.count('~')
        features['qty_at_file'] = file.count('@')
        features['qty_dot_file'] = file.count('.')
        features['qty_space_file'] = file.count(' ')
        features['qty_plus_file'] = file.count('+')
        features['qty_asterisk_file'] = file.count('*')
        features['qty_hyphen_file'] = file.count('-')
        features['qty_percent_file'] = file.count('%')

        # Parameters features
        params = url.split('?')[1] if '?' in url else ""
        features['qty_underline_params'] = params.count('_')
        features['qty_equal_params'] = params.count('=')
        features['qty_exclamation_params'] = params.count('!')
        features['qty_comma_params'] = params.count(',')
        features['qty_hashtag_params'] = params.count('#')
        features['params_length'] = len(params)
        features['qty_tilde_params'] = params.count('~')
        features['qty_asterisk_params'] = params.count('*')
        features['qty_space_params'] = params.count(' ')
        features['qty_dollar_params'] = params.count('$')
        features['qty_questionmark_params'] = params.count('?')
        features['tld_present_params'] = 1 if re.search(tld_pattern, params, re.IGNORECASE) else 0
        features['qty_plus_params'] = params.count('+')
        features['qty_at_params'] = params.count('@')
        features['qty_params'] = url.count('?')
        features['qty_and_params'] = params.count('&')
        features['qty_hyphen_params'] = params.count('-')
        features['qty_dot_params'] = params.count('.')
        features['qty_percent_params'] = params.count('%')
        features['qty_slash_params'] = params.count('/')

        # Other features
        features['asn_ip'] = get_asn(get_ip_from_url(url))
        features['qty_ip_resolved'] = get_resolved_ips(domain)
        features['ttl_hostname'] = get_ttl(domain)

        # Extract domain time features and ensure timestamps
        features['time_domain_activation'], features['time_domain_expiration'] = get_domain_time_features(domain)
        # Convert activation time to a timestamp if it's a datetime object
        if isinstance(features['time_domain_activation'], datetime):
            features['time_domain_activation'] = features['time_domain_activation'].timestamp()
        # Convert expiration time to a timestamp if it's a datetime object
        if isinstance(features['time_domain_expiration'], datetime):
            features['time_domain_expiration'] = features['time_domain_expiration'].timestamp()

        try:
            features['qty_redirects'] = len(requests.get(url, timeout=5).history)
        except requests.exceptions.RequestException as e:
            logger.warning("Error counting redirects for URL '%s': %s", url, e)
            features['qty_redirects'] = -1

        features['qty_mx_servers'] = get_mx_record_count(domain)
        features['qty_nameservers'] = get_nameserver_count(domain)
        features['tls_ssl_certificate'] = get_tls_ssl_certificate(domain)
        features['time_response'] = get_response_time(url)
    except Exception as e:
        logger.error("Error extracting features for %s: %s", url, e)
        # Mark every feature collected so far as invalid
        for key in features.keys():
            features[key] = -1
    return features
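
# Example call (a sketch; "example.com" is a placeholder URL, and the numeric
# values depend on live DNS/WHOIS/HTTP lookups at run time):
#   extract_url_features("https://example.com/dir/file.html?q=1")
#   -> {'qty_dot_url': 2, 'qty_slash_url': 4, ..., 'time_response': 0.12}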

# Function to count specific characters in a URL
def count_char_in_url(url, char):
    try:
        return url.count(char)
    except Exception:
        return -1  # Return -1 if unable to count characters

# Function to extract domain features from a URL
def extract_domain_features(url):
    try:
        domain = urlparse(url).netloc
        domain_parts = domain.split('.')
        tld_length = len(domain_parts[-1])  # Top-level domain length
        return domain, tld_length
    except Exception:
        return -1, -1  # Return -1 if extraction fails

def is_domain_indexed(domain):
    query = f"site:{domain}"
    try:
        results = list(search(query, num=1, stop=1))  # stop=1 keeps the result generator bounded
        return 1 if results else 0
    except Exception as e:
        logger.warning("Error checking Google index for %s: %s", domain, e)
        return -1

def get_response_time(url):
    try:
        start_time = time.time()
        requests.get(url, timeout=10)  # 10-second timeout
        end_time = time.time()
        return end_time - start_time  # Response time in seconds
    except requests.exceptions.RequestException as e:
        logger.warning("Error measuring response time for %s: %s", url, e)
        return -1  # Return -1 if there's an error

def get_mx_record_count(domain):
    try:
        # Use dns.resolver.resolve instead of the deprecated query()
        answers = dns.resolver.resolve(domain, 'MX')
        return len(answers)
    except dns.resolver.NoAnswer:
        # The domain exists but has no MX records
        logger.info("No MX records found for %s.", domain)
        return 0
    except (dns.resolver.NXDOMAIN, dns.resolver.Timeout, dns.resolver.NoNameservers) as e:
        # Handle other DNS errors
        logger.warning("Error fetching MX records for %s: %s", domain, e)
        return -1

def get_nameserver_count(domain):
    try:
        # Query the NS records for the domain
        ns_records = dns.resolver.resolve(domain, 'NS')
        return len(ns_records)  # Return the count of NS records
    except dns.resolver.NoAnswer:
        logger.info("No NS records found for %s.", domain)
        return 0  # Return 0 if no NS records exist
    except (dns.resolver.NXDOMAIN, dns.resolver.Timeout, dns.resolver.NoNameservers) as e:
        logger.warning("Error fetching NS records for %s: %s", domain, e)
        return -1  # Match the error convention used for MX lookups

# Function to extract directory and file related features
def extract_path_features(url):
    try:
        parsed_url = urlparse(url)
        path = parsed_url.path
        return path
    except Exception:
        return -1  # Return -1 if extraction fails

# Function to extract query parameter related features
def extract_query_features(url):
    try:
        parsed_url = urlparse(url)
        query_params = parse_qs(parsed_url.query)
        return query_params
    except Exception:
        return -1  # Return -1 if extraction fails

# Function to check if the domain is an IP address format
def is_ip_address(domain):
    try:
        socket.inet_aton(domain)
        return 1  # It's an IP address
    except socket.error:
        return 0  # It's not an IP address

# Function to get the time-related features
def get_domain_time_features(domain):
    try:
        domain_info = whois.whois(domain)
        activation_time = domain_info.creation_date
        expiration_time = domain_info.expiration_date
        # python-whois may return a list of dates; keep the first entry
        if isinstance(activation_time, list):
            activation_time = activation_time[0]
        if isinstance(expiration_time, list):
            expiration_time = expiration_time[0]
        # Ensure activation_time and expiration_time are valid datetime objects or None
        if not isinstance(activation_time, (datetime, type(None))):
            activation_time = None
        if not isinstance(expiration_time, (datetime, type(None))):
            expiration_time = None
        return activation_time, expiration_time
    except Exception as e:
        logger.warning("Error fetching domain times for %s: %s", domain, e)
        return -1, -1
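
# Example (a sketch; WHOIS output varies by registrar and TLD):
#   get_domain_time_features("example.com")
#   -> (datetime(1995, 8, 14, ...), datetime(2025, 8, 13, ...)), or (-1, -1) on failure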

'''def get_domain_time_features(domain):
    rdap_url = f"https://rdap.org/domain/{domain}"  # RDAP public API endpoint
    try:
        response = requests.get(rdap_url, timeout=5)
        if response.status_code == 200:
            domain_data = response.json()
            # Extract activation and expiration dates
            events = domain_data.get("events", [{}])
            creation_date = None
            expiration_date = None
            for event in events:
                if event.get("eventAction") == "registration":
                    creation_date = event.get("eventDate")
                elif event.get("eventAction") == "expiration":
                    expiration_date = event.get("eventDate")
            # Convert string dates to datetime objects
            creation_date = datetime.fromisoformat(creation_date) if creation_date else 0
            expiration_date = datetime.fromisoformat(expiration_date) if expiration_date else 0
            return creation_date, expiration_date
        elif response.status_code == 404:
            # Domain does not exist
            return 0, 0
        else:
            # Failed to fetch data for other reasons
            return -1, -1
    except Exception as e:
        print(f"Error fetching domain times for {domain}: {e}")
        return -1, -1'''

# Function to get SPF record
def get_spf_record(domain):
    try:
        txt_records = dns.resolver.resolve(domain, 'TXT')
        for record in txt_records:
            if "v=spf1" in str(record):
                return 1
        return 0
    except Exception:
        return -1  # Return -1 if SPF check fails

# Function to get TLS/SSL certificate
def get_tls_ssl_certificate(domain):
    try:
        context = ssl.create_default_context()
        with socket.create_connection((domain, 443), timeout=10) as sock:
            with context.wrap_socket(sock, server_hostname=domain) as ssock:
                ssock.getpeercert()  # Raises if the handshake or certificate fails
                return 1  # TLS/SSL certificate exists
    except Exception:
        return -1  # Return -1 if SSL check fails

# Function to get IP address from URL
def get_ip_from_url(url):
    try:
        domain = urlparse(url).netloc if '://' in url else url.split('/')[0]
        ip = socket.gethostbyname(domain)
        return ip
    except Exception:
        return -1  # Return -1 if IP extraction fails

# Function to get ASN for an IP
def get_asn(ip):
    if not ip or ip == -1:
        return -1  # Return -1 if IP is invalid
    try:
        response = requests.get(f"https://ipinfo.io/{ip}/json", timeout=5)
        data = response.json()
        org = data.get("org", "Unknown ASN")
        match = re.search(r'AS(\d+)', org)
        return int(match.group(1)) if match else -1
    except Exception:
        return -1

# Function to get resolved IPs for a domain
def get_resolved_ips(domain):
    try:
        return len(socket.gethostbyname_ex(domain)[2])
    except Exception:
        return -1

# Function to get TTL value for a domain
def get_ttl(domain):
    try:
        answers = dns.resolver.resolve(domain, 'A')
        return answers.rrset.ttl
    except Exception:
        return -1

def is_url_indexed(url):
    query = f"site:{url}"
    try:
        results = list(search(query, num=1, stop=1))  # stop=1 keeps the result generator bounded
        return 1 if results else 0
    except Exception as e:
        logger.warning("Error checking if URL is indexed: %s", e)
        return -1

def process_urls(urls):
    url_features = []
    for url in urls:
        if not (url.startswith("http://") or url.startswith("https://")):
            url = "https://" + url
        features = extract_url_features(url)
        url_features.append(features)
    return pd.DataFrame(url_features)
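
# Usage sketch (placeholder URLs; each row triggers live network lookups):
#   df = process_urls(["example.com", "https://example.org/a/b.html?q=1"])
#   df.fillna(-1, inplace=True)  # match the -1 error convention before modeling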

'''def predict_urls(urls, model_path):
    features_df = process_urls(urls)
    features_df.fillna(-1, inplace=True)
    model = CatBoostClassifier()
    model.load_model(model_path)
    predictions = model.predict(features_df)
    return predictions'''

'''def explain_prediction(features_df, model_path):
    model = CatBoostClassifier()
    model.load_model(model_path)
    explainer = LimeTabularExplainer(
        training_data=features_df.values,
        feature_names=features_df.columns.tolist(),
        class_names=["Legitimate", "Malicious"],
        mode="classification"
    )
    explanation = explainer.explain_instance(
        data_row=features_df.iloc[0].values,
        predict_fn=model.predict_proba,
        num_features=5
    )
    explanation.show_in_notebook(show_table=True)
    return explanation'''

# Async enhancements for faster processing (optional)
async def fetch_url(session, url):
    try:
        async with session.get(url, timeout=aiohttp.ClientTimeout(total=10)) as response:
            return len(response.history)
    except Exception:
        return -1

async def process_urls_async(urls):
    async with aiohttp.ClientSession() as session:
        tasks = [fetch_url(session, url) for url in urls]
        results = await asyncio.gather(*tasks)
        return results
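
# Minimal usage sketch, assuming network access is available; "example.com" and
# "example.org" are placeholder URLs, and no trained model file is loaded here.
if __name__ == "__main__":
    sample_urls = ["example.com", "https://example.org/dir/file.html?q=1"]
    # Synchronous feature extraction into a DataFrame
    df = process_urls(sample_urls)
    logger.info("Extracted %d features for %d URLs", df.shape[1], df.shape[0])
    # Optional async pass: redirect-history lengths for normalized URLs
    normalized = [u if u.startswith(("http://", "https://")) else "https://" + u
                  for u in sample_urls]
    redirects = asyncio.run(process_urls_async(normalized))
    logger.info("Redirect history lengths: %s", redirects)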