from enum import Enum as pyEnum
from collections import OrderedDict
from inspect import isclass
from sqlalchemy import Column, Integer, ForeignKey, Enum as sqlEnum
from sqlalchemy.dialects.postgresql import ENUM
from sqlalchemy.orm import attributes
from sqlalchemy_utils.generic import (
TypeMapper, GenericRelationshipProperty as GenericRelationshipProperty_)
from future.utils import string_types
from .decl_enums import UpdatablePgEnum
class BaseTableEnum(object):
    """Pseudo-enum listing the base tables of a declarative model.

    It imitates a Python ``Enum`` just enough to be used as an enum class
    by exposing a ``__members__`` mapping of table name -> mapped class.
    The mapping stays empty until :meth:`init_members` fills it.
    """

    # Class-level (shared) mapping: base table name -> mapped class.
    __members__ = OrderedDict()

    @classmethod
    def init_members(cls, Base):
        """Add every base-table class registered on ``Base`` to ``__members__``.

        Call this after the model classes have been registered on ``Base``;
        only classes present in ``Base._decl_class_registry`` at call time
        are considered. A class qualifies when ``__tablename__`` is set in
        its own ``__dict__`` (not inherited) and equals its
        ``base_tablename()``. New entries are added in sorted table-name
        order so the member order is predictable; entries from earlier
        calls are not removed.

        :param Base: declarative base exposing ``_decl_class_registry``.
        """
        def declared_tablename(klass):
            # Registry values may be non-class markers without __dict__.
            return getattr(klass, '__dict__', {}).get('__tablename__', None)

        by_table = {}
        for klass in Base._decl_class_registry.values():
            if not declared_tablename(klass):
                continue
            if klass.__dict__['__tablename__'] != klass.base_tablename():
                continue
            by_table[klass.__tablename__] = klass

        for tablename in sorted(by_table):
            cls.__members__[tablename] = by_table[tablename]
class UniversalTableRefColType(UpdatablePgEnum):
    """Enum column type whose labels are the ``BaseTableEnum`` members.

    The database type is always named ``base_tables_enum``, whatever
    ``name`` the caller passes.
    """

    def __init__(self, *args, **kwargs):
        # Force the shared type name; kwargs is a fresh dict, so this is local.
        kwargs['name'] = 'base_tables_enum'
        super(UniversalTableRefColType, self).__init__(
            BaseTableEnum, *args, ordered=False, **kwargs)

    def reset_enum(self):
        """Recompute the enum labels from ``self.enum_class``.

        ``BaseTableEnum.__members__`` may still have been empty when this
        type was constructed, so call this after
        ``BaseTableEnum.init_members`` has populated it.
        """
        parse_options = {}
        labels, members = self._parse_into_values(
            [self.enum_class], parse_options)
        self._setup_for_values(labels, members, parse_options)
def init_datatype(base_class):
    """Populate ``BaseTableEnum`` and refresh ImportRecord's target type.

    Fills ``BaseTableEnum.__members__`` from ``base_class``, then calls
    ``reset_enum()`` on the type of ``ImportRecord.target_table`` so its
    labels reflect the newly registered tables.

    :param base_class: declarative base passed to
        ``BaseTableEnum.init_members``.
    """
    # Local import, presumably to avoid a circular import with the models
    # package (TODO confirm).
    from ..models.import_records import ImportRecord
    BaseTableEnum.init_members(base_class)
    # Expected to be a UniversalTableRefColType (the only type in this
    # module defining reset_enum). Named col_type to avoid shadowing
    # the builtin ``type``.
    col_type = ImportRecord.__mapper__.columns['target_table'].type
    col_type.reset_enum()
def init_dbtype(session):
    """Call ``update_type`` on ImportRecord's ``target_table`` column type.

    :param session: session whose ``bind`` is handed to ``update_type``
        (defined on the project's ``UpdatablePgEnum``).
    """
    from ..models.import_records import ImportRecord
    # Named col_type to avoid shadowing the builtin ``type``.
    col_type = ImportRecord.__mapper__.columns['target_table'].type
    col_type.update_type(session.bind)
