Disclaimer: Yes I am well aware this is a mad attempt.
Use case:
I am reading from a config file to run test collections, where each collection comprises a set of test cases with their corresponding expected results and a fixed setup.
Flow (for each test case):
Setup: wipe and setup database with specific test case dataset (glorified SQL file)
load expected test case results from csv
execute collections query/report
compare results.
Sounds good, except the people writing the test cases are more from a tech admin perspective, so the goal is to enable this without writing any python code.
code
Assume these functions exist.
# test_queries.py
def gather_collections(): (collection, query, config)
def gather_cases(collection): (test_case)
def load_collection_stubs(collection): None
def load_case_dataset(test_case): None
def read_case_result_csv(test_case): [csv_result]
def execute(query, test_case): [query_result]
class TestQueries(unittest.TestCase):
    """Holder for the per-collection test methods that are attached at import time."""

    def setup_method(self, method):
        """Load the stubs of the collection the current test belongs to.

        The collection name is the name of the running test item with the
        ``test_`` marker removed.  ``self._item`` is assigned by the
        ``pytest_runtest_protocol`` hook in conftest.py.
        """
        collection_name = self._item.name.replace('test_', '')
        load_collection_stubs(collection_name)
# conftest.py
import pytest
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_protocol(item, nextitem):
    """Expose the test item that is about to run as ``_item`` on its class.

    ``setup_method`` only receives the test method, so the item is stored on
    the test class before the run protocol starts, letting the class read the
    item's name.

    Args:
        item: the test item about to be run.
        nextitem: the item scheduled after it (unused).
    """
    # Items that do not live in a class (plain test functions, other item
    # types) have no class to attach to.
    test_class = getattr(item, "cls", None)
    if test_class is not None:
        test_class._item = item
    yield
Example Data
Collection stubs / data (setting up of environment)
-- stubs/test_setup_log.sql
DROP DATABASE IF EXISTS `test`;
CREATE DATABASE `test`;
USE test;
-- MySQL requires an AUTO_INCREMENT column to be defined as a key, so `id` is the primary key.
CREATE TABLE log (`id` int(9) NOT NULL AUTO_INCREMENT PRIMARY KEY, `timestamp` datetime NOT NULL DEFAULT NOW(), `username` varchar(100) NOT NULL, `message` varchar(500));
Query to test
-- queries/count.sql
SELECT count(*) as `log_count` from test.log where username = 'unicorn';
Test case 1 input data
-- test_case_1.sql
-- Timestamps use the 'YYYY-MM-DD hh:mm:ss' datetime literal form.
INSERT INTO log (`id`, `timestamp`, `username`, `message`)
VALUES
(1,'2020-12-18 11:23:01', 'unicorn', 'user logged in'),
(2,'2020-12-18 11:23:02', 'halsey', 'user logged off'),
(3,'2020-12-18 11:23:04', 'unicorn', 'user navigated to home');
Test case 1 expected result
test_case_1.csv
log_count
2
Attempt 1
def _make_collection_test(query):
    """Return a test method that checks ``query`` against one test case.

    ``query`` is bound here as a parameter.  A lambda written directly inside
    the loop below would look the loop variable up only when the test runs,
    by which time it holds the query of the last collection.
    """
    def _collection_test(self, case_name):
        load_case_dataset(case_name)
        self.assertEqual(execute(query, case_name), read_case_result_csv(case_name))

    return _collection_test


# Attach one parametrized test method per collection to the test class.
# NOTE(review): pytest does not apply parametrization to methods of
# unittest.TestCase subclasses; if TestQueries derives from TestCase this
# attempt cannot work as written -- confirm against the pytest unittest docs.
for collection, query, config in gather_collections():
    test_method_name = 'test_{}'.format(collection)
    LOGGER.debug("collections.{}.test - {}".format(collection, config))
    cases = gather_cases(collection)
    LOGGER.debug("collections.{}.cases - {}".format(collection, cases))
    setattr(
        TestQueries,
        test_method_name,
        pytest.mark.parametrize(
            'case_name',
            cases,
            ids=cases
        )(_make_collection_test(query))
    )
Attempt 2
def _make_case_test(query):
    """Return a test method that checks ``query`` against one test case.

    Binding ``query`` through this factory avoids the late-binding closure
    problem: a lambda created in the loop would see the query of the last
    collection once the tests actually run.
    """
    def _case_test(self, case_name):
        load_case_dataset(case_name)
        self.assertEqual(execute(query, case_name), read_case_result_csv(case_name))

    return _case_test


# One test method per collection; ``case_name`` is supplied later by
# ``pytest_generate_tests``.
for collection, query, config in gather_collections():
    test_method_name = 'test_{}'.format(collection)
    LOGGER.debug("collections.{}.test - {}".format(collection, config))
    setattr(TestQueries, test_method_name, _make_case_test(query))
def pytest_generate_tests(metafunc):
    """Parametrize each collection test with the test cases of its collection.

    Args:
        metafunc: pytest's ``Metafunc`` object for the test being collected.
    """
    # Only tests that request ``case_name`` can be parametrized with it.
    if 'case_name' not in metafunc.fixturenames:
        return
    # Use the collected name (the attribute the test was attached under).
    # The function's own ``__name__`` is that of the lambda/helper that was
    # attached with setattr, not ``test_<collection>``.
    collection = metafunc.definition.name.replace('test_', '')
    cases = gather_cases(collection)
    LOGGER.info("collections.{}.pytest.cases - {}".format(collection, cases))
    metafunc.parametrize(
        'case_name',
        cases,
        ids=cases
    )
