I'm using the nornir (3.3.0) automation framework with Python 3.8. I'd like to mock SSH access to the devices so I can test without having real or virtual network equipment online. How would I use patch or Mock/MagicMock from unittest.mock to mock netmiko_send_command (the SSH interaction with the device)?
I have the following nornir task function:
# dbb_automation/tasks.py
import logging

from nornir.core.task import MultiResult, Result, Task
from nornir_netmiko.tasks import netmiko_send_command
from nornir_utils.plugins.tasks.files import write_file

log = logging.getLogger(__name__)

# assumed module-level constants (not shown in the original snippet)
purpose = "interfaces"
ending = "txt"


def get_interfaces_with_ip(task: Task):
    log.debug(f"{task.name}: Getting result on host {task.host}")
    result: MultiResult = task.run(name="show ip int br | e unass", task=netmiko_send_command,
                                   command_string="show ip int br | e unass")
    content_str = result[0].result
    task.run(
        task=write_file,
        filename=f"outputs/{task.host}-{purpose}.{ending}",
        content=content_str
    )
    return Result(
        host=task.host,
        result=f"{task.host.name} got ip result"
    )
and the following test case (work in progress):
# tests/test_tasks.py
from dbb_automation.tasks import get_interfaces_with_ip
from nornir import InitNornir
from nornir.core.filter import F
from tests.settings import *
def test_get_interfaces_with_ip():
    # [x] init nornir with fake host
    # [ ] patch/mock netmiko_send_command
    # [ ] check file contents with patched return string of netmiko_send_command
    nr = InitNornir(
        core={
            "raise_on_error": True
        },
        runner={
            "plugin": "threaded",
            "options": {
                "num_workers": 1,
            }
        },
        inventory={
            "plugin": "SimpleInventory",
            "options": {
                "host_file": DNAC_HOSTS_YAML,
                "group_file": DNAC_GROUPS_YAML,
                "defaults_file": DNAC_DEFAULT_YAML
            }
        },
        logging={
            "log_file": "logs/nornir.log"
        }
    )
    result = nr.filter(F(has_parent_group="Borders")).run(name="get_interfaces_with_ip", task=get_interfaces_with_ip)
    # todo: test code
    assert False
Regards,
Gérard
I think I found the solution. The key was to patch the function where it is used (the module it is imported into), not where it is defined, and to set the return value on the mock object.
@patch("dbb_automation.tasks.netmiko_send_command")
def test_get_interfaces_with_ip(mock_netmiko_send_command, nr):
    ...
    mock_netmiko_send_command.return_value = """Interface IP-Address OK? Method Status Protocol
GigabitEthernet22 10.1.54.146 YES TFTP up up
Loopback0 10.150.32.2 YES other up up
Port-channel1.2 10.150.33.65 YES manual up up
...
"""
import pytest
import os
import shutil
from unittest.mock import patch
from dbb_automation.tasks import get_interfaces_with_ip
from nornir import InitNornir
from nornir.core.filter import F
from tests.settings import *
@pytest.fixture()
def nr():
    nr = InitNornir(
        core={
            "raise_on_error": True
        },
        runner={
            "plugin": "threaded",
            "options": {
                "num_workers": 1,
            }
        },
        inventory={
            "plugin": "SimpleInventory",
            "options": {
                "host_file": DNAC_HOSTS_YAML,
                "group_file": DNAC_GROUPS_YAML,
                "defaults_file": DNAC_DEFAULT_YAML
            }
        },
        logging={
            "log_file": "logs/nornir.log"
        }
    )
    return nr
@patch("dbb_automation.tasks.netmiko_send_command")
def test_get_interfaces_with_ip(mock_netmiko_send_command, nr):
    output_folder_name = "outputs"
    shutil.rmtree(output_folder_name, ignore_errors=True)  # tolerate a missing folder on the first run
    os.mkdir(output_folder_name)
    mock_netmiko_send_command.return_value = """Interface IP-Address OK? Method Status Protocol
GigabitEthernet22 10.1.54.146 YES TFTP up up
Loopback0 10.150.32.2 YES other up up
Port-channel1.2 10.150.33.65 YES manual up up
"""
    nr.filter(F(has_parent_group="Borders")).run(name="get_interfaces_with_ip", task=get_interfaces_with_ip)
    # test code
    count = 0
    files_found = None
    for root_dir, cur_dir, files in os.walk(output_folder_name):
        count += len(files)
        assert files_found is None  # make sure there are no subdirectories
        files_found = files
    assert count == 4  # we expect a file for each host
    for file_name in files_found:
        with open(f"{output_folder_name}/{file_name}") as f:
            assert f.read() == mock_netmiko_send_command.return_value
I have a problem. I want to make a prediction with TFServing, but unfortunately as soon as I call the TFServing API, the Docker container crashes with the following error:
2022-10-05 08:22:19.091237: I tensorflow_serving/model_servers/server.cc:442] Exporting HTTP/REST API at:localhost:8601 ...
terminate called after throwing an instance of 'std::bad_alloc'
what(): std::bad_alloc
I am running TFServing inside a Docker container and the call comes from a Flask server. What could be causing this? The VM has 16 GB of RAM.
server.py
from flask import current_app, flash, jsonify, make_response, redirect, request, url_for
from keras_preprocessing.sequence import pad_sequences
from keras.preprocessing.text import Tokenizer
from dotenv import load_dotenv
from loguru import logger
from pathlib import Path
from flask import Flask
import tensorflow as tf
import numpy as np
import requests
import string
import pickle5 as pickle
import nltk
import re
import os
app = Flask(__name__)
load_dotenv()
@app.route("/test")
def index():
    txt = "This is a text"
    output = get_prediction_probability(txt)
    return output
def text_wragling(text):
    x = text.lower()
    x = remove_URL(x)
    x = remove_punct(x)
    x = remove_stopwords(x)
    with open('tokenizer.pickle', 'rb') as handle:
        tokenizer = pickle.load(handle)
    x = tokenizer.texts_to_sequences([x])
    # pad
    x = pad_sequences(x, maxlen=int(os.getenv('NLP__MAXLEN')))
    return x
def remove_URL(text):
    url = re.compile(r"https?://\S+|www\.\S+")  # fixed the original pattern "www.\.S+", which had a misplaced backslash
    return url.sub(r"", text)
def remove_punct(text):
    translator = str.maketrans("", "", string.punctuation)
    return text.translate(translator)
def remove_stopwords(text):
    # nltk.download()
    nltk.download('stopwords')
    from nltk.corpus import stopwords
    stop = set(stopwords.words("english"))
    filtered_words = [word.lower() for word in text.split() if word.lower() not in stop]
    return " ".join(filtered_words)
def get_prediction_probability(txt):
    x = text_wragling(txt)
    logger.info("Txt wragling")
    data = {
        "instances": [
            x.tolist()
        ]
    }
    # logger.info(data)
    logger.info("Get prediction from model")
    response = requests.post("http://localhost:8601/v1/models/nlp_model/labels/production:predict", json=data)
    probability = (np.asarray(response.json()['predictions']).max(axis=1))
    pred = np.asarray(response.json()['predictions']).argmax(axis=1)
    with open('labelenconder.pickle', 'rb') as handle:
        le = pickle.load(handle)
    pred = le.classes_[pred]
    prediction = pred[0]
    return {
        "prediction": prediction,
        "probability": probability[0]
    }
if __name__ == '__main__':
    # test()
    app.run(host='0.0.0.0')
Dockerfile
FROM tensorflow/serving
EXPOSE 8601
docker-compose.yml
version: '3'
services:
  tfserving:
    container_name: tfserving
    build: ..
    ports:
      - "8601:8601"
    volumes:
      - ./model.config:/models/model.config
      - ../model:/models/model
    environment:
      - TENSORFLOW_SERVING_REST_API_PORT=8061
      - TENSORFLOW_SERVING_MODEL_NAME=model
      - TENSORFLOW_MODEL_BASE_PATH=/models/model/
    entrypoint: [ "bash", "-c", "tensorflow_model_server --rest_api_port=8601 --allow_version_labels_for_unavailable_models --model_config_file=/models/model.config"]
model.config
model_config_list {
  config {
    name: 'nlp_model'
    base_path: '/models/model/'
    model_platform: 'tensorflow'
    model_version_policy {
      specific {
        versions: 1
        versions: 2
      }
    }
    version_labels {
      key: 'production'
      value: 1
    }
    version_labels {
      key: 'beta'
      value: 2
    }
  }
}
This is a known error; it has already been reported and fixed for TensorFlow Serving 2.11 (not yet released at the time of writing).
Until then, you can use the nightly release from Docker Hub (e.g. the tensorflow/serving:nightly tag).
You can find the issue here: #2048.
I'm trying to test the function start_dojot_messenger, which creates some objects and calls methods on them, and I need to verify that those methods were called with specific parameters.
For example, we have messenger.create_channel("dojot.notifications", "r"), and I need to test that it is called with exactly these parameters.
Here is the method:
def start_dojot_messenger(config, persister, dojot_persist_notifications_only):
    messenger = Messenger("Persister", config)
    messenger.init()

    # Persister Only Notification
    messenger.create_channel("dojot.notifications", "r")
    messenger.on(config.dojot['subjects']['tenancy'],
                 "message", persister.handle_new_tenant)
    messenger.on("dojot.notifications", "message",
                 persister.handle_notification)
    LOGGER.info('Listen to notification events')

    if str2_bool(dojot_persist_notifications_only) != True:
        LOGGER.info("Listen to devices events")
        # TODO: add notifications to config on dojot-module-python
        messenger.create_channel(config.dojot['subjects']['devices'], "r")
        messenger.create_channel(config.dojot['subjects']['device_data'], "r")
        messenger.on(config.dojot['subjects']['devices'],
                     "message", persister.handle_event_devices)
        messenger.on(config.dojot['subjects']['device_data'],
                     "message", persister.handle_event_data)
In my tests I'm doing it this way:
@patch.object(Persister, 'create_indexes')
@patch.object(Config, 'load_defaults')
@patch('history.subscriber.persister.Messenger')
@patch.object(Messenger, 'create_channel', return_value=None)
@patch.object(Messenger, 'on', return_value=None)
def test_persist_only_notifications(mock_on, mock_create_channel, mock_messenger, mock_config, create_indexes):
    # note: stacked @patch decorators are applied bottom-up, so the bottom-most
    # patch arrives as the first argument
    from history.subscriber.persister import start_dojot_messenger
    from history import conf

    p = Persister()
    p.create_indexes_for_notifications('admin')

    mock_config.dojot = {
        "management": {
            "user": "dojot-management",
            "tenant": "dojot-management"
        },
        "subjects": {
            "tenancy": "dojot.tenancy",
            "devices": "dojot.device-manager.device",
            "device_data": "device-data"
        }
    }

    # test persist only boolean valued notifications
    start_dojot_messenger(mock_config, p, True)

    mock_messenger.assert_called()
    assert mock_messenger.call_count == 1
The first check, that the Messenger class has been instantiated, I can do with the mock: mock_messenger.assert_called(). But I can't access the other calls. I figured it could be something like this:

mock_messenger.mock_create_channel.assert_called
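Since Messenger itself is patched as a class, the instance created inside start_dojot_messenger is mock_messenger.return_value, so the method calls have to be asserted on that object rather than on the class mock. A minimal sketch of what those assertions could look like, assuming the patching shown above:

# the mocked class returns the mocked instance when it is called
mock_instance = mock_messenger.return_value

# verify the channel was created with exactly these parameters
mock_instance.create_channel.assert_any_call("dojot.notifications", "r")

# verify the notification handler was registered
mock_instance.on.assert_any_call("dojot.notifications", "message", p.handle_notification)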
I have a gRPC server implemented in Python and I am calling an RPC from Node.js, but it gives a "Method not found" error. When I call it using the Python client, the request succeeds.
stream_csv.proto
syntax = "proto3";

package csv;

service Stream {
    rpc csvToObject(CSVDataRequest) returns (stream CSVDataResponse) {};
    rpc sayHello(HelloRequest) returns (HelloReply);
}

message CSVDataRequest {
    string url = 1;
    enum Protocol {
        HTTP = 0;
        HTTPS = 1;
        FTP = 2;
        SFTP = 3;
    }
    Protocol protocol = 2;
}

message CSVDataResponse {
    repeated string row = 1;
}

message HelloRequest {
    string name = 1;
}

// The response message containing the greetings
message HelloReply {
    string message = 1;
}
client.js
var PROTO_PATH = '../stream_csv.proto';

var grpc = require('grpc');
var protoLoader = require('@grpc/proto-loader');

var packageDefinition = protoLoader.loadSync(
    PROTO_PATH,
    {keepCase: true,
     longs: String,
     enums: String,
     defaults: true,
     oneofs: true
    });
var proto = grpc.loadPackageDefinition(packageDefinition).csv;

function main() {
    var client = new proto.Stream('localhost:5000',
                                  grpc.credentials.createInsecure());
    var user;
    if (process.argv.length >= 3) {
        user = process.argv[2];
    } else {
        user = 'world';
    }
    console.log(user);
    client.sayHello({name: user}, function(err, response) {
        console.log(err);
    });
}

main();
server.py
import grpc
import stream_csv_pb2
import urllib.request
from urllib.error import HTTPError, URLError
from concurrent import futures


class DataService:
    def csv_to_object(self, request, context):
        url = request.url
        protocol = stream_csv_pb2.CSVDataRequest.Protocol.Name(
            request.protocol)
        fetch_url = protocol.lower() + "://" + url
        try:
            with urllib.request.urlopen(fetch_url) as data:
                for line in data:
                    decoded_line = line.decode()
                    val = decoded_line.split(',')
                    print(val)
                    print("Data send")
                    yield stream_csv_pb2.CSVDataResponse(row=val)
            print("Sending finished!")
        except URLError as e:
            context.abort(grpc.StatusCode.UNKNOWN,
                          'Randomly injected failure.')
            # return stream_csv_pb2.CSVDataResponse(row=[], error=e.reason)

    def SayHello(self, request, context):
        name = request.name
        print(name)
        return stream_csv_pb2.HelloReply(message='Hello %s' % (name))


def add_DataServicer_to_server(servicer, server):
    rpc_method_handlers = {
        'CSVToObject': grpc.unary_stream_rpc_method_handler(
            servicer.csv_to_object,
            request_deserializer=stream_csv_pb2.CSVDataRequest.FromString,
            response_serializer=stream_csv_pb2.CSVDataResponse.SerializeToString,
        ),
        'SayHello': grpc.unary_unary_rpc_method_handler(
            servicer.SayHello,
            request_deserializer=stream_csv_pb2.HelloRequest.FromString,
            response_serializer=stream_csv_pb2.HelloReply.SerializeToString,
        )
    }
    generic_handler = grpc.method_handlers_generic_handler(
        'stream_csv.Stream', rpc_method_handlers)
    server.add_generic_rpc_handlers((generic_handler,))


def serve():
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    add_DataServicer_to_server(DataService(), server)
    server.add_insecure_port('[::]:5000')
    server.start()
    server.wait_for_termination()


if __name__ == '__main__':
    serve()
Error
Error: 12 UNIMPLEMENTED: Method not found!
    at Object.exports.createStatusError (/home/shantam/Documents/grepsr/client/node_modules/grpc/src/common.js:91:15)
    at Object.onReceiveStatus (/home/shantam/Documents/grepsr/client/node_modules/grpc/src/client_interceptors.js:1209:28)
    at InterceptingListener._callNext (/home/shantam/Documents/grepsr/client/node_modules/grpc/src/client_interceptors.js:568:42)
    at InterceptingListener.onReceiveStatus (/home/shantam/Documents/grepsr/client/node_modules/grpc/src/client_interceptors.js:618:8)
    at callback (/home/shantam/Documents/grepsr/client/node_modules/grpc/src/client_interceptors.js:847:24) {
  code: 12,
  metadata: Metadata { _internal_repr: {}, flags: 0 },
  details: 'Method not found!'
}
I wrote a simpler variant of your Python (and Golang) server implementations.
Both work with your Node.js client as-is.
I think perhaps your issue is as simple as the method name: it needs to be sayHello (rather than SayHello) to match the rpc name in your proto.
from concurrent import futures
import logging

import grpc

import stream_csv_pb2
import stream_csv_pb2_grpc


class Stream(stream_csv_pb2_grpc.StreamServicer):
    def sayHello(self, request, context):
        logging.info("[sayHello]")
        return stream_csv_pb2.HelloReply(message='Hello, %s!' % request.name)


def serve():
    logging.info("[serve]")
    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    stream_csv_pb2_grpc.add_StreamServicer_to_server(Stream(), server)
    server.add_insecure_port('[::]:50051')
    server.start()
    server.wait_for_termination()


if __name__ == '__main__':
    logging.basicConfig()
    serve()
And:
node client.js
Greeting: Hello, Freddie!
null
Conventionally (!) gRPC follows Protobuf's style guide of CamelCasing service, rpc, and message names, and using underscores elsewhere; see the Style Guide. When protoc compiles a proto, the generated names don't always match perfectly (e.g. Python keeps CamelCased method names rather than lowercasing them with underscores).
Conventionally, your proto would be:
service Stream {
    rpc CsvToObject(CSVDataRequest) returns (stream CSVDataResponse) {};
    rpc SayHello(HelloRequest) returns (HelloReply);
}
...and then the Python generated function would be (!) SayHello too.
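For completeness: with a hand-rolled generic handler like the one in the question, the dictionary keys must match the rpc names exactly as written in the proto, and the service name passed to method_handlers_generic_handler must be <package>.<service>. A sketch of what the registration in the question's server.py would need to look like (names taken from the posted stream_csv.proto):

rpc_method_handlers = {
    # keys must match the rpc names in the proto exactly
    'csvToObject': grpc.unary_stream_rpc_method_handler(
        servicer.csv_to_object,
        request_deserializer=stream_csv_pb2.CSVDataRequest.FromString,
        response_serializer=stream_csv_pb2.CSVDataResponse.SerializeToString,
    ),
    'sayHello': grpc.unary_unary_rpc_method_handler(
        servicer.SayHello,
        request_deserializer=stream_csv_pb2.HelloRequest.FromString,
        response_serializer=stream_csv_pb2.HelloReply.SerializeToString,
    ),
}
# the proto declares "package csv", so the full service name is csv.Stream
generic_handler = grpc.method_handlers_generic_handler('csv.Stream', rpc_method_handlers)
server.add_generic_rpc_handlers((generic_handler,))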
I am trying to mock the result of an API call made to Compute Engine to list VMs, but I haven't managed to mock the exact function.
I've tried using patch and Mock to mock the specific calls, still unsuccessfully.
My code.py file looks like this:
import googleapiclient.discovery
import logging


class Service:
    def __init__(self, project, event):
        self.project_id = project
        self.compute = googleapiclient.discovery.build('compute', 'v1',
                                                       cache_discovery=False)
        self.event = event
        self.zones = self._validate_event()

    def _validate_event(self):
        if "jsonPayload" not in self.event:
            zones = self.compute.zones().list(
                project=self.project_id).execute()['items']
        else:
            zones = self.compute.zones().get(project=self.project_id,
                                             zone=self.event["jsonPayload"]
                                             ["resource"]["zone"]).execute()
        logging.debug(f"Identified Zones are {zones}")
        return [zone["name"] for zone in zones]
My test file looks like this:
# in-built
from unittest import TestCase
from unittest.mock import patch

# custom
import code


class TestServiceModule(TestCase):
    def setUp(self):
        self.project_id = "sample-project-id"

    @patch('code.googleapiclient.discovery')
    def test__validate_event_with_empty_inputs(self, mock_discovery):
        mock_discovery.build.zones.list.execute.return_value = {"items": [
            {
                "name": "eu-west-1"
            }
        ]}
        obj = code.Service(event={}, project=self.project_id)
        print(obj.zones)
In the above test case, I expected to see "eu-west-1" as the value when I print obj.zones.
You didn't mock the googleapiclient.discovery.build method correctly. Here is the unit test solution:
E.g.
code.py:
import googleapiclient.discovery
import logging


class Service:
    def __init__(self, project, event):
        self.project_id = project
        self.compute = googleapiclient.discovery.build('compute', 'v1', cache_discovery=False)
        self.event = event
        self.zones = self._validate_event()

    def _validate_event(self):
        if "jsonPayload" not in self.event:
            zones = self.compute.zones().list(project=self.project_id).execute()['items']
        else:
            zones = self.compute.zones().get(project=self.project_id,
                                             zone=self.event["jsonPayload"]["resource"]["zone"]).execute()
        logging.debug(f"Identified Zones are {zones}")
        return [zone["name"] for zone in zones]
test_code.py:
from unittest import TestCase, main
from unittest.mock import patch

import code


class TestService(TestCase):
    def setUp(self):
        self.project_id = "sample-project-id"

    @patch('code.googleapiclient.discovery')
    def test__validate_event_with_empty_inputs(self, mock_discovery):
        # Arrange
        mock_discovery.build.return_value.zones.return_value.list.return_value.execute.return_value = {
            "items": [{"name": "eu-west-1"}]}
        # Act
        obj = code.Service(event={}, project=self.project_id)
        # Assert
        mock_discovery.build.assert_called_once_with('compute', 'v1', cache_discovery=False)
        mock_discovery.build.return_value.zones.assert_called_once()
        mock_discovery.build.return_value.zones.return_value.list.assert_called_once_with(project='sample-project-id')
        mock_discovery.build.return_value.zones.return_value.list.return_value.execute.assert_called_once()
        self.assertEqual(obj.zones, ["eu-west-1"])


if __name__ == '__main__':
    main()
unit test result with coverage report:
.
----------------------------------------------------------------------
Ran 1 test in 0.002s
OK
Name                                      Stmts   Miss  Cover   Missing
-----------------------------------------------------------------------
src/stackoverflow/56794377/code.py           14      1    93%   16
src/stackoverflow/56794377/test_code.py      16      0   100%
-----------------------------------------------------------------------
TOTAL                                        30      1    97%
Versions:
google-api-python-client==1.12.3
Python 3.7.5
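The pattern to remember here: calling a MagicMock returns its return_value, so each () in the production chain compute.zones().list().execute() corresponds to one .return_value hop on the mock. A standalone illustration of the principle:

from unittest.mock import MagicMock

m = MagicMock()
m.build.return_value.zones.return_value.list.return_value.execute.return_value = {"items": []}

# mirrors googleapiclient's build(...).zones().list(...).execute() chain
assert m.build().zones().list(project="p").execute() == {"items": []}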
I am trying to create my own Ansible module (which will update a CMDB) and I am looking for how to use ansible_facts in the module code.
An example of my module script:
#!/usr/bin/python
from ansible.module_utils.basic import *
import json, ast
from servicenow import ServiceNow
from servicenow import Connection


def __get_server_info(table, server_name="", sys_id=""):
    if sys_id == "":
        return table.fetch_one({'name': server_name})
    if server_name == "":
        return table.fetch_one({'sys_id': sys_id})


def __update_cmdb_hwinfo(table, sys_id, server_name=""):
    # here, for example, I want to put the ansible_facts value for the server's RAM size
    return table.update({'sys_id': sys_id, 'hw_ram': '...'})


def main():
    fields = {
        "snow_instance": {"required": True, "type": "str"},
        "snow_username": {"required": True, "type": "str"},
        "snow_password": {"required": True, "type": "str"},
        "server_name": {"required": True, "type": "str"},
        "api_type": {"default": "JSONv2", "type": "str"},
    }
    module = AnsibleModule(argument_spec=fields)
    # Define connection object to ServiceNow instance
    snow_connection = Connection.Auth(username=module.params['snow_username'],
                                      password=module.params['snow_password'],
                                      instance=module.params['snow_instance'],
                                      api=module.params['api_type'])
    server = ServiceNow.Base(snow_connection)
    server.__table__ = 'cmdb_ci_server_list.do'
    machine = __get_server_info(server, server_name=module.params['server_name'])
    module.exit_json(changed=False, meta=module.params, msg=machine)


if __name__ == '__main__':
    main()
Which variable should I use to access ansible_facts in the module script? (And is it even possible?)
I doubt this is possible from inside the module itself, because modules are executed in the context of the remote machine with predefined parameters.
But you can wrap your module with an action plugin (which is executed in the local context), collect the required data from the available variables, and pass it as parameters to your module.
Like this (./action_plugins/a_test.py):
from ansible.plugins.action import ActionBase


class ActionModule(ActionBase):
    def run(self, tmp=None, task_vars=None):
        result = super(ActionModule, self).run(tmp, task_vars)
        module_args = self._task.args.copy()
        module_args['mem_size'] = self._templar._available_variables.get('ansible_memtotal_mb')
        return self._execute_module(module_args=module_args, task_vars=task_vars, tmp=tmp)
In this case, if your module expects a mem_size parameter, it will be set to the value of ansible_memtotal_mb by the action plugin.
Module example (./library/a_test.py):
#!/usr/bin/python


def main():
    module = AnsibleModule(
        argument_spec=dict(
            mem_size=dict(required=False, default=None),
        ),
        supports_check_mode=False
    )
    module.exit_json(changed=False, mem_size=module.params['mem_size'])


from ansible.module_utils.basic import *
from ansible.module_utils.urls import *

main()
Test playbook:
---
- hosts: all
  tasks:
    - a_test: