Training models in Java

Example of training a model (and saving and restoring checkpoints) using the TensorFlow Java API.

Quickstart

  1. Train for a few steps:

    mvn -q compile exec:java -Dexec.args="model/graph.pb checkpoint"
    
  2. Resume training from the previous checkpoint and train some more:

    mvn -q exec:java -Dexec.args="model/graph.pb checkpoint"
    
  3. Delete checkpoint:

    rm -rf checkpoint
    

Details

The model in model/graph.pb represents a very simple linear model:

y = x * W + b

The graph.pb file is generated by executing create_graph.py in Python.

Training is orchestrated by src/main/java/Train.java, which generates training data of the form y = 3.0 * x + 2.0. As gradient descent proceeds, the model should "learn" these parameters: the value of W should converge to 3.0, and b to 2.0.
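
To make the convergence claim concrete, here is a minimal plain-Java sketch of gradient descent on the same linear model. It is not the code in Train.java and does not use the TensorFlow Java API. The class name `LinearFitSketch`, the squared-error loss, the input range, the zero initial values, and the batch size, learning rate and step count are all assumptions chosen for illustration.

```java
import java.util.Random;

/** Illustrative sketch only: fits y = x * W + b by gradient descent without TensorFlow. */
final class LinearFitSketch {

    /** Returns {W, b} after the given number of gradient-descent steps. */
    static double[] train(int steps, int batchSize, double learningRate, long seed) {
        Random random = new Random(seed);
        double w = 0.0; // assumed starting point, not taken from graph.pb
        double b = 0.0;

        for (int step = 0; step < steps; step++) {
            double gradW = 0.0;
            double gradB = 0.0;
            for (int i = 0; i < batchSize; i++) {
                double x = random.nextDouble();      // assumed input range [0, 1)
                double target = 3.0 * x + 2.0;       // training data of the form y = 3.0 * x + 2.0
                double error = (w * x + b) - target; // prediction minus target
                gradW += 2.0 * error * x;            // derivative of squared error w.r.t. W
                gradB += 2.0 * error;                // derivative of squared error w.r.t. b
            }
            // Step against the batch-averaged gradient.
            w -= learningRate * gradW / batchSize;
            b -= learningRate * gradB / batchSize;
        }
        return new double[] {w, b};
    }

    public static void main(String[] args) {
        double[] params = train(5000, 10, 0.1, 42L);
        System.out.printf("W = %.3f, b = %.3f%n", params[0], params[1]);
    }
}
```

Run with these settings, the printed values should be close to 3.0 for W and 2.0 for b. Stopping after only a few steps leaves both parameters visibly short of their targets, which is why resuming from a checkpoint and training further brings them closer.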