Clone Kubernetes objects programmatically using the Python API - python

The Python API is available to read objects from a cluster. By cloning we can say:
Get a copy of an existing Kubernetes object using kubectl get
Change the properties of the object
Apply the new object
kubectl's --export option used to cover this, but it was deprecated in 1.14. How can we use the Python Kubernetes API to perform steps 1-3 described above?
There are multiple questions about how to extract the code from Python API to YAML, but it's unclear how to transform the Kubernetes API object.

Just use to_dict() which is now offered by Kubernetes Client objects. Note that it creates a partly deep copy. So to be safe:
copied_obj = copy.deepcopy(obj.to_dict())
Dicts can be passed to create* and patch* methods.
For convenience, you can also wrap the dict in Prodict.
copied_obj = Prodict.from_dict(copy.deepcopy(obj.to_dict()))
The final issue is getting rid of superfluous fields. (Unfortunately, Kubernetes sprinkles them throughout the object.) I use kopf's internal facility for getting the "essence" of an object. (It takes care of the deep copy.)
# Let kopf build the "essence" of the object (it also takes care of the deep copy),
# then wrap the resulting dict in Prodict for attribute-style access.
copied_obj = kopf.AnnotationsDiffBaseStorage().build(body=kopf.Body(obj.to_dict()))
copied_obj = Prodict.from_dict(copied_obj)

