"""
All utility methods, classes and functions needed for testing applications
"""
from builtins import str
from builtins import object
from itertools import chain
from webtest import TestRequest
from webob.request import environ_from_url
from pyramid.request import apply_request_extensions
from pyramid.threadlocal import manager
from assembl.lib.sqla import (
get_session_maker, get_metadata, mark_changed)
from assembl.lib import logging
log = logging.getLogger('pytest.assembl')
class PyramidWebTestRequest(TestRequest):
    """
    A mock Pyramid web request that pushes itself onto the threadlocal stack.

    This is very useful because throughout the model logic a request is often
    required to determine the current user, but outside of a Pyramid view.
    The request is made reachable through Pyramid's threadlocal ``manager``
    (pushed in ``__init__``, popped in ``get_response``).

    NOTE(review): ``_pyramid_app`` (used like a Pyramid router: it must
    provide ``request_factory``, ``routes_mapper`` and ``root_factory``) and
    ``registry`` are not defined in this class; presumably a test fixture
    sets them on the class before instantiation -- confirm.
    NOTE(review): if ``get_response`` is never called, the threadlocal entry
    pushed in ``__init__`` is never popped.
    """

    def __init__(self, *args, **kwargs):
        super(PyramidWebTestRequest, self).__init__(*args, **kwargs)
        # Make this request the current one for threadlocal lookups; the
        # matching pop happens in get_response.
        manager.push({'request': self, 'registry': self.registry})
        # A real Pyramid request over the same WSGI environ, to which the
        # session and URL-generation helpers below delegate.
        self._base_pyramid_request = self._pyramid_app.request_factory(
            self.environ)
        self._base_pyramid_request.registry = self.registry
        apply_request_extensions(self)

    def populate(self):
        """Simulate route matching and traversal on this request.

        This normally happens when the request goes through the app, but
        some tests need it without dispatching. Sets ``matchdict`` and
        ``matched_route`` when a route matches, then stores every key of the
        dict returned by the traverser as an attribute of the request.
        """
        routes_mapper = self._pyramid_app.routes_mapper
        info = routes_mapper(self)
        match, route = info['match'], info['route']
        if route:
            self.matchdict = match
            self.matched_route = route
        traverser = self._traverser()
        self.__dict__.update(traverser(self))

    @property
    def session(self):
        """Session of the underlying Pyramid request."""
        return self._base_pyramid_request.session

    def _traverser(self):
        """Build a resource-tree traverser rooted at the app's root factory."""
        from pyramid.traversal import ResourceTreeTraverser
        ctx_root = self._pyramid_app.root_factory(self)
        return ResourceTreeTraverser(ctx_root)

    def get_response(self, app, catch_exc_info=True):
        """Run this request through *app* and return the response.

        :param app: WSGI application to call.
        :param catch_exc_info: forwarded to the base implementation.
        :returns: the response produced by the base ``get_response``.

        The threadlocal entry pushed in ``__init__`` is popped afterwards,
        even if the application raises.
        """
        try:
            # ``self`` (not ``app``) is the instance for super(); ``app`` is
            # the WSGI application to run, and the response must be returned.
            return super(PyramidWebTestRequest, self).get_response(
                app, catch_exc_info=catch_exc_info)
        finally:
            manager.pop()

    def route_path(self, name, *args, **kwargs):
        """Delegate to the underlying Pyramid request's ``route_path``."""
        return self._base_pyramid_request.route_path(
            name, *args, **kwargs)

    def route_url(self, name, *args, **kwargs):
        """Delegate to the underlying Pyramid request's ``route_url``."""
        return self._base_pyramid_request.route_url(
            name, *args, **kwargs)

    # TODO: Find a way to change user here
    authenticated_userid = None
    unauthenticated_userid = None
    # How come this is missing in TestRequest?
    # TODO: Use the negotiator
    locale_name = 'en'
def committing_session_tween_factory(handler, registry):
    """Build a Pyramid tween that syncs the DB session around each request.

    Before the wrapped handler runs, the session is committed so the app
    sees the latest state; after it returns, the session is flushed.
    ``registry`` belongs to the tween-factory signature and is not used.
    """
    def committing_session_tween(request):
        get_session_maker().commit()
        # 'discussion' / 'discussion_id' may have been reified as None too
        # early; removing the cached None lets the property be evaluated
        # again on next access.
        cached = request.__dict__
        for attr_name in ('discussion', 'discussion_id'):
            if attr_name in cached and cached[attr_name] is None:
                del cached[attr_name]
        response = handler(request)
        get_session_maker().flush()
        return response

    return committing_session_tween
