Creating a custom operation with tf.py_func (CPU only)

Basic example

The tf.py_func(func, inp, Tout) operator creates a TensorFlow operation that calls a Python function, func, on a list of tensors, inp.

See the documentation for tf.py_func(func, inp, Tout).

Warning: The tf.py_func() operation will only run on CPU. If you are using distributed TensorFlow, the tf.py_func() operation must be placed on a CPU device in the same process as the client.

import tensorflow as tf

def func(x):
    # x is passed in as a NumPy array
    return 2*x

x = tf.constant(1.)
res = tf.py_func(func, [x], [tf.float32])
# res is a list of length 1

Why use tf.py_func

The tf.py_func() operator enables you to run arbitrary Python code in the middle of a TensorFlow graph. It is particularly convenient for wrapping custom NumPy operators for which no equivalent TensorFlow operator exists yet.

Without tf.py_func(), the alternative is to cut the graph into two parts and call sess.run() between them:

# Part 1 of the graph
inputs = ...  # in the TF graph

# Get the numpy array and apply func
val = sess.run(inputs)  # get the value of inputs
output_val = func(val)  # numpy array

# Part 2 of the graph
output = tf.placeholder(tf.float32, shape=...)
train_op = ...

# We feed the output_val to the tensor output
sess.run(train_op, feed_dict={output: output_val})

With tf.py_func this is much easier:

# Part 1 of the graph
inputs = ...

# call to tf.py_func
output = tf.py_func(func, [inputs], [tf.float32])[0]

# Part 2 of the graph
train_op = ...

# Only one call to sess.run, no need for an intermediate placeholder
sess.run(train_op)
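
To make the single-graph pattern concrete, here is a small runnable sketch. The median function and the multiplication that consumes its result are hypothetical stand-ins for func and for the second part of the graph. The code assumes a TensorFlow release that ships the tensorflow.compat.v1 module.

```python
import numpy as np
import tensorflow.compat.v1 as tf

# Assumption: sessions need graph mode (only matters on TensorFlow 2.x)
tf.disable_eager_execution()

def median(values):
    # Hypothetical NumPy-only step; the cast makes the dtype match Tout below
    return np.float32(np.median(values))

graph = tf.Graph()
with graph.as_default():
    # Part 1 of the graph
    inputs = tf.constant([3., 1., 2., 5., 4.])

    # The NumPy function runs as an ordinary node in the graph
    output = tf.py_func(median, [inputs], [tf.float32])[0]

    # Part 2 of the graph: a stand-in for whatever consumes the output
    result = output * 10.

with tf.Session(graph=graph) as sess:
    result_val = sess.run(result)  # one call evaluates both parts

print(result_val)
```

No placeholder or feed_dict is involved: the value computed by NumPy flows straight into the rest of the graph.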

Parameters:

Parameter   Details
func        Python function that takes NumPy arrays as its inputs and returns NumPy arrays as its outputs
inp         list of Tensors (inputs)
Tout        list of TensorFlow data types for the outputs of func
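
The following sketch shows how the three parameters fit together when func takes two inputs and returns two outputs of different types. The function stats and its data are made up for illustration, and the code assumes a TensorFlow release that ships the tensorflow.compat.v1 module. Since Tout declares only data types, the static shape of each output is unknown to TensorFlow, so the sketch restores one with set_shape.

```python
import numpy as np
import tensorflow.compat.v1 as tf

# Assumption: sessions need graph mode (only matters on TensorFlow 2.x)
tf.disable_eager_execution()

def stats(a, b):
    # Hypothetical function: one return value per entry in Tout
    total = (a + b).astype(np.float32)    # matches tf.float32
    largest = np.int64(np.argmax(total))  # matches tf.int64
    return total, largest

graph = tf.Graph()
with graph.as_default():
    a = tf.constant([1., 5., 2.])
    b = tf.constant([4., 1., 1.])

    # inp lists the input tensors, Tout the dtype of each output
    total, largest = tf.py_func(stats, [a, b], [tf.float32, tf.int64])

    # Restore the static shape, which py_func cannot infer
    total.set_shape(a.get_shape())

with tf.Session(graph=graph) as sess:
    total_val, largest_val = sess.run([total, largest])

print(total_val, largest_val)
```

The explicit casts matter: if func returns a NumPy dtype that differs from the corresponding entry in Tout, the operation fails when it runs.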
