利用TensorFlow.js,D3.js 和 Web 的力量使训练模型的过程可视化

描述

在这篇文章中,我们将利用 TensorFlow.js,D3.js 和 Web 的力量使训练模型的过程可视化,以预测棒球数据中的坏球(蓝色区域)和好球(橙色区域)。 随着我们的进展,我们将模型在整个训练过程中理解的打击区域可视化。您可以通过访问此 Observable 笔记本在浏览器中运行此模型。

注:Observable 链接

https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d

如果你不熟悉棒球的击球区,这里有一篇详细的文章。

上面的 GIF 可视化神经网络学习调用坏球(蓝色区域)和好球(橙色区域)在每个训练步骤之后,热图会根据模型的预测进行更新

使用 Observable 直接在浏览器中运行此模型。

注:文章链接

https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d

体育运动中的高级指标

当今的职业体育环境中充斥着大量的数据。这些数据被团队,业余爱好者和粉丝应用于各种用例中。感谢像 TensorFlow 这样的框架 - 这些数据集已准备好应用于机器学习。

美国职业棒球大联盟先进媒体(MLBAM)的 PITCHf/x

美国职业棒球大联盟先进媒体(MLBAM)发布了一个可供公众研究的大型数据集。该数据集包含有关过去几年在美国职业棒球大联盟比赛中投掷的投球的传感器信息。 利用这个数据集,我们已编写了一个包含 5,000 个样本的训练集(2,500 个坏球和 2,500 个好球)。

以下是训练数据中前几个字段的示例:

注:示例链接

https://gist.github.com/nkreeger/01b5386b522b0cd1f22bc864320f3084#file-baseball-training-data-sample-csv

神经网络

以下是针对打击区域绘制的训练数据的样子。蓝点标记为坏球,橙点标记为好球(此为大联盟裁判员称谓):

利用 TensorFlow.js 构建模型

TensorFlow.js 将机器学习引入 JavaScript 和 Web。 我们将利用这个很棒的框架来构建一个深度神经网络模型。这个模型将能够按大联盟裁判的精准度来称呼好球和坏球。

输入 Input

该模型在 PITCHf / x 的以下字段中进行了训练:

协调球越过本垒的位置('px'和'pz')。

击球手站在垒的哪一侧。

击球区(击球手的躯干)的高度,以英尺为单位。

击球区底部的高度(击球手的膝盖)以英尺为单位。

裁判所称的投球(好球或坏球)的实际标签。

结构 Architecture

该模型将通过使用 TensorFlow.js 图层 API 定义。Layers API 基于 Keras,对以前使用过该框架的人来说应该很熟悉:

1    const model = tf.sequential(); 

2

3    // Two fully connected layers with dropout between each:    

4    model.add(tf.layers.dense({units: 24, activation: 'relu', inputShape: [5]}));    

5    model.add(tf.layers.dropout({rate: 0.01}));    

6    model.add(tf.layers.dense({units: 16, activation: 'relu'}));    

7    model.add(tf.layers.dropout({rate: 0.01}));    

8

9    // Only two classes: "strike" and "ball":    

10    model.add(tf.layers.dense({units: 2, activation: 'softmax'}));    

11

12    model.compile({    

13        optimizer: tf.train.adam(0.01),    

14        loss: 'categoricalCrossentropy',    

15        metrics: ['accuracy']    

16    });    

加载和准备数据

精选的训练集可通过 GitHub gist 获得。需要下载此数据集才能开始将 CSV 数据转换为 TensorFlow.js 用于训练的格式。

注:GitHub gist 链接

https://gist.github.com/nkreeger/43edc6e6daecc2cb02a2dd3293a08f29

1    const data = [];    

2    csvData.forEach((values) => {    

3        // 'logit' data uses the 5 fields:    

4        const x = [];    

5        x.push(parseFloat(values.px));    

6        x.push(parseFloat(values.pz));    

7        x.push(parseFloat(values.sz_top));    

8        x.push(parseFloat(values.sz_bot));    

9        x.push(parseFloat(values.left_handed_batter));    

10        // The label is simply 'is strike' or 'is ball':    

11        const y = parseInt(values.is_strike, 10);    

12        data.push({x: x, y: y});    

13    });    

