TensorflowでのTf.cast()の使用法



Tf Cast Usage Tensorflow



tf.cast(x, dtype, name=None)

xのデータ形式をdtypeデータ型に変換します。たとえば、元のxデータ形式はboolです。
パラメータ
x:入力
dtype:変換ターゲットタイプ
名前:名前
戻り値:Tensor

import tensorflow as tf a = tf.Variable([1.0, 1.3, 2.1, 3.41, 4.51]) b = tf.cast(a, dtype=tf.int8) c = tf.cast(a > 3, dtype=tf.bool) d = tf.cast(a > 3, dtype=tf.int8) e = tf.cast(a <2, dtype=tf.float32) sess = tf.Session() sess.run(tf.initialize_all_variables()) print(sess.run(b)) print(sess.run(c)) print(sess.run(d)) print(sess.run(e))

出力:



[1 1 2 3 4] [False False False True True] [0 0 0 1 1] [1. 1. 0. 0. 0.] Process finished with exit code 0