最小二乗法によるデータ解析

今回は「最小二乗法」をやります。
ざっくり言うとデータがどんな関数から作られたか予測しよう!ってことです。
名前の由来は最初に出てきたコンセプトで、実際のデータと予測値の誤差を二乗したものを最小にし(近づけ)よう!ってなとこからです。
前提条件はいくつかありますがそんなのは最小二乗法 - Wikipediaに任せましょう。
さて。まず最初に覚えてもらうことがあります。
AI(今回は違いますが)とかの学習に使う元のデータを「データセット」と言います。テストに出します。
関数と言いましたが正確には多項式です。つまり多項式近似をするプログラムと言い換えられます。

理由は後にして具体的な方法を話すと、前述のWikipediaに載っている「正規方程式」を使います。
行列の計算によってまとめて答えをはじき出します。それだけ。
{ \displaystyle \boldsymbol{a} = (G^\textrm{T} G)^{-1} G^\textrm{T} \boldsymbol{y}}を解くだけですね。
とは言っても記号の意味とかわからねえよ!と思うのでWikipediaには書いてありますが一つづつ説明します。
まず、データセットの形式ですが、XとYの1:1対応(YはXの関数だと想定しています)なので、

No X Y
1 0.1 0.0
2 0.2 0.3
3 0.3 0.5
4 0.4 0.3
5 0.5 0.0

ってな感じになります。Noはプログラム中でしか使わないので無視してもらって構いません。
で、{ \displaystyle G }ですが、これはXを作りたい関数にそのまま入れてできた行列です。
今回は{ \displaystyle f(x) = a_1 + a_2 x + a_3 x^2 ... a_n x^\textrm{n-1}}のような多項式を想定しているので、先ほどのデータを使うと

0.1^0 0.1^1 0.1^2
0.2^0 0.2^1 0.2^2
0.3^0 0.3^1 0.3^2
0.4^0 0.4^1 0.4^2
0.5^0 0.5^1 0.5^2

となります。Tは転置行列を表すので、

0.1^0 0.2^0 0.3^0 0.4^0 0.5^0
0.1^1 0.2^1 0.3^1 0.4^1 0.5^1
0.1^2 0.2^2 0.3^2 0.4^2 0.5^2

こんな感じになります。ひっくり返したって感じで、詳しく知りたい人は調べるなり下記のソースコード(transposedMatrix())を参照してください。
で、行列が隣り合っているのは行列の積を表します。書くのが大変なので省略します。dot()参照。
行列の-1乗は逆行列を表します。inverseMatrix()参照。
yは文字通りy成分です。

じゃあそれぞれの理由についてですが私なりの解釈が多分に含まれているのでご留意ください。
まずなぜ誤差を2乗するのか、は-を消すため、大きい誤差をより重大だと捉えるため、が考えられます。
最小二乗法と言っているのにそれを計算していないのは、偏微分する際には使ったが最終的な方程式には出てこない、ということだと考えています。数学に強くないので詳しくは分かりません。
Wikipediaにはあまり良くない方法みたいなことが書かれているけど、なぜ別の方法を使わなかったのかというと、理解できなかった計算的な無駄はありますがどうせコンピュータがやる上、最近のCPUは非常に早いのでわかりやすさを優先しました。

また、ここではDouble型を使っていますが、わずかに誤差が出るのでFloatを推奨します。Floatを使うと消える程度の計算上のバグです。
なので、doubleの精度が必要な状況ではミスが出る可能性があります。
ここでは、学習させるデータと同じサイズのテスト用データも生成して「学習しすぎ」を防ぎます。
「学習しすぎた」状態のことをオーバーフィッティングと言います。過学習、ないし過剰適合とも言われます。
最適な値を出すためには、学習用データとテスト用データのerrorがともに下がっている最大のところを取ります。
以下ソースコードになります。

import java.util.Random;

/*    列↓ 列↓ 列↓ 列↓
   行→ 0   1   2   4
   行→ -1  2   0   1
   行→ 1   0  -2  -1
   行→ 3   1   1   3
   Array[行][列]
*/
public class LeastSquaresMethod {
    public static void main (String[] args){
        int[] power = {0,1,3,5,9};
        int size = 20;
        double[][] dataSet = generateDataSet(size);
        double[][] testData = generateDataSet(size);
        double[][] a = new double[size][1];
        showArray(dataSet);
        showArray(testData);
        for(int i = 0; i < size; i++){
            a[i][0] = dataSet[i][1];
        }
        for(int p = 0; p < power.length; p++){
            power[p]++;
            double[][] g = new double[size][power[p]];
            for(int i = 0; i < size; i++){
                for(int j = 0; j < power[p]; j++){
                    g[i][j] = Math.pow(dataSet[i][0], j);
                }
            }
            System.out.println("Power : " + (power[p] - 1));
            double[][] ans = dot(dot(inverseMatrix(dot(transposedMatrix(g), g)), transposedMatrix(g)), a);
            showArray(ans);
            
            double error = 0d;
            double errorTest = 0d;
            for(int i = 0; i < dataSet.length; i++){
                double tmp = 0d;
                double tmpTest = 0d;
                for(int j = 0; j < ans.length; j++){
                    tmp += ans[j][0] * Math.pow(dataSet[i][0], j);
                    tmpTest += ans[j][0] * Math.pow(testData[i][0], j);
                }
                error += Math.pow((dataSet[i][1] - tmp), 2);
                errorTest += Math.pow((testData[i][1] - tmpTest), 2);
            }
            System.out.println("error     : " + error);
            System.out.println("errorTest : " + errorTest + "\n");
        }
    }
    