def as_boolean(s):
    """Interpret *s* as a boolean flag.

    Real booleans are returned unchanged. Anything else is converted to text
    and compared case-insensitively against ``true``, ``1``, ``on`` and
    ``yes``; every other value (including ``None``) yields ``False``.
    """
    truthy_words = ('true', '1', 'on', 'yes')
    return s if isinstance(s, bool) else str(s).lower() in truthy_words
def get_all_tables(app_settings, session, reversed=True):
    """Return the names of all tables in the configured database schema.

    :param app_settings: settings mapping; ``db_schema`` selects the schema
        (defaults to ``'assembl_test'``).
    :param session: SQLAlchemy session used to query ``information_schema``.
    :param reversed: when true (the default), reverse the final list.
        (The name shadows the builtin but is kept for caller compatibility.)
    :returns: list of table names. Tables known to the SQLAlchemy metadata
        come first, in metadata dependency order (to minimize cascades);
        other tables found in the schema follow in alphabetical order. The
        whole list is then reversed if ``reversed`` is true.
    """
    schema = app_settings.get('db_schema', 'assembl_test')
    # Bind the schema name as a query parameter instead of interpolating it
    # into the SQL text, so quoting in the setting cannot break the query.
    rows = session.execute(
        "SELECT table_name FROM information_schema.tables "
        "WHERE table_schema = :schema ORDER BY table_name",
        {'schema': schema}).fetchall()
    # Keep the SQL ordering (alphabetical) for deterministic output and use
    # a set only for membership tests.
    db_tables = [row[0] for row in rows]
    db_table_set = set(db_tables)
    # get the ordered version to minimize cascade.
    # cascade does not exist on virtuoso.
    # Presumably imported so all model tables are registered on the metadata.
    import assembl.models  # noqa: F401
    ordered = [t.name for t in get_metadata().sorted_tables
               if t.name in db_table_set]
    known = set(ordered)
    ordered.extend(t for t in db_tables if t not in known)
    if reversed:
        ordered.reverse()
    log.debug('Current tables: %s', ordered)
    return ordered
def self_referential_columns(table):
    """Return the columns of *table* that hold a foreign key to *table* itself.

    Walks every column's foreign keys, in column order and then foreign-key
    order, and keeps the owning column (``fk.parent``) of each foreign key
    whose target column belongs to the same table.
    """
    return [
        foreign_key.parent
        for column in table.columns
        for foreign_key in column.foreign_keys
        if foreign_key.column.table == table
    ]
def clear_rows(app_settings, session):
    """Delete every row of every table in the configured schema, then commit.

    Tables are processed in the order returned by ``get_all_tables``. For
    tables known to the SQLAlchemy metadata, self-referencing foreign-key
    columns are set to NULL first, so that rows of the same table do not
    block each other's deletion.

    :param app_settings: settings mapping passed to ``get_all_tables``.
    :param session: SQLAlchemy session used to run the statements.
    """
    log.info('Clearing database rows.')
    tables_by_name = {
        t.name: t for t in get_metadata().sorted_tables}
    for table_name in get_all_tables(app_settings, session):
        log.debug("Clearing table: %s", table_name)
        table = tables_by_name.get(table_name)
        if table is not None:
            cols = self_referential_columns(table)
            for col in cols:
                # Use the SQL column name (``col.key`` is SQLAlchemy's
                # Python-side key and may differ), and quote identifiers
                # like the DELETE below does.
                session.execute('UPDATE "%s" SET "%s"=NULL' % (
                    table_name, col.name))
            if cols:
                session.flush()
        session.execute("DELETE FROM \"%s\"" % table_name)
    session.commit()
    # NOTE(review): closes the session's transaction object after commit;
    # purpose not visible here -- confirm it is still needed.
    session.transaction.close()
