文献[1]で紹介されている線形回帰をJavaで実装してみました.
線形回帰は,与えられたデータに適した関数を基底関数の線形結合から求める手法です. 多項式基底については正則化項なし(λ=0)とあり(λ=0.01)で, ガウス基底については正則化の強さを変えて(λ=1.0 と λ=0.01)関数を求めてみました. 結果は次の図のとおりです.

[1] 中谷 秀洋: 線形回帰を実装してみよう, 機械学習はじめよう 第11回, http://gihyo.jp/dev/serial/01/machine-learning/0011
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.ArrayList; | |
import java.util.List; | |
import org.apache.commons.math3.linear.Array2DRowRealMatrix; | |
import org.apache.commons.math3.linear.ArrayRealVector; | |
import org.apache.commons.math3.linear.LUDecomposition; | |
import org.apache.commons.math3.linear.RealMatrix; | |
import org.apache.commons.math3.linear.RealVector; | |
public class LinearRegression { | |
List<BasisFunction> basisFunctions = new ArrayList<BasisFunction>(); | |
double lambda = 0.0; | |
RealVector weightVector; | |
RealMatrix identityMatrix() { | |
RealMatrix matrix = new Array2DRowRealMatrix(this.basisFunctions.size(), this.basisFunctions.size()); | |
for (int i = 0; i < this.basisFunctions.size(); i++) { | |
matrix.setEntry(i, i, 1.0); | |
} | |
return matrix; | |
} | |
RealVector phiVector(double x) { | |
RealVector vector = new ArrayRealVector(this.basisFunctions.size()); | |
for (int i = 0; i < this.basisFunctions.size(); i++) { | |
double term = this.basisFunctions.get(i).calculate(x); | |
vector.setEntry(i, term); | |
} | |
return vector; | |
}; | |
RealMatrix phiMatrix(List<Example> trainingSet) { | |
RealMatrix matrix = new Array2DRowRealMatrix(trainingSet.size(), this.basisFunctions.size()); | |
for (int i = 0; i < trainingSet.size(); i++) { | |
Example example = trainingSet.get(i); | |
RealVector rowVector = phiVector(example.x); | |
matrix.setRowVector(i, rowVector); | |
} | |
return matrix; | |
} | |
RealVector tVector(List<Example> trainingSet) { | |
RealVector vector = new ArrayRealVector(trainingSet.size()); | |
for (int i = 0; i < trainingSet.size(); i++) { | |
Example example = trainingSet.get(i); | |
vector.setEntry(i, example.y); | |
} | |
return vector; | |
} | |
public void learn(List<Example> trainingSet) { | |
RealMatrix I = this.identityMatrix(); | |
RealMatrix PHI = this.phiMatrix(trainingSet); | |
RealMatrix PHI_T = PHI.transpose(); | |
RealVector t = this.tVector(trainingSet); | |
RealMatrix C = I.scalarMultiply(this.lambda).add(PHI_T.multiply(PHI)); | |
RealVector c = PHI_T.operate(t); | |
this.weightVector = new LUDecomposition(C).getSolver().solve(c); | |
} | |
public double calculate(double x) { | |
return this.weightVector.dotProduct(this.phiVector(x)); | |
} | |
public void addBaseFunction(BasisFunction baseFunction) { | |
this.basisFunctions.add(baseFunction); | |
} | |
public void setLambda(double lambda) { | |
this.lambda = lambda; | |
} | |
public static interface BasisFunction { | |
double calculate(double x); | |
} | |
public static class Example { | |
double x; | |
double y; | |
Example(double x, double y) { | |
this.x = x; | |
this.y = y; | |
} | |
} | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.util.ArrayList; | |
import java.util.List; | |
import org.junit.Test; | |
public class LinearRegressionTest { | |
@SuppressWarnings("serial") | |
static List<LinearRegression.Example> trainingSet = new ArrayList<LinearRegression.Example>() {{ | |
add(new LinearRegression.Example(0.02, 0.05)); | |
add(new LinearRegression.Example(0.12, 0.87)); | |
add(new LinearRegression.Example(0.19, 0.94)); | |
add(new LinearRegression.Example(0.27, 0.92)); | |
add(new LinearRegression.Example(0.42, 0.54)); | |
add(new LinearRegression.Example(0.51,-0.11)); | |
add(new LinearRegression.Example(0.64,-0.78)); | |
add(new LinearRegression.Example(0.84,-0.79)); | |
add(new LinearRegression.Example(0.88,-0.89)); | |
add(new LinearRegression.Example(0.99,-0.04)); | |
}}; | |
@Test | |
public void test1() { | |
System.out.println("test1"); | |
LinearRegression linearRegression = new LinearRegression(); | |
for (int i = 0; i < 4; i++) { | |
linearRegression.addBaseFunction(new PolynominalBasis(i)); | |
} | |
linearRegression.learn(trainingSet); | |
for (double x = 0.0; x <= 1.0; x += 0.01) { | |
double y = linearRegression.calculate(x); | |
System.out.printf("%.2f\t%.2f\n", x, y); | |
} | |
} | |
@Test | |
public void test2() { | |
System.out.println("test2"); | |
LinearRegression linearRegression = new LinearRegression(); | |
for (int i = 0; i < 4; i++) { | |
linearRegression.addBaseFunction(new PolynominalBasis(i)); | |
} | |
linearRegression.setLambda(0.01); | |
linearRegression.learn(trainingSet); | |
for (double x = 0.0; x <= 1.0; x += 0.01) { | |
double y = linearRegression.calculate(x); | |
System.out.printf("%.2f\t%.2f\n", x, y); | |
} | |
} | |
@Test | |
public void test3() { | |
System.out.println("test3"); | |
int N = 10; | |
double s = 0.1; | |
LinearRegression linearRegression = new LinearRegression(); | |
for (int i = 0; i <= N; i++) { | |
linearRegression.addBaseFunction(new GaussianBasis(i * 1.0 / N, s)); | |
} | |
linearRegression.setLambda(1.0); | |
linearRegression.learn(trainingSet); | |
for (double x = 0.0; x <= 1.0; x += 0.01) { | |
double y = linearRegression.calculate(x); | |
System.out.printf("%.2f\t%.2f\n", x, y); | |
} | |
} | |
@Test | |
public void test4() { | |
System.out.println("test4"); | |
int N = 10; | |
double s = 0.1; | |
LinearRegression linearRegression = new LinearRegression(); | |
for (int i = 0; i <= N; i++) { | |
linearRegression.addBaseFunction(new GaussianBasis(i * 1.0 / N, s)); | |
} | |
linearRegression.setLambda(0.01); | |
linearRegression.learn(trainingSet); | |
for (double x = 0.0; x <= 1.0; x += 0.01) { | |
double y = linearRegression.calculate(x); | |
System.out.printf("%.2f\t%.2f\n", x, y); | |
} | |
} | |
static class PolynominalBasis implements LinearRegression.BasisFunction { | |
int n; | |
PolynominalBasis(int n) { | |
this.n = n; | |
} | |
@Override | |
public double calculate(double x) { | |
return Math.pow(x, n); | |
} | |
} | |
static class GaussianBasis implements LinearRegression.BasisFunction { | |
double mu; | |
double s; | |
GaussianBasis(double mu, double s) { | |
this.mu = mu; | |
this.s = s; | |
} | |
@Override | |
public double calculate(double x) { | |
return Math.exp(-1.0 * (x - this.mu) * (x - this.mu) / (2.0 * this.s * this.s)); | |
} | |
} | |
} |
0 件のコメント:
コメントを投稿