Spaces:
Running
Running
File size: 8,030 Bytes
2a0bc63 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 |
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from cassandra.graph import SimpleGraphStatement, GraphProtocol
from cassandra.cluster import EXEC_PROFILE_GRAPH_DEFAULT
from gremlin_python.process.graph_traversal import GraphTraversal
from gremlin_python.structure.io.graphsonV2d0 import GraphSONWriter as GraphSONWriterV2
from gremlin_python.structure.io.graphsonV3d0 import GraphSONWriter as GraphSONWriterV3
from cassandra.datastax.graph.fluent.serializers import GremlinUserTypeIO, \
dse_graphson2_serializers, dse_graphson3_serializers
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
# Names exported on `from <this module> import *`. Two of them are
# underscore-prefixed and would be skipped by a star-import if they were not
# listed here explicitly.
__all__ = ['TraversalBatch', '_query_from_traversal', '_DefaultTraversalBatch']
class _GremlinGraphSONWriterAdapter(object):
    """
    Mixin adding a custom serializer lookup on top of a GraphSON writer.

    The class it is combined with must expose a ``serializers`` mapping of
    python type -> serializer (presumably provided by the gremlin-python
    GraphSONWriter base class -- TODO confirm).
    """

    def __init__(self, context, **kwargs):
        """
        :param context: Serialization context. When truthy, it is indexed with
            the 'cluster' and 'graph_name' keys to look up user types.
        :param kwargs: Passed through to the next ``__init__`` in the MRO.
        """
        super(_GremlinGraphSONWriterAdapter, self).__init__(**kwargs)
        self.context = context
        # Filled lazily by get_serializer() the first time a context is used.
        self.user_types = None

    def serialize(self, value, _):
        """Serialize `value`; the second positional argument is ignored."""
        return self.toDict(value)

    def get_serializer(self, value):
        """
        Find the serializer to use for `value`.

        :param value: The object to be serialized.
        :return: The matching serializer, or None when nothing matches.
        """
        registry = self.serializers
        value_type = type(value)

        try:
            found = registry[value_type]
        except KeyError:
            # No exact type match: fall back to isinstance checks. There is no
            # break, so the last matching entry in iteration order wins.
            found = None
            for candidate_type, candidate in registry.items():
                if isinstance(value, candidate_type):
                    found = candidate

        if self.context:
            # Check if UDT
            if self.user_types is None:
                try:
                    registered = self.context['cluster']._user_types[self.context['graph_name']]
                    # Invert the mapping so it can be queried by its values.
                    self.user_types = {v: k for k, v in registered.items()}
                except KeyError:
                    self.user_types = {}

            # Custom detection to map a namedtuple to udt
            looks_like_namedtuple = (tuple in registry and
                                     found is registry[tuple] and
                                     hasattr(value, '_fields'))
            if looks_like_namedtuple or (not found and value_type in self.user_types):
                found = GremlinUserTypeIO

        if not found:
            return found

        try:
            # A serializer can have specialized serializers (e.g. for Int32
            # and Int64), so the final choice depends on the value.
            return found.get_specialized_serializer(value)
        except AttributeError:
            return found

    def toDict(self, obj):
        """Return the dictified form of `obj`, or `obj` itself if no serializer applies."""
        serializer = self.get_serializer(obj)
        if not serializer:
            return obj
        return serializer.dictify(obj, self)

    def definition(self, value):
        """Return the definition produced for `value` by its serializer."""
        return self.get_serializer(value).definition(value, self)
class GremlinGraphSON2Writer(_GremlinGraphSONWriterAdapter, GraphSONWriterV2):
    """GraphSON 2.0 writer: the adapter's serializer lookup combined with GraphSONWriterV2."""
class GremlinGraphSON3Writer(_GremlinGraphSONWriterAdapter, GraphSONWriterV3):
    """GraphSON 3.0 writer: the adapter's serializer lookup combined with GraphSONWriterV3."""
# Module-level aliases for the writer classes; _query_from_traversal
# instantiates the writer through these names.
graphson2_writer = GremlinGraphSON2Writer
graphson3_writer = GremlinGraphSON3Writer
def _query_from_traversal(traversal, graph_protocol, context=None):
    """
    From a GraphTraversal, return a query string.

    :param traversal: The GraphTraversal object
    :param graph_protocol: The graph protocol to determine the output format.
    :param context: (Optional) Serialization context handed to the GraphSON
        writer. Mandatory when `graph_protocol` is GraphProtocol.GRAPHSON_3_0.
    :raises ValueError: if `graph_protocol` is neither GRAPHSON_2_0 nor
        GRAPHSON_3_0, or if GRAPHSON_3_0 is requested without a context.
    """
    if graph_protocol == GraphProtocol.GRAPHSON_2_0:
        graphson_writer = graphson2_writer(context, serializer_map=dse_graphson2_serializers)
    elif graph_protocol == GraphProtocol.GRAPHSON_3_0:
        if context is None:
            raise ValueError('GraphSON3 serialization requires a context.')
        graphson_writer = graphson3_writer(context, serializer_map=dse_graphson3_serializers)
    else:
        raise ValueError('Unknown graph protocol: {}'.format(graph_protocol))

    try:
        return graphson_writer.writeObject(traversal)
    except Exception:
        # Record the traceback for diagnosis, then let the original error
        # propagate unchanged to the caller.
        log.exception("Error preparing graphson traversal query:")
        raise