So I figured it out, but it's not the most elegant solution.
Essentially you use one function and then use some of pytest's hooks to change the function names for reporting.
There are numerous issues, e.g. if you don't use pytest.param to pass the parameters to parametrize then you do not have the required information available.
Also the method passed to setup_method is not aware of the actual iteration being run when its called, so I had to hack that in with the iter counter.
# test_queries.py
def gather_tests():
    """Yield one parameter tuple per expected-result CSV of every collection.

    Each tuple is ``(case name, query path, collection name, collection
    config)``; the case name is the CSV file name without its extension.
    """
    for collection_name, collection_config in TESTS.items():
        LOGGER.debug("collections.{}.gather - {}".format(collection_name, collection_config))
        query_path = path.join(SRC_DIR, collection_config['query'])
        csv_pattern = path.join(TEST_DIR, collection_config['cases'], '*.csv')
        for csv_file in glob.glob(csv_pattern):
            case_name, _extension = path.splitext(path.basename(csv_file))
            yield case_name, query_path, collection_name, collection_config
class TestQueries():
    """Runs each collection's report query against every one of its test cases."""

    # Position of the parametrize entry the next ``setup_method`` call belongs
    # to.  It lives on the class: pytest creates a new instance for every
    # test, so a counter stored on ``self`` would never advance.
    iter = 0

    def setup_method(self, method):
        """Load the stubs of the collection the upcoming test case belongs to.

        ``method`` carries no information about the current parametrize
        iteration, so the entry is looked up by counting calls.
        NOTE(review): this assumes the cases run in declaration order and
        that none are deselected -- confirm for the way the suite is invoked.

        Raises:
            Exception: if ``test_scripts_reports`` has no parametrization
                containing ``collection_name`` or the parameter list is empty.
        """
        method_name = method.__name__  # or self._item.originalname
        if method_name != 'test_scripts_reports':
            return
        _mark = next(
            (m for m in method.pytestmark
             if m.name == 'parametrize' and 'collection_name' in m.args[0]),
            None,
        )
        if not _mark:
            raise Exception('test {} missing collection_name parametrization'.format(method_name))  # nothing to do here
        _args = _mark.args[0]
        _params = _mark.args[1]
        LOGGER.debug('setup_method: _params - {}'.format(_params))
        if not _params:
            raise Exception('test {} missing pytest.params'.format(method_name))  # nothing to do here
        counter_owner = type(self)
        _currparams = _params[counter_owner.iter]
        counter_owner.iter += 1
        # Locate ``collection_name`` inside the comma separated argument names.
        _argpos = [arg.strip() for arg in _args.split(',')].index('collection_name')
        collection = _currparams.values[_argpos]
        LOGGER.debug('collections.{}.setup_method[{}] - {}'.format(collection, self.iter, _currparams))
        load_collection_stubs(collection)

    @pytest.mark.parametrize(
        'case_name, collection_query, collection_name, collection_config',
        [pytest.param(*c, id='{}:{}'.format(c[2], c[0])) for c in gather_tests()]
    )
    def test_scripts_reports(self, case_name, collection_query, collection_name, collection_config):
        """Load the case dataset, run the collection query and compare with the CSV."""
        if not path.isfile(collection_query):
            pytest.skip("report query does not exist: {}".format(collection_query))
        LOGGER.debug("test_scripts_reports.{}.{} - ".format(collection_name, case_name))
        load_case_dataset(case_name)
        assert execute(collection_query, case_name) == read_case_result_csv(case_name)
Then to make the test ids more human you can do this
# conftest.py
def pytest_collection_modifyitems(items):
    """Shorten the reported node ids of the ``test_scripts_reports`` items.

    The ``::Class::function`` part in front of the parameter id is dropped,
    leaving ``file[params]``.
    """
    # https://stackoverflow.com/questions/61317809/pytest-dynamically-generating-test-name-during-runtime
    report_items = (entry for entry in items if entry.originalname == 'test_scripts_reports')
    for entry in report_items:
        entry._nodeid = re.sub(r'::\w+::\w+\[', '[', entry.nodeid)
the result with the following files:
stubs/
00-wipe-db.sql
setup-db.sql
queries/
report1.sql
collection/
report1/
case1.sql
case1.csv
case2.sql
case2.csv
# results (with setup_method firing before each test and loading the appropriate stubs as per configuration)
FAILED test_queries.py[report1:case1]
FAILED test_queries.py[report1:case2]
Related
So, I am trying to write an Airflow DAG to 1) read a few different CSVs from my local disk, 2) create different PostgreSQL tables, and 3) load the files into their respective tables. When I run the DAG, the second step seems to fail.
Below are the DAG logic operators code:
AIRFLOW_HOME = os.getenv('AIRFLOW_HOME')
def get_listings_data():
    """Read ``dags/data/listings.csv`` below ``AIRFLOW_HOME`` into a DataFrame."""
    return pd.read_csv(AIRFLOW_HOME + '/dags/data/listings.csv')
def get_g01_data():
    """Read ``dags/data/demographics.csv`` below ``AIRFLOW_HOME`` into a DataFrame."""
    return pd.read_csv(AIRFLOW_HOME + '/dags/data/demographics.csv')
def insert_listing_data_func(**kwargs):
    """Bulk-insert the listings CSV into ``assignment_2.listings``.

    Args:
        **kwargs: Airflow task context; accepted so the function can be used
            as a ``python_callable`` but not otherwise needed.

    Returns:
        None
    """
    # ``pd.DataFrame`` has no ``listings`` attribute; the data comes from the
    # CSV loader defined alongside this function.
    insert_df = get_listings_data()
    if len(insert_df) == 0:
        return None

    col_names = ['host_id', 'host_name', 'host_neighbourhood', 'host_total_listings_count', 'neighbourhood_cleansed', 'property_type', 'price', 'has_availability', 'availability_30']
    values = insert_df[col_names].to_dict('split')['data']
    logging.info(values)
    # The target column list is generated from ``col_names`` so the INSERT
    # always names exactly as many columns as each row supplies values.
    insert_sql = "INSERT INTO assignment_2.listings ({}) VALUES %s".format(', '.join(col_names))

    conn_ps = PostgresHook(postgres_conn_id="postgres").get_conn()
    try:
        with conn_ps.cursor() as cursor:
            execute_values(cursor, insert_sql, values, page_size=len(values))
        conn_ps.commit()
    finally:
        # Release the connection whether or not the insert succeeded.
        conn_ps.close()
    return None
def insert_demographics_data_func(**kwargs):
    """Bulk-insert the demographics CSV into ``assignment_2.demographics``.

    Args:
        **kwargs: Airflow task context; accepted so the function can be used
            as a ``python_callable`` but not otherwise needed.

    Returns:
        None
    """
    # ``pd.DataFrame`` has no ``demographics`` attribute; the data comes from
    # the CSV loader defined alongside this function.
    insert_df = get_g01_data()
    if len(insert_df) == 0:
        return None

    col_names = ['LGA', 'Median_age_persons', 'Median_mortgage_repay_monthly', 'Median_tot_prsnl_inc_weekly', 'Median_rent_weekly', 'Median_tot_fam_inc_weekly', 'Average_num_psns_per_bedroom', 'Median_tot_hhd_inc_weekly', 'Average_household_size']
    values = insert_df[col_names].to_dict('split')['data']
    logging.info(values)
    # Generate the column list from ``col_names`` so columns and values
    # cannot drift apart.
    insert_sql = "INSERT INTO assignment_2.demographics ({}) VALUES %s".format(', '.join(col_names))

    conn_ps = PostgresHook(postgres_conn_id="postgres").get_conn()
    try:
        with conn_ps.cursor() as cursor:
            execute_values(cursor, insert_sql, values, page_size=len(values))
        conn_ps.commit()
    finally:
        # Release the connection whether or not the insert succeeded.
        conn_ps.close()
    return None
And my postgresQL hook for the demographics table (just an example) is below:
# The insert task writes to ``assignment_2.demographics``, so the table is
# created in that schema, and the schema itself is created first.
# NOTE(review): the original targeted schema ``postgres``; confirm that
# ``assignment_2`` is the intended schema.
create_psql_table_demographics = PostgresOperator(
    task_id="create_psql_table_demographics",
    postgres_conn_id="postgres",
    sql="""
        CREATE SCHEMA IF NOT EXISTS assignment_2;
        CREATE TABLE IF NOT EXISTS assignment_2.demographics (
            LGA VARCHAR,
            Median_age_persons INT,
            Median_mortgage_repay_monthly INT,
            Median_tot_prsnl_inc_weekly INT,
            Median_rent_weekly INT,
            Median_tot_fam_inc_weekly INT,
            Average_num_psns_per_bedroom DECIMAL(10,1),
            Median_tot_hhd_inc_weekly INT,
            Average_household_size DECIMAL(10,2)
        );
    """,
    dag=dag)
Am I missing something in my code that stops the completion of that create_psql_table_demographics from running successfully on Airflow?
If your PostgreSQL database has access to the CSV files,
you may simply use the copy_expert method of the PostgresHook class (cf. the documentation).
PostgreSQL is pretty efficient at loading flat files: you'll save a lot of CPU cycles by not involving Python (and pandas!), not to mention the potential encoding issues that you would otherwise have to address.
I have created an ObjectType using Python Graphene however in the query object to return the data, I don't know what the return should be from the resolver.
My code is below:
def _decode(value):
    """Decode a Redis bytes reply to text; a missing key (``None``) stays ``None``."""
    return value.decode('utf-8') if value is not None else None


class RunLog(ObjectType):
    """Status, result and log lines stored in Redis for one run.

    No field resolvers are declared: the instance returned by the query
    resolver already carries the values, and Graphene reads them from its
    attributes.
    """

    status = String()
    result = String()
    log = List(String)


class Query(ObjectType):
    """Root query type exposing ``logByRunId``."""

    log_by_run_id = Field(RunLog, run_id=Int(required=True))

    def resolve_log_by_run_id(root, info, run_id):
        """Read the run's hash and log list from Redis and wrap them in a RunLog.

        ``run_id`` is an argument of this field only, which is why the Redis
        reads happen here rather than in per-field resolvers on ``RunLog``.
        """
        run_key = "run:%i" % run_id
        return RunLog(
            status=_decode(r.hget(run_key, "status")),
            result=_decode(r.hget(run_key, "result")),
            log=[_decode(entry) for entry in r.lrange("run:%i:log" % run_id, 0, -1)],
        )
The RunLog object should read from a redis database and return the data at the relevant run_id.
I want to be able to execute the following query to get the data associated with that run:
{
logByRunId(runId: 1) {
status
result
log
}
}
What should the return be from 'resolve_log_by_run_id'? The Graphene documentation is not helpful.
Try returning a RunLog object, i.e.
def resolve_log_by_run_id(root, info, run_id):
# fetch values from Redis here using run_id (status, result, log)
# NOTE(review): status, result and log are not defined in this snippet; they
# have to be assigned from the Redis reads before the RunLog is built.
run_log = RunLog(status=status, result=result, log=log)
return run_log
Also, since the arguments passed to your method are named, you should consider renaming your method something less restrictive like resolve_log or resolve_run_log. If you need to add another filter to your resolver, you won't need to add another method.
I am writing a unit test case for a function which has multiple SQL queries in it. I am using the psycopg2 module and trying to mock the cursor.
app.py
import psycopg2
def my_function():
    """Run the customer and product queries and return the product rows as dicts.

    Each row is converted to a dict keyed by the column names taken from
    ``cursor.description``.  The customer rows are fetched and converted as
    well, but only the product rows are returned.
    """
    # all connection related code goes here ...
    customer_query = "SELECT name,phone FROM customer WHERE name='shanky'"
    cursor.execute(customer_query)
    customer_columns = [column[0] for column in cursor.description]
    customer_response = [dict(zip(customer_columns, record)) for record in cursor.fetchall()]

    product_query = "SELECT name,id FROM product WHERE name='soap'"
    cursor.execute(product_query)
    product_columns = [column[0] for column in cursor.description]
    return [dict(zip(product_columns, record)) for record in cursor.fetchall()]
test.py
from pytest_mock import mocker
import psycopg2
def test_my_function(mocker):
from my_module import app
mocker.patch('psycopg2.connect')
#first query
# NOTE(review): mocked_cursor_one and mocked_cursor_two are reached through
# the same return_value chain, so they are the same mock object; the values
# set for the second query overwrite those set for the first.
mocked_cursor_one = psycopg2.connect.return_value.cursor.return_value
mocked_cursor_one.description = [['name'],['phone']]
mocked_cursor_one.fetchall.return_value = [('shanky', '347539593')]
# NOTE(review): this comparison is evaluated and discarded; it asserts nothing.
mocked_cursor_one.execute.call_args == "SELECT name,phone FROM customer WHERE name='shanky'"
#second query
mocked_cursor_two = psycopg2.connect.return_value.cursor.return_value
mocked_cursor_two.description = [['name'],['id']]
mocked_cursor_two.fetchall.return_value = [('nirma', 12313)]
# NOTE(review): same as above -- a bare comparison, not an assertion.
mocked_cursor_two.execute.call_args == "SELECT name,id FROM product WHERE name='soap'"
ret = app.my_function()
# NOTE(review): my_function builds a list of dicts, while this expects a
# single dict -- confirm the intended return shape.
assert ret == {'name' : 'nirma', 'id' : 12313}
But the mocker always takes the last mock object (the second query). I have already tried multiple hacks, but they didn't work out. How can I mock multiple queries in one function and successfully pass the unit test case? Is it possible to write a unit test case in this fashion, or do I need to split the queries into different functions?
After drilling a lot through the documentation, I was able to achieve this with the help of the unittest mock decorator and side_effect, which was suggested by @Pavel Vergeev. I was able to write a unit test case that is good enough to test the functionality.
from unittest import mock
from my_module import app
@mock.patch('psycopg2.connect')
def test_my_function(mocked_db):
    """Check both queries of ``app.my_function`` against one mocked cursor.

    ``side_effect`` lists make ``fetchall()`` and the ``description`` property
    return a different value for each query, in call order.
    """
    mocked_cursor = mocked_db.return_value.cursor.return_value
    # ``description`` is read as an attribute, so it has to be a PropertyMock
    # installed on the mock's type to yield a new value on each access.
    description_mock = mock.PropertyMock(side_effect=[
        [['name'], ['phone']],
        [['name'], ['id']],
    ])
    type(mocked_cursor).description = description_mock
    mocked_cursor.fetchall.side_effect = [
        [('shanky', '347539593')],
        [('nirma', 12313)],
    ]

    ret = app.my_function()

    # exactly the two expected queries, in order
    assert mocked_cursor.execute.call_args_list == [
        mock.call("SELECT name,phone FROM customer WHERE name='shanky'"),
        mock.call("SELECT name,id FROM product WHERE name='soap'"),
    ]
    assert mocked_cursor.fetchall.call_count == 2
    # my_function returns the product rows as a list of dicts
    assert ret == [{'name': 'nirma', 'id': 12313}]
In addition to this if there are dynamic parameters in the query, that can be asserted too by the following method.
# The second positional argument of the first execute() call holds the query parameters.
assert mocked_db.return_value.cursor.return_value.execute.call_args_list[0][0][1] == (parameter_name,)
so when the first query is executed, cursor.execute(query,(parameter_name,)) at call_args_list[0][0][0] the query can be obtained and asserted, at call_args_list[0][0][1] the first parameter parameter_name can be obtained. similarly incrementing the index, all the other params and different queries can be obtained and asserted.
Try side_effect argument of mocker.patch:
from unittest.mock import MagicMock
from pytest_mock import mocker
import psycopg2
def test_my_function(mocker):
from my_module import app
# side_effect with a list makes each call to the patched connect() return the
# next element of the list.
mocker.patch('psycopg2.connect', side_effect=[MagicMock(), MagicMock()])
#first query
mocked_cursor_one = psycopg2.connect().cursor.return_value # note that we actually call psyocpg2.connect -- it's important
mocked_cursor_one.description = [['name'],['phone']]
mocked_cursor_one.fetchall.return_value = [('shanky', '347539593')]
# NOTE(review): bare comparison; the result is discarded, nothing is asserted.
mocked_cursor_one.execute.call_args == "SELECT name,phone FROM customer WHERE name='shanky'"
#second query
mocked_cursor_two = psycopg2.connect().cursor.return_value
mocked_cursor_two.description = [['name'],['id']]
mocked_cursor_two.fetchall.return_value = [('nirma', 12313)]
# NOTE(review): bare comparison; the result is discarded, nothing is asserted.
mocked_cursor_two.execute.call_args == "SELECT name,id FROM product WHERE name='soap'"
assert mocked_cursor_one is not mocked_cursor_two # show that they are different
# NOTE(review): both side_effect values were consumed by the two connect()
# calls above, so a connect() made inside my_function would find the list
# exhausted -- confirm how the code under test obtains its connection.
ret = app.my_function()
assert ret == {'name' : 'nirma', 'id' : 12313}
As per the docs, side_effect allows you to change returned value each time the patched object is called:
If you pass in an iterable, it is used to retrieve an iterator which must yield a value on every call. This value can either be an exception instance to be raised, or a value to be returned from the call to the mock
As I have mentioned in an earlier comment, the best way to make unit testing portable is to develop a complete Mock of your database's behavior.
I've done it for MySQL but it's pretty much the same for all databases.
First of all, I like using wrapper classes over the packages I'm using, it helps quickly change the database at one place instead of changing it everywhere in the code.
Here's a sample of what I use as a wrapper:
Now, you would need to Mock this MySQL class:
# _database.py
# -----------------------------------------------------------------------------
# Database Metaclass
# -----------------------------------------------------------------------------
"""Metaclass for Database implementation.
"""
# -----------------------------------------------------------------------------
import logging

logger = logging.getLogger(__name__)


class Database:
    """Base class that wraps a database connection."""

    def __init__(self, connect_func, **kwargs):
        """Open the connection.

        Args:
            connect_func (callable): Function that opens the connection.
            **kwargs: Passed to ``connect_func`` unchanged.
        """
        self.connection = connect_func(**kwargs)

    def execute(self, statement, fetchall=True):
        """Execute a statement.

        Execute the statement passed as argument on a new cursor.

        Args:
            statement (str): SQL Query or Command to execute.
            fetchall (bool): Return ``cursor.fetchall()`` when true,
                ``cursor.fetchone()`` otherwise.

        Returns:
            Whatever the cursor's ``fetchall()`` / ``fetchone()`` returns.
        """
        cursor = self.connection.cursor()
        logger.debug("Executing: %s", statement)
        cursor.execute(statement)
        if fetchall:
            return cursor.fetchall()
        return cursor.fetchone()

    def __del__(self):
        """Close connection on object deletion."""
        # ``connection`` is missing when ``__init__`` failed before assigning
        # it (or was bypassed); there is nothing to close in that case.
        connection = getattr(self, "connection", None)
        if connection is not None:
            connection.close()
And the mysql module:
# mysql.py
# -*- coding: utf-8 -*-
# -----------------------------------------------------------------------------
# MySQL Database Class
# -----------------------------------------------------------------------------
"""Class for MySQL Database connection."""
# -----------------------------------------------------------------------------
import logging
import mysql.connector
from . import _database
logger = logging.getLogger(__name__)
class MySQL(_database.Database):
"""MySQL Database Class Wrapper.
Attributes:
connection (obj): Object returned from mysql.connector.connect
"""
# The base class opens the connection with mysql.connector.connect; the
# autocommit flag is then set on the resulting connection.
def __init__(self, autocommit=True, **kwargs):
super().__init__(connect_func=mysql.connector.connect, **kwargs)
self.connection.autocommit = autocommit
Instantiate like: db = MySQL(user='...', password='...', ...)
Here's the data file:
# database_mock_data.json
{
"customer": {
"name": [
"shanky",
"nirma"
],
"phone": [
123123123,
232342342
]
},
"product": {
"name": [
"shanky",
"nirma"
],
"id": [
1,
2
]
}
}
The mocks.py
# mocks.py
import json
import re
from . import mysql
_MOCK_DATA_PATH = 'database_mock_data.json'
class MockDatabase(mysql.MySQL):
    """Stand-in for ``MySQL`` whose connection is a ``MockConnection``.

    The parent ``__init__`` is deliberately not called, so no real connection
    is opened.  The class is reached as ``mysql.MySQL`` because this module
    imports the ``mysql`` module, not the class.
    """

    def __init__(self, **kwargs):
        self.connection = MockConnection()
class MockConnection:
    """
    Mock the connection object by returning a mock cursor.
    """

    @staticmethod
    def cursor():
        """Return a fresh ``MockCursor``."""
        return MockCursor()

    def close(self):
        """Do nothing; present because ``Database.__del__`` closes its connection."""
class MockCursor:
    """
    The Mocked Cursor

    A call to execute() with a SELECT statement reads the json data file and
    sets the description object (containing the column names usually).
    Any other statement is accepted and yields no rows.
    You could implement an update function like `_json_sql_update()`
    """

    def __init__(self):
        self.description = []
        self.__result = []

    def execute(self, statement):
        """Evaluate ``statement``; only SELECT statements produce rows."""
        if statement.lstrip().upper().startswith('SELECT'):
            # The data file is only needed for SELECT statements.
            data = _read_json_file(_MOCK_DATA_PATH)
            self.__result, self.description = _json_sql_select(data, statement)
        else:
            # Drop the rows of an earlier SELECT so they are not returned
            # as the result of this statement.
            self.__result, self.description = [], []

    def fetchall(self):
        """Return all rows of the last statement (empty list if none)."""
        return self.__result

    def fetchone(self):
        """Return the first row of the last statement, or ``None`` if there is none."""
        return self.__result[0] if self.__result else None
def _json_sql_select(data, query):
    """
    Takes a dictionary and returns the values from a sql query.

    NOTE: It does not work with other where clauses than '=', and conditions
    can only be combined with AND.
    Also, note that a where statement is expected.
    Table and column names are matched case-insensitively, and values are
    compared case-insensitively by their string form.

    :param (dict) data: Dictionary with the following structure:
        {
            'tablename': {
                'column_name_1': ['value1', 'value2'],
                'column_name_2': ['value1', 'value2'],
                ...
            },
            ...
        }
    :param (str) query: A select sql query as:
        `select column_name_1 from TABLENAME
         where column_name_2='value1'`
    :return: List of list of values and header description (a list of
        ``(column_name, '...')`` tuples using the column spelling of `data`)
    :raises AttributeError: if the query is not a select/from/where statement
    :raises KeyError: if the table or a column does not exist in `data`
    """
    try:
        select_part, table_part, where_part = re.search(
            r"\bselect\b(.*?)\bfrom\b(.*?)\bwhere\b(.*?);?\s*$", query,
            re.IGNORECASE | re.DOTALL).groups()
    except AttributeError:
        print("Select Query pattern mismatch... {}".format(query))
        raise

    # Resolve names case-insensitively while keeping the spelling used in data
    table_names = {name.upper(): name for name in data}
    table = data[table_names[table_part.strip().upper()]]
    column_names = {name.upper(): name for name in table}

    # Parse values from the select query
    requested = [col.strip() for col in select_part.split(",")]
    if requested == ['*']:
        columns = list(table)
    else:
        columns = [column_names[col.upper()] for col in requested]

    # Each condition becomes (column, expected upper-cased text); surrounding
    # whitespace and quotes of the value are dropped.
    conditions = []
    for condition in re.split(r"\band\b", where_part, flags=re.IGNORECASE):
        key, _, value = condition.partition('=')
        conditions.append((column_names[key.strip().upper()],
                           value.strip().strip("'\"").upper()))

    # Select values
    nb_lines = len(next(iter(table.values()), []))
    selected_values = [
        [table[column][i] for column in columns]
        for i in range(nb_lines)
        if all(str(table[key][i]).upper() == expected
               for key, expected in conditions)
    ]

    # Usual descriptor has nested list; a real list so it can be read repeatedly
    description = list(zip(columns, ['...'] * len(columns)))
    return selected_values, description
def _read_json_file(file_path):
    """Parse the JSON document stored at ``file_path`` and return it."""
    with open(file_path, 'r') as json_file:
        return json.load(json_file)
And then you have your test in a test_module_yourfunction.py
import pytest
def my_function(db, query):
    """Placeholder for the function under test; it receives the database wrapper and a query."""
    # Code goes here
    raise NotImplementedError("replace with the function under test")


@pytest.fixture
def db_connection():
    """Provide a ``MockDatabase`` in place of a real database connection."""
    return MockDatabase()


# NOTE(review): the expected values below are illustrative placeholders; they
# have to match what the real my_function returns for the mock data.
@pytest.mark.parametrize(
    ("query", "expected"),
    [
        ("SELECT name,phone FROM customer WHERE name='shanky'", {'name' : 'nirma', 'id' : 12313}),
        ("<second query goes here>", "<second result goes here>")
    ]
)
def test_my_function(db_connection, query, expected):
    """Run ``my_function`` against the mock database for each query/expected pair."""
    assert my_function(db_connection, query) == expected
Now I'm sorry if you can't copy/paste this code and make it work, but you get the feeling :) just trying to help
Brand new to this library
Here is the call stack of my mocked object
[call(),
call('test'),
call().instance('test'),
call().instance().database('test'),
call().instance().database().snapshot(),
call().instance().database().snapshot().__enter__(),
call().instance().database().snapshot().__enter__().execute_sql('SELECT * FROM users'),
call().instance().database().snapshot().__exit__(None, None, None),
call().instance().database().snapshot().__enter__().execute_sql().__iter__()]
Here is the code I have used
#mock.patch('testmodule.Client')
def test_read_with_query(self, mock_client):
# Calling the patched class returns the mock that stands in for the client instance.
mock = mock_client()
pipeline = TestPipeline()
records = pipeline | ReadFromSpanner(TEST_PROJECT_ID, TEST_INSTANCE_ID, self.database_id).with_query('SELECT * FROM users')
pipeline.run()
# Debugging aid: dump every call recorded on the patched class, then stop the run.
print mock_client.mock_calls
exit()
I want to mock this whole stack that eventually it gives me some fake data which I will provide as a return value.
The code being tested is
# Chain: Client -> instance -> database -> snapshot (used as a context
# manager) -> execute_sql.  The value to fake is what execute_sql returns.
spanner_client = Client(self.project_id)
instance = spanner_client.instance(self.instance_id)
database = instance.database(self.database_id)
with database.snapshot() as snapshot:
results = snapshot.execute_sql(self.query)
So my requirements is that the results variable should contain the data I will provide.
How can I provide a return value for such nested calls?
Thanks
Create separate MagicMock instances for the instance, database and snapshot objects in the code under test. Use return_value to configure the return values of each method. Here is an example. I simplified the method under test to just be a free standing function called mut.
# test_module.py : the module under test
class Client:
    """Stand-in for the real client class; the test replaces it with a mock."""


def mut(project_id, instance_id, database_id, query):
    """Run ``query`` inside a snapshot of the given database and return the results.

    The snapshot is reached through client -> instance -> database.
    """
    database = Client(project_id).instance(instance_id).database(database_id)
    with database.snapshot() as snapshot:
        return snapshot.execute_sql(query)
# test code (pytest)
from unittest.mock import MagicMock
from unittest import mock
from test_module import mut
@mock.patch('test_module.Client')
def test_read_with_query(mock_client_class):
    """Drive ``mut`` through a fully mocked client/instance/database/snapshot chain."""
    mock_client = MagicMock()
    mock_instance = MagicMock()
    mock_database = MagicMock()
    mock_snapshot = MagicMock()
    expected = 'fake query results'

    # Wire each level so it returns the mock of the next level down.
    mock_client_class.return_value = mock_client
    mock_client.instance.return_value = mock_instance
    mock_instance.database.return_value = mock_database
    mock_database.snapshot.return_value = mock_snapshot
    mock_snapshot.execute_sql.return_value = expected
    # ``with database.snapshot() as snapshot`` binds what __enter__ returns.
    mock_snapshot.__enter__.return_value = mock_snapshot

    observed = mut(29, 42, 77, 'select *')

    mock_client_class.assert_called_once_with(29)
    mock_client.instance.assert_called_once_with(42)
    mock_instance.database.assert_called_once_with(77)
    mock_database.snapshot.assert_called_once_with()
    mock_snapshot.__enter__.assert_called_once_with()
    mock_snapshot.execute_sql.assert_called_once_with('select *')
    assert observed == expected
