Turn a tf.data.Dataset into a jax.numpy iterator

I am interested in training a neural network with JAX. I had a look at tf.data.Dataset, but it provides exclusively tf tensors. I looked for a way to convert the dataset into JAX numpy arrays, and I found a lot of implementations that use Dataset.as_numpy_iterator() to turn the tf tensors into numpy arrays. However, I wonder if this is good practice, as numpy arrays are stored in CPU memory and that is not what I want for my training (I use the GPU). The last idea I found is to manually recast the arrays by calling jnp.array, but that is not really elegant (I am worried about the copy into GPU memory). Does anyone have a better idea?
Quick code to illustrate:
import os
import jax.numpy as jnp
import tensorflow as tf

def generator():
    for _ in range(2):
        yield tf.random.uniform((1,))

ds = tf.data.Dataset.from_generator(generator, output_types=tf.float32,
                                    output_shapes=tf.TensorShape([1]))

ds1 = ds.take(1).as_numpy_iterator()
ds2 = ds.skip(1)

for i, batch in enumerate(ds1):
    print(type(batch))
for i, batch in enumerate(ds2):
    print(type(jnp.array(batch)))
# returns:
<class 'numpy.ndarray'> # not good
<class 'jaxlib.xla_extension.DeviceArray'> # good but not elegant

Both TensorFlow and JAX can convert arrays to DLPack tensors without copying memory, so one way to create a JAX array from a TensorFlow tensor without copying the underlying data buffer is to go through DLPack:
import numpy as np
import tensorflow as tf
import jax.dlpack
tf_arr = tf.random.uniform((10,))
dl_arr = tf.experimental.dlpack.to_dlpack(tf_arr)
jax_arr = jax.dlpack.from_dlpack(dl_arr)
np.testing.assert_array_equal(tf_arr, jax_arr)
By round-tripping from JAX to TensorFlow and back, you can compare unsafe_buffer_pointer() on the two JAX arrays to confirm that they point at the same buffer rather than copying it along the way:
def tf_to_jax(arr):
    return jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(arr))

def jax_to_tf(arr):
    return tf.experimental.dlpack.from_dlpack(jax.dlpack.to_dlpack(arr))

jax_arr = jnp.arange(20.)
tf_arr = jax_to_tf(jax_arr)
jax_arr2 = tf_to_jax(tf_arr)

print(jnp.all(jax_arr == jax_arr2))
# True
print(jax_arr.unsafe_buffer_pointer() == jax_arr2.unsafe_buffer_pointer())
# True
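
Putting the two together, here is a minimal sketch of wrapping a tf.data.Dataset into an iterator of JAX arrays via DLPack (jax_iterator is just an illustrative helper; it assumes each dataset element is a single tensor, so nested structures would need jax.tree_util.tree_map, and the JAX array lives wherever the tf tensor's buffer does):

import tensorflow as tf
import jax.dlpack

def jax_iterator(ds):
    # yield JAX arrays that share their buffers with the tf tensors
    for batch in ds:
        yield jax.dlpack.from_dlpack(tf.experimental.dlpack.to_dlpack(batch))

for batch in jax_iterator(ds2):
    print(type(batch))  # a JAX device array, no host copy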

From the Flax ImageNet example:
https://github.com/google/flax/blob/6ae22681ef6f6c004140c3759e7175533bda55bd/examples/imagenet/train.py#L183
def prepare_tf_data(xs):
    local_device_count = jax.local_device_count()
    def _prepare(x):
        # use tf's private _numpy() for a zero-copy conversion to numpy
        x = x._numpy()
        # reshape (host_batch, ...) into (local_devices, device_batch, ...)
        return x.reshape((local_device_count, -1) + x.shape[1:])
    return jax.tree_util.tree_map(_prepare, xs)

it = map(prepare_tf_data, ds)
it = jax_utils.prefetch_to_device(it, 2)  # jax_utils is flax.jax_utils

Related

Benefit of storing state as a list/integer in tensorflow agents

In the environment tutorial for tensorflow agents (https://www.tensorflow.org/agents/tutorials/2_environments_tutorial), the state is stored as an integer. Whenever the state is required, it is converted to a numpy array:
from tf_agents.environments import py_environment
from tf_agents.trajectories import time_step as ts
import numpy as np

class CardGameEnv(py_environment.PyEnvironment):
    def __init__(self):
        self._state = 0

    def _step(self, action):
        # the integer state is converted to a numpy array on every step
        state_array = np.array([self._state], dtype=np.int32)
        return ts.transition(state_array, reward=1.0, discount=0.9)
Is there any reason why they do this, instead of just storing the state directly as a numpy array? So like this:
from tf_agents.environments import py_environment
from tf_agents.trajectories import time_step as ts
import numpy as np

class CardGameEnv(py_environment.PyEnvironment):
    def __init__(self):
        self._state = np.array([0], dtype=np.int32)

    def _step(self, action):
        return ts.transition(self._state, reward=1.0, discount=0.9)
Is there any downside to using the second method? Or is this equally valid?
For convenience, I often do not store the state as a numpy array; sometimes I use pandas DataFrames, sometimes lists, depending on how you update your current state.
Nevertheless, storing the state as a numpy array is always more efficient, since you do not need to convert it to a numpy array when returning an observation within a transition.
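To put a rough number on that per-step conversion cost, here is a quick micro-benchmark sketch (the timings are illustrative and will vary by machine):

import timeit
import numpy as np

# building the observation array on every step...
per_step = timeit.timeit(lambda: np.array([0], dtype=np.int32), number=100_000)

# ...versus returning a pre-built array
state = np.array([0], dtype=np.int32)
pre_built = timeit.timeit(lambda: state, number=100_000)

print(per_step, pre_built)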

How to accelerate numpy array masking?

I am profiling the performance of a piece of Python code using a line profiler.
In the code, I have a numpy array tt of shape (106906,) and dtype=int64. With the help of the profiler, I find that the second line below, mask[tt] = True, is quite slow. Is there any way to accelerate it? I am on Python 3, if that matters.
mask = np.zeros(100000, dtype='bool')
mask[tt] = True
You can use Numba, as @orlevii has suggested:
import numpy as np
from numba import njit

@njit
def f(mask, tt):
    mask[tt] = True

# Test:
mask = np.zeros(1000000, dtype='bool')
tt = np.random.randint(0, 1000000, 106906)
f(mask, tt)
A simple %%timeit check suggests that you should expect roughly 3 times faster execution.
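If you want to reproduce that comparison outside a notebook, a minimal timing harness might look like this (a sketch; the iteration count is arbitrary and the numbers will vary by machine):

import timeit
import numpy as np
from numba import njit

@njit
def f(mask, tt):
    mask[tt] = True

mask = np.zeros(1000000, dtype='bool')
tt = np.random.randint(0, 1000000, 106906)
f(mask, tt)  # the first call triggers JIT compilation; exclude it from timing

t_numpy = timeit.timeit(lambda: mask.__setitem__(tt, True), number=100)
t_numba = timeit.timeit(lambda: f(mask, tt), number=100)
print(t_numpy, t_numba)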
Further speed-up can be achieved by utilizing the GPU. An example of how to do it with PyTorch:
import torch

mask = torch.zeros(1000000).type(torch.cuda.FloatTensor)
tt = torch.randint(0, 1000000, torch.Size([106906])).type(torch.cuda.LongTensor)
mask[tt] = True
Note that here we use a torch.Tensor, which is PyTorch's equivalent of numpy.ndarray. The code will run only if you have an NVIDIA GPU with CUDA. Expect roughly a 30x speed-up over your original code on a Tesla V100-SXM2.
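On recent PyTorch versions the same idea is usually written with a bool dtype and an explicit device argument; a small sketch, assuming CUDA is available:

import torch

device = torch.device('cuda')
mask = torch.zeros(1000000, dtype=torch.bool, device=device)
tt = torch.randint(0, 1000000, (106906,), device=device)
mask[tt] = True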

fastest way to load images in python for processing

I want to load more than 10000 images into my 8 GB of RAM in the form of numpy arrays. So far I have tried cv2.imread, keras.preprocessing.image.load_img, PIL, imageio, and scipy. I want to do it the fastest way possible, but I can't figure out which one that is.
One of the fastest ways is to get multiple processes to do the job in parallel: multiprocessing puts several processors to work on your task at the same time when concurrent execution isn't an issue. The example below is just a simple sketch of how it might look; you can practice with small functions and then integrate them with your own code:
from multiprocessing import Process

# this is the function to be parallelized
def image_load_here(image_paths):
    pass

if __name__ == '__main__':
    # start the worker process and provide your dataset
    p = Process(target=image_load_here, args=(['img1', 'img2', 'img3', 'img4'],))
    p.start()
    p.join()
Feel free to write, I'll try to help.
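A single Process does not by itself spread the work across cores; in practice a multiprocessing.Pool that maps a loader over the paths is the usual shape. A minimal sketch (load_image and the path list are placeholders for your own loader and dataset):

from multiprocessing import Pool
import numpy as np

def load_image(path):
    # placeholder loader; swap in cv2.imread, PIL.Image.open, etc.
    return np.zeros((300, 300, 3), dtype=np.uint8)

if __name__ == '__main__':
    paths = ['img%d.jpg' % i for i in range(10000)]
    with Pool() as pool:
        images = pool.map(load_image, paths)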
If you're using the keras library to create a deep learning model, I suggest using the image class from the keras.preprocessing package.
The image class provides a method img_to_array which already returns a numpy array.
Keras also uses NumPy internally for all its array manipulations/computations.
train_image = image.load_img(path, target_size = (height, width))
train_image = image.img_to_array(train_image)
import numpy as np
import os
from keras.preprocessing import image

def batch_data_generator(data, indexes):
    # indexes is a sub-array of indices into the data;
    # config and label_mapping are defined elsewhere in the project
    X = np.zeros((len(indexes), config.IMG_INPUT_SHAPE[0],
                  config.IMG_INPUT_SHAPE[1], config.IMG_INPUT_SHAPE[2]))
    Y = np.zeros((len(indexes), len(label_mapping)))
    i = 0
    for idx in indexes:
        image_id = data['X'][idx]
        filename = os.path.join('images', str(image_id) + '.jpg')
        img = image.load_img(filename, target_size=(300, 300))
        X[i] = np.array(img, dtype='float32')
        label_id = label_mapping[data['Y'][idx]]
        Y[i][label_id] = 1
        i += 1
    # subtract the per-channel mean and normalize
    for depth in range(3):
        X[:, :, :, depth] = (X[:, :, :, depth] - np.mean(X[:, :, :, depth])) / 255
    return X, Y

For loops with Dask arrays and/or h5py

I have a time series with over a hundred million rows of data. I am trying to reshape it to include a time window. My sample data is of shape (79499, 9), and with a window of 10 I am trying to reshape it to (79489, 10, 9). The following for loop works fine in numpy:
def munge(data, backprop_window):
    result = []
    for index in range(len(data) - backprop_window):
        result.append(data[index: index + backprop_window])
    return np.array(result)

X_train = munge(X_train, backprop_window)
I have tried a few variations with dask, but all of them seem to hang without giving any error messages, including this one:
import h5py
import dask.array as da

f1 = h5py.File("data.hdf5")
X_train = f1.create_dataset('X_train', data=X_train, dtype='float32')
x = da.from_array(X_train, chunks=(10000, X_train.shape[1]))
result = x.compute(munge(x, backprop_window))
Any wise thoughts appreciated.
This doesn't necessarily solve your dask issue, but as a much faster alternative to munge you could instead use numpy's stride_tricks to create a rolling view into your data (based on the example here).
from numpy.lib.stride_tricks import as_strided

def munge_strides(data, backprop_window):
    """Take a rolling view into the array by manipulating strides."""
    new_shape = (data.shape[0] - backprop_window,
                 backprop_window,
                 data.shape[1])
    new_strides = (data.strides[0], data.strides[0], data.strides[1])
    return as_strided(data, shape=new_shape, strides=new_strides)

X_train = np.arange(100).reshape(20, 5)
np.array_equal(munge(X_train, backprop_window=3),
               munge_strides(X_train, backprop_window=3))
# Out[112]: True
as_strided needs to be used very carefully: it is an 'advanced' feature, and incorrect parameters can easily lead you into segfaults; see its docstring.
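If your NumPy is recent enough (1.20+), numpy.lib.stride_tricks.sliding_window_view gives the same rolling view with the bounds checking done for you; a sketch that matches munge's output (the trailing [:-1] drops the one extra window sliding_window_view produces relative to munge):

import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

def munge_view(data, backprop_window):
    # shape (N - w + 1, d, w) -> (N - w + 1, w, d), then drop the
    # last window so the result matches munge's N - w windows
    windows = sliding_window_view(data, backprop_window, axis=0)
    return windows.transpose(0, 2, 1)[:-1]

X_train = np.arange(100).reshape(20, 5)
print(munge_view(X_train, 3).shape)  # (17, 3, 5)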

Joblib parallel write to "shared" numpy sparse matrix

I'm trying to compute the number of shared neighbors for each node of a very big graph (~1m nodes). Using joblib, I'm trying to run it in parallel, but I'm worried about parallel writes to the sparse matrix that is supposed to hold all the data. Will this piece of code produce consistent results?
vNum = 1259084
NN_Matrix = csc_matrix((vNum, vNum), dtype=np.int8)

def nn_calc_parallel(node_id=None):
    i, j = np.unravel_index(node_id, (1259084, 1259084))
    NN_Matrix[i, j] = len(np.intersect1d(nx.neighbors(G, i), nx.neighbors(G, j)))

num_cores = multiprocessing.cpu_count()
result = Parallel(n_jobs=num_cores)(delayed(nn_calc_parallel)(i) for i in xrange(vNum**2))
If not, can you help me to solve this?
I needed to do the same work; in my case it was fine to merge the matrices together into one matrix, which you can do this way:
from joblib import Parallel, delayed
from scipy.sparse import vstack

# nn_calc_parallel must return a sparse (1, vNum) row for vstack to work
matrixes = Parallel(n_jobs=-3)(delayed(nn_calc_parallel)(x) for x in documents)
matrix = vstack(matrixes)
n_jobs=-3 means use all CPUs except two; otherwise it might throw some memory errors.
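For that merge to work, each worker has to return its result rather than mutate shared state: joblib workers run in separate processes that do not share memory, so writes to a global sparse matrix are silently lost. A minimal sketch of the return-a-row pattern (make_row and its contents are placeholders for your own per-node computation):

import numpy as np
from joblib import Parallel, delayed
from scipy.sparse import csr_matrix, vstack

def make_row(i, n):
    # build and return one sparse (1, n) row instead of writing
    # into a shared matrix across processes
    data = np.array([1], dtype=np.int8)   # placeholder value
    rows = np.zeros(1, dtype=int)
    cols = np.array([i % n])              # placeholder column
    return csr_matrix((data, (rows, cols)), shape=(1, n))

n = 1000
row_list = Parallel(n_jobs=2)(delayed(make_row)(i, n) for i in range(10))
result = vstack(row_list)
print(result.shape)  # (10, 1000)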
