Using if condition inside the TensorFlow graph with tf.cond


Remarks:

  • pred cannot be just True or False; it must be a Tensor of type bool.
  • The functions fn1 and fn2 must return the same number of outputs, with the same types.
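The following sketch illustrates both remarks. It uses the `tf.compat.v1` API inside an explicit graph so that it also runs under TensorFlow 2.x. The Python flag is wrapped in a bool tensor with `tf.constant`, and the int32 branch is cast with `tf.cast` so that both branches return float32. The names `count` and `scale` are illustrative only.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    # Wrap the Python flag in a bool tensor, as the remarks require.
    pred = tf.constant(False)
    count = tf.constant(3)    # int32 (illustrative value)
    scale = tf.constant(0.5)  # float32 (illustrative value)

    # Both branches must return the same dtype, so cast the int32 branch to float32.
    res = tf.cond(pred,
                  lambda: tf.cast(count, tf.float32),
                  lambda: scale * 2.0)

with tf.compat.v1.Session(graph=graph) as sess:
    out = sess.run(res)

print(out)  # 1.0, since pred is False and 0.5 * 2.0 == 1.0
```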

Basic example

import tensorflow as tf

x = tf.constant(1.)
pred = tf.constant(True)

res = tf.cond(pred, lambda: tf.add(x, 1.), lambda: tf.add(x, 10.))
# sess.run(res) returns 2.0
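TensorFlow 2.x no longer has `tf.Session`. The sketch below shows the same example written for TensorFlow 2.x: wrapping the call in `tf.function` traces `tf.cond` into a graph, and the result is read back directly instead of through `sess.run`. The function name `add_one_or_ten` is hypothetical.

```python
import tensorflow as tf

@tf.function
def add_one_or_ten(x, pred):
    # Inside tf.function, tf.cond is traced into the graph, as in TF1 graph mode.
    return tf.cond(pred, lambda: tf.add(x, 1.0), lambda: tf.add(x, 10.0))

res_true = add_one_or_ten(tf.constant(1.0), tf.constant(True))
res_false = add_one_or_ten(tf.constant(1.0), tf.constant(False))
print(res_true.numpy(), res_false.numpy())  # 2.0 11.0
```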

When fn1 and fn2 return multiple tensors

The two functions fn1 and fn2 can return multiple tensors, but they must return exactly the same number and types of outputs.

import tensorflow as tf

x = tf.constant(1.)
pred = tf.constant(True)

def fn1():
    return tf.add(x, 1.), x

def fn2():
    return tf.add(x, 10.), x

res1, res2 = tf.cond(pred, fn1, fn2)
# tf.cond returns the two tensors produced by the selected branch
# sess.run([res1, res2]) will return [2.0, 1.0]
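Branches can also return nested structures instead of a flat tuple. Both branches must then produce the same structure, which here is a dict with the same keys and dtypes. The sketch below assumes a TensorFlow version whose `tf.cond` accepts nested outputs (recent 1.x releases and 2.x, to my understanding) and uses the `tf.compat.v1` graph API. The keys `"value"` and `"input"` are illustrative.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    x = tf.constant(1.0)
    pred = tf.compat.v1.placeholder(tf.bool, shape=[])

    # Each branch returns the same structure: a dict with identical keys and dtypes.
    out = tf.cond(pred,
                  lambda: {"value": tf.add(x, 1.0), "input": x},
                  lambda: {"value": tf.add(x, 10.0), "input": x})

with tf.compat.v1.Session(graph=graph) as sess:
    # Fetching a dict returns a dict of NumPy values with the same keys.
    res_true = sess.run(out, feed_dict={pred: True})
    res_false = sess.run(out, feed_dict={pred: False})

print(res_true["value"], res_false["value"])  # 2.0 11.0
```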

Defining and using functions fn1 and fn2 with parameters

tf.cond() expects callables that take no arguments, so to pass parameters to fn1 and fn2, wrap each call in a lambda:

import tensorflow as tf

x = tf.placeholder(tf.float32)
y = tf.placeholder(tf.float32)
z = tf.placeholder(tf.float32)

def fn1(a, b):
    return tf.multiply(a, b)  # tf.mul was renamed tf.multiply in TensorFlow 1.0

def fn2(a, b):
    return tf.add(a, b)

pred = tf.placeholder(tf.bool)
result = tf.cond(pred, lambda: fn1(x, y), lambda: fn2(y, z))

You can then evaluate it as follows:

with tf.Session() as sess:
    print(sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: True}))
    # The result is 2.0
    print(sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: False}))
    # The result is 5.0
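Because `tf.cond` only needs zero-argument callables, `functools.partial` from the Python standard library is an alternative to `lambda` for binding the arguments. The sketch below rewrites the example above that way. It uses `tf.compat.v1` placeholders inside an explicit graph so that it also runs under TensorFlow 2.x.

```python
import functools

import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    x = tf.compat.v1.placeholder(tf.float32)
    y = tf.compat.v1.placeholder(tf.float32)
    z = tf.compat.v1.placeholder(tf.float32)
    pred = tf.compat.v1.placeholder(tf.bool, shape=[])

    def fn1(a, b):
        return tf.multiply(a, b)

    def fn2(a, b):
        return tf.add(a, b)

    # functools.partial binds (a, b) up front and yields zero-argument callables.
    result = tf.cond(pred, functools.partial(fn1, x, y), functools.partial(fn2, y, z))

with tf.compat.v1.Session(graph=graph) as sess:
    out_true = sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: True})    # fn1(x, y) = 1 * 2
    out_false = sess.run(result, feed_dict={x: 1, y: 2, z: 3, pred: False})  # fn2(y, z) = 2 + 3

print(out_true, out_false)  # 2.0 5.0
```

In graph mode, `tf.cond` calls both callables once while building the graph, so choosing between `lambda` and `functools.partial` is mostly a matter of style.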

Parameters:

Parameter   Details
pred        a scalar TensorFlow tensor of type bool
fn1         a callable with no arguments, whose result is returned when pred is True
fn2         a callable with no arguments, whose result is returned when pred is False
name        (optional) name for the operation
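Later TensorFlow releases deprecated `fn1` and `fn2` in favor of the keyword names `true_fn` and `false_fn`, and TensorFlow 2.x accepts only the new names. The sketch below passes every parameter of the basic example by keyword, including `name`. The name string `"add_one_or_ten"` is arbitrary.

```python
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    x = tf.constant(1.)
    pred = tf.constant(True)

    # true_fn / false_fn are the keyword names that replaced fn1 / fn2.
    res = tf.cond(pred,
                  true_fn=lambda: tf.add(x, 1.),
                  false_fn=lambda: tf.add(x, 10.),
                  name="add_one_or_ten")

with tf.compat.v1.Session(graph=graph) as sess:
    out = sess.run(res)

print(out)  # 2.0
```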