After looking at the requirement, I spent a couple of hours researching the Kubernetes Python API. Issue 340 and others ask about how to transform the Kubernetes API object into a dict, but the only workaround I found was to retrieve the raw data and then convert to JSON.
The following code uses the Kubernetes API to get a deployment and its related hpa from the namespaced objects, but retrieving their raw values as JSON.
Then, after transforming the data into a dict, you can optionally clean it up by removing null references.
Once you are done, you can transform the dict as YAML payload to then save the YAML to the file system
Finally, you can apply either using kubectl or the Kubernetes Python API.
Note:
Make sure to set KUBECONFIG=config so that you can point to a cluster
Make sure to adjust the values of origin_obj_name = "istio-ingressgateway" and origin_obj_namespace = "istio-system" with the name of the corresponding objects to be cloned in the given namespace.
import os
import logging
import yaml
import json
logging.basicConfig(level = logging.INFO)
import crayons
from kubernetes import client, config
from kubernetes.client.rest import ApiException
LOGGER = logging.getLogger(" IngressGatewayCreator ")
class IngressGatewayCreator:
    """Clone an existing Deployment and its HPA into new YAML manifests.

    Every method is static and driven by a ``clone_context`` object that must
    provide ``name``, ``origin_obj_name``, ``origin_obj_namespace`` and a
    ``save_clone_as_yaml(k8s_object, kind)`` method.
    """

    @staticmethod
    def clone_default_ingress(clone_context):
        """Clone the origin Deployment first, then its HPA."""
        IngressGatewayCreator.clone_deployment_object(clone_context)
        IngressGatewayCreator.clone_hpa_object(clone_context)

    @staticmethod
    def clone_deployment_object(clone_context):
        """Read the origin Deployment, retarget it at the clone and save it.

        Logs and returns silently when the Deployment does not exist (HTTP
        404); any other ``ApiException`` is re-raised.
        """
        kubeconfig = os.getenv('KUBECONFIG')
        config.load_kube_config(kubeconfig)
        # apps/v1 replaces the apps/v1beta1 API group, which current client
        # releases no longer ship.
        v1apps = client.AppsV1Api()
        deployment_name = clone_context.origin_obj_name
        namespace = clone_context.origin_obj_namespace
        try:
            # Fetch the raw HTTP response without deserialization to a model
            # https://github.com/kubernetes-client/python/issues/574#issuecomment-405400414
            deployment = v1apps.read_namespaced_deployment(
                deployment_name, namespace, _preload_content=False)
        except ApiException as error:
            if error.status == 404:
                LOGGER.info("Deployment %s not found in namespace %s", deployment_name, namespace)
                return
            raise
        # Clone the deployment as a dict
        cloned_dict = IngressGatewayCreator.clone_k8s_object(deployment, clone_context)
        # The selector and the pod template labels must both point at the clone
        cloned_dict["spec"]["selector"]["matchLabels"]["istio"] = clone_context.name
        cloned_dict["spec"]["template"]["metadata"]["labels"]["istio"] = clone_context.name
        # Save through the context that was passed in, not a module-level global
        clone_context.save_clone_as_yaml(cloned_dict, "deployment")

    @staticmethod
    def clone_hpa_object(clone_context):
        """Read the origin HPA, retarget it at the clone and save it.

        Logs and returns silently when the HPA does not exist (HTTP 404); any
        other ``ApiException`` is re-raised.
        """
        kubeconfig = os.getenv('KUBECONFIG')
        config.load_kube_config(kubeconfig)
        hpas = client.AutoscalingV1Api()
        hpa_name = clone_context.origin_obj_name
        namespace = clone_context.origin_obj_namespace
        try:
            # Fetch the raw HTTP response without deserialization to a model
            # https://github.com/kubernetes-client/python/issues/574#issuecomment-405400414
            hpa = hpas.read_namespaced_horizontal_pod_autoscaler(
                hpa_name, namespace, _preload_content=False)
        except ApiException as error:
            if error.status == 404:
                LOGGER.info("HPA %s not found in namespace %s", hpa_name, namespace)
                return
            raise
        # Clone the HPA as a dict
        cloned_dict = IngressGatewayCreator.clone_k8s_object(hpa, clone_context)
        # The HPA must scale the cloned deployment, not the original one
        cloned_dict["spec"]["scaleTargetRef"]["name"] = clone_context.name
        # Save through the context that was passed in, not a module-level global
        clone_context.save_clone_as_yaml(cloned_dict, "hpa")

    @staticmethod
    def clone_k8s_object(k8s_object, clone_context):
        """Turn a raw API response into a cleaned dict describing the clone.

        :param k8s_object: object whose ``data`` attribute holds the JSON body
        :param clone_context: provides ``name`` and ``origin_obj_namespace``
        :return: dict without ``status`` and without empty values, whose
            ``metadata`` holds only the new name, the namespace and the labels
        """
        # Manipulate at the dict level, starting from the fetched raw object
        # https://github.com/kubernetes-client/python/issues/574#issuecomment-405400414
        cloned_obj = json.loads(k8s_object.data)
        # Objects without any label have no "labels" key at all
        labels = cloned_obj['metadata'].get('labels') or {}
        labels['istio'] = clone_context.name
        cloned_obj['status'] = None
        # Scrub by removing the "null" and "None" values
        cloned_obj = IngressGatewayCreator.scrub_dict(cloned_obj)
        # Replace the metadata wholesale so that server-populated fields are dropped
        cloned_obj['metadata'] = {
            "name": clone_context.name,
            "namespace": clone_context.origin_obj_namespace,
            "labels": labels
        }
        return cloned_obj

    # https://stackoverflow.com/questions/12118695/efficient-way-to-remove-keys-with-empty-strings-from-a-dict/59959570#59959570
    @staticmethod
    def scrub_dict(d):
        """Return a copy of ``d`` without keys whose value is '', None or {}.

        Nested dicts and lists are scrubbed first, so a dict that becomes
        empty after scrubbing is dropped as well.
        """
        new_dict = {}
        for k, v in d.items():
            if isinstance(v, dict):
                v = IngressGatewayCreator.scrub_dict(v)
            if isinstance(v, list):
                v = IngressGatewayCreator.scrub_list(v)
            if v not in ('', None, {}):
                new_dict[k] = v
        return new_dict

    # https://stackoverflow.com/questions/12118695/efficient-way-to-remove-keys-with-empty-strings-from-a-dict/59959570#59959570
    @staticmethod
    def scrub_list(d):
        """Return a copy of list ``d`` in which every dict item is scrubbed.

        Non-dict items are kept untouched and no item is ever removed.
        """
        scrubbed_list = []
        for i in d:
            if isinstance(i, dict):
                i = IngressGatewayCreator.scrub_dict(i)
            scrubbed_list.append(i)
        return scrubbed_list
class IngressGatewayContext:
    """Settings for one clone plus the helper that writes its manifests."""

    def __init__(self, manifest_dir, name, hostname, nats, type):
        # Object that gets cloned; adjust these two to clone something else.
        self.origin_obj_name = "istio-ingressgateway"
        self.origin_obj_namespace = "istio-system"
        # Caller-supplied settings of the clone.
        self.manifest_dir = manifest_dir
        self.name = name
        self.hostname = hostname
        self.nats = nats
        self.ingress_type = type

    def save_clone_as_yaml(self, k8s_object, kind):
        """Dump ``k8s_object`` to ``<manifest_dir>/<name>-<kind>.yaml``."""
        target_dir = self.manifest_dir
        try:
            # Creating the directory is attempted unconditionally
            os.makedirs(target_dir)
        except FileExistsError:
            LOGGER.debug("Dir already exists %s", target_dir)
        file_name = ''.join((self.name, '-', kind, '.yaml'))
        full_file_path = os.path.join(target_dir, file_name)
        # Write the manifest under the name computed above
        # https://stackoverflow.com/questions/12470665/how-can-i-write-data-in-yaml-format-in-a-file/18210750#18210750
        with open(full_file_path, 'w') as yaml_file:
            yaml.dump(k8s_object, yaml_file, default_flow_style=False)
        LOGGER.info(crayons.yellow("Saved %s '%s' at %s: \n%s"), kind, self.name, full_file_path, k8s_object)
