Spaces:
Running
Running
# Copyright DataStax, Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from datetime import datetime, timedelta | |
import time | |
from cassandra.query import FETCH_SIZE_UNSET | |
from cassandra.cqlengine import columns | |
from cassandra.cqlengine import UnicodeMixin | |
from cassandra.cqlengine.functions import QueryValue | |
from cassandra.cqlengine.operators import BaseWhereOperator, InOperator, EqualsOperator, IsNotNullOperator | |
class StatementException(Exception):
    """Error raised when a statement or clause is constructed with invalid arguments."""
class ValueQuoter(UnicodeMixin):
    """Wraps a value so that its text form is a CQL-quoted literal."""

    def __init__(self, value):
        self.value = value

    def __unicode__(self):
        """Quote the wrapped value.

        Lists and tuples render inside ``[...]``, dicts as ``{key:value, ...}``
        and sets inside ``{...}``; anything else is passed to ``cql_quote``
        as a single value.
        """
        from cassandra.encoder import cql_quote

        raw = self.value
        if isinstance(raw, (list, tuple)):
            body = ', '.join(cql_quote(item) for item in raw)
            return '[' + body + ']'
        if isinstance(raw, dict):
            pairs = (cql_quote(key) + ':' + cql_quote(val) for key, val in raw.items())
            return '{' + ', '.join(pairs) + '}'
        if isinstance(raw, set):
            body = ', '.join(cql_quote(item) for item in raw)
            return '{' + body + '}'
        return cql_quote(raw)

    def __eq__(self, other):
        # only instances of the same (sub)class can compare equal
        return isinstance(other, self.__class__) and self.value == other.value
class InQuoter(ValueQuoter):
    """Quoter that renders its values as a parenthesised, comma-separated list."""

    def __unicode__(self):
        from cassandra.encoder import cql_quote

        quoted = (cql_quote(item) for item in self.value)
        return '(' + ', '.join(quoted) + ')'
class BaseClause(UnicodeMixin):
    """Base class for clauses: a field, a value and a query-context placeholder id."""

    def __init__(self, field, value):
        self.field, self.value = field, value
        # placeholder id; stays None until set_context_id is called
        self.context_id = None

    def __unicode__(self):
        # concrete clauses supply their own rendering
        raise NotImplementedError

    def __hash__(self):
        field_hash = hash(self.field)
        value_hash = hash(self.value)
        return field_hash ^ value_hash

    def __eq__(self, other):
        if not isinstance(other, self.__class__):
            return False
        return self.field == other.field and self.value == other.value

    def __ne__(self, other):
        return not self.__eq__(other)

    def get_context_size(self):
        """ number of entries this clause contributes to the query context """
        return 1

    def set_context_id(self, i):
        """ stores the placeholder id this clause uses in the query """
        self.context_id = i

    def update_context(self, ctx):
        """ writes this clause's value into the context dict under its placeholder id """
        assert isinstance(ctx, dict)
        key = str(self.context_id)
        ctx[key] = self.value
