File size: 6,577 Bytes
2a0bc63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import logging
import json
import sys
import tempfile
import shutil
from urllib.request import urlopen

_HAS_SSL = True
try:
    from ssl import SSLContext, PROTOCOL_TLS, CERT_REQUIRED
except:
    _HAS_SSL = False

from zipfile import ZipFile

# 2.7 vs 3.x
try:
    from zipfile import BadZipFile
except:
    from zipfile import BadZipfile as BadZipFile

from cassandra import DriverException

# Module-level logger; read_metadata_info logs metadata-service connection
# failures through it.
log = logging.getLogger(__name__)

# Only get_cloud_config is exported by ``from <module> import *``.
__all__ = ['get_cloud_config']

# Product-type identifier string for the DataStax cloud service.
# NOTE(review): not referenced in this module's visible code; presumably
# consumed by other driver modules — confirm before renaming.
DATASTAX_CLOUD_PRODUCT_TYPE = "DATASTAX_APOLLO"


class CloudConfig(object):
    """Connection settings for a DataStax cloud cluster.

    Some fields come from the secure connect bundle's ``config.json`` (see
    :meth:`from_dict`); ``sni_host``, ``sni_port`` and ``host_ids`` are filled
    in later from the cluster metadata response.
    """

    # Fields read from the bundle's config.json by from_dict().
    username = None
    password = None
    host = None
    port = None  # int when the bundle value is numeric, otherwise kept as given
    keyspace = None
    local_dc = None  # config.json "localDC"
    ssl_context = None

    # Optional pyOpenSSL context built from the bundle certificates. Declared
    # here so reading it is safe (None) when no such context was created.
    pyopenssl_context = None

    # Fields filled from the cluster metadata response.
    sni_host = None
    sni_port = None
    host_ids = None

    @classmethod
    def from_dict(cls, d):
        """Build a CloudConfig from the parsed ``config.json`` mapping *d*.

        Missing keys leave the corresponding attribute as ``None``. ``port``
        is converted to ``int`` when possible; a value that cannot be
        converted is stored unchanged.
        """
        c = cls()

        port = d.get('port', None)
        try:
            port = int(port)
        except (TypeError, ValueError, OverflowError):
            # Missing key (None), non-numeric text, or a non-finite float:
            # keep the value exactly as it appeared in the bundle.
            pass
        c.port = port

        c.username = d.get('username', None)
        c.password = d.get('password', None)
        c.host = d.get('host', None)
        c.keyspace = d.get('keyspace', None)
        c.local_dc = d.get('localDC', None)

        return c


def get_cloud_config(cloud_config, create_pyopenssl_context=False):
    """Build a :class:`CloudConfig` from a user-supplied cloud config mapping.

    :param cloud_config: mapping that must contain ``secure_connect_bundle``;
        ``ssl_context``, ``use_default_tempdir`` and ``connect_timeout`` are
        optional.
    :param create_pyopenssl_context: when true and a pyOpenSSL context was
        built from the bundle, it replaces the stdlib context as the returned
        config's ``ssl_context``.
    :returns: the populated :class:`CloudConfig`.
    :raises DriverException: if Python lacks SSL support, or the cluster
        metadata cannot be fetched or parsed.
    :raises ValueError: if no bundle is configured or the bundle is not a
        readable zip file.
    """
    if not _HAS_SSL:
        raise DriverException("A Python installation with SSL is required to connect to a cloud cluster.")

    if 'secure_connect_bundle' not in cloud_config:
        raise ValueError("The cloud config doesn't have a secure_connect_bundle specified.")

    try:
        config = read_cloud_config_from_zip(cloud_config, create_pyopenssl_context)
    except BadZipFile as e:
        raise ValueError("Unable to open the zip file for the cloud config. Check your secure connect bundle.") from e

    # The stdlib ssl_context is still needed to query the metadata service,
    # so the pyOpenSSL swap can only happen afterwards.
    config = read_metadata_info(config, cloud_config)
    if create_pyopenssl_context:
        # No pyOpenSSL context is built when the caller supplied its own
        # ``ssl_context``; keep that one rather than failing with
        # AttributeError or overwriting it with None.
        pyopenssl_context = getattr(config, 'pyopenssl_context', None)
        if pyopenssl_context is not None:
            config.ssl_context = pyopenssl_context
    return config


def read_cloud_config_from_zip(cloud_config, create_pyopenssl_context):
    """Extract the secure connect bundle to a scratch directory and parse it.

    The scratch directory is created next to the bundle file, or in the
    system temp directory when ``cloud_config['use_default_tempdir']`` is
    truthy. It is removed before returning, whether parsing succeeds or not.
    """
    bundle_path = cloud_config['secure_connect_bundle']
    prefer_system_tmp = cloud_config.get('use_default_tempdir', None)
    with ZipFile(bundle_path) as archive:
        if prefer_system_tmp:
            parent_dir = tempfile.gettempdir()
        else:
            parent_dir = os.path.dirname(bundle_path)
        scratch_dir = tempfile.mkdtemp(dir=parent_dir)
        try:
            archive.extractall(path=scratch_dir)
            config_json = os.path.join(scratch_dir, 'config.json')
            return parse_cloud_config(config_json, cloud_config, create_pyopenssl_context)
        finally:
            # The extracted files include the client private key; never leave
            # them behind on disk.
            shutil.rmtree(scratch_dir)


