Tensorflow conditionals and while loops


Recently I found myself needing to implement more advanced control flow in some models I have been hacking on in my free time. In the past I never really needed graph conditionals, loops or any combination of them, so I had to dive into the documentation and read up on them.

This blog post covers the tf.cond and tf.while_loop control flow operations and was written to document and share my experience learning about them. Both operations seem intuitive at first glance, but I got bitten by them, so I wanted to document their usage with practical examples and have a reference I can return to in the future should I need to.

tf.cond: simple lambdas

This is a slightly modified version of the code from the official documentation. Here we have two constant tensors t1 and t2, and we execute either f1() or f2() based on the result of the tf.less() operation:

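import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)

t1 = tf.constant(1)
t2 = tf.constant(2)

# branch functions for tf.cond take no arguments
def f1(): return t1+t2
def f2(): return t1-t2

res = tf.cond(tf.less(t1, t2), f1, f2)

with tf.Session() as sess:
    print(sess.run(res))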

As expected, the printed number is 3: since 1 < 2, f1() gets executed. It’s worth noting that both functions here are one-liners and neither accepts any parameters; tf.cond calls its branch functions with no arguments, so any values they need must come from the enclosing scope.
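Because the branch functions take no arguments, the usual way to feed values into a branch is a closure or a lambda. The sketch below models only the run-time selection in plain Python; `cond_model`, `add` and `sub` are hypothetical helpers, not TensorFlow APIs, and plain ints stand in for tensors. With real tensors the same lambdas would be passed straight to tf.cond.

```python
def cond_model(pred, true_fn, false_fn):
    """Plain-Python stand-in for tf.cond's run-time behaviour (hypothetical helper).

    Like tf.cond, both branch callables take no arguments. Unlike tf.cond in
    graph mode, only the chosen branch is called here; graph-mode TensorFlow
    calls both once while building the graph.
    """
    return true_fn() if pred else false_fn()


def add(a, b):
    return a + b


def sub(a, b):
    return a - b


t1, t2 = 1, 2
# The lambdas close over t1 and t2, so the branches themselves need no parameters.
res = cond_model(t1 < t2, lambda: add(t1, t2), lambda: sub(t1, t2))
print(res)  # 3
```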

tf.while_loop: basics

Let’s start again with a simple, slightly modified example of tf.while_loop usage borrowed from the official documentation. Once again, we will have two constant tensors t1 and t2, and we will run a loop that increments t1 while it’s less than t2:

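import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)

# cond and body receive one argument per loop variable
def cond(t1, t2):
    return tf.less(t1, t2)

def body(t1, t2):
    # must return values with the same structure as loop_vars
    return [tf.add(t1, 1), t2]

t1 = tf.constant(1)
t2 = tf.constant(5)

res = tf.while_loop(cond, body, [t1, t2])

with tf.Session() as sess:
    print(sess.run(res))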

Note that both the cond and body functions must accept as many arguments as there are loop_vars; in this case our loop_vars are the constant tensors t1 and t2, and body must return a list with the same structure. tf.while_loop then returns its result with the same structure as loop_vars (let’s forget about shape_invariants for now):

[5, 5]

This result makes perfect sense: we keep incrementing the original value (1) while it’s less than 5. Once it reaches 5, cond returns False, tf.while_loop stops, and the last values returned by body are returned as the result.
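If it helps to reason about results like this one, the loop semantics fit in a few lines of plain Python. This is a sketch of the behaviour, not TensorFlow’s implementation; `while_loop_model` is a hypothetical helper and plain ints stand in for the tensors.

```python
def while_loop_model(cond, body, loop_vars):
    """Plain-Python model of what tf.while_loop computes at run time (hypothetical helper).

    cond and body receive the loop variables as separate positional
    arguments; body returns a sequence with the same structure.
    """
    loop_vars = list(loop_vars)
    while cond(*loop_vars):
        loop_vars = list(body(*loop_vars))
    return loop_vars


def cond(t1, t2):
    return t1 < t2


def body(t1, t2):
    return [t1 + 1, t2]


res = while_loop_model(cond, body, [1, 5])
print(res)  # [5, 5]
```

Tracing it by hand gives 1, 2, 3, 4, 5 for t1; the check 5 < 5 fails, so [5, 5] comes back.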

tf.while_loop: fixed number of iterations

Now, if we wanted a fixed number of iterations, we would modify the code as follows:

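import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)

# i is the iteration counter; the loop runs while i < iters
def cond(t1, t2, i, iters):
    return tf.less(i, iters)

def body(t1, t2, i, iters):
    return [tf.add(t1, 1), t2, tf.add(i, 1), iters]

t1 = tf.constant(1)
t2 = tf.constant(5)
iters = tf.constant(3)

# the Python int 0 is converted to an int32 tensor that starts the counter
res = tf.while_loop(cond, body, [t1, t2, 0, iters])

with tf.Session() as sess:
    print(sess.run(res))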

There are a couple of things to notice in this code:
* the third item in loop_vars is 0; this is the iteration counter we will be incrementing
* the counter is incremented in the body function: tf.add(i, 1)
* once again, the value returned by body has as many elements as there are in loop_vars

This code prints the following result:

[4, 5, 3, 3]

First we have the t1 value incremented iters (3) times; we don’t modify t2 in this code; the third element is the final value of the counter, which started at 0 and stopped once iters (3) iterations were reached (this is controlled by the cond function); the last element is iters itself, passed through unchanged.
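Loop-invariant values such as iters (and t2 here) do not strictly have to travel through loop_vars; cond and body can simply close over them, which is what the cond_loop example below does with iters. A plain-Python sketch of that variant, using a hypothetical `while_loop_model` helper that models tf.while_loop with ordinary ints:

```python
def while_loop_model(cond, body, loop_vars):
    """Plain-Python model of tf.while_loop's run-time behaviour (hypothetical helper)."""
    loop_vars = list(loop_vars)
    while cond(*loop_vars):
        loop_vars = list(body(*loop_vars))
    return loop_vars


iters = 3  # loop-invariant: captured by the closures, not passed as a loop variable


def cond(t1, i):
    return i < iters


def body(t1, i):
    return [t1 + 1, i + 1]


res = while_loop_model(cond, body, [1, 0])
print(res)  # [4, 3]
```

The result still shows t1 incremented three times and the counter stopping at 3, with fewer values to thread through every iteration.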

tf.while_loop: conditional break

With all this knowledge of tf.cond and tf.while_loop, we are now well equipped to write conditional loops, i.e. loops whose behaviour changes based on some condition, sort of like break clauses in imperative programming. For brevity, we will wrap the code in a dedicated function called cond_loop, which returns the loop’s output tensors that we will run in the session.

This is what our code is going to do:
* loop a fixed number of times, set by iters
* in each iteration, increment our familiar constant tensors t1 and t2
* in the final iteration, swap the tensors instead of incrementing them

You can see this in the code below:

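import tensorflow as tf  # TensorFlow 1.x graph-mode API (tf.Session)

def cond_loop(t1, t2, iters):
    # iters is captured by the closures below, so it is not a loop variable
    def cond(t1, t2, i):
        return tf.less(i, iters)

    def body(t1, t2, i):
        # factories returning zero-argument branch functions for tf.cond
        def increment(t1, t2):
            def f1(): return tf.add(t1, 1), tf.add(t2, 1)
            return f1

        def swap(t1, t2):
            def f2(): return t2, t1
            return f2

        # increment on every iteration except the last one, where we swap
        t1, t2 = tf.cond(tf.less(i+1, iters),
                         increment(t1, t2),
                         swap(t1, t2))

        return [t1, t2, tf.add(i, 1)]

    return tf.while_loop(cond, body, [t1, t2, 0])

t1 = tf.constant(1)
t2 = tf.constant(5)
iters = tf.constant(3)

with tf.Session() as sess:
    loop = cond_loop(t1, t2, iters)
    print(sess.run(loop))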

The main difference between this code and the code we discussed previously is the tf.cond inside the body function, which gets executed in each tf.while_loop “iteration” (the double quotes are deliberate: in graph mode, while_loop calls the Python cond and body functions exactly once to build the graph, and the actual iterations happen when the graph runs). This tf.cond makes body execute the swap branch instead of increment in the last iteration. The returned result confirms this:

[7, 3, 3]

We set out to run 3 iterations, incrementing the constants 1 and 5 in each iteration except the last one, where we swap both values and return them along with the iteration counter: 1+1+1 = 3 and 5+1+1 = 7, which after the swap become 7 and 3.
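To double-check the arithmetic, here is a plain-Python trace of the same logic. `cond_loop_model` is a hypothetical helper that mirrors cond_loop with ordinary ints instead of tensors and records the loop variables after each iteration; it models the run-time behaviour only.

```python
def cond_loop_model(t1, t2, iters):
    """Plain-Python trace of cond_loop: increment, except swap on the last pass."""
    i = 0
    trace = []
    while i < iters:  # the while_loop's cond
        if i + 1 < iters:  # the tf.cond predicate inside body
            t1, t2 = t1 + 1, t2 + 1  # increment branch
        else:
            t1, t2 = t2, t1  # swap branch
        i += 1
        trace.append((t1, t2, i))
    return trace


for step in cond_loop_model(1, 5, 3):
    print(step)
# (2, 6, 1)
# (3, 7, 2)
# (7, 3, 3)
```

The last tuple matches the [7, 3, 3] printed by the TensorFlow version.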


TensorFlow provides powerful data flow control structures, and we have hardly scratched the surface here. You should definitely check out the official documentation and read about the shape_invariants, parallel_iterations and swap_memory parameters, which can speed up your code and save you some headaches. Happy hacking!
