Linear regression models the relationship between inputs and outputs by fitting a straight line to the data.
In this example, five points are plotted, and each point's Y-coordinate includes some random variation. The goal of the neural network that performs the linear regression (a single neuron computing y = w * x + b) is to find the line that fits all the points as well as it can.
After 60 iterations, the points are reset and the algorithm runs the linear regression again.
Here is the code:
// Canvas geometry; setup() fills these in from the #p5container element.
let gameWidth;
let gameHeight;
let cellSize;

// Data points: X holds the inputs and Y the noisy outputs (set by resetPoints()).
let X;
let Y;

// Parameters of the model line y = w * x + b, starting from an arbitrary guess.
let w = 2.8;
let b = -0.4;

// Cost reported by train() on its latest step, and the number of training
// steps since the points were last reset.
let cost;
let iterations = 0;
/**
 * Evaluates the current linear model at a single input.
 *
 * @param {number} input - The x value to evaluate.
 * @returns {number} w * input + b, using the current global w and b.
 */
function predict(input) {
  const scaled = w * input;
  return scaled + b;
}
/**
 * Starts a new round: clears the iteration counter and generates a fresh set
 * of five points.
 *
 * Inputs are fixed at -8, -4, 0, 4, 8. Each output lies on y = x / 2, shifted
 * upward by a uniform random offset in [0, 5).
 */
function resetPoints() {
  iterations = 0;
  X = [-8, -4, 0, 4, 8];
  Y = X.map((x) => x / 2 + Math.random() * 5);
}
/**
 * p5 setup callback: configures drawing defaults and sizes the canvas to its
 * host element. Also derives the grid cell size (40 cells across the width)
 * and generates the first set of points.
 */
function setup() {
  noStroke();
  frameRate(24);

  // Match the canvas to the #p5container element it is attached to.
  const container = document.getElementById("p5container");
  gameWidth = container.clientWidth;
  gameHeight = container.clientHeight;
  createCanvas(gameWidth, gameHeight).parent('p5container');

  // One world unit maps to one cell; text is sized to match.
  cellSize = gameWidth / 40;
  textSize(cellSize);

  resetPoints();
}
/**
 * Runs one step of batch gradient descent on the mean squared error between
 * the model's predictions predict(X[i]) and the targets Y[i].
 *
 * Side effects:
 * - sets the global `cost` to the mean squared error measured before this
 *   step's update;
 * - updates the global weight `w` and bias `b`;
 * - increments `iterations`, and calls resetPoints() once `maxIterations`
 *   steps have run on the current points.
 *
 * @param {number} [learningRate=0.01] - Step size of the gradient update.
 * @param {number} [maxIterations=60] - Steps to run before the points are reset.
 */
function train(learningRate = 0.01, maxIterations = 60) {
  const n = X.length;
  if (n === 0) {
    // Nothing to fit; dividing by zero below would turn w and b into NaN.
    return;
  }

  let squaredError = 0;
  let dw = 0;
  let db = 0;
  for (let i = 0; i < n; i++) {
    // Error of the current model at this point. The cost must compare the
    // prediction against the target Y[i], not against the input X[i].
    const residual = predict(X[i]) - Y[i];
    squaredError += residual ** 2;
    // Partial derivatives of residual^2 with respect to w and b.
    dw += 2 * residual * X[i];
    db += 2 * residual;
  }

  cost = squaredError / n;
  w -= (learningRate * dw) / n;
  b -= (learningRate * db) / n;

  iterations++;
  // Reset after exactly `maxIterations` steps (previously `>` ran one extra).
  if (iterations >= maxIterations) {
    resetPoints();
  }
}
/**
 * Maps a world x-coordinate to a canvas x-coordinate. World x = 0 sits at the
 * horizontal centre of the canvas, and one world unit spans `cellSize` pixels.
 *
 * @param {number} x - World x-coordinate.
 * @returns {number} Canvas x-coordinate in pixels.
 */
function toScreenX(x) {
  return x * cellSize + gameWidth / 2;
}

/**
 * Maps a world y-coordinate to a canvas y-coordinate. World y = 0 sits at the
 * vertical centre of the canvas. The sign is flipped because canvas y grows
 * downward while world y grows upward.
 *
 * @param {number} y - World y-coordinate.
 * @returns {number} Canvas y-coordinate in pixels.
 */
function toScreenY(y) {
  return gameHeight / 2 - y * cellSize;
}

/**
 * Formats a value for the on-screen readout.
 *
 * @param {*} value - Value to display.
 * @returns {string} Finite numbers with 4 decimals (hides float noise such as
 *   2.8000000000000003); anything else as String(value).
 */
function formatNumber(value) {
  return Number.isFinite(value) ? value.toFixed(4) : String(value);
}

/**
 * p5 per-frame callback: runs one training step, then redraws the scene.
 * The scene consists of the axes, the current model line, the data points
 * and a readout of w, b, cost and the iteration count.
 */
function draw() {
  train();
  background("#252525");

  // Axes through the world origin (the centre of the canvas).
  strokeWeight(1);
  stroke("#845EC2");
  line(gameWidth / 2, 0, gameWidth / 2, gameHeight);
  line(0, gameHeight / 2, gameWidth, gameHeight / 2);

  // Current model line, drawn between world x = -12 and x = 12.
  stroke("#F9F871");
  const x1 = -12;
  const x2 = 12;
  line(toScreenX(x1), toScreenY(predict(x1)), toScreenX(x2), toScreenY(predict(x2)));

  // Data points.
  noStroke();
  fill("#C34A36");
  X.forEach((x, i) => {
    ellipse(toScreenX(x), toScreenY(Y[i]), cellSize * 0.75);
  });

  // Readout.
  fill("#FFC75F");
  text(`w: ${formatNumber(w)}`, cellSize, cellSize * 2);
  text(`b: ${formatNumber(b)}`, cellSize, cellSize * 3.5);
  text(`cost: ${formatNumber(cost)}`, cellSize, cellSize * 5);
  text(`iterations: ${iterations}`, cellSize, cellSize * 6.5);
}