class WhereClause(BaseClause):
    """ one condition of a WHERE expression: field, operator and value """

    def __init__(self, field, operator, value, quote_field=True):
        """
        :param field: name of the field being compared
        :param operator: a BaseWhereOperator instance
        :param value: comparison value; wrapped in a QueryValue unless it already is one
        :param quote_field: when false the field is rendered without double quotes
            (hack to get the token function rendering properly)
        :raises StatementException: if operator is not a BaseWhereOperator
        """
        if not isinstance(operator, BaseWhereOperator):
            message = "operator must be of type {0}, got {1}".format(BaseWhereOperator, type(operator))
            raise StatementException(message)
        super().__init__(field, value)
        self.operator = operator
        if isinstance(self.value, QueryValue):
            self.query_value = self.value
        else:
            self.query_value = QueryValue(self.value)
        self.quote_field = quote_field

    def __unicode__(self):
        template = '"{0}"' if self.quote_field else '{0}'
        rendered_field = template.format(self.field)
        return u'{0} {1} {2}'.format(rendered_field, self.operator, str(self.query_value))

    def __hash__(self):
        base_hash = super().__hash__()
        return base_hash ^ hash(self.operator)

    def __eq__(self, other):
        if not super().__eq__(other):
            return False
        # operators are compared by class, not by instance
        return self.operator.__class__ == other.operator.__class__

    def get_context_size(self):
        return self.query_value.get_context_size()

    def set_context_id(self, i):
        # keep the clause and its wrapped query value on the same placeholder id
        super().set_context_id(i)
        self.query_value.set_context_id(i)

    def update_context(self, ctx):
        if not isinstance(self.operator, InOperator):
            self.query_value.update_context(ctx)
            return
        # IN conditions bind the raw value wrapped in an InQuoter
        ctx[str(self.context_id)] = InQuoter(self.value)
class IsNotNullClause(WhereClause):
    """Where clause built on IsNotNullOperator; it binds no context values."""

    def __init__(self, field):
        # the empty string is only a placeholder value for the parent constructor
        super().__init__(field, IsNotNullOperator(), '')

    def __unicode__(self):
        template = '"{0}"' if self.quote_field else '{0}'
        return u'{0} {1}'.format(template.format(self.field), self.operator)

    def update_context(self, ctx):
        # nothing to bind
        pass

    def get_context_size(self):
        return 0


# short alias for IsNotNullClause
IsNotNull = IsNotNullClause
class AssignmentClause(BaseClause):
    """ a single ``"field" = <placeholder>`` assignment """

    def __unicode__(self):
        template = u'"{0}" = %({1})s'
        return template.format(self.field, self.context_id)

    def insert_tuple(self):
        """ returns the (field, context id) pair for this assignment """
        return (self.field, self.context_id)
class ConditionalClause(BaseClause):
    """ a single ``"field" = <placeholder>`` condition, joined into the IF part of a statement """

    def __unicode__(self):
        template = u'"{0}" = %({1})s'
        return template.format(self.field, self.context_id)

    def insert_tuple(self):
        """ returns the (field, context id) pair for this condition """
        return (self.field, self.context_id)
class ContainerUpdateTypeMapMeta(type):
    """Metaclass that maintains a ``type_map`` registry keyed by ``col_type``.

    A class with no visible ``type_map`` attribute gets a fresh empty dict;
    any class that already sees one is stored in it under its ``col_type``.
    """

    def __init__(cls, name, bases, dct):
        if hasattr(cls, 'type_map'):
            # registry already visible: record this class in it
            cls.type_map[cls.col_type] = cls
        else:
            cls.type_map = {}
        super().__init__(name, bases, dct)
class ContainerUpdateClause(AssignmentClause, metaclass=ContainerUpdateTypeMapMeta):
    """Base for collection update clauses; subclasses implement the analysis and context hooks."""

    def __init__(self, field, value, operation=None, previous=None):
        super().__init__(field, value)
        self._operation = operation
        self.previous = previous
        # filled in by a subclass's _analyze()
        self._assignments = None
        self._analyzed = False

    def _analyze(self):
        # subclass hook
        raise NotImplementedError

    def get_context_size(self):
        # subclass hook
        raise NotImplementedError

    def update_context(self, ctx):
        # subclass hook
        raise NotImplementedError