def drop_tables(app_settings, session):
    """Drop every table listed by ``get_all_tables`` for the given settings.

    The session is closed first. Tables are dropped one by one, the session
    is committed, and ``mark_changed()`` is called. Any failure is re-raised
    as a plain ``Exception`` whose message starts with
    ``Error dropping tables:``.

    :param app_settings: settings mapping passed to ``get_all_tables``.
    :param session: SQLAlchemy session used to run the DROP statements.
    """
    log.info('Dropping all tables.')
    # postgres. Thank you to
    # http://stackoverflow.com/questions/5408156/how-to-drop-a-postgresql-database-if-there-are-active-connections-to-it
    session.close()
    # Terminating the other backends was tried and is left disabled:
    # session.execute(
    #     """SELECT pg_terminate_backend(pg_stat_activity.pid)
    #     FROM pg_stat_activity
    #     WHERE pg_stat_activity.datname = '%s'
    #     AND pid <> pg_backend_pid()""" % (
    #         app_settings.get("db_database")))
    try:
        table_names = get_all_tables(app_settings, session)
        for table_name in table_names:
            log.debug('Dropping table: %s' % table_name)
            session.execute('drop table "%s"' % table_name)
        session.commit()
        mark_changed()
    except Exception as err:
        raise Exception('Error dropping tables: %s' % err)
def base_fixture_dirname():
    """Return the directory holding the JavaScript test fixtures.

    The result is the directory four levels above this module's file,
    followed by ``/assembl/static/js/app/tests/fixtures/``; the trailing
    slash is part of the returned string.
    """
    from os.path import dirname
    root = __file__
    for _ in range(4):
        root = dirname(root)
    return root + "/assembl/static/js/app/tests/fixtures/"
def api_call_to_fname(api_call, method="GET", **args):
    """Translate an API call to a filename containing most of the call information.

    Used in :js:func:`ajaxMock`.

    The directory part of ``api_call`` (everything before the last ``/``)
    is appended to :func:`base_fixture_dirname` and created if missing. The
    file name is the last path segment, followed by the ``key_value`` pairs
    of ``args`` sorted by key, prefixed with ``method`` when it is not
    ``GET``, and suffixed with ``.json``.

    :param api_call: slash-separated API path.
    :param method: HTTP method name.
    :param args: request parameters folded into the file name.
    :returns: full path of the fixture file. Only its directory is created.
    """
    import os
    base_fixture_dir = base_fixture_dirname()
    # rpartition (unlike tuple-unpacking rsplit) also accepts a call with
    # no "/": the file then goes directly in the fixture root.
    api_dir, _, fname = api_call.rpartition("/")
    api_dir = base_fixture_dir + api_dir
    # Create the directory, tolerating it already existing (including a
    # concurrent creation) without relying on makedirs(exist_ok=...).
    try:
        os.makedirs(api_dir)
    except OSError:
        if not os.path.isdir(api_dir):
            raise
    arg_suffix = "_".join("%s_%s" % item for item in sorted(args.items()))
    if arg_suffix:
        fname += "_" + arg_suffix
    if method != "GET":
        fname = method + "_" + fname
    fname += ".json"
    return os.path.join(api_dir, fname)
class RecordingApp(object):
    """Decorator for the test_app.

    Wraps a test application so that every successful (2xx) call to one of
    the HTTP verb methods in ``_RECORDED_METHODS`` also writes the response
    body to the fixture file named by :func:`api_call_to_fname`. Any other
    attribute is looked up on the wrapper itself, not on the wrapped app.
    """

    # Names of the wrapped-app methods whose responses are recorded.
    _RECORDED_METHODS = frozenset((
        "get", "post", "post_json", "put", "put_json",
        "delete", "patch", "patch_json"))

    def __init__(self, test_app):
        self.app = test_app

    def __getattribute__(self, name):
        if name not in RecordingApp._RECORDED_METHODS:
            return super(RecordingApp, self).__getattribute__(name)

        def appmethod(url, params=None, headers=None):
            # Forward only the arguments the caller supplied, and always as
            # keywords: webtest's JSON helpers (post_json, ...) accept
            # ``headers`` only as a keyword, and each method has its own
            # default for ``params``.
            kwargs = {}
            if params is not None:
                kwargs['params'] = params
            if headers is not None:
                kwargs['headers'] = headers
            r = getattr(self.app, name)(url, **kwargs)
            assert 200 <= r.status_code < 300, (
                "%s %s returned status %s" % (name, url, r.status_code))
            # Only mapping parameters can be folded into the fixture name;
            # other bodies (strings, lists) are recorded without them.
            query = params if isinstance(params, dict) else {}
            methodname = name.split("_")[0].upper()
            with open(api_call_to_fname(url, methodname, **query), "wb") as f:
                f.write(r.body)
            return r

        return appmethod