TensorFlow conditionals and while loops
Recently I found myself needing to implement more advanced control flow in some models I have been hacking on in my free time. In the past I never really needed graph conditionals or loops, or any combination thereof, so I had to dive into the documentation and read up on them.
This blog post covers the tf.cond and tf.while_loop control flow operations and was written to document and share my experience learning about them. Both operations seem intuitive at first look, but I got bitten by them, so I wanted to document their usage with practical examples that I can return to in the future should I need to.
tf.cond: simple lambdas
This is basically slightly modified code from the official documentation. Here we have two constant tensors, t1 and t2, and we execute either f1() or f2() based on the result of the tf.less() operation:
import tensorflow as tf

t1 = tf.constant(1)
t2 = tf.constant(2)

def f1(): return t1 + t2
def f2(): return t1 - t2

res = tf.cond(tf.less(t1, t2), f1, f2)

with tf.Session() as sess:
    print(sess.run(res))
As expected, the printed number is 3: since 1 < 2, f1() gets executed. It’s worth noting that both lambdas here are single-line functions, neither of which accepts any parameters.
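Both branch callables must take zero arguments, but that doesn’t mean the branches can’t work with values: the usual trick is to capture whatever they need with lambdas or closures. Here is a minimal sketch of my own (the names x, y, f1 and f2 are just for illustration):

x = tf.constant(4)
y = tf.constant(10)

def f1(a, b): return a + b
def f2(a, b): return a - b

# wrap the parameterised functions in zero-argument lambdas for tf.cond
res = tf.cond(tf.less(x, y), lambda: f1(x, y), lambda: f2(x, y))

with tf.Session() as sess:
    print(sess.run(res))  # 14, since 4 < 10 picks the f1 branch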
tf.while_loop: basics
Let’s start again with a simple, slightly modified example of tf.while_loop usage borrowed from the official documentation. Once again we have two constant tensors, t1 and t2, and we will run a loop that keeps incrementing t1 while it’s less than t2:
def cond(t1, t2):
    return tf.less(t1, t2)

def body(t1, t2):
    return [tf.add(t1, 1), t2]

t1 = tf.constant(1)
t2 = tf.constant(5)

res = tf.while_loop(cond, body, [t1, t2])

with tf.Session() as sess:
    print(sess.run(res))
Note that both the cond and body functions must accept as many arguments as there are loop_vars; in this case our loop_vars are the constant tensors t1 and t2. tf.while_loop then returns the result as a tensor of the same shape as loop_vars (let’s forget about shape_invariants for now):
[5, 5]
This result makes perfect sense: we keep incrementing the original value (1) while it’s less than 5. Once it reaches 5, tf.while_loop stops and the last value returned by body is returned as the result.
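As a small aside (my own variation, not from the official example), cond and body don’t have to be named functions; the same loop can be written more compactly with lambdas:

res = tf.while_loop(lambda t1, t2: tf.less(t1, t2),
                    lambda t1, t2: [tf.add(t1, 1), t2],
                    [t1, t2])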
tf.while_loop: fixed number of iterations
Now if we wanted a fixed number of iterations, we would modify the code as follows:
def cond(t1, t2, i, iters):
    return tf.less(i, iters)

def body(t1, t2, i, iters):
    return [tf.add(t1, 1), t2, tf.add(i, 1), iters]

t1 = tf.constant(1)
t2 = tf.constant(5)
iters = tf.constant(3)

res = tf.while_loop(cond, body, [t1, t2, 0, iters])

with tf.Session() as sess:
    print(sess.run(res))
There are a couple of things to notice in this code:
- the third item in loop_vars is 0; this is the value we will be incrementing
- the loop incrementation happens in the body function: tf.add(i, 1)
- once again, the returned value (in this particular example) has as many elements as there are in loop_vars
This code prints the following result:
[4, 5, 3, 3]
First we have the t1 value incremented iters times (3); we don’t modify t2 in this code; the third element is the final value of the counter: we started with 0 and finished once iters (3) iterations were reached (this is controlled by the cond function).
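Incidentally, iters doesn’t strictly have to travel through loop_vars at all: cond and body are ordinary Python closures, so they can capture it from the enclosing scope. A sketch of the same fixed-iteration loop written that way (my own variation):

t1 = tf.constant(1)
t2 = tf.constant(5)
iters = tf.constant(3)

def cond(t1, t2, i):
    # iters is captured from the enclosing scope instead of being a loop var
    return tf.less(i, iters)

def body(t1, t2, i):
    return [tf.add(t1, 1), t2, tf.add(i, 1)]

res = tf.while_loop(cond, body, [t1, t2, 0])

with tf.Session() as sess:
    print(sess.run(res))  # [4, 5, 3]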
tf.while_loop: conditional break
With all the knowledge of tf.cond and tf.while_loop we are now well equipped to write conditional loops, i.e. loops whose behaviour changes based on some condition, sort of like break clauses in imperative programming. For brevity we will stick the code into a dedicated function called cond_loop which will return a tensor operation we will run in the session.
This is what our code is going to do:
- we will be looping a fixed number of times, set by iters
- in each iteration we will increment our familiar constant tensors t1 and t2
- in the final iteration, we will swap the tensors instead of incrementing them
You can see this in the code below:
def cond_loop(t1, t2, iters):
    def cond(t1, t2, i):
        return tf.less(i, iters)

    def body(t1, t2, i):
        def increment(t1, t2):
            def f1(): return tf.add(t1, 1), tf.add(t2, 1)
            return f1

        def swap(t1, t2):
            def f2(): return t2, t1
            return f2

        t1, t2 = tf.cond(tf.less(i + 1, iters),
                         increment(t1, t2),
                         swap(t1, t2))

        return [t1, t2, tf.add(i, 1)]

    return tf.while_loop(cond, body, [t1, t2, 0])

t1 = tf.constant(1)
t2 = tf.constant(5)
iters = tf.constant(3)

with tf.Session() as sess:
    loop = cond_loop(t1, t2, iters)
    print(sess.run(loop))
The main difference between this code and the code we discussed previously is the presence of tf.cond in the body function, which gets executed in each tf.while_loop “iteration” (the double quotes here are deliberate: while_loop calls cond and body exactly once, when the graph is built). This tf.cond causes body to execute the swap branch instead of increment at the end of the loop. The returned result confirms this:
[7, 3, 3]
We set out to run 3 iterations: the constants 1 and 5 get incremented in each iteration except the last one, when we swap both values and return them along with the iteration counter (1+1+1 = 3 and 5+1+1 = 7, which swap to give 7 and 3).
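If you want to convince yourself that body really runs only once, a quick sketch of my own: put a plain Python print inside it and note that the message shows up a single time, at graph-construction time, not once per iteration:

def body(i):
    print("building body")  # Python-level print: fires once, when the graph is built
    return tf.add(i, 1)

res = tf.while_loop(lambda i: tf.less(i, 5), body, [tf.constant(0)])

with tf.Session() as sess:
    print(sess.run(res))  # 5; "building body" was printed only once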
Summary
TensorFlow provides powerful data flow control structures. We have hardly scratched the surface here. You should definitely check out the official documentation and read about things like shape_invariants, parallel_iterations or swap_memory, parameters that can speed up your code and save you some headaches.
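To give a taste of why shape_invariants exists, here is a rough sketch of my own: when a loop variable changes shape between iterations, e.g. a vector you keep appending to, you have to relax its shape invariant explicitly, otherwise tf.while_loop complains that the shapes don’t match:

i0 = tf.constant(0)
acc0 = tf.zeros([1], dtype=tf.int32)

def cond(i, acc):
    return tf.less(i, 5)

def body(i, acc):
    # append the current counter value; acc grows by one element per iteration
    return [tf.add(i, 1), tf.concat([acc, tf.expand_dims(i, 0)], axis=0)]

i_final, acc = tf.while_loop(
    cond, body, [i0, acc0],
    shape_invariants=[i0.get_shape(), tf.TensorShape([None])])

with tf.Session() as sess:
    print(sess.run(acc))  # [0 0 1 2 3 4]

Happy hacking!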