This test is kind of portly. Consider breaking it apart by using a fixture and a before function that sets up the mocks.
Either set the value directly to your Mock instance (those enters and exit should have not seen) with:
mock.return_value.instance.return_value.database.return_value.snapshot.return_value.execute_sql.return_value = MY_MOCKED_DATA
or patch and set return_value to target method, something like:
@mock.patch('database_engine.execute_sql', return_value=MY_MOCKED_DATA)
I am using the Products.FacultyStaffDirectory product to manage contacts in my website.
One of the content types, i.e. FacultyStaffDirectory, has the description field in the 'categorization' tab when you are in edit mode. Now I need to remove the description field and put it in the default tab.
To do this, I'm using archetypes.schemaextender product to edit the field. In particular, I am using ISchemaModifier.
I have implemented the code but the field is still being shown in the categorization tab. May be something I missed. Below is the code:
This is the class that contains the description field that I want to modify:
# -*- coding: utf-8 -*-
__author__ = """WebLion <support#weblion.psu.edu>"""
__docformat__ = 'plaintext'
from AccessControl import ClassSecurityInfo
from Products.Archetypes.atapi import *
from zope.interface import implements
from Products.FacultyStaffDirectory.interfaces.facultystaffdirectory import IFacultyStaffDirectory
from Products.FacultyStaffDirectory.config import *
from Products.CMFCore.permissions import View, ManageUsers
from Products.CMFCore.utils import getToolByName
from Products.ATContentTypes.content.base import ATCTContent
from Products.ATContentTypes.content.schemata import ATContentTypeSchema, finalizeATCTSchema
from Products.membrane.at.interfaces import IPropertiesProvider
from Products.membrane.utils import getFilteredValidRolesForPortal
from Acquisition import aq_inner, aq_parent
from Products.FacultyStaffDirectory import FSDMessageFactory as _
schema = ATContentTypeSchema.copy() + Schema((
LinesField('roles_',
accessor='getRoles',
mutator='setRoles',
edit_accessor='getRawRoles',
vocabulary='getRoleSet',
default = ['Member'],
multiValued=1,
write_permission=ManageUsers,
widget=MultiSelectionWidget(
label=_(u"FacultyStaffDirectory_label_FacultyStaffDirectoryRoles", default=u"Roles"),
description=_(u"FacultyStaffDirectory_description_FacultyStaffDirectoryRoles", default=u"The roles all people in this directory will be granted site-wide"),
i18n_domain="FacultyStaffDirectory",
),
),
IntegerField('personClassificationViewThumbnailWidth',
accessor='getClassificationViewThumbnailWidth',
mutator='setClassificationViewThumbnailWidth',
schemata='Display',
default=100,
write_permission=ManageUsers,
widget=IntegerWidget(
label=_(u"FacultyStaffDirectory_label_personClassificationViewThumbnailWidth", default=u"Width for thumbnails in classification view"),
description=_(u"FacultyStaffDirectory_description_personClassificationViewThumbnailWidth", default=u"Show all person thumbnails with a fixed width (in pixels) within the classification view"),
i18n_domain="FacultyStaffDirectory",
),
),
))
FacultyStaffDirectory_schema = OrderedBaseFolderSchema.copy() + schema.copy() # + on Schemas does only a shallow copy
finalizeATCTSchema(FacultyStaffDirectory_schema, folderish=True)
class FacultyStaffDirectory(OrderedBaseFolder, ATCTContent):
"""
"""
security = ClassSecurityInfo()
implements(IFacultyStaffDirectory, IPropertiesProvider)
meta_type = portal_type = 'FSDFacultyStaffDirectory'
# Make this permission show up on every containery object in the Zope instance. This is a Good Thing, because it easy to factor up permissions. The Zope Developer's Guide says to put this here, not in the install procedure (http://www.zope.org/Documentation/Books/ZDG/current/Security.stx). This is because it isn't "sticky", in the sense of being persisted through the ZODB. Thus, it has to run every time Zope starts up. Thus, when you uninstall the product, the permission doesn't stop showing up, but when you actually remove it from the Products folder, it does.
security.setPermissionDefault('FacultyStaffDirectory: Add or Remove People', ['Manager', 'Owner'])
# moved schema setting after finalizeATCTSchema, so the order of the fieldsets
# is preserved. Also after updateActions is called since it seems to overwrite the schema changes.
# Move the description field, but not in Plone 2.5 since it's already in the metadata tab. Although,
# decription and relateditems are occasionally showing up in the "default" schemata. Move them
# to "metadata" just to be safe.
if 'categorization' in FacultyStaffDirectory_schema.getSchemataNames():
FacultyStaffDirectory_schema.changeSchemataForField('description', 'categorization')
else:
FacultyStaffDirectory_schema.changeSchemataForField('description', 'metadata')
FacultyStaffDirectory_schema.changeSchemataForField('relatedItems', 'metadata')
_at_rename_after_creation = True
schema = FacultyStaffDirectory_schema
# Methods
security.declarePrivate('at_post_create_script')
def at_post_create_script(self):
"""Actions to perform after a FacultyStaffDirectory is added to a Plone site"""
# Create some default contents
# Create some base classifications
self.invokeFactory('FSDClassification', id='faculty', title='Faculty')
self.invokeFactory('FSDClassification', id='staff', title='Staff')
self.invokeFactory('FSDClassification', id='grad-students', title='Graduate Students')
# Create a committees folder
self.invokeFactory('FSDCommitteesFolder', id='committees', title='Committees')
# Create a specialties folder
self.invokeFactory('FSDSpecialtiesFolder', id='specialties', title='Specialties')
security.declareProtected(View, 'getDirectoryRoot')
def getDirectoryRoot(self):
"""Return the current FSD object through acquisition."""
return self
security.declareProtected(View, 'getClassifications')
def getClassifications(self):
"""Return the classifications (in brains form) within this FacultyStaffDirectory."""
portal_catalog = getToolByName(self, 'portal_catalog')
return portal_catalog(path='/'.join(self.getPhysicalPath()), portal_type='FSDClassification', depth=1, sort_on='getObjPositionInParent')
security.declareProtected(View, 'getSpecialtiesFolder')
def getSpecialtiesFolder(self):
"""Return a random SpecialtiesFolder contained in this FacultyStaffDirectory.
If none exists, return None."""
specialtiesFolders = self.getFolderContents({'portal_type': 'FSDSpecialtiesFolder'})
if specialtiesFolders:
return specialtiesFolders[0].getObject()
else:
return None
security.declareProtected(View, 'getPeople')
def getPeople(self):
"""Return a list of people contained within this FacultyStaffDirectory."""
portal_catalog = getToolByName(self, 'portal_catalog')
results = portal_catalog(path='/'.join(self.getPhysicalPath()), portal_type='FSDPerson', depth=1)
return [brain.getObject() for brain in results]
security.declareProtected(View, 'getSortedPeople')
def getSortedPeople(self):
    """ Return a list of people, sorted by SortableName
    """
    # ``key=`` gives the same ordering as comparing the sortable names with
    # ``cmp=`` and, unlike ``cmp=``, is also accepted by Python 3's sorted().
    return sorted(self.getPeople(), key=lambda person: person.getSortableName())