class SetUpdateClause(ContainerUpdateClause):
    """ update clause for a set collection column """

    col_type = columns.Set

    _additions = None
    _removals = None

    def __unicode__(self):
        fragments = []
        next_id = self.context_id
        nothing_recorded = (
            self._assignments is None
            and self._additions is None
            and self._removals is None
        )
        if self.previous is None and nothing_recorded:
            fragments.append('"{0}" = %({1})s'.format(self.field, next_id))
        if self._assignments is not None:
            fragments.append('"{0}" = %({1})s'.format(self.field, next_id))
            next_id += 1
        if self._additions is not None:
            fragments.append('"{0}" = "{0}" + %({1})s'.format(self.field, next_id))
            next_id += 1
        if self._removals is not None:
            fragments.append('"{0}" = "{0}" - %({1})s'.format(self.field, next_id))
        return ', '.join(fragments)

    def _analyze(self):
        """ decides between assignment, additions and removals """
        if not (self.value is None or self.value == self.previous):
            if self._operation == "add":
                self._additions = self.value
            elif self._operation == "remove":
                self._removals = self.value
            elif self.previous is None:
                self._assignments = self.value
            else:
                # partial update: an empty difference is stored as None
                self._additions = (self.value - self.previous) or None
                self._removals = (self.previous - self.value) or None
        self._analyzed = True

    def get_context_size(self):
        if not self._analyzed:
            self._analyze()
        no_prior_and_nothing_to_send = (
            self.previous is None
            and not self._assignments
            and self._additions is None
            and self._removals is None
        )
        if no_prior_and_nothing_to_send:
            return 1
        pending = (self._assignments, self._additions, self._removals)
        return sum(int(bool(item)) for item in pending)

    def update_context(self, ctx):
        if not self._analyzed:
            self._analyze()
        next_id = self.context_id
        nothing_recorded = (
            self._assignments is None
            and self._additions is None
            and self._removals is None
        )
        if self.previous is None and nothing_recorded:
            # no previous value and nothing recorded: bind an empty set
            ctx[str(next_id)] = set()
        if self._assignments is not None:
            ctx[str(next_id)] = self._assignments
            next_id += 1
        if self._additions is not None:
            ctx[str(next_id)] = self._additions
            next_id += 1
        if self._removals is not None:
            ctx[str(next_id)] = self._removals
class ListUpdateClause(ContainerUpdateClause):
    """ update clause for a list collection column """

    col_type = columns.List

    _append = None
    _prepend = None

    def __unicode__(self):
        if not self._analyzed:
            self._analyze()
        fragments = []
        next_id = self.context_id
        if self._assignments is not None:
            fragments.append('"{0}" = %({1})s'.format(self.field, next_id))
            next_id += 1
        if self._prepend is not None:
            fragments.append('"{0}" = %({1})s + "{0}"'.format(self.field, next_id))
            next_id += 1
        if self._append is not None:
            fragments.append('"{0}" = "{0}" + %({1})s'.format(self.field, next_id))
        return ', '.join(fragments)

    def get_context_size(self):
        if not self._analyzed:
            self._analyze()
        has_assignment = int(self._assignments is not None)
        has_append = int(bool(self._append))
        has_prepend = int(bool(self._prepend))
        return has_assignment + has_append + has_prepend

    def update_context(self, ctx):
        if not self._analyzed:
            self._analyze()
        next_id = self.context_id
        if self._assignments is not None:
            ctx[str(next_id)] = self._assignments
            next_id += 1
        if self._prepend is not None:
            ctx[str(next_id)] = self._prepend
            next_id += 1
        if self._append is not None:
            ctx[str(next_id)] = self._append

    def _analyze(self):
        """ decides between full assignment, prepend and append """
        if self.value is None or self.value == self.previous:
            # nothing to send
            pass
        elif self._operation == "append":
            self._append = self.value
        elif self._operation == "prepend":
            self._prepend = self.value
        elif self.previous is None:
            self._assignments = self.value
        elif len(self.value) < len(self.previous):
            # elements were removed, so the whole list is rewritten
            self._assignments = self.value
        elif len(self.previous) == 0:
            # starting from an empty list, so the whole list is written
            self._assignments = self.value
        else:
            window = len(self.previous)
            # number of start offsets at which a full-size window fits in the new list
            offsets = len(self.value) - max(0, window - 1)
            for start in range(offsets):
                stop = start + window
                candidate = self.value[start:stop]
                # compare the two ends first, then the whole slice
                ends_match = (
                    self.previous[0] == candidate[0]
                    and self.previous[-1] == candidate[-1]
                )
                if ends_match and self.previous == candidate:
                    self._prepend = self.value[:start] or None
                    self._append = self.value[stop:] or None
                    break
            # neither a prepend nor an append was found: write the whole list
            if self._prepend is self._append is None:
                self._assignments = self.value
        self._analyzed = True
class MapUpdateClause(ContainerUpdateClause):
    """ update clause for a map collection column

    Depending on the requested operation and the previous value, the clause
    renders as a whole-map assignment, a key removal, or a series of per-key
    assignments (two placeholders each: key and value).
    """

    col_type = columns.Map

    _updates = None
    _removals = None

    def _analyze(self):
        """ works out which keys are updated or removed """
        if self._operation == "update":
            self._updates = self.value.keys()
        elif self._operation == "remove":
            self._removals = set(self.value.keys())
        elif self.previous is None:
            self._updates = sorted(self.value.keys())
        else:
            # only keys whose value differs from the previous map; None when nothing changed
            changed = [k for k, v in self.value.items() if v != self.previous.get(k)]
            self._updates = sorted(changed) or None
        self._analyzed = True

    def get_context_size(self):
        """ 1 for an assignment, otherwise two entries per updated key plus one for removals """
        if self.is_assignment:
            return 1
        return int((len(self._updates or []) * 2) + int(bool(self._removals)))

    def update_context(self, ctx):
        ctx_id = self.context_id
        if self.is_assignment:
            ctx[str(ctx_id)] = {}
        elif self._removals is not None:
            ctx[str(ctx_id)] = self._removals
        else:
            for key in self._updates or []:
                # each updated key takes two consecutive placeholders: key, then value
                ctx[str(ctx_id)] = key
                ctx[str(ctx_id + 1)] = self.value.get(key)
                ctx_id += 2

    @property
    def is_assignment(self):
        """ True when there is no previous value and no updates or removals.

        This must be a property: every caller in this class reads it as an
        attribute (``if self.is_assignment:``), and a plain bound method is
        always truthy, which would force every map update down the
        assignment path and skip the analysis.
        """
        if not self._analyzed:
            self._analyze()
        return self.previous is None and not self._updates and not self._removals

    def __unicode__(self):
        qs = []
        ctx_id = self.context_id
        if self.is_assignment:
            qs.append('"{0}" = %({1})s'.format(self.field, ctx_id))
        elif self._removals is not None:
            qs.append('"{0}" = "{0}" - %({1})s'.format(self.field, ctx_id))
        else:
            for _ in self._updates or []:
                qs.append('"{0}"[%({1})s] = %({2})s'.format(self.field, ctx_id, ctx_id + 1))
                ctx_id += 2
        return ', '.join(qs)
class CounterUpdateClause(AssignmentClause):
    """Assignment clause that adds to or subtracts from a counter column."""

    col_type = columns.Counter

    def __init__(self, field, value, previous=None):
        super().__init__(field, value)
        # a missing previous value counts as zero
        self.previous = previous or 0

    def get_context_size(self):
        return 1

    def update_context(self, ctx):
        # only the magnitude is bound; the sign is rendered by __unicode__
        delta = self.value - self.previous
        ctx[str(self.context_id)] = abs(delta)

    def __unicode__(self):
        sign = '-' if (self.value - self.previous) < 0 else '+'
        return '"{0}" = "{0}" {1} %({2})s'.format(self.field, sign, self.context_id)
class BaseDeleteClause(BaseClause):
    """Common base class of the delete clauses; adds nothing to BaseClause."""
class FieldDeleteClause(BaseDeleteClause):
    """ names a field to delete from a row; binds no context values """

    def __init__(self, field):
        # there is no value associated with a field deletion
        super().__init__(field, None)

    def __unicode__(self):
        template = '"{0}"'
        return template.format(self.field)

    def update_context(self, ctx):
        # nothing to bind
        pass

    def get_context_size(self):
        return 0
class MapDeleteClause(BaseDeleteClause):
    """ deletes the map keys that are in the previous value but not in the new one """

    def __init__(self, field, value, previous=None):
        super().__init__(field, value)
        # falsy inputs are treated as empty maps
        self.value = self.value or {}
        self.previous = previous or {}
        self._removals = None
        self._analyzed = False

    def _analyze(self):
        dropped = (key for key in self.previous if key not in self.value)
        self._removals = sorted(dropped)
        self._analyzed = True

    def update_context(self, ctx):
        if not self._analyzed:
            self._analyze()
        # one placeholder per removed key, numbered consecutively from context_id
        for offset, key in enumerate(self._removals):
            ctx[str(self.context_id + offset)] = key

    def get_context_size(self):
        if not self._analyzed:
            self._analyze()
        return len(self._removals)

    def __unicode__(self):
        if not self._analyzed:
            self._analyze()
        return ', '.join(
            '"{0}"[%({1})s]'.format(self.field, self.context_id + offset)
            for offset in range(len(self._removals))
        )
class BaseCQLStatement(UnicodeMixin):
    """ The base cql statement class

    Holds the table name, the where clauses and conditional clauses, and
    hands out consecutive context (placeholder) ids to the clauses added.
    """

    def __init__(self, table, timestamp=None, where=None, fetch_size=None, conditionals=None):
        """
        :param table: table name rendered into the statement
        :param timestamp: int, datetime or timedelta (see timestamp_normalized)
        :param where: iterable of where clauses to add
        :param fetch_size: falsy values are replaced by FETCH_SIZE_UNSET
        :param conditionals: iterable of conditional clauses to add
        """
        super(BaseCQLStatement, self).__init__()
        self.table = table
        self.context_id = 0
        self.context_counter = self.context_id
        self.timestamp = timestamp
        self.fetch_size = fetch_size if fetch_size else FETCH_SIZE_UNSET

        self.where_clauses = []
        for clause in where or []:
            self._add_where_clause(clause)

        self.conditionals = []
        for conditional in conditionals or []:
            self.add_conditional_clause(conditional)

    def _update_part_key_values(self, field_index_map, clauses, parts):
        """ copies the value of each clause whose field is in field_index_map into parts """
        for clause in clauses:
            if clause.field in field_index_map:
                parts[field_index_map[clause.field]] = clause.value

    def partition_key_values(self, field_index_map):
        """ returns a list of values, indexed per field_index_map, taken from the equality where clauses """
        parts = [None] * len(field_index_map)
        equality_clauses = (w for w in self.where_clauses if w.operator.__class__ == EqualsOperator)
        self._update_part_key_values(field_index_map, equality_clauses, parts)
        return parts

    def add_where(self, column, operator, value, quote_field=True):
        """ builds a WhereClause from a column (value converted with to_database) and adds it """
        value = column.to_database(value)
        clause = WhereClause(column.db_field_name, operator, value, quote_field)
        self._add_where_clause(clause)

    def _add_where_clause(self, clause):
        # give the clause the next free placeholder id, then reserve its slots
        clause.set_context_id(self.context_counter)
        self.context_counter += clause.get_context_size()
        self.where_clauses.append(clause)

    def get_context(self):
        """
        returns the context dict for this statement

        :rtype: dict
        """
        ctx = {}
        for clause in self.where_clauses or []:
            clause.update_context(ctx)
        return ctx

    def add_conditional_clause(self, clause):
        """
        Adds a iff clause to this statement

        :param clause: The clause that will be added to the iff statement
        :type clause: ConditionalClause
        """
        clause.set_context_id(self.context_counter)
        self.context_counter += clause.get_context_size()
        self.conditionals.append(clause)

    def _get_conditionals(self):
        return 'IF {0}'.format(' AND '.join([str(c) for c in self.conditionals]))

    def get_context_size(self):
        return len(self.get_context())

    def update_context_id(self, i):
        """ renumbers the where clauses' placeholder ids starting from i """
        self.context_id = i
        self.context_counter = self.context_id
        for clause in self.where_clauses:
            clause.set_context_id(self.context_counter)
            self.context_counter += clause.get_context_size()

    @property
    def timestamp_normalized(self):
        """
        we're expecting self.timestamp to be either a long, int, a datetime, or a timedelta

        Exposed as a property because the statement classes in this module
        interpolate ``self.timestamp_normalized`` directly into the rendered
        CQL; as a plain method the bound-method repr would be rendered
        instead of the number.

        :return: None for a falsy timestamp; an int timestamp unchanged;
            otherwise an int computed from the datetime (a timedelta is first
            added to datetime.now()) as mktime seconds * 1e6 + microseconds
        """
        if not self.timestamp:
            return None

        if isinstance(self.timestamp, int):
            return self.timestamp

        if isinstance(self.timestamp, timedelta):
            tmp = datetime.now() + self.timestamp
        else:
            tmp = self.timestamp

        return int(time.mktime(tmp.timetuple()) * 1e+6 + tmp.microsecond)

    def __unicode__(self):
        raise NotImplementedError

    def __repr__(self):
        return self.__unicode__()

    @property
    def _where(self):
        """ the rendered WHERE fragment

        Exposed as a property because the statement classes in this module
        append ``self._where`` (not ``self._where()``) to the list of string
        fragments they join.
        """
        return 'WHERE {0}'.format(' AND '.join([str(c) for c in self.where_clauses]))
