JavaはKeras、Tensorflowモデルを呼び出します



Java Calls Keras Tensorflow Model



Pythonオフライントレーニングモデル、Javaオンライン予測デプロイメントを実装します。 オリジナルを見る

現在、ディープラーニングの主流はPythonを使用して独自のモデルをトレーニングしています。ニューラルネットワークをすばやく構築する機能を提供するフレームワークはたくさんあります。 Kerasは高レベルの構文を提供し、基礎となるものはtensorflowまたはtheanoを使用できます。



ただし、Javaで開発された多くの企業のバックエンドアプリケーションがあります。 Pythonを使用してHTTPインターフェースを提供し、ビジネスレイテンシーの要件が高い場合でも、多少の遅延が発生します。では、Java呼び出しモデルを使用でき、Pythonはモデルをオフラインでトレーニングできますか? (tensorflowは成熟したデプロイメントソリューションも提供します TensorFlowサービング )。

手持ちのKerasで訓練されたモデルがあります。インターネット上のJava呼び出しKerasモデルに関する情報は多くなく、それらのほとんどは反復的であり、あまり詳細ではありません。大まかに2つのオプションがあります。1つはKerasモデル実装をインポートするためのJavaベースの深層学習ライブラリであり、もう1つはtensorflowによって提供されるJavaインターフェイス呼び出しです。



Deeplearning4J

Eclipse Deeplearning4j は、JavaおよびScala用に作成された最初の商用グレードのオープンソースの分散型ディープラーニングライブラリです。 DL4JはHadoopおよびSparkと統合されており、 AI 分散GPUおよびCPUで使用するためのAIからビジネス環境へ。

Deeplearning4jは現在、Kerasトレーニング用のモデルのインポートをサポートしており、構造化データをより簡単に処理するためのPythonのnumpyなどの機能を提供しています。残念ながら、Deeplearning4jは現在Kerasのほとんどのレイヤーしかカバーしていません<2.0. If you are using Keras 2.0 or higher, you may get an error when importing the model.

もっと理解する:
Kerasモデルのインポート:サポートされている機能
KerasからDeeplearning4jへのモデルのインポート



Tensorflow

資料 Javaにはドキュメントがほとんどありませんが、モデルを呼び出すプロセスも非常に簡単です。この方法でモデルを呼び出すには、最初にKerasから派生したモデルをtensorflowのprotobufプロトコルのモデルに変換する必要があります。

1.Kerasのh5モデルがpbモデルに変換されます

Kerasで使用model.save(model.h5)現在のモデルをHDF5形式のファイルとして保存します。
Kerasのバックエンドフレームワークはテンソルフローを使用するため、最初にモデルをpbモデルとしてエクスポートします。 Javaでは、予測のためにモデルを呼び出すだけでよいので、現在のグラフの変数を定数に変更し、トレーニングされた重みを使用します。フリーズグラフのコードは次のとおりです。

def freeze_session(session, keep_var_names=None, output_names=None, clear_devices=True): ''' :param session: session of tensorflow that needs to be converted :param keep_var_names: variables that need to be retained, default all convert constant :param output_names: the name of the output :param clear_devices: Whether to remove device directives for better portability :return: ''' from tensorflow.python.framework.graph_util import convert_variables_to_constants graph = session.graph with graph.as_default(): freeze_var_names = list(set(v.op.name for v in tf.global_variables()).difference(keep_var_names or [])) output_names = output_names or [] # If the output name is specified, copy a new Tensor and name it with the specified name if len(output_names) > 0: for i in range(output_names): # Copy a new Tensor in the current graph, specify the name tf.identity(model.model.outputs[i], name=output_names[i]) output_names += [v.op.name for v in tf.global_variables()] input_graph_def = graph.as_graph_def() if clear_devices: for node in input_graph_def.node: node.device = '' frozen_graph = convert_variables_to_constants(session, input_graph_def, output_names, freeze_var_names) return frozen_graph

この方法では、テンソルを変数グラフにすべて定数に変換し、トレーニングされた重みを使用できます。 output_nameはより重要であり、後でJavaがモデルを呼び出すときに使用されることに注意してください。

Kerasでは、モデルは次のように定義されています。

def create_model(self): input_tensor = Input(shape=(self.maxlen,), name='input') x = Embedding(len(self.text2id) + 1, 200)(input_tensor) x = Bidirectional(LSTM(128))(x) x = Dense(256, activation='relu')(x) x = Dropout(self.dropout)(x) x = Dense(len(self.id2class), activation='softmax', name='output_softmax')(x) model = Model(inputs=input_tensor, outputs=x) model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])

次のコードは、定義されたKerasモデルの入力名と出力名を確認できます。これは、後続のJava呼び出しに役立ちます。

print(model.input.op.name) print(model.output.op.name)

Kerasモデルをトレーニングした後、pbモデルに変換します。

from keras import backend as K import tensorflow as tf model.load_model('model.h5') print(model.input.op.name) print(model.output.op.name) # output_names frozen_graph = freeze_session(K.get_session(), output_names=['output']) tf.train.write_graph(frozen_graph, './', 'model.pb', as_text=False) ### Output: # input # output_softmax/Softmax # If you do not customize output_name, the generated output_name of the pb model is output_softmax/Softmax. If you customize it, the custom name is output_name.

実行後、model.pbのモデルが生成されます。これは、後で呼び出されるモデルになります。

2、Java呼び出し

新しいMavenプロジェクトを作成し、tensorflowパッケージをpomにインポートします。

org.tensorflow tensorflow 1.6.0

コアコード:

public void predict() throws Exception { try (Graph graph = new Graph()) { graph.importGraphDef(Files.readAllBytes(Paths.get( 'path/to/model.pb' ))) try (Session sess = new Session(graph)) { // Construct an input yourself float[][] input = {{56, 632, 675, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}} try (Tensor x = Tensor.create(input) // input is the name of the input, output is the name of the output Tensor y = sess.runner().feed('input', x).fetch('output').run().get(0)) { float[][] result = new float[1][y.shape[1]] y.copyTo(result) System.out.println(Arrays.toString(y.shape())) System.out.println(Arrays.toString(result[0])) } } } }

GraphオブジェクトとTensorオブジェクトの両方を渡す必要がありますclose()メソッドは、コードで使用されている占有リソースを明示的に解放します。try-with-resourcesメソッドが実装されます。

この時点で、Kerasオフライントレーニング、Javaオンライン予測機能を実現できます。