def parse_cloud_config(path, cloud_config, create_pyopenssl_context):
    """Parse the bundle's ``config.json`` at *path* into a :class:`CloudConfig`.

    If ``cloud_config`` contains ``ssl_context`` it is used as-is. Otherwise a
    stdlib SSL context is built from the ``ca.crt``, ``cert`` and ``key`` files
    located next to *path*. A pyOpenSSL context is built as well, stored as
    ``pyopenssl_context``, when *create_pyopenssl_context* is true.
    """
    # Binary mode lets json detect the encoding (UTF-8/16/32, with or without
    # a BOM) instead of depending on the platform's locale encoding.
    with open(path, 'rb') as stream:
        data = json.load(stream)

    config = CloudConfig.from_dict(data)
    config_dir = os.path.dirname(path)

    if 'ssl_context' in cloud_config:
        # A caller-supplied context wins; no certificate files are read.
        config.ssl_context = cloud_config['ssl_context']
    else:
        # Load the SSL contexts before the temporary directory holding the
        # extracted bundle is deleted.
        ca_cert_location = os.path.join(config_dir, 'ca.crt')
        cert_location = os.path.join(config_dir, 'cert')
        key_location = os.path.join(config_dir, 'key')
        # Even when a pyOpenSSL context is created, the builtin one is still
        # needed to connect to the metadata service.
        config.ssl_context = _ssl_context_from_cert(ca_cert_location, cert_location, key_location)
        if create_pyopenssl_context:
            config.pyopenssl_context = _pyopenssl_context_from_cert(ca_cert_location, cert_location, key_location)

    return config


def read_metadata_info(config, cloud_config):
    """Fetch the cluster metadata over HTTPS and merge it into *config*.

    Requests ``https://<config.host>:<config.port>/metadata`` using
    ``config.ssl_context``, with ``cloud_config['connect_timeout']`` as the
    timeout (5 when absent).

    :returns: *config*, updated by :func:`parse_metadata_info`.
    :raises DriverException: if the service cannot be reached, answers with a
        status other than 200, or returns unparsable metadata.
    """
    url = "https://{}:{}/metadata".format(config.host, config.port)
    timeout = cloud_config['connect_timeout'] if 'connect_timeout' in cloud_config else 5
    try:
        response = urlopen(url, context=config.ssl_context, timeout=timeout)
    except Exception as e:
        log.exception(e)
        raise DriverException("Unable to connect to the metadata service at %s. "
                              "Check the cluster status in the cloud console. " % url) from e

    # Ensure the HTTP response (and its socket) is closed on every path.
    with response:
        if response.code != 200:
            raise DriverException(("Error while fetching the metadata at: %s. "
                                   "The service returned error code %d." % (url, response.code)))
        body = response.read().decode('utf-8')
    return parse_metadata_info(config, body)


def parse_metadata_info(config, http_data):
    """Merge the metadata service's JSON response *http_data* into *config*.

    Reads the ``contact_info`` object and sets ``local_dc``, ``sni_host``,
    ``sni_port`` and ``host_ids`` on *config*. ``sni_port`` falls back to 9042
    when ``sni_proxy_address`` has no usable ``:port`` suffix.

    :returns: *config* (mutated in place).
    :raises DriverException: if *http_data* is not valid JSON.
    """
    try:
        data = json.loads(http_data)
    except (TypeError, ValueError) as e:
        # json.JSONDecodeError is a ValueError subclass.
        msg = "Failed to load cluster metadata"
        raise DriverException(msg) from e

    contact_info = data['contact_info']
    config.local_dc = contact_info['local_dc']

    proxy_info = contact_info['sni_proxy_address'].split(':')
    config.sni_host = proxy_info[0]
    try:
        config.sni_port = int(proxy_info[1])
    except (IndexError, ValueError):
        # No ":port" part, or a non-numeric one: use the default port.
        config.sni_port = 9042

    config.host_ids = list(contact_info['contact_points'])

    return config


def _ssl_context_from_cert(ca_cert_location, cert_location, key_location):
    ssl_context = SSLContext(PROTOCOL_TLS)
    ssl_context.load_verify_locations(ca_cert_location)
    ssl_context.verify_mode = CERT_REQUIRED
    ssl_context.load_cert_chain(certfile=cert_location, keyfile=key_location)

    return ssl_context


def _pyopenssl_context_from_cert(ca_cert_location, cert_location, key_location):
    """Return a pyOpenSSL ``SSL.Context`` built from the bundle's certificate files.

    Peers are verified against the CA file at *ca_cert_location*; this side
    presents the certificate/key pair at *cert_location* / *key_location*.

    :raises ImportError: if pyOpenSSL is not installed.
    """
    try:
        from OpenSSL import SSL
    except ImportError as e:
        raise ImportError(
            "PyOpenSSL must be installed to connect to Astra with the Eventlet or Twisted event loops") from e
    # TLSv1_METHOD would pin the handshake to TLS 1.0 only. Negotiate the
    # highest mutually supported version instead: TLS_METHOD on current
    # pyOpenSSL, SSLv23_METHOD (same meaning) on older releases.
    method = getattr(SSL, 'TLS_METHOD', None)
    if method is None:
        method = SSL.SSLv23_METHOD
    ssl_context = SSL.Context(method)
    # Never fall back to the broken SSLv2/SSLv3 protocols.
    ssl_context.set_options(SSL.OP_NO_SSLv2 | SSL.OP_NO_SSLv3)
    # Accept the peer only if OpenSSL's own verification succeeded ("ok").
    ssl_context.set_verify(SSL.VERIFY_PEER, callback=lambda _1, _2, _3, _4, ok: ok)
    ssl_context.use_certificate_file(cert_location)
    ssl_context.use_privatekey_file(key_location)
    ssl_context.load_verify_locations(ca_cert_location)

    return ssl_context