
How To Restore Pretrained Checkpoint For Current Model In Tensorflow?

I have a pretrained checkpoint, and now I'm trying to restore this pretrained model into the current network. However, the variable names are different. The TensorFlow documentation says that you can pass tf.train.Saver a dictionary mapping names in the checkpoint to variables in the current graph, but how do I build that dictionary when the names don't match?

Solution 1:

You can use tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES) to get a list of all variables in the current graph. You can also restrict it to a scope:

tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES, scope='a')

You can use tf.train.list_variables(ckpt_file) to get a list of (name, shape) pairs for all variables in the checkpoint.
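For example, a quick way to see both sides at once (a sketch assuming TF 1.x and a hypothetical ckpt_file path):

import tensorflow as tf

ckpt_file = '/tmp/model.ckpt'  # hypothetical checkpoint prefix

# Variables in the current graph
for var in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    print('graph:', var.op.name, var.shape)

# Variables stored in the checkpoint, as (name, shape) pairs
for name, shape in tf.train.list_variables(ckpt_file):
    print('checkpoint:', name, shape)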

Suppose you have a variable named v2 in your checkpoint, and you want to load it inside tf.variable_scope('a') under the name a/b. To do that, you just define b in that scope:

with tf.variable_scope('a'):
    b = tf.get_variable('b', shape=[2, 3])  # shape/dtype must match 'v2'; [2, 3] is a placeholder

And load it, mapping the checkpoint name 'v2' to the new variable:

saver = tf.train.Saver({'v2': b})  # checkpoint name -> variable in current graph

with tf.Session() as sess:
    saver.restore(sess, ckpt_file)
    print(b)

This will output something like

<tf.Variable 'a/b:0' shape=(2, 3) dtype=float32_ref>
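Putting it together, here is a minimal end-to-end sketch (TF 1.x; the names v2, a, b and the [2, 3] shape are illustrative assumptions): first write a checkpoint containing v2, then restore it into a/b.

import tensorflow as tf

# First graph: create a checkpoint containing a variable named 'v2'
with tf.Graph().as_default():
    v2 = tf.get_variable('v2', shape=[2, 3], initializer=tf.ones_initializer())
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        ckpt_file = saver.save(sess, '/tmp/demo.ckpt')

# Second graph: restore 'v2' into a variable named 'a/b'
with tf.Graph().as_default():
    with tf.variable_scope('a'):
        b = tf.get_variable('b', shape=[2, 3])  # same shape/dtype as 'v2'
    saver = tf.train.Saver({'v2': b})  # checkpoint name -> current variable
    with tf.Session() as sess:
        saver.restore(sess, ckpt_file)
        print(b)            # <tf.Variable 'a/b:0' shape=(2, 3) dtype=float32_ref>
        print(sess.run(b))  # values loaded from the checkpoint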

Edit: As mentioned earlier, you can get the variable names on both sides with

vars_dict = {}
for var_current in tf.global_variables():
    print(var_current)          # full repr
    print(var_current.op.name)  # this gets only the name

for var_ckpt in tf.train.list_variables(ckpt_file):
    print(var_ckpt[0])          # this gets only the name

When you know the exact names of all variables, you can assign whatever values you need, provided the variables have the same shape and dtype. So, to build the dict, add an entry per pair of names:

vars_dict[var_ckpt[0]] = tf.get_variable(var_current.op.name, shape=shape)  # remember to specify shape; you can always get it from var_current

You can construct this dictionary either explicitly or in any loop you see fit, and then pass it to the saver:

saver = tf.train.Saver(vars_dict)
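For example, if every variable in the current graph differs from its checkpoint counterpart only by a leading 'a/' scope prefix (an assumption for illustration), the whole dict can be built in one loop:

# Assumes checkpoint names equal current names minus a leading 'a/' prefix
vars_dict = {}
for var_current in tf.global_variables():
    ckpt_name = var_current.op.name.replace('a/', '', 1)  # 'a/b' -> 'b'
    vars_dict[ckpt_name] = var_current

saver = tf.train.Saver(vars_dict)
with tf.Session() as sess:
    saver.restore(sess, ckpt_file)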