# Example run: clone the default ingress gateway under a new name.
# The hostname and NAT addresses are sample values only.
manifest_dir = "out/clones"
k8s_clone_name = "http2-ingressgateway"
hostname = "my-nlb-awesome.a.company.com"
nats = ["123.345.678.11", "333.444.222.111", "33.221.444.23"]

try:
    context = IngressGatewayContext(
        manifest_dir=manifest_dir,
        name=k8s_clone_name,
        hostname=hostname,
        nats=nats,
        type="nlb",
    )
    IngressGatewayCreator.clone_default_ingress(context)
except Exception as err:
    # Top-level boundary of the script: report the failure and stop.
    print("ERROR: {}".format(err))

Not python, but I've used jq in the past to quickly clone something with the small customisations required for each use case (usually cloning secrets into a new namespace).
kc get pod whatever-85pmk -o json \
| jq 'del(.status, .metadata ) | .metadata.name="newname"' \
| kc apply -f - -o yaml --dry-run

This is really easy to do with Hikaru.
Here is an example from my own open source project:
def duplicate_without_fields(obj: HikaruBase, omitted_fields: List[str]):
    """
    Return a copy of a hikaru object in which the given fields are set to None.

    Handy for comparing two versions of an object after blanking out fields
    that should not take part in the comparison.

    :param HikaruBase obj: A kubernetes object; None is passed through as None
    :param List[str] omitted_fields: '.'-separated field paths to blank out,
        for example: ["status", "metadata.generation"]
    :return: the modified duplicate (the input object is left untouched)
    """
    if obj is None:
        return None

    duplication = obj.dup()
    for path in omitted_fields:
        *parent_path, leaf = path.split(".")
        try:
            # A dotted path is resolved down to the object that owns the last part
            owner = duplication.object_at_path(parent_path) if parent_path else duplication
            setattr(owner, leaf, None)
        except Exception:
            # in case the field doesn't exist on this object
            pass
    return duplication
Dumping the object to yaml afterwards or re-applying it to the cluster is trivial with Hikaru
We're using this to clean up objects so that we can show users a GitHub-style diff when objects change, without spammy fields that change often, like generation.

Related

Flask storing large dataframes across multiple requests

I have a Flask web app that uses a large DataFrame (hundreds of megabytes). The DataFrame is used in the app for several different machine learning models. I want to create the DataFrame only once in the application and use it across multiple requests so that the user may build different models based on the same data. The Flask session is not built for large data, so that is not an option. I do not want to go back and recreate the DataFrame in case the source of the data is a CSV file (yuck).
I have a solution that works, but I cannot find any discussion of this solution in stack overflow. That makes me suspicious that my solution may not be a good design idea. I have always used the assumption that a well beaten path in software development is a path well chosen.
My solution is simply to create a data holder class with one class variable:
class DataHolder:
    """Namespace class that keeps a single DataFrame reference at class level."""

    # Class attribute: the same reference is visible through the class and
    # through every instance. Starts out empty.
    dataFrameHolder = None
Now the dataFrameHolder is known across all class instances (like a static variable in Java) since it is stored in memory on the server.
I can now create the DataFrame once, put it into the DataHolder class:
import pandas as pd
from dataholder import DataHolder
# Build the DataFrame once and park it on the holder class.
result_set = pd.read_sql_query(some_SQL, connection)
# 'col1', 'col2' are placeholders: list the real column names here.
df = pd.DataFrame(result_set, columns=['col1', 'col2'])
DataHolder.dataFrameHolder = df
Then access that DataFrame from any code that imports the DataHolder class. I can then use the stored DataFrame anywhere in the application, including across different requests:
.
.
modelDataFrame = DataHolder.dataFrameHolder
do_some_model(modelDataFrame)
.
.
Is this a bad idea, a good idea, or is there something else that I am not aware of that already solves the problem?
Redis can be used. My use case is smaller data frames so have not tested with larger data frames. This allows me to provide 3 second ticking data to multiple browser clients. pyarrow serialisation / deserialisation is performing well. Works locally and across AWS/GCloud and Azure
GET route
@app.route('/cacheget/<path:key>', methods=['GET'])
def cacheget(key):
    """Return the raw bytes stored in Redis under ``key`` as an octet stream.

    The response carries the key, the stored type tag (when one exists) and
    the payload size as headers. A key that is not in Redis yields a 404
    instead of crashing while building the body.
    """
    c = mycache()
    data = c.redis().get(key)
    if data is None:
        # BytesIO(None) raises TypeError, so report the missing key explicitly
        return Response(status=404)
    resp = Response(BytesIO(data), mimetype="application/octet-stream", direct_passthrough=True)
    resp.headers["key"] = key
    valuetype = c.redis().get(f"{key}.type")
    if valuetype is not None:
        # Redis hands back bytes; headers are text
        resp.headers["type"] = valuetype.decode() if isinstance(valuetype, bytes) else valuetype
    size = sys.getsizeof(data)
    resp.headers["size"] = size
    # Same payload as above, so a second round trip to Redis is not needed
    resp.headers["redissize"] = size
    return resp
