2014年1月29日水曜日

線形回帰をJavaで実装

文献[1]で紹介されている線形回帰をJavaで実装してみました.

線形回帰は,与えられたデータに適した関数を基底関数の線形結合から求める手法です. 多項式基底とガウス基底について正則化項ありとなしで関数を求めてみました. 結果は次の図のとおりです.

[1] 中谷 秀洋: 線形回帰を実装してみよう, 機械学習はじめよう 第11回, http://gihyo.jp/dev/serial/01/machine-learning/0011

import java.util.ArrayList;
import java.util.List;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.LUDecomposition;
import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.linear.RealVector;
/**
 * Linear regression: fits a weight vector w so that y is approximated by the
 * dot product w . phi(x), where phi(x) is the vector of user-supplied basis
 * function values at x. An optional L2 penalty (coefficient lambda) turns the
 * fit into ridge regression; lambda = 0.0 gives ordinary least squares.
 */
public class LinearRegression {

    /** Basis functions whose linear combination forms the model. */
    private final List<BasisFunction> basisFunctions = new ArrayList<BasisFunction>();

    /** L2 regularization coefficient; 0.0 disables regularization. */
    private double lambda = 0.0;

    /** Learned weights; remains null until learn(...) succeeds. */
    private RealVector weightVector;

    /** Builds an M x M identity matrix, where M is the number of basis functions. */
    RealMatrix identityMatrix() {
        int m = this.basisFunctions.size();
        RealMatrix matrix = new Array2DRowRealMatrix(m, m);
        for (int i = 0; i < m; i++) {
            matrix.setEntry(i, i, 1.0);
        }
        return matrix;
    }

    /** Evaluates every basis function at x, returning the feature vector phi(x). */
    RealVector phiVector(double x) {
        RealVector vector = new ArrayRealVector(this.basisFunctions.size());
        for (int i = 0; i < this.basisFunctions.size(); i++) {
            vector.setEntry(i, this.basisFunctions.get(i).calculate(x));
        }
        return vector;
    }

    /** Design matrix PHI: row i is phi(x_i) for the i-th training example. */
    RealMatrix phiMatrix(List<Example> trainingSet) {
        RealMatrix matrix = new Array2DRowRealMatrix(trainingSet.size(), this.basisFunctions.size());
        for (int i = 0; i < trainingSet.size(); i++) {
            matrix.setRowVector(i, phiVector(trainingSet.get(i).x));
        }
        return matrix;
    }

    /** Target vector t: entry i is y_i of the i-th training example. */
    RealVector tVector(List<Example> trainingSet) {
        RealVector vector = new ArrayRealVector(trainingSet.size());
        for (int i = 0; i < trainingSet.size(); i++) {
            vector.setEntry(i, trainingSet.get(i).y);
        }
        return vector;
    }

    /**
     * Fits the weight vector by solving the regularized normal equations
     * (lambda * I + PHI^T * PHI) w = PHI^T t via LU decomposition.
     *
     * @param trainingSet the (x, y) examples to fit
     * @throws IllegalStateException if no basis function has been added yet
     */
    public void learn(List<Example> trainingSet) {
        if (this.basisFunctions.isEmpty()) {
            // Without basis functions the normal equations are a degenerate
            // 0 x 0 system; fail fast with a clear message instead.
            throw new IllegalStateException("add at least one basis function before calling learn()");
        }
        RealMatrix I = this.identityMatrix();
        RealMatrix PHI = this.phiMatrix(trainingSet);
        RealMatrix PHI_T = PHI.transpose();
        RealVector t = this.tVector(trainingSet);
        RealMatrix C = I.scalarMultiply(this.lambda).add(PHI_T.multiply(PHI));
        RealVector c = PHI_T.operate(t);
        this.weightVector = new LUDecomposition(C).getSolver().solve(c);
    }

    /**
     * Predicts y for the given x as the dot product of the learned weights
     * with phi(x).
     *
     * @throws IllegalStateException if learn(...) has not been called yet
     */
    public double calculate(double x) {
        if (this.weightVector == null) {
            // Previously this surfaced as a bare NullPointerException.
            throw new IllegalStateException("call learn() before calculate()");
        }
        return this.weightVector.dotProduct(this.phiVector(x));
    }

    /** Registers one basis function; call before learn(). */
    public void addBaseFunction(BasisFunction baseFunction) {
        this.basisFunctions.add(baseFunction);
    }

    /** Sets the L2 regularization coefficient (0.0 disables regularization). */
    public void setLambda(double lambda) {
        this.lambda = lambda;
    }

    /** A single basis function phi_i(x). */
    public static interface BasisFunction {
        double calculate(double x);
    }

    /** A training example: one observed (x, y) pair. */
    public static class Example {
        double x;
        double y;
        Example(double x, double y) {
            this.x = x;
            this.y = y;
        }
    }
}
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import org.junit.Test;
/**
 * Demonstration "tests" for LinearRegression. Each test fits the shared
 * training set with a particular basis / regularization setup and prints the
 * predicted (x, y) pairs over [0, 1], reproducing the figures from the blog
 * post. The tests make no assertions; they exist to produce plot data.
 */
public class LinearRegressionTest {

    /** Hard-coded training examples taken from the reference article. */
    static final List<LinearRegression.Example> trainingSet = Arrays.asList(
            new LinearRegression.Example(0.02, 0.05),
            new LinearRegression.Example(0.12, 0.87),
            new LinearRegression.Example(0.19, 0.94),
            new LinearRegression.Example(0.27, 0.92),
            new LinearRegression.Example(0.42, 0.54),
            new LinearRegression.Example(0.51, -0.11),
            new LinearRegression.Example(0.64, -0.78),
            new LinearRegression.Example(0.84, -0.79),
            new LinearRegression.Example(0.88, -0.89),
            new LinearRegression.Example(0.99, -0.04));

    /** Polynomial basis, no regularization. */
    @Test
    public void test1() {
        System.out.println("test1");
        LinearRegression linearRegression = newPolynomialModel(4);
        linearRegression.learn(trainingSet);
        printPredictions(linearRegression);
    }

    /** Polynomial basis with a small regularization term. */
    @Test
    public void test2() {
        System.out.println("test2");
        LinearRegression linearRegression = newPolynomialModel(4);
        linearRegression.setLambda(0.01);
        linearRegression.learn(trainingSet);
        printPredictions(linearRegression);
    }

    /** Gaussian basis with strong regularization. */
    @Test
    public void test3() {
        System.out.println("test3");
        LinearRegression linearRegression = newGaussianModel(10, 0.1);
        linearRegression.setLambda(1.0);
        linearRegression.learn(trainingSet);
        printPredictions(linearRegression);
    }

    /** Gaussian basis with weak regularization. */
    @Test
    public void test4() {
        System.out.println("test4");
        LinearRegression linearRegression = newGaussianModel(10, 0.1);
        linearRegression.setLambda(0.01);
        linearRegression.learn(trainingSet);
        printPredictions(linearRegression);
    }

    /** Builds a model with polynomial bases x^0 .. x^(degreeCount-1). */
    private static LinearRegression newPolynomialModel(int degreeCount) {
        LinearRegression linearRegression = new LinearRegression();
        for (int i = 0; i < degreeCount; i++) {
            linearRegression.addBaseFunction(new PolynomialBasis(i));
        }
        return linearRegression;
    }

    /** Builds a model with n + 1 Gaussian bases centered evenly on [0, 1], width s. */
    private static LinearRegression newGaussianModel(int n, double s) {
        LinearRegression linearRegression = new LinearRegression();
        for (int i = 0; i <= n; i++) {
            linearRegression.addBaseFunction(new GaussianBasis(i * 1.0 / n, s));
        }
        return linearRegression;
    }

    /** Prints "x<TAB>y" predictions over [0, 1] in steps of 0.01. */
    private static void printPredictions(LinearRegression linearRegression) {
        for (double x = 0.0; x <= 1.0; x += 0.01) {
            double y = linearRegression.calculate(x);
            System.out.printf("%.2f\t%.2f\n", x, y);
        }
    }

    /** Monomial basis function phi(x) = x^n. (Name fixed from "Polynominal".) */
    static class PolynomialBasis implements LinearRegression.BasisFunction {
        int n;
        PolynomialBasis(int n) {
            this.n = n;
        }
        @Override
        public double calculate(double x) {
            return Math.pow(x, n);
        }
    }

    /** Unnormalized Gaussian bump phi(x) = exp(-(x - mu)^2 / (2 s^2)). */
    static class GaussianBasis implements LinearRegression.BasisFunction {
        double mu;
        double s;
        GaussianBasis(double mu, double s) {
            this.mu = mu;
            this.s = s;
        }
        @Override
        public double calculate(double x) {
            return Math.exp(-1.0 * (x - this.mu) * (x - this.mu) / (2.0 * this.s * this.s));
        }
    }
}

0 件のコメント: