const math = require('mathjs/core').create();
math.import(require('mathjs/lib/type/matrix'));
math.import(require('mathjs/lib/function/arithmetic'));
math.import(require('mathjs/lib/function/matrix'));
var stat = require('pw-stat');
/**
* An LDA object.
* @constructor
* @param {...number[][]} classes - Each parameter is a 2d class array. In each class array, rows are samples, columns are variables.
* @example
* let classifier = new LDA(class1, class2, class3);
*/
function LDA(...classes) {
// Compute pairwise LDA classes (needed for multiclass LDA)
if(classes.length < 2) {
throw new Error('Please pass at least 2 classes');
}
let numberOfPairs = classes.length * (classes.length - 1) / 2;
let pair1 = 0;
let pair2 = 1;
let pairs = new Array(numberOfPairs);
for(let i = 0; i < numberOfPairs; i++){
pairs[i] = computeLdaParams(classes[pair1], classes[pair2], pair1, pair2);
pair2++;
if(pair2 == classes.length) {
pair1++;
pair2 = pair1 + 1;
}
}
this.pairs = pairs;
this.numberOfClasses = classes.length;
}
function computeLdaParams(class1, class2, class1id, class2id) {
let mu1 = math.transpose(stat.mean(class1));
let mu2 = math.transpose(stat.mean(class2));
let pooledCov = math.add(stat.cov(class1), stat.cov(class2));
let theta = math.multiply(math.inv(pooledCov), math.subtract(mu2, mu1));
let b = math.multiply(-1, math.transpose(theta), math.add(mu1, mu2), 1 / 2);
return {
theta: theta,
b: b,
class1id: class1id,
class2id: class2id
}
}
/**
* Project the unknown data point to one dimension.
* Currently only supports binary LDA.
* @param {number[]} point - The data point to be projected.
* @returns {number} value less than 0 if predicted to be in class 1, 0 if exactly in between, greater than 0 if class 2
*/
LDA.prototype.project = function (point) {
if(this.pairs.length != 1) {
throw new Error('LDA project currently only supports 2 classes. LDA classify can be used to perform multiclass classification.');
}
return projectPoint(point, this.pairs[0].theta, this.pairs[0].b);
}
function projectPoint(point, theta, b) {
return math.add(math.multiply(point, theta), b);
}
/**
* Classify an unknown point. Uses a pairwise voting system in the event of multiclass classification.
* @param {number[]} point - The data point to be classified.
* @returns {number} Returns the predicted class. Class numbers range from 0 to (number_of_classes - 1).
*/
LDA.prototype.classify = function(point) {
// In the event of a binary classifier, skip the voting process
if(this.numberOfClasses == 2) {
return projectPoint(point, this.pairs[0].theta, this.pairs[0].b) <= 0 ? 0 : 1;
}
// Start each class with 0 votes
let votes = new Array(this.numberOfClasses);
for(let i = 0; i < this.numberOfClasses; i++) {
votes[i] = 0;
}
// Allow each pair to cast a vote
for(let i = 0; i < this.pairs.length; i++) {
let params = this.pairs[i];
let projection = projectPoint(point, params.theta, params.b);
if(projection <= 0) {
votes[params.class1id]++;
} else {
votes[params.class2id]++;
}
}
// Find the winning class
let classificaion = 0;
let maxVotes = votes[0];
for(let i = 1; i < votes.length; i++){
if(votes[i] > maxVotes) {
classificaion = i;
maxVotes = votes[i];
}
}
return classificaion;
}
module.exports = LDA;