security.declareProtected(View, 'getDepartments')
def getDepartments(self):
"""Return a list of FSDDepartments contained within this site."""
portal_catalog = getToolByName(self, 'portal_catalog')
results = portal_catalog(portal_type='FSDDepartment')
return [brain.getObject() for brain in results]
security.declareProtected(View, 'getAddableInterfaceSubscribers')
def getAddableInterfaceSubscribers(self):
    """Return a list of (names of) content types marked as addable using the
    IFacultyStaffDirectoryAddable interface."""
    # ``self`` is required: the method is looked up on an instance, and a
    # zero-argument function cannot be called that way.
    return [type_info['name'] for type_info in listTypes() if IFacultyStaffDirectoryAddable.implementedBy(type_info['klass'])]
security.declarePrivate('getRoleSet')
def getRoleSet(self):
"""Get the roles vocabulary to use."""
portal_roles = getFilteredValidRolesForPortal(self)
allowed_roles = [r for r in portal_roles if r not in INVALID_ROLES]
return allowed_roles
#
# Validators
#
security.declarePrivate('validate_id')
def validate_id(self, value):
"""Ensure the id is unique, also among groups globally."""
if value != self.getId():
parent = aq_parent(aq_inner(self))
if value in parent.objectIds():
return _(u"An object with id '%s' already exists in this folder") % value
groups = getToolByName(self, 'portal_groups')
if groups.getGroupById(value) is not None:
return _(u"A group with id '%s' already exists in the portal") % value
registerType(FacultyStaffDirectory, PROJECTNAME)
Below is the portion of code from the class that I have implemented to change the description field's schemata:
from Products.Archetypes.public import ImageField, ImageWidget, StringField, StringWidget, SelectionWidget, TextField, RichWidget
from Products.FacultyStaffDirectory.interfaces.facultystaffdirectory import IFacultyStaffDirectory
from archetypes.schemaextender.field import ExtensionField
from archetypes.schemaextender.interfaces import ISchemaModifier, ISchemaExtender, IBrowserLayerAwareExtender
from apkn.templates.interfaces import ITemplatesLayer
from zope.component import adapts
from zope.interface import implements
class _ExtensionImageField(ExtensionField, ImageField): pass
class _ExtensionStringField(ExtensionField, StringField): pass
class _ExtensionTextField(ExtensionField, TextField): pass
class FacultyStaffDirectoryExtender(object):
    """
    Schema modifier that moves the description field of a
    FacultyStaffDirectory to the "default" schemata.
    """
    adapts(IFacultyStaffDirectory)
    implements(ISchemaModifier, IBrowserLayerAwareExtender)

    # Browser layer this extender is tied to.
    layer = ITemplatesLayer

    # No additional fields are contributed.
    fields = []

    def __init__(self, context):
        self.context = context

    def getFields(self):
        """Return the (empty) list of extension fields."""
        return self.fields

    def fiddle(self, schema):
        """Replace ``description`` with a copy assigned to the default schemata."""
        relocated = schema['description'].copy()
        relocated.schemata = "default"
        schema['description'] = relocated
And here is the code from my configure.zcml:
<adapter
name="apkn-FacultyStaffDirectoryExtender"
factory="apkn.templates.extender.FacultyStaffDirectoryExtender"
provides="archetypes.schemaextender.interfaces.ISchemaModifier"
/>
Is there something that I'm missing in this?
def fiddle(self, schema):
# NOTE(review): this changes the field object in place.  The
# archetypes.schemaextender documentation should be checked on whether the
# fields handed to fiddle() are shared with the class schema, in which case
# a copy of the field is needed -- confirm before dropping copy().
schema['description'].schemata = 'default'
should be sufficient. The copy() operation does not make any sense here.
In order to check if the fiddle() method is actually used: use pdb or add debug print statements.
It turns out that there was nothing wrong with the code.
I got it to work by doing the following:
Restarted the plone website
Re-installed the product
Clear and rebuilt my catalog
This got it working somehow.