14    // Shuffle the contents to ensure the model does not always train on the same    

15    // sequence of pitch data:    

16    tf.util.shuffle(data);    

解析 CSV 数据后,需要将 JS 类型转换为 Tensor 批次进行培训和评估。有关此过程的详细信息,请参阅代码实验室。TensorFlow.js 团队正在开发一种新的 Data API,以便将来更容易获取。

注:代码实验室

https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d#batches

训练模型

让我们把这一切都整合在一起吧。定义了模型,准备好了训练数据,现在我们已经准备好开始训练了。以下异步方法训练一批训练样本并更新热图:

1    // Trains and reports loss+accuracy for one batch of training data:

2    async function trainBatch(index) {

3        const history = await model.fit(batches[index].x, batches[index].y, {

4            epochs: 1,

5            shuffle: false,

6            validationData: [batches[index].x, batches[index].y],

7            batchSize: CONSTANTS.BATCH_SIZE

8        });

9

10        // Don't block the UI frame by using tf.nextFrame()

11        await tf.nextFrame();

12        updateHeatmap();

13        await tf.nextFrame();

14    }    

可视化模型的准确性

使用来自均匀放置在本垒板上方的 4 x 4 英尺栅格的预测矩阵来构建热图。在每个训练步骤之后将该矩阵传递到模型中以检查模型的准确度。使用 D3 库将该预测的结果呈现为热图。

构建预测矩阵

热图中使用的预测矩阵从本垒板的中间开始,向左和向右各延伸 2 英尺。它的范围也从本垒板的底部到 4 英尺高。击打区样本位于本垒板上方 1.5 至 3.5 英尺之间。下图有助于让这些 2d 窗格可视化:

该视觉显示了打击区域和预测矩阵与本垒板和游戏区域相关的位置

将预测矩阵与模型一起使用

每个批次在模型中训练之后,预测矩阵被传递到模型中用以请求矩阵中的好球或坏球预测:

1    function predictZone() {    

2        const predictions = model.predictOnBatch(predictionMatrix.data);    

3        const values = predictions.dataSync();    

4

5        // Sort each value so the higher prediction is the first element in the array:    

6        const results = [];    

7        let index = 0;    

8        for (let i = 0; i < values.length; i++) {    

9            let list = [];    

10            list.push({value: values[index++], strike: 0}); 

11            list.push({value: values[index++], strike: 1});    

12            list = list.sort((a, b) => b.value - a.value);    

13            results.push(list);    

14        }    

15        return results;    

16    }

热图与 D3

现在可以使用 D3 显示预测结果。 来自 50x50 网格中的每一个元素将在 SVG 中呈现为 10px x 10px 的矩形。每个矩形的颜色取决于预测结果(好球或者坏球)以及模型对该结果的确定程度(范围从 50%-100%)。 以下代码段显示了如何从 D3 svg  矩形分组更新数据:

1    function updateHeatmap() {

2        rects.data(generateHeatmapData());

3        rects    

4            .attr('x', (coord) => { return scaleX(coord.x) * CONSTANTS.HEATMAP_SIZE; })

5            .attr('y', (coord) => { return scaleY(coord.y) * CONSTANTS.HEATMAP_SIZE; })

6            .attr('width', CONSTANTS.HEATMAP_SIZE)

7            .attr('height', CONSTANTS.HEATMAP_SIZE)

8            .style('fill', (coord) => {

9                if (coord.strike) {    

10                    return strikeColorScale(coord.value);

11                } else {    

12                    return ballColorScale(coord.value);

13                }

14            });

15    }

有关使用 D3 绘制热图的完整详细信息,请参阅此部分。

注:此部分链接

https://beta.observablehq.com/@nkreeger/visualizing-ml-training-using-tensorflow-js-and-baseball-d#colorDomain

总结

网络上有许多令人惊叹的第三方库和工具,可用于创建视觉效果。将这些与机器学习的强大功能与 TensorFlow.js 相结合,开发人员能够创建一些非常新奇有趣的演示。

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分