class MulticlassTableRefColType(ENUM):
    """PostgreSQL ENUM limited to the base tables of ``target_classes``.

    A Python enum named after the ``name`` keyword argument is built with
    one member per target class: member name ``c.base_tablename()``,
    member value ``c.base_concrete_class()``.

    :param target_classes: iterable of classes providing
        ``base_tablename()`` and ``base_concrete_class()``.
    :raises KeyError: if ``name`` is not passed as a keyword argument.
    """

    def __init__(self, target_classes, *args, **kwargs):
        class_enum = pyEnum(kwargs['name'], {
            c.base_tablename(): c.base_concrete_class()
            for c in target_classes})
        super(MulticlassTableRefColType, self).__init__(
            class_enum, *args, **kwargs)

    def create(self, bind=None, checkfirst=True):
        """Create the type in the database, always with ``checkfirst=True``.

        Alembic uses ``checkfirst=False``; the argument is deliberately
        ignored so an existing type is not re-created.
        """
        # super() must name this class: self is not an instance of
        # UniversalTableRefColType, so naming that class raises TypeError.
        super(MulticlassTableRefColType, self).create(
            bind=bind, checkfirst=True)
class MyTypeMapper(TypeMapper):
    """Type mapper keyed on base table names.

    Classes are represented by ``base_tablename()`` and names are resolved
    back through ``BaseTableEnum.__members__``.
    """

    def class_to_value(self, cls):
        """Return the stored value for ``cls``: its base table name."""
        return cls.base_tablename()

    def column_is_type(self, column, other_type):
        """Compare ``column`` with ``other_type.base_concrete_class()``."""
        concrete = other_type.base_concrete_class()
        return column == concrete

    def value_to_class(self, value, base_class):
        """Resolve a stored value to a class.

        Strings are looked up in ``BaseTableEnum.__members__`` (``None``
        when unknown); classes are returned unchanged. ``base_class`` is
        unused.

        :raises RuntimeError: for any other kind of value.
        """
        if isinstance(value, string_types):
            return BaseTableEnum.__members__.get(value)
        if isclass(value):
            return value
        raise RuntimeError("Wrong value")
class GenericRelationshipProperty(GenericRelationshipProperty_):
    """Generic relationship property that always uses :class:`MyTypeMapper`.

    Every (property, mapper) pair this property instruments is recorded in
    the class-level ``_generic_pointers`` list. That list is what
    :meth:`declare_universal_delete_cascades` walks to install
    delete-cascade triggers in the database.
    """

    # Class attribute shared by all instances: (property, mapper) pairs
    # appended by instrument_class.
    _generic_pointers = []

    # Per-pointer trigger DDL, filled in with ``%`` formatting. It is
    # idempotent: the function is replaced and the trigger is dropped
    # before re-creation. DEFERRABLE is only valid on a CONSTRAINT
    # trigger. A PL/pgSQL trigger function must execute RETURN; the value
    # is ignored for AFTER row triggers.
    _CASCADE_DDL = """
CREATE OR REPLACE FUNCTION %(fname)s() RETURNS trigger AS $%(fname)s$
BEGIN
    DELETE FROM public.%(source_table)s
    WHERE %(discriminator)s = CAST('%(target_table)s' AS %(enum_type)s)
    AND %(id_col)s = OLD.id;
    RETURN NULL;
END;
$%(fname)s$ LANGUAGE plpgsql;
DROP TRIGGER IF EXISTS %(fname)s ON %(target_table)s;
CREATE CONSTRAINT TRIGGER %(fname)s AFTER DELETE ON %(target_table)s
DEFERRABLE FOR EACH ROW
EXECUTE PROCEDURE %(fname)s();
"""

    def __init__(self, *args, **kwargs):
        # Any caller-supplied type_mapper is replaced.
        kwargs['type_mapper'] = MyTypeMapper()
        super(GenericRelationshipProperty, self).__init__(*args, **kwargs)

    def instrument_class(self, mapper):
        """Instrument as the parent does, then record ``(self, mapper)``."""
        super(GenericRelationshipProperty, self).instrument_class(mapper)
        self._generic_pointers.append((self, mapper))

    @staticmethod
    def _pointer_key(discriminator_name):
        """Return ``key`` for a discriminator column named ``<key>_table``."""
        suffix = '_table'
        if not discriminator_name.endswith(suffix):
            raise ValueError(
                "Generic pointer discriminator column %r does not end "
                "with %r" % (discriminator_name, suffix))
        return discriminator_name[:-len(suffix)]

    @classmethod
    def declare_universal_delete_cascades(cls, db):
        """Install delete-cascade triggers for all registered generic pointers.

        For each table ``T`` in ``BaseTableEnum.__members__`` and each
        registered pointer whose discriminator enum has a ``T`` label,
        create a trigger on ``T``. After a row of ``T`` is deleted, that
        trigger deletes the rows of the pointer's source table whose
        ``<key>_table`` equals ``T`` and whose ``<key>_id`` equals the
        deleted row's ``id``. Does nothing until
        ``BaseTableEnum.init_members`` has populated the members.

        :param db: object whose ``execute`` accepts a SQL string
            (presumably a session or connection — confirm with callers).
        :raises ValueError: if a pointer's discriminator column name does
            not end with ``_table``.
        """
        for target_cls in BaseTableEnum.__members__.values():
            target_table = target_cls.__tablename__
            for pointer, mapper in cls._generic_pointers:
                discriminator = pointer._discriminator_col
                target_type = discriminator.type
                if target_table not in target_type.enum_class.__members__:
                    # This pointer's enum has no such label, so no row can
                    # point at target_table and the enum cast would fail.
                    continue
                key = cls._pointer_key(discriminator.name)
                source_table = mapper.class_.base_tablename()
                # NOTE(review): PostgreSQL truncates identifiers to 63
                # bytes; very long names could collide after truncation.
                fname = "on_delete_%s_universal_cascade_%s_%s" % (
                    target_table, source_table, key)
                db.execute(cls._CASCADE_DDL % {
                    'fname': fname,
                    'source_table': source_table,
                    'target_table': target_table,
                    'discriminator': discriminator.name,
                    'id_col': key + '_id',
                    'enum_type': target_type.name,
                })
def generic_relationship(*args, **kwargs):
    """Shorthand constructor for :class:`GenericRelationshipProperty`.

    All positional and keyword arguments are forwarded unchanged. The
    property itself always installs its own ``MyTypeMapper``.
    """
    prop = GenericRelationshipProperty(*args, **kwargs)
    return prop
# class GenericPointerMixin(object):
# @classmethod
# def list_keys(cls):
# for col in cls.__mapper__._column_to_property:
# if issubclass(col.type, GenericPointerTable):
# name = col.name
# assert name.endswith('_table')
# name = name[:-6]
# assert name + "_id" in cls.__mapper__._props
# yield name
# @classmethod
# def guess_key(cls):
# return next(cls.list_keys())
# def get_instance(self, key=None):
# key = key or self.guess_key()
# table_enum = getattr(self, key + '_table', None)
# if not table_enum:
# return
# assert issubclass(table_enum, object)
# table_id = getattr(self, key + '_id', None)
# return table_enum.get()
# @classmethod
# def references_to_instance_query(cls, instance, key=None):
# key = key or self.guess_key()
# filter_args = {
# key + "_id": instance.id,
# key + "_table": BaseTableEnum.__members__[instance.base_tablename()]
# }
# return instance.db.query(cls).filter_by(**filter_args)