This section provides an overview of what TensorFlow is, and why a developer might want to use it.
It should also mention any large subjects within TensorFlow, and link out to the related topics. Since the Documentation for TensorFlow is new, you may need to create initial versions of those related topics.
pred (the first argument to tf.cond) cannot be just True or False; it needs to be a Tensor. fn1 and fn2 should return the same number of outputs, with the same types.
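For instance, a minimal tf.cond sketch (the tensors x and y here are illustrative, not from the examples above):

import tensorflow as tf

x = tf.constant(2)
y = tf.constant(5)

# pred must be a boolean Tensor, e.g. the output of a comparison op
pred = tf.less(x, y)

# both branches return a single int32 tensor, so tf.cond's contract is satisfied
result = tf.cond(pred, lambda: tf.add(x, y), lambda: tf.square(y))

with tf.Session() as sess:
    print(sess.run(result))  # prints 7, because 2 < 5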
In the restoring model section above, if I understand correctly, you build the model and then restore the variables. I believe rebuilding the model is not necessary so long as you add the relevant tensors/placeholders when saving, using tf.add_to_collection(). For example:
tf.add_to_collection('cost_op', cost_op)
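Putting that call in context, the saving side might look roughly like this (cost_op stands in for whatever tensor you want back later; the checkpoint name 'model' matches the restore snippet below):

cost_op = ...  # any tensor/op you will want to retrieve after restoring
tf.add_to_collection('cost_op', cost_op)
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, 'model')  # also writes the graph definition to 'model.meta'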
Then later you can restore the saved graph and get access to cost_op using:
with tf.Session() as sess:
    new_saver = tf.train.import_meta_graph('model.meta')
    new_saver.restore(sess, 'model')
    # get_collection returns a list, so take the first (and only) entry
    cost_op = tf.get_collection('cost_op')[0]
Even if you don't run tf.add_to_collection(), you can retrieve your tensors, but the process is a bit more cumbersome, and you may have to do some digging to find the right names for things. For example:
In a script that builds a TensorFlow graph, we define a set of tensors lab_squeeze:
...
with tf.variable_scope("inputs"):
    y = tf.convert_to_tensor([[0, 1], [1, 0]])
    # split y into two tensors along axis 1, then squeeze each one
    split_labels = tf.split(y, 2, axis=1, name='lab_split')
    split_labels = [tf.squeeze(i, name='lab_squeeze') for i in split_labels]
...
with tf.Session().as_default() as sess:
    # Saver needs at least one tf.Variable in the graph (created in the elided parts)
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    saver.save(sess, "./checkpoint.chk")
We can recall them later as follows:
with tf.Session() as sess:
    g = tf.get_default_graph()
    new_saver = tf.train.import_meta_graph('./checkpoint.chk.meta')
    new_saver.restore(sess, './checkpoint.chk')
    # the squeezed tensors were saved under the names
    # 'inputs/lab_squeeze:0' and 'inputs/lab_squeeze_1:0'
    split_label_0 = g.get_tensor_by_name('inputs/lab_squeeze:0')
    split_label_1 = g.get_tensor_by_name('inputs/lab_squeeze_1:0')
There are a number of ways to find the name of a tensor: you can find it in your graph in TensorBoard, or you can search for it with something like:
sess = tf.Session()
g = tf.get_default_graph()
...
keys = g.get_all_collection_keys()
[i.name for j in keys for i in g.get_collection(j)]  # lists most, if not all, tensors on the graph
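Another option is to walk the graph's operations directly with Graph.get_operations(); a tensor's name is the producing op's name plus an output index:

# every operation in the graph; tensor names look like "<op_name>:<output_index>"
print([op.name for op in g.get_operations()])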
When you have a huge model, it is useful to form groups of tensors in your computational graph that are connected with each other. For example, the tf.GraphKeys class contains such standard collections as:
tf.GraphKeys.VARIABLES
tf.GraphKeys.TRAINABLE_VARIABLES
tf.GraphKeys.SUMMARIES
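As a rough sketch of how these collections behave (the variable w and the summary here are illustrative):

import tensorflow as tf

w = tf.Variable(tf.zeros([2, 2]), name='w')    # trainable=True by default
tf.summary.scalar('w_sum', tf.reduce_sum(w))   # registered in tf.GraphKeys.SUMMARIES

# variables created with trainable=True land in TRAINABLE_VARIABLES
train_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
print([v.name for v in train_vars])  # ['w:0']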
Here is a screenshot of the result of the working example above.
The code and a Jupyter notebook version of this working example can be found at the author's repository.
I used TensorBoard name-scope syntax to track the behavior of some parts of the model: the cost, train, and activation elements.
with tf.name_scope("") as scope:
Imports used:
import numpy as np
import tensorflow as tf
Type of application and language used:
I used a traditional console application, developed in Python, for this example.
Version of TensorFlow used:
1.0.1
Conceptual academic example/reference extracted from here:
The model can accept any number of inputs, so change NUM_PREDICTIONS if you want to run more predictions than one. Note that the Java code uses JNI to call into the C++ TensorFlow runtime, so you will see some info messages coming from the model when you run it.