tf.get_collection()



Tf Get_collection



この関数は、keyとscopeの2つの引数を取ります。

Args:



  • 1.key:コレクションのキー。たとえば、GraphKeysクラスには、コレクションの多くの標準名が含まれています。
  • 2.scope :(オプション)指定されている場合、結果のリストは、re.matchを使用してname属性が一致するアイテムのみを含むようにフィルタリングされます。 name属性のないアイテムは、スコープが指定されている場合は返されません。choiceまたはre.matchは、特別なトークンのないスコープがプレフィックスでフィルタリングされることを意味します。

例えば:

# 'My-TensorFlow-tutorials-master/02 CIFAR10/cifar10.py' code variables = tf.get_collection(tf.GraphKeys.VARIABLES) for i in variables: print(i) >>>

Tf.get_collectionは、キーのすべての値を一覧表示します。




さらに:

tf.GraphKeysのポイントの後には、多くのクラスが続く可能性があります。
たとえば、VARIABLESクラス(すべての変数を含む)、
REGULARIZATION_LOSSESのように。

特定のtf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)の使用:

def easier_network(x, reg): ''' A network based on tf.contrib.learn, with input `x`. ''' with tf.variable_scope('EasyNet'): out = layers.flatten(x) out = layers.fully_connected(out, num_outputs=200, weights_initializer = layers.xavier_initializer(uniform=True), weights_regularizer = layers.l2_regularizer(scale=reg), activation_fn = tf.nn.tanh) out = layers.fully_connected(out, num_outputs=200, weights_initializer = layers.xavier_initializer(uniform=True), weights_regularizer = layers.l2_regularizer(scale=reg), activation_fn = tf.nn.tanh) out = layers.fully_connected(out, num_outputs=10, # Because there are ten digits! weights_initializer = layers.xavier_initializer(uniform=True), weights_regularizer = layers.l2_regularizer(scale=reg), activation_fn = None) return out def main(_): mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True) x = tf.placeholder(tf.float32, [None, 784]) y_ = tf.placeholder(tf.float32, [None, 10]) # Make a network with regularization y_conv = easier_network(x, FLAGS.regu) weights = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'EasyNet') print('') for w in weights: shp = w.get_shape().as_list() print('- {} shape:{} size:{}'.format(w.name, shp, np.prod(shp))) print('') reg_ws = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES, 'EasyNet') for w in reg_ws: shp = w.get_shape().as_list() print('- {} shape:{} size:{}'.format(w.name, shp, np.prod(shp))) print('') # Make the loss function `loss_fn` with regularization. cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y_conv)) loss_fn = cross_entropy + tf.reduce_sum(reg_ws) train_step = tf.train.AdamOptimizer(1e-4).minimize(loss_fn) main() >>> - EasyNet/fully_connected/weights:0 shape:[784, 200] size:156800 - EasyNet/fully_connected/biases:0 shape:[200] size:200 - EasyNet/fully_connected_1/weights:0 shape:[200, 200] size:40000 - EasyNet/fully_connected_1/biases:0 shape:[200] size:200 - EasyNet/fully_connected_2/weights:0 shape:[200, 10] size:2000 - EasyNet/fully_connected_2/biases:0 shape:[10] size:10 - EasyNet/fully_connected/kernel/Regularizer/l2_regularizer:0 shape:[] size:1.0 - EasyNet/fully_connected_1/kernel/Regularizer/l2_regularizer:0 shape:[] size:1.0 - EasyNet/fully_connected_2/kernel/Regularizer/l2_regularizer:0 shape:[] size:1.0

による:



for w in reg_ws: shp = ....

このコードの出力は次のことを示しています
グラフ上のすべての正則化は、tf.GraphKeys.REGULARIZATION_LOSSESに一元的に保存されます。

コレクションの詳細については、以下を参照してください。
http://blog.csdn.net/shenxiaolu1984/article/details/52815641

tf.GraphKeys.REGULARIZATION_LOSSESの詳細については、以下を参照してください。
https://gxnotes.com/article/178205.html