sample route to put dataframe into cache
@app.route('/sensor_data', methods=['POST'])
def sensor_data() -> str:
    """Prepend the posted sensor readings to the cached "dfsensor" frame.

    The JSON body is flattened with ``json_normalize``; "epoch"/"value" are
    folded into an "xy" dict column and "value" is copied to "amin"/"amax".
    Only the first 500 rows of the combined frame are written back.

    NOTE(review): the ``-> str`` annotation is kept for compatibility, but the
    function returns whatever ``jsonify`` produces.
    """
    c = mycache()
    dfsensor = c.get("dfsensor")
    newsensor = json_normalize(request.get_json())
    newsensor[["x", "y"]] = newsensor[["epoch", "value"]]
    newsensor["xy"] = newsensor[['x', 'y']].agg(pd.Series.to_dict, axis=1)
    newsensor["amin"] = newsensor["value"]
    newsensor["amax"] = newsensor["value"]
    newsensor = newsensor.drop(columns=["x", "y"])
    # New data goes first, the previously cached rows follow.
    # DataFrame.append was removed in pandas 2.0; concat is the replacement
    # and skips a None entry (nothing cached yet).
    dfsensor = pd.concat([newsensor, dfsensor], sort=False)
    # keep size down - only the first 500 rows (newest first)
    c.set("dfsensor", dfsensor[:500])
    return jsonify(result={"status": "ok"})
utility class
import pandas as pd
import pyarrow as pa, os
import redis,json, os, pickle
import ebutils
from logenv import logenv
from pandas.core.frame import DataFrame
from redis.client import Redis
from typing import (Union, Optional)
class mycache():
    """Redis-backed cache for DataFrames, dicts and plain values.

    Each value lives under ``key``; a companion ``"<key>.type"`` entry holds
    ``str(type(value))`` so that :meth:`get` knows how to decode it.
    DataFrames are serialised with pyarrow and, when ``config["pickle"]`` is
    set, additionally pickled under ``"<key>.pickle"`` as a fallback.
    """

    __redisClient: Redis
    CONFIGKEY = "cacheconfig"

    def __init__(self) -> None:
        """Connect to Redis and load (or create) the cache configuration.

        When ``REDIS_HOST`` is not set it is derived from ``HOST_ENV``.

        :raises KeyError: if neither ``REDIS_HOST`` nor ``HOST_ENV`` is set, or
            ``HOST_ENV`` is "AZURE" without ``REDIS_HOST`` being provided
        :raises RuntimeError: if ``HOST_ENV`` holds an unknown value
        """
        if "REDIS_HOST" not in os.environ:
            host_env = os.environ["HOST_ENV"]
            if host_env == "GCLOUD":
                os.environ["REDIS_HOST"] = "redis://10.0.0.3"
            elif host_env == "EB":
                os.environ["REDIS_HOST"] = "redis://" + ebutils.get_redis_endpoint()
            elif host_env == "AZURE":
                pass  # should be set in azure env variable
            elif host_env == "LOCAL":
                os.environ["REDIS_HOST"] = "redis://127.0.0.1"
            else:
                # Raising a plain string is itself a TypeError in Python 3
                raise RuntimeError("could not initialise redis")
        # Set up the logger first: get() may need it on its fallback path
        self.alog = logenv.alog()
        self.__redisClient = redis.Redis.from_url(os.environ["REDIS_HOST"])
        self.__redisClient.ping()
        # get config as well...
        self.config = self.get(self.CONFIGKEY)
        if self.config is None:
            self.config = {"pyarrow": True, "pickle": False}
            self.set(self.CONFIGKEY, self.config)

    def redis(self) -> Redis:
        """Return the underlying Redis client."""
        return self.__redisClient

    def exists(self, key: str) -> bool:
        """Return True when ``key`` is present in Redis."""
        if self.__redisClient is None:
            return False
        return self.__redisClient.exists(key) == 1

    def get(self, key: str) -> Optional[Union[DataFrame, dict, str]]:
        """Fetch ``key`` and decode it according to its stored type tag.

        Without a type tag, keys ending in ".pickle" are unpickled and
        everything else is decoded as text; a missing key yields None.
        """
        keytype = "{k}.type".format(k=key)
        valuetype = self.__redisClient.get(keytype)
        if valuetype is None:
            raw = self.redis().get(key)
            if raw is None:
                return None
            if key.split(".")[-1] == "pickle":
                # NOTE: pickle.loads executes arbitrary code; only safe while
                # the Redis instance is trusted.
                return pickle.loads(raw)
            return raw.decode()

        tag = valuetype.decode()
        if tag == str(pd.DataFrame):
            # fallback to pickle serialized form if pyarrow fails
            # https://issues.apache.org/jira/browse/ARROW-7961
            try:
                return pa.deserialize(self.__redisClient.get(key))
            except pa.lib.ArrowIOError as err:
                return self._load_pickle_fallback(key, err)
            except OSError as err:
                if "Expected IPC" in str(err):
                    return self._load_pickle_fallback(key, err)
                raise
        if tag == str(dict):
            return json.loads(self.__redisClient.get(key).decode())
        return self.__redisClient.get(key).decode()  # type: ignore

    def _load_pickle_fallback(self, key: str, err: Exception):
        """Log the pyarrow failure and load the "<key>.pickle" copy instead."""
        self.alog.warning("using pickle from cache %s - %s - %s", key, pa.__version__, str(err))
        # NOTE: same trust assumption as above for pickle.loads
        return pickle.loads(self.redis().get(f"{key}.pickle"))

    def set(self, key: str, value: Union[DataFrame, dict, str]) -> None:
        """Store ``value`` under ``key`` together with its type tag."""
        if self.__redisClient is None:
            return
        keytype = "{k}.type".format(k=key)
        valuetype = str(type(value))
        if valuetype == str(pd.DataFrame):
            self.__redisClient.set(key, pa.serialize(value).to_buffer().to_pybytes())
            if self.config["pickle"]:
                self.redis().set(f"{key}.pickle", pickle.dumps(value))
                # issue should be transient through an upgrade....
                # once switched off data can go away
                self.redis().expire(f"{key}.pickle", 60*60*24)
        elif valuetype == str(dict):
            self.__redisClient.set(key, json.dumps(value))
        else:
            self.__redisClient.set(key, value)
        self.__redisClient.set(keytype, valuetype)