class SelectStatement(BaseCQLStatement):
    """ renders a cql SELECT """

    def __init__(self,
                 table,
                 fields=None,
                 count=False,
                 where=None,
                 order_by=None,
                 limit=None,
                 allow_filtering=False,
                 distinct_fields=None,
                 fetch_size=None):
        """
        :param where
        :type where list of cqlengine.statements.WhereClause
        """
        super().__init__(table, where=where, fetch_size=fetch_size)

        # a bare string is treated as a single field / ordering
        if isinstance(fields, str):
            self.fields = [fields]
        else:
            self.fields = fields or []
        self.distinct_fields = distinct_fields
        self.count = count
        if isinstance(order_by, str):
            self.order_by = [order_by]
        else:
            self.order_by = order_by
        self.limit = limit
        self.allow_filtering = allow_filtering

    def __unicode__(self):
        def quoted(names):
            # double-quote each name and join with commas
            return ', '.join(['"{0}"'.format(n) for n in names])

        parts = ['SELECT']
        if self.distinct_fields:
            template = 'DISTINCT COUNT({0})' if self.count else 'DISTINCT {0}'
            parts.append(template.format(quoted(self.distinct_fields)))
        elif self.count:
            parts.append('COUNT(*)')
        elif self.fields:
            parts.append(quoted(self.fields))
        else:
            parts.append('*')

        parts.extend(['FROM', self.table])

        if self.where_clauses:
            # NOTE(review): _where is read as an attribute, not called; this
            # relies on the base class exposing it as a property - confirm.
            parts.append(self._where)

        # ordering is skipped for count queries
        if self.order_by and not self.count:
            parts.append('ORDER BY {0}'.format(', '.join(str(o) for o in self.order_by)))

        if self.limit:
            parts.append('LIMIT {0}'.format(self.limit))

        if self.allow_filtering:
            parts.append('ALLOW FILTERING')

        return ' '.join(parts)
