用 Java 训练深度学习模型,原来能这么简单( 三 )

  • 更新权重:我们会根据选择的优化器(Optimizer)更新每一个在 Block 上参数的值 。
  • DJL 利用了 Trainer 结构体精简了整个过程 。开发者只需要创建 Trainer 并指定对应的 Initializer、Loss 和 Optimizer 即可 。这些参数都是由 TrainingConfig 设定的 。下面我们来看一下具体的参数设置:
    • TrainingListener:这个是对训练过程设定的监听器 。它可以实时反馈每个阶段的训练结果 。这些结果可以用于记录训练过程或者帮助 debug 神经网络训练过程中的问题 。用户也可以定制自己的 TrainingListener 来对训练过程进行监听 。
    DefaultTrainingConfig config = new DefaultTrainingConfig(Loss.softmaxCrossEntropyLoss())    .addEvaluator(new Accuracy())    .addTrainingListeners(TrainingListener.Defaults.logging());try (Trainer trainer = model.newTrainer(config)){    // 训练代码}当训练器产生后,我们可以定义输入的 Shape 。之后就可以调用 fit 函数来进行训练 。fit 函数会对输入数据,训练多个 epoch 是并最终将结果存储在本地目录下 。
    /* * MNIST 包含 28x28 灰度图片并导入成 28 * 28 NDArray 。 * 第一个维度是批大小, 在这里我们设置批大小为 1 用于初始化 。 */Shape inputShape = new Shape(1, Mnist.IMAGE_HEIGHT * Mnist.IMAGE_WIDTH);int numEpoch = 5;String outputDir = "/build/model";// 用输入初始化 trainertrainer.initialize(inputShape);TrainingUtils.fit(trainer, numEpoch, trainingSet, validateSet, outputDir, "mlp");这就是训练过程的全部流程了!用 DJL 训练是不是还是很轻松的?之后看一下输出每一步的训练结果 。如果你用了我们默认的监听器,那么输出是类似于下图:
    [INFO ] - Downloading libmxnet.dylib ...[INFO ] - Training on: cpu().[INFO ] - Load MXNet Engine Version 1.7.0 in 0.131 ms.Training:    100% |████████████████████████████████████████| Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24, speed: 1235.20 items/secValidating:  100% |████████████████████████████████████████|[INFO ] - Epoch 1 finished.[INFO ] - Train: Accuracy: 0.93, SoftmaxCrossEntropyLoss: 0.24[INFO ] - Validate: Accuracy: 0.95, SoftmaxCrossEntropyLoss: 0.14Training:    100% |████████████████████████████████████████| Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10, speed: 2851.06 items/secValidating:  100% |████████████████████████████████████████|[INFO ] - Epoch 2 finished.NG [1m 41s][INFO ] - Train: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.10[INFO ] - Validate: Accuracy: 0.97, SoftmaxCrossEntropyLoss: 0.09[INFO ] - train P50: 12.756 ms, P90: 21.044 ms[INFO ] - forward P50: 0.375 ms, P90: 0.607 ms[INFO ] - training-metrics P50: 0.021 ms, P90: 0.034 ms[INFO ] - backward P50: 0.608 ms, P90: 0.973 ms[INFO ] - step P50: 0.543 ms, P90: 0.869 ms[INFO ] - epoch P50: 35.989 s, P90: 35.989 s当训练结果完成后,我们可以用刚才的模型进行推理来识别手写数字 。
    四、最后在这个文章中,我们介绍了深度学习的基本概念,同时还有如何优雅的利用 DJL 构建深度学习模型并进行训练 。




    推荐阅读