if __name__ == '__main__':
    # Manual check against a local Redis: list every "cache*" key with its
    # TTL, followed by the raw stored value.
    os.environ["HOST_ENV"] = "LOCAL"
    cache = mycache()
    connection = cache.redis()
    for raw_key in connection.keys("cache*"):
        key_name = raw_key.decode()
        print(key_name, connection.ttl(raw_key))
        print(connection.get(key_name))
I had a similar problem, as I was importing CSVs (100s of MBs) and creating DataFrames on the fly for each request, which as you said was yucky! I also tried the REDIS way, to cache it, and that improved performance for a while, until I realized that making changes to the underlying data meant updating the cache as well.
Then I found a world beyond CSV, and more performant file formats like Pickle, Feather, Parquet, and others. You can read more about them here. You can import/export CSVs all you want, but use intermediate formats for processing.
I did run into some issues though. I have read that Pickle has security issues, even though I still use it. Feather wouldn't let me write some object types in my data, it needed them categorized. Your mileage may vary, but if you have good clean data, use Feather.
And more recently I've found that I manage large data using Datatable instead of Pandas, and storing them in Jay for even better read/write performance.
This does however mean re-writing bits of code that use Pandas into DataTable, but I believe the APIs are very similar. I have not yet done it myself, because of very large codebase, but you can give it a try.

How to use documentation of azure DevOps python API, I am trying to get what members an object has when API call is made?

I'm using the python API for Azure DevOps, I am trying to get a list of repo and branches each repo has. I don't know what members an object has when the API call returns the object. I am not sure if the documentation is complete, How to know what members an object has?
like for an object of "repo", I guessed the "name" property
from azure.devops.connection import Connection
from msrest.authentication import BasicAuthentication
import pprint
# Fill in with your personal access token and org URL
personal_access_token = 'YOURPAT'
organization_url = 'https://dev.azure.com/YOURORG'
project_name = "XYZ"
# Create a connection to the org
credentials = BasicAuthentication('', personal_access_token)
connection = Connection(base_url=organization_url, creds=credentials)
# Get the git client. get_git_client is a factory method and has to be
# called; without the parentheses git_client would be the bound method itself.
git_client = connection.clients.get_git_client()
# Get the repositories of the project
get_repos_response = git_client.get_repositories(project_name)
for index, repo in enumerate(get_repos_response):
    pprint.pprint(str(index) + "." + repo.name)
