含义

和线性回归不同,逻辑回归解决的是分类问题,会输出一个介于 0 到 1 之间(不包括 0 和 1)的概率值(非线性变化)

操作步骤

  • 加载二分类数据集
  • 定义模型结构:带有激活函数(产生非线性变化)的单个神经元
  • 训练模型并预测

举例

预测输入值[x,y]的属于哪个分类的概率

  1. import * as tf from '@tensorflow/tfjs'
  2. import { callbacks } from '@tensorflow/tfjs'
  3. import * as tfvis from '@tensorflow/tfjs-vis'
  4. import { getData } from './data' // 二分类数据集生成
  5. window.onload = async () => {
  6. const data = getData(400)
  7. // 可视化数据集
  8. tfvis.render.scatterplot({
  9. name: '逻辑回归训练数据'
  10. },{
  11. values: [
  12. data.filter(p => p.label === 1), // 分类为1 的数组
  13. data.filter(p => p.label === 0) // 分类为0 的数组
  14. ] // 这里是嵌套数组,来渲染不同颜色的点
  15. })
  16. // 定义模型:带有激活函数的单个神经元
  17. // 1、初始化模型
  18. const model = tf.sequential()
  19. // 2、添加层, 并设计层的神经元个数、inputShape、激活函数
  20. model.add(tf.layers.dense({
  21. units: 1,
  22. inputShape: [2], // 长度为2的一纬数组 (x, y)
  23. activation: 'sigmoid' // S状弯曲
  24. }))
  25. // 训练数据
  26. // 1、将训练数据转化成tensor
  27. // 2、训练模型
  28. // 3、可视化训练过程
  29. model.compile({
  30. loss: tf.losses.logLoss,
  31. optimizer: tf.train.adam(0.1)
  32. })
  33. const inputs = tf.tensor(data.map((p)=> [p.x, p.y])); // 特征数量为二的转成tensor
  34. const label = tf.tensor(data.map((p)=> p.label)) // 输出分类值
  35. await model.fit(
  36. inputs,
  37. label, {
  38. batchSize: 40,
  39. epochs: 20,
  40. callbacks: tfvis.show.fitCallbacks({
  41. name: '训练过程'
  42. }, ['loss'])
  43. })
  44. window.predict = (form) => {
  45. const pred = model.predict(tf.tensor([[form.x.value*1, form.y.value*1]])) // * 1转化成数字
  46. alert(`预测结果${pred.dataSync()[0]}`)
  47. }
  48. }

二分类数据集生成

正态分布算法

  1. export function getData(numSamples) {
  2. let points = [];
  3. function genGauss(cx, cy, label) {
  4. for (let i = 0; i < numSamples / 2; i++) {
  5. let x = normalRandom(cx);
  6. let y = normalRandom(cy);
  7. // label 分类类别
  8. points.push({ x, y, label });
  9. }
  10. }
  11. genGauss(2, 2, 1);
  12. genGauss(-2, -2, 0);
  13. return points;
  14. }
  15. /** 生成正态分布的随机数
  16. * Samples from a normal distribution. Uses the seedrandom library as the
  17. * random generator.
  18. *
  19. * @param mean The mean. Default is 0.
  20. * @param variance The variance. Default is 1.
  21. */
  22. function normalRandom(mean = 0, variance = 1) {
  23. let v1, v2, s;
  24. do {
  25. v1 = 2 * Math.random() - 1;
  26. v2 = 2 * Math.random() - 1;
  27. s = v1 * v1 + v2 * v2;
  28. } while (s > 1);
  29. let result = Math.sqrt(-2 * Math.log(s) / s) * v1;
  30. return mean + Math.sqrt(variance) * result;
  31. }