含义
和线性回归不同,逻辑回归解决的是分类问题,会输出一个介于 0 到 1 之间(不包括 0 和 1)的概率值(非线性变化)
操作步骤
- 加载二分类数据集
- 定义模型结构:带有激活函数(产生非线性变化)的单个神经元
- 训练模型并预测
举例
预测输入值[x,y]的属于哪个分类的概率
import * as tf from '@tensorflow/tfjs'
import { callbacks } from '@tensorflow/tfjs'
import * as tfvis from '@tensorflow/tfjs-vis'
import { getData } from './data' // 二分类数据集生成
window.onload = async () => {
const data = getData(400)
// 可视化数据集
tfvis.render.scatterplot({
name: '逻辑回归训练数据'
},{
values: [
data.filter(p => p.label === 1), // 分类为1 的数组
data.filter(p => p.label === 0) // 分类为0 的数组
] // 这里是嵌套数组,来渲染不同颜色的点
})
// 定义模型:带有激活函数的单个神经元
// 1、初始化模型
const model = tf.sequential()
// 2、添加层, 并设计层的神经元个数、inputShape、激活函数
model.add(tf.layers.dense({
units: 1,
inputShape: [2], // 长度为2的一纬数组 (x, y)
activation: 'sigmoid' // S状弯曲
}))
// 训练数据
// 1、将训练数据转化成tensor
// 2、训练模型
// 3、可视化训练过程
model.compile({
loss: tf.losses.logLoss,
optimizer: tf.train.adam(0.1)
})
const inputs = tf.tensor(data.map((p)=> [p.x, p.y])); // 特征数量为二的转成tensor
const label = tf.tensor(data.map((p)=> p.label)) // 输出分类值
await model.fit(
inputs,
label, {
batchSize: 40,
epochs: 20,
callbacks: tfvis.show.fitCallbacks({
name: '训练过程'
}, ['loss'])
})
window.predict = (form) => {
const pred = model.predict(tf.tensor([[form.x.value*1, form.y.value*1]])) // * 1转化成数字
alert(`预测结果${pred.dataSync()[0]}`)
}
}
二分类数据集生成
正态分布算法
export function getData(numSamples) {
let points = [];
function genGauss(cx, cy, label) {
for (let i = 0; i < numSamples / 2; i++) {
let x = normalRandom(cx);
let y = normalRandom(cy);
// label 分类类别
points.push({ x, y, label });
}
}
genGauss(2, 2, 1);
genGauss(-2, -2, 0);
return points;
}
/** 生成正态分布的随机数
* Samples from a normal distribution. Uses the seedrandom library as the
* random generator.
*
* @param mean The mean. Default is 0.
* @param variance The variance. Default is 1.
*/
function normalRandom(mean = 0, variance = 1) {
let v1, v2, s;
do {
v1 = 2 * Math.random() - 1;
v2 = 2 * Math.random() - 1;
s = v1 * v1 + v2 * v2;
} while (s > 1);
let result = Math.sqrt(-2 * Math.log(s) / s) * v1;
return mean + Math.sqrt(variance) * result;
}