In the above code, I just guessed the name property of a repo. I want to know what branches each of these repos have.
TIA
Based on this description, you want to use the Python API to get the list of repos, the list of branches for each repo, and the creator of each branch. As for "when the branch was last used", I am not sure whether that means the latest update time of the branch.
If this, just refer to this source code sample.
Get repositories:
It can only be got based on the project name provided, so use _send to pass the parameters to client:
route_values = {}
if project is not None:
route_values['project'] = self._serialize.url('project', project, 'str')
query_parameters = {}
if include_links is not None:
query_parameters['includeLinks'] = self._serialize.query('include_links', include_links, 'bool')
if include_all_urls is not None:
query_parameters['includeAllUrls'] = self._serialize.query('include_all_urls', include_all_urls, 'bool')
if include_hidden is not None:
query_parameters['includeHidden'] = self._serialize.query('include_hidden', include_hidden, 'bool')
response = self._send(http_method='GET',
location_id='225f7195-f9c7-4d14-ab28-a83f7ff77e1f',
version='6.0-preview.1',
route_values=route_values,
query_parameters=query_parameters)
return self._deserialize('[GitRepository]', self._unwrap_collection(response))
Get branches:
Add repository_id and base_version_descriptor additionally:
if project is not None:
route_values['project'] = self._serialize.url('project', project, 'str')
query_parameters = {}
if include_links is not None:
query_parameters['includeLinks'] = self._serialize.query('include_links', include_links, 'bool')
if include_all_urls is not None:
query_parameters['includeAllUrls'] = self._serialize.query('include_all_urls', include_all_urls, 'bool')
if include_hidden is not None:
query_parameters['includeHidden'] = self._serialize.query('include_hidden', include_hidden, 'bool')
Branch creator and latest update time: this information belongs to one specific branch and is obtained by fetching that branch. Here the name parameter is required, since it identifies the branch:
route_values = {}
if project is not None:
route_values['project'] = self._serialize.url('project', project, 'str')
if repository_id is not None:
route_values['repositoryId'] = self._serialize.url('repository_id', repository_id, 'str')
query_parameters = {}
if name is not None:
query_parameters['name'] = self._serialize.query('name', name, 'str')
if base_version_descriptor is not None:
if base_version_descriptor.version_type is not None:
query_parameters['baseVersionDescriptor.versionType'] = base_version_descriptor.version_type
if base_version_descriptor.version is not None:
query_parameters['baseVersionDescriptor.version'] = base_version_descriptor.version
if base_version_descriptor.version_options is not None:
query_parameters['baseVersionDescriptor.versionOptions'] = base_version_descriptor.version_options

AWS KMS client not returning an alias name

I've been using the list_aliases() method of the KMS client for a while now without any issues. But recently it has stopped listing one of the alias names I want to use.
import boto3
kms_client = boto3.client('kms')
# Getting all the aliases from my KMS.
# A single list_aliases() call returns one page only, so walk every page;
# otherwise aliases beyond the first page are silently missing.
paginator = kms_client.get_paginator('list_aliases')
key_aliases = []
for page in paginator.paginate():
    key_aliases.extend(page['Aliases'])
# DO SOMETHING...
The key_aliases list above contains all the keys except the one I want to use. However, I can see from the AWS KMS UI that the key is enabled. Not sure why the list_aliases() method is not returning it.
Has anyone faced this problem?
It looks like the response is truncated. The default number of aliases fetched by this API call is 50. You can increase the limit up to 100, which should solve your problem.
key_aliases = kms_client.list_aliases(Limit=100)
You should also check if the truncated field in the response is set to True. In that case, you can just make another API call to fetch the remaining results:
if key_aliases['Truncated'] is True:
key_aliases = kms_client.list_aliases(Marker=key_aliases['NextMarker'])
...
def get_keys_arn(kmsclient, key_name):
    """Return the ``TargetKeyId`` of the alias ``alias/<key_name>``.

    Walks every page of ``list_aliases`` until the alias is found.

    :param kmsclient: KMS client exposing ``list_aliases``
    :param key_name: alias name without the ``alias/`` prefix
    :return: the ``TargetKeyId`` of the matching alias, or None if no page
        contains it
    """
    wanted = "alias/" + key_name
    # The service documents Limit as 1-100, so ask for the maximum page size
    request = {"Limit": 100}
    while True:
        page = kmsclient.list_aliases(**request)
        for alias in page['Aliases']:
            if alias["AliasName"] == wanted:
                return alias["TargetKeyId"]
        if not page['Truncated']:
            return None
        # More pages follow: continue from where this one stopped
        request["Marker"] = page['NextMarker']

How to Parse YAML Using PyYAML if there are '!' within the YAML

I have a YAML file that I'd like to parse the description variable only; however, I know that the exclamation points in my CloudFormation template (YAML file) are giving PyYAML trouble.
I am receiving the following error:
yaml.constructor.ConstructorError: could not determine a constructor for the tag '!Equals'
The file has many !Ref and !Equals. How can I ignore these constructors and get a specific variable I'm looking for -- in this case, the description variable.
If you have to deal with a YAML document with multiple different tags, and
are only interested in a subset of them, you should still
handle them all. If the elements you are interested in are nested
within other tagged constructs you at least need to handle all of the "enclosing" tags
properly.
There is however no need to handle all of the tags individually, you
can write a constructor routine that can handle mappings, sequences
and scalars, and register that with PyYAML's SafeLoader using:
import yaml
# Sample CloudFormation-style document. The nesting (indentation) is
# significant: it is what produces the nested result printed by the example.
inp = """\
MyEIP:
  Type: !Join [ "::", [AWS, EC2, EIP] ]
  Properties:
    InstanceId: !Ref MyEC2Instance
"""
description = []


def any_constructor(loader, tag_suffix, node):
    """Build the plain Python value of ``node`` and ignore its tag.

    Mapping nodes become dicts, sequence nodes become lists and everything
    else is constructed as a scalar.
    """
    builders = (
        (yaml.MappingNode, loader.construct_mapping),
        (yaml.SequenceNode, loader.construct_sequence),
    )
    for node_type, build in builders:
        if isinstance(node, node_type):
            return build(node)
    return loader.construct_scalar(node)


# An empty tag prefix matches every tag, so this handles all of them.
yaml.add_multi_constructor('', any_constructor, Loader=yaml.SafeLoader)
data = yaml.safe_load(inp)
print(data)
which gives:
{'MyEIP': {'Type': ['::', ['AWS', 'EC2', 'EIP']], 'Properties': {'InstanceId': 'MyEC2Instance'}}}
(inp can also be a file opened for reading).
As you can see, the above will also continue to work if an unexpected !Join tag shows up in your code,
as well as any other tag like !Equal. The tags are just dropped.
Since there are no variables in YAML, it is a bit of guesswork what
you mean by "like to parse the description variable only". If that has
an explicit tag (e.g. !Description), you can filter out the values by adding 2-3 lines
to the any_constructor, by matching the tag_suffix parameter.
if tag_suffix == u'!Description':
description.append(loader.construct_scalar(node))
It is however more likely that there is some key in a mapping that is a scalar description,
and that you are interested in the value associated with that key.
if isinstance(node, yaml.MappingNode):
d = loader.construct_mapping(node)
for k in d:
if k == 'description':
description.append(d[k])
return d
If you know the exact position in the data hierarchy, You can of
course also walk the data structure and extract anything you need
based on keys or list positions. Especially in that case you'd be better off
using my ruamel.yaml, as this can load tagged YAML in round-trip mode without
extra effort (assuming the above inp):
from ruamel.yaml import YAML
with YAML() as yaml:
data = yaml.load(inp)
You can define custom constructors using a custom yaml.SafeLoader subclass:
import yaml
# Sample document. CreateNewSecurityGroup has to be nested under Conditions
# (indented) to give the nested result printed by the example.
doc = '''
Conditions:
  CreateNewSecurityGroup: !Equals [!Ref ExistingSecurityGroup, NONE]
'''
class Equals(object):
    """Holds the value constructed from an ``!Equals``-tagged node."""

    def __init__(self, data):
        self.data = data

    def __repr__(self):
        # Wrap in a one-element tuple so that tuple payloads are formatted as
        # a single value instead of being unpacked as format arguments.
        return "Equals(%s)" % (self.data,)
class Ref(object):
    """Holds the value constructed from a ``!Ref``-tagged node."""

    def __init__(self, data):
        self.data = data

    def __repr__(self):
        # Wrap in a one-element tuple so that tuple payloads are formatted as
        # a single value instead of being unpacked as format arguments.
        return "Ref(%s)" % (self.data,)
def create_equals(loader, node):
    """Construct an ``Equals`` from the sequence held by ``node``."""
    return Equals(loader.construct_sequence(node))
def create_ref(loader, node):
    """Construct a ``Ref`` from the scalar held by ``node``."""
    return Ref(loader.construct_scalar(node))
class Loader(yaml.SafeLoader):
    """SafeLoader subclass that receives the custom tag constructors."""


# Register each tag on the subclass only.
for _tag, _factory in ((u'!Equals', create_equals), (u'!Ref', create_ref)):
    yaml.add_constructor(_tag, _factory, Loader)

