简介

前面涉及到的都是二分类的逻辑回归问题:模型只要输出一个概率,就可以把数据集分成两类。但在实际生活中,很多问题(例如对图片进行分类)需要在多个类别之间做出判断,这就是多分类问题。

鸢尾花(iris)分类

详情

根据鸢尾花花萼和花瓣的长度和宽度对其进行分类。鸢尾属约有 300 个品种,但我们的程序将仅对下列三个品种进行分类:

  • 山鸢尾
  • 维吉尼亚鸢尾
  • 变色鸢尾

维基百科

操作步骤

1、加载IRIS数据集(训练集和验证集)

  1. /**
  2. * @license
  3. * Copyright 2018 Google LLC. All Rights Reserved.
  4. * Licensed under the Apache License, Version 2.0 (the "License");
  5. * you may not use this file except in compliance with the License.
  6. * You may obtain a copy of the License at
  7. *
  8. * http://www.apache.org/licenses/LICENSE-2.0
  9. *
  10. * Unless required by applicable law or agreed to in writing, software
  11. * distributed under the License is distributed on an "AS IS" BASIS,
  12. * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. * See the License for the specific language governing permissions and
  14. * limitations under the License.
  15. * =============================================================================
  16. */
  17. import * as tf from '@tensorflow/tfjs';
  /**
   * Display names of the three Iris species handled by this example.
   * The array index of each name is the class label stored in the last
   * column of every IRIS_DATA row (0, 1 or 2).
   */
  export const IRIS_CLASSES = [
    '山鸢尾',
    '变色鸢尾',
    '维吉尼亚鸢尾',
  ];

  /** Number of target classes, i.e. the width of a one-hot label vector. */
  export const IRIS_NUM_CLASSES = IRIS_CLASSES.length;
  // Iris flowers data. Source:
  // https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data
  //
  // 150 examples, 50 per class, stored grouped by class (all class-0 rows
  // first, then class 1, then class 2). Each row has five numbers:
  //   [sepal length, sepal width, petal length, petal width, class index]
  // The last entry is the class label; it indexes into IRIS_CLASSES.
  const IRIS_DATA = [
    [5.1, 3.5, 1.4, 0.2, 0], [4.9, 3.0, 1.4, 0.2, 0], [4.7, 3.2, 1.3, 0.2, 0],
    [4.6, 3.1, 1.5, 0.2, 0], [5.0, 3.6, 1.4, 0.2, 0], [5.4, 3.9, 1.7, 0.4, 0],
    [4.6, 3.4, 1.4, 0.3, 0], [5.0, 3.4, 1.5, 0.2, 0], [4.4, 2.9, 1.4, 0.2, 0],
    [4.9, 3.1, 1.5, 0.1, 0], [5.4, 3.7, 1.5, 0.2, 0], [4.8, 3.4, 1.6, 0.2, 0],
    [4.8, 3.0, 1.4, 0.1, 0], [4.3, 3.0, 1.1, 0.1, 0], [5.8, 4.0, 1.2, 0.2, 0],
    [5.7, 4.4, 1.5, 0.4, 0], [5.4, 3.9, 1.3, 0.4, 0], [5.1, 3.5, 1.4, 0.3, 0],
    [5.7, 3.8, 1.7, 0.3, 0], [5.1, 3.8, 1.5, 0.3, 0], [5.4, 3.4, 1.7, 0.2, 0],
    [5.1, 3.7, 1.5, 0.4, 0], [4.6, 3.6, 1.0, 0.2, 0], [5.1, 3.3, 1.7, 0.5, 0],
    [4.8, 3.4, 1.9, 0.2, 0], [5.0, 3.0, 1.6, 0.2, 0], [5.0, 3.4, 1.6, 0.4, 0],
    [5.2, 3.5, 1.5, 0.2, 0], [5.2, 3.4, 1.4, 0.2, 0], [4.7, 3.2, 1.6, 0.2, 0],
    [4.8, 3.1, 1.6, 0.2, 0], [5.4, 3.4, 1.5, 0.4, 0], [5.2, 4.1, 1.5, 0.1, 0],
    [5.5, 4.2, 1.4, 0.2, 0], [4.9, 3.1, 1.5, 0.1, 0], [5.0, 3.2, 1.2, 0.2, 0],
    [5.5, 3.5, 1.3, 0.2, 0], [4.9, 3.1, 1.5, 0.1, 0], [4.4, 3.0, 1.3, 0.2, 0],
    [5.1, 3.4, 1.5, 0.2, 0], [5.0, 3.5, 1.3, 0.3, 0], [4.5, 2.3, 1.3, 0.3, 0],
    [4.4, 3.2, 1.3, 0.2, 0], [5.0, 3.5, 1.6, 0.6, 0], [5.1, 3.8, 1.9, 0.4, 0],
    [4.8, 3.0, 1.4, 0.3, 0], [5.1, 3.8, 1.6, 0.2, 0], [4.6, 3.2, 1.4, 0.2, 0],
    [5.3, 3.7, 1.5, 0.2, 0], [5.0, 3.3, 1.4, 0.2, 0], [7.0, 3.2, 4.7, 1.4, 1],
    [6.4, 3.2, 4.5, 1.5, 1], [6.9, 3.1, 4.9, 1.5, 1], [5.5, 2.3, 4.0, 1.3, 1],
    [6.5, 2.8, 4.6, 1.5, 1], [5.7, 2.8, 4.5, 1.3, 1], [6.3, 3.3, 4.7, 1.6, 1],
    [4.9, 2.4, 3.3, 1.0, 1], [6.6, 2.9, 4.6, 1.3, 1], [5.2, 2.7, 3.9, 1.4, 1],
    [5.0, 2.0, 3.5, 1.0, 1], [5.9, 3.0, 4.2, 1.5, 1], [6.0, 2.2, 4.0, 1.0, 1],
    [6.1, 2.9, 4.7, 1.4, 1], [5.6, 2.9, 3.6, 1.3, 1], [6.7, 3.1, 4.4, 1.4, 1],
    [5.6, 3.0, 4.5, 1.5, 1], [5.8, 2.7, 4.1, 1.0, 1], [6.2, 2.2, 4.5, 1.5, 1],
    [5.6, 2.5, 3.9, 1.1, 1], [5.9, 3.2, 4.8, 1.8, 1], [6.1, 2.8, 4.0, 1.3, 1],
    [6.3, 2.5, 4.9, 1.5, 1], [6.1, 2.8, 4.7, 1.2, 1], [6.4, 2.9, 4.3, 1.3, 1],
    [6.6, 3.0, 4.4, 1.4, 1], [6.8, 2.8, 4.8, 1.4, 1], [6.7, 3.0, 5.0, 1.7, 1],
    [6.0, 2.9, 4.5, 1.5, 1], [5.7, 2.6, 3.5, 1.0, 1], [5.5, 2.4, 3.8, 1.1, 1],
    [5.5, 2.4, 3.7, 1.0, 1], [5.8, 2.7, 3.9, 1.2, 1], [6.0, 2.7, 5.1, 1.6, 1],
    [5.4, 3.0, 4.5, 1.5, 1], [6.0, 3.4, 4.5, 1.6, 1], [6.7, 3.1, 4.7, 1.5, 1],
    [6.3, 2.3, 4.4, 1.3, 1], [5.6, 3.0, 4.1, 1.3, 1], [5.5, 2.5, 4.0, 1.3, 1],
    [5.5, 2.6, 4.4, 1.2, 1], [6.1, 3.0, 4.6, 1.4, 1], [5.8, 2.6, 4.0, 1.2, 1],
    [5.0, 2.3, 3.3, 1.0, 1], [5.6, 2.7, 4.2, 1.3, 1], [5.7, 3.0, 4.2, 1.2, 1],
    [5.7, 2.9, 4.2, 1.3, 1], [6.2, 2.9, 4.3, 1.3, 1], [5.1, 2.5, 3.0, 1.1, 1],
    [5.7, 2.8, 4.1, 1.3, 1], [6.3, 3.3, 6.0, 2.5, 2], [5.8, 2.7, 5.1, 1.9, 2],
    [7.1, 3.0, 5.9, 2.1, 2], [6.3, 2.9, 5.6, 1.8, 2], [6.5, 3.0, 5.8, 2.2, 2],
    [7.6, 3.0, 6.6, 2.1, 2], [4.9, 2.5, 4.5, 1.7, 2], [7.3, 2.9, 6.3, 1.8, 2],
    [6.7, 2.5, 5.8, 1.8, 2], [7.2, 3.6, 6.1, 2.5, 2], [6.5, 3.2, 5.1, 2.0, 2],
    [6.4, 2.7, 5.3, 1.9, 2], [6.8, 3.0, 5.5, 2.1, 2], [5.7, 2.5, 5.0, 2.0, 2],
    [5.8, 2.8, 5.1, 2.4, 2], [6.4, 3.2, 5.3, 2.3, 2], [6.5, 3.0, 5.5, 1.8, 2],
    [7.7, 3.8, 6.7, 2.2, 2], [7.7, 2.6, 6.9, 2.3, 2], [6.0, 2.2, 5.0, 1.5, 2],
    [6.9, 3.2, 5.7, 2.3, 2], [5.6, 2.8, 4.9, 2.0, 2], [7.7, 2.8, 6.7, 2.0, 2],
    [6.3, 2.7, 4.9, 1.8, 2], [6.7, 3.3, 5.7, 2.1, 2], [7.2, 3.2, 6.0, 1.8, 2],
    [6.2, 2.8, 4.8, 1.8, 2], [6.1, 3.0, 4.9, 1.8, 2], [6.4, 2.8, 5.6, 2.1, 2],
    [7.2, 3.0, 5.8, 1.6, 2], [7.4, 2.8, 6.1, 1.9, 2], [7.9, 3.8, 6.4, 2.0, 2],
    [6.4, 2.8, 5.6, 2.2, 2], [6.3, 2.8, 5.1, 1.5, 2], [6.1, 2.6, 5.6, 1.4, 2],
    [7.7, 3.0, 6.1, 2.3, 2], [6.3, 3.4, 5.6, 2.4, 2], [6.4, 3.1, 5.5, 1.8, 2],
    [6.0, 3.0, 4.8, 1.8, 2], [6.9, 3.1, 5.4, 2.1, 2], [6.7, 3.1, 5.6, 2.4, 2],
    [6.9, 3.1, 5.1, 2.3, 2], [5.8, 2.7, 5.1, 1.9, 2], [6.8, 3.2, 5.9, 2.3, 2],
    [6.7, 3.3, 5.7, 2.5, 2], [6.7, 3.0, 5.2, 2.3, 2], [6.3, 2.5, 5.0, 1.9, 2],
    [6.5, 3.0, 5.2, 2.0, 2], [6.2, 3.4, 5.4, 2.3, 2], [5.9, 3.0, 5.1, 1.8, 2],
  ];
  /**
   * Convert Iris data arrays to `tf.Tensor`s, shuffled and split into a
   * training set and a test set.
   *
   * @param data The Iris input feature data: an `Array` of length-4 `Array`s
   *   (sepal length, sepal width, petal length, petal width).
   * @param targets An `Array` of class indices from the set {0, 1, 2}, one per
   *   row of `data`; must have the same length as `data`.
   * @param testSplit Fraction of the shuffled examples to hold out as test
   *   data: a number between 0 and 1 (inclusive).
   * @return A length-4 `Array` of `tf.Tensor`s:
   *   - training data, shape [numTrainExamples, 4];
   *   - training one-hot labels, shape [numTrainExamples, IRIS_NUM_CLASSES];
   *   - test data, shape [numTestExamples, 4];
   *   - test one-hot labels, shape [numTestExamples, IRIS_NUM_CLASSES].
   * @throws {Error} If `data` and `targets` differ in length, `data` is empty,
   *   or `testSplit` is not a number in [0, 1].
   */
  function convertToTensors(data, targets, testSplit) {
    const numExamples = data.length;
    if (numExamples !== targets.length) {
      throw new Error('data and targets have different numbers of examples');
    }
    if (numExamples === 0) {
      throw new Error('data must contain at least one example');
    }
    // Written as a negated range check so that NaN/undefined are rejected too.
    if (!(testSplit >= 0 && testSplit <= 1)) {
      throw new Error(`testSplit must be a number between 0 and 1, got ${testSplit}`);
    }

    // Shuffle `data` and `targets` with the SAME random permutation so every
    // feature row stays paired with its label.
    const indices = Array.from({ length: numExamples }, (_, i) => i);
    tf.util.shuffle(indices);
    const shuffledData = indices.map((i) => data[i]);
    const shuffledTargets = indices.map((i) => targets[i]);

    // Rows [0, numTrainExamples) become training data; the rest become test data.
    const numTestExamples = Math.round(numExamples * testSplit);
    const numTrainExamples = numExamples - numTestExamples;
    const xDims = shuffledData[0].length;

    // tf.tidy disposes the full-size intermediate tensors (xs, ys, the raw
    // label tensor); only the four returned slices survive.
    return tf.tidy(() => {
      // 2D tensor holding the feature data, one row per example.
      const xs = tf.tensor2d(shuffledData, [numExamples, xDims]);
      // Convert integer labels into one-hot vectors (e.g. 0 --> [1, 0, 0]).
      const ys = tf.oneHot(tf.tensor1d(shuffledTargets, 'int32'), IRIS_NUM_CLASSES);

      // Features and labels must be sliced at the same row offsets: test rows
      // start right after the training rows in BOTH tensors. (Slicing the test
      // labels from row 0 would pair test features with training labels.)
      const xTrain = xs.slice([0, 0], [numTrainExamples, xDims]);
      const xTest = xs.slice([numTrainExamples, 0], [numTestExamples, xDims]);
      const yTrain = ys.slice([0, 0], [numTrainExamples, IRIS_NUM_CLASSES]);
      const yTest = ys.slice([numTrainExamples, 0], [numTestExamples, IRIS_NUM_CLASSES]);
      return [xTrain, yTrain, xTest, yTest];
    });
  }
  /**
   * Obtains the Iris data set, split into training and test sets.
   *
   * The split is stratified: the examples of each class are shuffled and split
   * separately, and the per-class pieces are then concatenated. As a result,
   * every class contributes (up to rounding) the same fraction `testSplit` of
   * its examples to the test set.
   *
   * @param testSplit Fraction of each class's examples to hold out as test
   *   data: a number between 0 and 1.
   * @return A length-4 `Array` of `tf.Tensor`s:
   *   - training data, shape [numTrainExamples, 4];
   *   - training one-hot labels, shape [numTrainExamples, IRIS_NUM_CLASSES];
   *   - test data, shape [numTestExamples, 4];
   *   - test one-hot labels, shape [numTestExamples, IRIS_NUM_CLASSES].
   *   Within each tensor, the rows are grouped by class in class-index order
   *   (shuffled only within a class).
   */
  export function getIrisData(testSplit) {
    // tf.tidy disposes the per-class intermediate tensors and keeps only the
    // four concatenated tensors that are returned.
    return tf.tidy(() => {
      // One bucket of feature rows and one bucket of labels per class.
      const dataByClass = Array.from({ length: IRIS_NUM_CLASSES }, () => []);
      const targetsByClass = Array.from({ length: IRIS_NUM_CLASSES }, () => []);
      for (const example of IRIS_DATA) {
        // The last column is the class label; the preceding columns are features.
        const target = example[example.length - 1];
        dataByClass[target].push(example.slice(0, -1));
        targetsByClass[target].push(target);
      }

      const xTrains = [];
      const yTrains = [];
      const xTests = [];
      const yTests = [];
      for (let i = 0; i < IRIS_NUM_CLASSES; ++i) {
        const [xTrain, yTrain, xTest, yTest] =
          convertToTensors(dataByClass[i], targetsByClass[i], testSplit);
        xTrains.push(xTrain);
        yTrains.push(yTrain);
        xTests.push(xTest);
        yTests.push(yTest);
      }

      // Stack the per-class pieces along the example (row) axis.
      const concatAxis = 0;
      return [
        tf.concat(xTrains, concatAxis), tf.concat(yTrains, concatAxis),
        tf.concat(xTests, concatAxis), tf.concat(yTests, concatAxis),
      ];
    });
  }