class TraversalBatch(object):
    """
    A `TraversalBatch` is used to execute multiple graph traversals in a
    single transaction. If any traversal in the batch fails, the entire
    batch will fail to apply.

    If a TraversalBatch is bound to a DSE session, it can be executed using
    `traversal_batch.execute()`.

    This class only defines the interface: every operation except
    construction and `__str__`/`__repr__` raises NotImplementedError and must
    be provided by a subclass.
    """

    _session = None
    _execution_profile = None

    def __init__(self, session=None, execution_profile=None):
        """
        :param session: (Optional) A DSE session
        :param execution_profile: (Optional) The execution profile to use for the batch execution
        """
        self._session = session
        self._execution_profile = execution_profile

    def add(self, traversal):
        """
        Add a traversal to the batch.

        :param traversal: A gremlin GraphTraversal
        """
        raise NotImplementedError()

    def add_all(self, traversals):
        """
        Adds a sequence of traversals to the batch.

        :param traversals: A sequence of gremlin GraphTraversal
        """
        raise NotImplementedError()

    def execute(self):
        """
        Execute the traversal batch if bound to a `DSE Session`.
        """
        raise NotImplementedError()

    def as_graph_statement(self, graph_protocol=GraphProtocol.GRAPHSON_2_0, context=None):
        """
        Return the traversal batch as GraphStatement.

        :param graph_protocol: The graph protocol for the GraphSONWriter. Default is GraphProtocol.GRAPHSON_2_0.
        :param context: (Optional) Serialization context forwarded to the GraphSONWriter.
        """
        raise NotImplementedError()

    def clear(self):
        """
        Clear a traversal batch for reuse.
        """
        raise NotImplementedError()

    def __len__(self):
        raise NotImplementedError()

    def __str__(self):
        # Relies on the subclass implementation of __len__.
        return u'<TraversalBatch traversals={0}>'.format(len(self))

    __repr__ = __str__
class _DefaultTraversalBatch(TraversalBatch):
    """TraversalBatch implementation that keeps its traversals in a list."""

    _traversals = None

    def __init__(self, *args, **kwargs):
        super(_DefaultTraversalBatch, self).__init__(*args, **kwargs)
        self._traversals = []

    def add(self, traversal):
        """
        Append one traversal to the batch and return the batch itself.

        :param traversal: A gremlin GraphTraversal
        :raises ValueError: if `traversal` is not a GraphTraversal
        """
        if isinstance(traversal, GraphTraversal):
            self._traversals.append(traversal)
            return self
        raise ValueError('traversal should be a gremlin GraphTraversal')

    def add_all(self, traversals):
        """
        Append every traversal of `traversals`, validating each through add().

        :param traversals: A sequence of gremlin GraphTraversal
        """
        for item in traversals:
            self.add(item)

    def as_graph_statement(self, graph_protocol=GraphProtocol.GRAPHSON_2_0, context=None):
        """
        Serialize each traversal and wrap the comma-joined results in
        square brackets as the query of a SimpleGraphStatement.

        :param graph_protocol: The graph protocol for the GraphSONWriter.
        :param context: (Optional) Serialization context passed to the writer.
        """
        joined = ','.join(
            _query_from_traversal(item, graph_protocol, context)
            for item in self._traversals
        )
        return SimpleGraphStatement(u"[{0}]".format(joined))

    def execute(self):
        """
        Execute the batch through the session given at construction.

        :raises ValueError: if no session was provided
        """
        session = self._session
        if session is None:
            raise ValueError('A DSE Session must be provided to execute the traversal batch.')

        # Fall back to the default graph profile when none was configured.
        execution_profile = self._execution_profile or EXEC_PROFILE_GRAPH_DEFAULT
        graph_options = session.get_execution_profile(execution_profile).graph_options
        context = {
            'cluster': session.cluster,
            'graph_name': graph_options.graph_name
        }

        protocol = graph_options.graph_protocol
        if protocol:
            statement = self.as_graph_statement(protocol, context=context)
        else:
            # No protocol configured: use as_graph_statement's own default.
            statement = self.as_graph_statement(context=context)
        return session.execute_graph(statement, execution_profile=execution_profile)

    def clear(self):
        """Empty the batch in place so it can be reused."""
        self._traversals[:] = []

    def __len__(self):
        return len(self._traversals)
|