a = yaml.load(doc, Loader)
print(a)
Outputs:
{'Conditions': {'CreateNewSecurityGroup': Equals([Ref(ExistingSecurityGroup), 'NONE'])}}

ndb.Key filter for MapReduce input_reader

Playing with new Google App Engine MapReduce library filters for input_reader I would like to know how can I filter by ndb.Key.
I read this post and I've played with datetime, string, int and float values in filter tuples, but how can I filter by ndb.Key?
When I try to filter by a ndb.Key I get this error:
BadReaderParamsError: Expected Key, got u"Key('Clients', 406)"
Or this error:
TypeError: Key('Clients', 406) is not JSON serializable
I tried to pass a ndb.Key object and string representation of the ndb.Key.
Here are my two filters tuples:
Sample 1:
input_reader': {
'input_reader': 'mapreduce.input_readers.DatastoreInputReader',
'entity_kind': 'model.Sales',
'filters': [("client","=", ndb.Key('Clients', 406))]
}
Sample 2:
input_reader': {
'input_reader': 'mapreduce.input_readers.DatastoreInputReader',
'entity_kind': 'model.Sales',
'filters': [("client","=", "%s" % ndb.Key('Clients', 406))]
}
This is a bit tricky.
If you look at the code on Google Code you can see that mapreduce.model defines a JSON_DEFAULTS dict which determines the classes that get special-case handling in JSON serialization/deserialization: by default, just datetime. So, you can monkey-patch the ndb.Key class into there, and provide it with functions to do that serialization/deserialization - something like:
from mapreduce import model
def _JsonEncodeKey(o):
    """Represent an ndb.Key as a dict holding its urlsafe string."""
    urlsafe_key = o.urlsafe()
    return {'key_string': urlsafe_key}
def _JsonDecodeKey(d):
    """Rebuild an ndb.Key from the dict written by the matching encoder."""
    urlsafe_key = d['key_string']
    return ndb.Key(urlsafe=urlsafe_key)
# Register the encode/decode pair for ndb.Key in mapreduce's table of classes
# that get special-case handling during JSON (de)serialization.
model.JSON_DEFAULTS[ndb.Key] = (_JsonEncodeKey, _JsonDecodeKey)
# NOTE(review): _TYPE_IDS is a private attribute of mapreduce.model; presumably
# it maps the stored type name back to the class. Verify after library upgrades.
model._TYPE_IDS['Key'] = ndb.Key
You may also need to repeat those last two lines to patch mapreduce.lib.pipeline.util as well.
Also note if you do this, you'll need to ensure that this gets run on any instance that runs any part of a mapreduce: the easiest way to do this is to write a wrapper script that imports the above registration code, as well as mapreduce.main.APP, and override the mapreduce URL in your app.yaml to point to your wrapper.
Make your own input reader based on DatastoreInputReader, which knows how to decode key-based filters:
class DatastoreKeyInputReader(input_readers.DatastoreKeyInputReader):
    """Input reader that turns list-encoded filter values back into db.Key objects"""

    def __init__(self, *args, **kwargs):
        try:
            # A filter value that arrives as a list is treated as a key path
            kwargs['filters'] = [
                (
                    f[0],
                    f[1],
                    db.Key.from_path(*f[2]) if isinstance(f[2], list) else f[2],
                )
                for f in kwargs['filters']
            ]
        except KeyError:
            # No filters were given: nothing to decode
            pass
        super(DatastoreKeyInputReader, self).__init__(*args, **kwargs)
Run this function on your filters before passing them in as options:
def encode_filters(filters):
    """Return ``filters`` with model and key values replaced by key paths.

    A db.Model value is first reduced to its key, and a db.Key value is then
    replaced by ``to_path()``. Other values pass through unchanged, and None
    is returned as None.
    """
    if filters is None:
        return filters

    def _encode(value):
        if isinstance(value, db.Model):
            value = value.key()
        return value.to_path() if isinstance(value, db.Key) else value

    return [(f[0], f[1], _encode(f[2])) for f in filters]
Are you aware of the to_old_key() and from_old_key() methods?
I had the same problem and came up with a workaround with computed properties.
You can add a new ndb.ComputedProperty holding the Key id to your Sales model. Ids are just strings, so you won't have any JSON problems.
client_id = ndb.ComputedProperty(lambda self: self.client.id())
And then add that condition to your mapreduce query filters
input_reader': {
'input_reader': 'mapreduce.input_readers.DatastoreInputReader',
'entity_kind': 'model.Sales',
'filters': [("client_id","=", '406']
}
The only drawback is that computed properties are not indexed and stored until you call the put() method, so you will have to traverse all the Sales entities and save them:
for sale in Sales.query().fetch():
sale.put()

Categories