2、定义模型结构

使用输出层带有 softmax 激活函数的神经网络:softmax 会把输出层的多个值转换为和为 1 的概率分布,每个值对应一个类别的概率。

3、训练模型并预测

与之前只关注损失函数不同,这里还会增加准确度(accuracy)作为度量指标:训练过程中会同时展示训练集损失、验证集损失、训练集准确度和验证集准确度。
如下图所示:训练集与验证集(测试集)的损失曲线、准确度曲线基本吻合,说明模型没有明显的过拟合,训练是有效的。注意:只有当 validationData 使用测试集(而不是训练集本身)时,这种对比才有意义。
WechatIMG447.jpeg

  1. import * as tf from '@tensorflow/tfjs'
  2. import * as tfvis from '@tensorflow/tfjs-vis'
  3. import { getIrisData, IRIS_CLASSES } from './data';
  console.log('IRIS_CLASSES', IRIS_CLASSES);

  /**
   * On page load: build the Iris data split, define and train a small
   * classifier with live training charts, then expose `window.predict` for
   * classifying a flower entered in a form.
   */
  window.onload = async () => {
    // Hold out 15% of each class as the test set. The test set doubles as the
    // validation set during training.
    const [xTrain, yTrain, xTest, yTest] = getIrisData(0.15);

    // Model: one sigmoid hidden layer, then a softmax output layer that
    // produces one probability per class (the probabilities sum to 1).
    const model = tf.sequential();
    model.add(tf.layers.dense({
      units: 10,
      inputShape: [xTrain.shape[1]], // number of features per example (4)
      activation: 'sigmoid',
    }));
    model.add(tf.layers.dense({
      units: IRIS_CLASSES.length, // one output probability per class (3)
      activation: 'softmax',
    }));

    model.compile({
      // Categorical cross-entropy: the loss for multi-class classification
      // with one-hot labels.
      loss: 'categoricalCrossentropy',
      optimizer: tf.train.adam(0.1),
      metrics: ['accuracy'], // reported as 'acc' / 'val_acc' in the logs
    });

    await model.fit(xTrain, yTrain, {
      epochs: 100,
      // Validate on the held-out test set. Validating on the training data
      // would make val_loss/val_acc mere copies of loss/acc, hiding overfitting.
      validationData: [xTest, yTest],
      callbacks: tfvis.show.fitCallbacks(
        { name: '训练效果' },
        ['loss', 'val_loss', 'acc', 'val_acc'],
        { callbacks: ['onEpochEnd'] },
      ),
    });

    // Global hook; presumably called by the page's form handler (HTML not shown).
    window.predict = (form) => {
      // tf.tidy disposes the input, probability and argMax tensors after use.
      const classIndex = tf.tidy(() => {
        const input = tf.tensor([[
          Number(form.a.value), // sepal length
          Number(form.b.value), // sepal width
          Number(form.c.value), // petal length
          Number(form.d.value), // petal width
        ]]);
        // argMax over axis 1 picks the most probable class for the single row.
        return model.predict(input).argMax(1).dataSync()[0];
      });
      alert(`预测结果:${IRIS_CLASSES[classIndex]}`);
    };
  };