    public static void showArray(double[][] array){
        for(int i = 0; i < array.length; i++){
            for(int j = 0; j < array[0].length; j++){
                //System.out.printf("%.3f ", array[i][j]);
                System.out.print(array[i][j] + " ");
            }
            System.out.println();
        }
        System.out.println();
    }
    
    public static double[][] generateDataSet(int size){
        Random rand = new Random();
        double[][] output = new double[size][2];
        for(int i = 0; i < size; i++){
            output[i][0] = (double)i / size;
            output[i][1] = Math.sin(2d * Math.PI * output[i][0]) + rand.nextGaussian() * 0.3d;
        }
        return output;
    }
    
    public static double[][] transposedMatrix(double[][] input){
        double[][] output = new double[input[0].length][input.length];
        for(int i = 0; i < output.length; i++){
            for(int j = 0; j < output[0].length; j++){
                output[i][j] = input[j][i];
            }
        }
        return output;
    }
    
    public static double[][] dot(double[][] input, double[][] input2){
        if(input[0].length != input2.length){
            return null;
        }
        int length = input[0].length;
        double[][] output = new double[input.length][input2[0].length];
        for(int i = 0; i < input.length; i++){
            for(int j = 0; j < input2[0].length; j++){
                double tmp = 0d;
                for(int k = 0; k < length; k++){
                    tmp += input[i][k] * input2[k][j];
                }
                output[i][j] = tmp;
            }
        }
        return output;
    }
    
    public static boolean swapRow(double[][] targetArray, int target, int target2){ // target行とtarget2行を入れ替えます
        int arrayColLength = targetArray[0].length;
        for(int i = 0; i < arrayColLength; i++){
            double tmp = targetArray[target][i];
            targetArray[target][i] = targetArray[target2][i];
            targetArray[target2][i] = tmp;
        }
        return true;
    }
    public static double[][] inverseMatrix(double[][] input){//掃き出し法(ガウス・ジョルダン法)で逆行列を計算します
        if(input.length != input[0].length){
            return null;
        }
        int length = input.length;
        double[][] target = connectArray(input, identityMatrix(length)); //単位行列を結合します
        for(int i = 0; i < length; i++){
            double max = Math.abs(target[i][i]);//行の入れ替えをします
            int maxPiv = i;
            for(int j = i; j < length; j++){
                if(Math.abs(target[j][i]) > max){
                    max = Math.abs(target[j][i]);
                    maxPiv = j;
                }
            }
            swapRow(target, i, maxPiv);
            
            for(int j = 0; j < length; j++){//前進消去します
                if(i == j || target[j][i] == 0){
                        continue;
                }
                double tmp = target[j][i] / target[i][i];
                for(int k = 0; k < length * 2; k++){
                    target[j][k] -= target[i][k] * tmp;
                }
            }
        }
        for(int i = 0; i < length; i++){//左側を単位行列にします
            double tmp = target[i][i];
            for(int j = 0; j < length * 2; j++){
                target[i][j] /= tmp;
            }
        }
        double[][] output = new double[length][length];
        for(int i = 0; i < length; i++){ //右側(計算結果)を切り出します
            for(int j = 0; j < length; j++){
                output[i][j] = target[i][j + length];
            }
        }
        return output;
    }
    
    public static double[][] identityMatrix(int size){ // 単位行列を作ります
        double[][] output = new double[size][size];
        for(int i = 0; i < size; i++){
            output[i][i] = 1;
        }
        return output;
    }
    public static double[][] connectArray(double[][] input, double[][] input2){ //inputにinput2を横に(列を増やして)結合します。行数が同じもののみです
        if(input.length != input2.length){
            return null;
        }
        double[][] output = new double[input.length][input[0].length + input2[0].length];
        for(int i = 0; i < output.length; i++){
            for(int j = 0; j < output[0].length; j++){
                if(j < input[0].length){
                    output[i][j] = input[i][j];
                }
                else{
                    output[i][j] = input2[i][j - input[0].length];
                }
            }
        }
        return output;
    }
}

参考文献:
最小二乗法 - Wikipedia
過剰適合 - Wikipedia
ガウスの消去法 - Wikipedia
行列の乗法 - Wikipedia