class AssignmentStatement(BaseCQLStatement):
    """ base for statements that carry a list of assignment clauses """

    def __init__(self,
                 table,
                 assignments=None,
                 where=None,
                 ttl=None,
                 timestamp=None,
                 conditionals=None):
        super().__init__(table, where=where, conditionals=conditionals)
        self.ttl = ttl
        self.timestamp = timestamp

        # assignments receive their placeholder ids after the clauses added by the base constructor
        self.assignments = []
        for clause in assignments or []:
            self._add_assignment_clause(clause)

    def update_context_id(self, i):
        # the base class renumbers the where clauses; assignments continue from there
        super().update_context_id(i)
        for clause in self.assignments:
            clause.set_context_id(self.context_counter)
            self.context_counter += clause.get_context_size()

    def partition_key_values(self, field_index_map):
        # start from the where-clause values, then overlay values taken from assignments
        values = super().partition_key_values(field_index_map)
        self._update_part_key_values(field_index_map, self.assignments, values)
        return values

    def add_assignment(self, column, value):
        """ builds an AssignmentClause from a column (value converted with to_database) and adds it """
        db_value = column.to_database(value)
        self._add_assignment_clause(AssignmentClause(column.db_field_name, db_value))

    def _add_assignment_clause(self, clause):
        clause.set_context_id(self.context_counter)
        self.context_counter += clause.get_context_size()
        self.assignments.append(clause)

    def is_empty(self):
        """ true when no assignment clauses have been added """
        return len(self.assignments) == 0

    def get_context(self):
        context = super().get_context()
        for clause in self.assignments:
            clause.update_context(context)
        return context
class InsertStatement(AssignmentStatement):
    """ renders a cql INSERT """

    def __init__(self,
                 table,
                 assignments=None,
                 where=None,
                 ttl=None,
                 timestamp=None,
                 if_not_exists=False):
        super().__init__(
            table,
            assignments=assignments,
            where=where,
            ttl=ttl,
            timestamp=timestamp,
        )
        self.if_not_exists = if_not_exists

    def __unicode__(self):
        parts = ['INSERT INTO {0}'.format(self.table)]

        # split the (field, context id) pairs into column names and placeholder ids
        pairs = [assignment.insert_tuple() for assignment in self.assignments]
        names, placeholder_ids = zip(*pairs)

        parts.append("({0})".format(', '.join(['"{0}"'.format(name) for name in names])))
        parts.append('VALUES')
        parts.append("({0})".format(', '.join(['%({0})s'.format(pid) for pid in placeholder_ids])))

        if self.if_not_exists:
            parts.append("IF NOT EXISTS")

        options = []
        if self.ttl:
            options.append("TTL {}".format(self.ttl))
        if self.timestamp:
            # NOTE(review): timestamp_normalized is read as an attribute, not
            # called; this relies on the base class exposing it as a property - confirm.
            options.append("TIMESTAMP {}".format(self.timestamp_normalized))
        if options:
            parts.append("USING {}".format(" AND ".join(options)))

        return ' '.join(parts)
class UpdateStatement(AssignmentStatement):
    """ renders a cql UPDATE """

    def __init__(self,
                 table,
                 assignments=None,
                 where=None,
                 ttl=None,
                 timestamp=None,
                 conditionals=None,
                 if_exists=False):
        super().__init__(
            table,
            assignments=assignments,
            where=where,
            ttl=ttl,
            timestamp=timestamp,
            conditionals=conditionals,
        )
        self.if_exists = if_exists

    def __unicode__(self):
        parts = ['UPDATE', self.table]

        options = []
        if self.ttl:
            options.append("TTL {0}".format(self.ttl))
        if self.timestamp:
            # NOTE(review): timestamp_normalized is read as an attribute, not
            # called; this relies on the base class exposing it as a property - confirm.
            options.append("TIMESTAMP {0}".format(self.timestamp_normalized))
        if options:
            parts.append("USING {0}".format(" AND ".join(options)))

        parts.append('SET')
        parts.append(', '.join([str(clause) for clause in self.assignments]))

        if self.where_clauses:
            # NOTE(review): _where is likewise read as an attribute - confirm.
            parts.append(self._where)

        if len(self.conditionals) > 0:
            parts.append(self._get_conditionals())

        if self.if_exists:
            parts.append("IF EXISTS")

        return ' '.join(parts)

    def get_context(self):
        context = super().get_context()
        for conditional in self.conditionals:
            conditional.update_context(context)
        return context

    def update_context_id(self, i):
        # conditionals are numbered after the where and assignment clauses
        super().update_context_id(i)
        for clause in self.conditionals:
            clause.set_context_id(self.context_counter)
            self.context_counter += clause.get_context_size()

    def add_update(self, column, value, operation=None, previous=None):
        """ adds the update clause matching the column's type

        Columns registered in ContainerUpdateClause.type_map get their
        container clause, Counter columns a CounterUpdateClause, everything
        else a plain AssignmentClause. Clauses reporting a context size of
        zero are not added.
        """
        db_value = column.to_database(value)
        col_type = type(column)

        container_cls = ContainerUpdateClause.type_map.get(col_type)
        if container_cls:
            db_previous = column.to_database(previous)
            clause = container_cls(column.db_field_name, db_value, operation, db_previous)
        elif col_type == columns.Counter:
            clause = CounterUpdateClause(column.db_field_name, db_value, previous)
        else:
            clause = AssignmentClause(column.db_field_name, db_value)

        # this is to exclude map removals from updates. Can go away if we drop
        # support for C* < 1.2.4 and remove two-phase updates
        if clause.get_context_size():
            self._add_assignment_clause(clause)
class DeleteStatement(BaseCQLStatement):
    """ renders a cql DELETE """

    def __init__(self, table, fields=None, where=None, timestamp=None, conditionals=None, if_exists=False):
        super().__init__(
            table,
            where=where,
            timestamp=timestamp,
            conditionals=conditionals,
        )
        self.fields = []
        # a bare string is treated as a single field name
        requested = [fields] if isinstance(fields, str) else fields
        for field in requested or []:
            self.add_field(field)
        self.if_exists = if_exists

    def update_context_id(self, i):
        # numbering order: where clauses (base class), then fields, then conditionals
        super().update_context_id(i)
        for clause in self.fields:
            clause.set_context_id(self.context_counter)
            self.context_counter += clause.get_context_size()
        for clause in self.conditionals:
            clause.set_context_id(self.context_counter)
            self.context_counter += clause.get_context_size()

    def get_context(self):
        context = super().get_context()
        for clause in self.fields:
            clause.update_context(context)
        for clause in self.conditionals:
            clause.update_context(context)
        return context

    def add_field(self, field):
        """ adds a field to delete; a string is wrapped in a FieldDeleteClause

        :raises StatementException: if the result is not a BaseClause instance
        """
        clause = FieldDeleteClause(field) if isinstance(field, str) else field
        if not isinstance(clause, BaseClause):
            raise StatementException("only instances of AssignmentClause can be added to statements")
        clause.set_context_id(self.context_counter)
        self.context_counter += clause.get_context_size()
        self.fields.append(clause)

    def __unicode__(self):
        parts = ['DELETE']
        if self.fields:
            parts.append(', '.join(['{0}'.format(f) for f in self.fields]))
        parts.extend(['FROM', self.table])

        options = []
        if self.timestamp:
            # NOTE(review): timestamp_normalized is read as an attribute, not
            # called; this relies on the base class exposing it as a property - confirm.
            options.append("TIMESTAMP {0}".format(self.timestamp_normalized))
        if options:
            parts.append(" USING {0} ".format(" AND ".join(options)))

        if self.where_clauses:
            # NOTE(review): _where is likewise read as an attribute - confirm.
            parts.append(self._where)

        if self.conditionals:
            parts.append(self._get_conditionals())

        if self.if_exists:
            parts.append("IF EXISTS")

        return ' '.join(parts)