OpenCvSharpをつかう その16(SVM)

前回の記事に引き続き、SVMです。今回はOpenCvSharpからOpenCVのopencv2/ml.hpp(実装当時はopencv/ml.h)にある機械学習を使います。

OpenCvSharpをつかう 記事一覧

準備

NuGetで導入した場合は特に追加の手順はありません。以下は手動で導入する場合の話です。

機械学習を提供するのは OpenCvSharp.CPlusPlus.dll です。OpenCvSharp.dllに加え、これを参照に追加してください。

また、OpenCVのmlは昔からC++実装で、C#からの単純なラップが難しいため、間を取り持つOpenCvSharpExtern.dllが必要です。これはその12のBlobと同じです。
このあたりの記事を参考に、ビルドした.exeと同じ場所に置いてあげてください。(「参照設定」ではありません。置くだけですので注意。)
http://schima.hatenablog.com/entries/2009/06/16
http://schima.hatenablog.com/entries/2009/12/01
http://schima.hatenablog.com/entry/20100528/1275043079


学習データの用意

お題

前回の記事と同じにします。適当に点(x, y)を打った時、こんな曲線の上か下かを判定するマシーンを作ることを目指します。色々とついている係数は、画面に収まりが良いようにしているだけで深い意味はありません。
 y = x + 50 \sin \left( \frac{x}{15} \right)
f:id:Schima:20131105122545p:plain

関数の定義

上のグラフで示す関数を定義しておきます。C#一般の命名規則的に良くないですが、数学っぽくf(x)。

static double f(double x)
{
    return x + 50 * Math.Sin(x/15.0);
}

訓練データの点を準備

500個適当に点を打ちまくります。
pointsが座標、responsesがそれに対する応答(曲線の上か下か)とします。

CvPoint2D32f[] points = new CvPoint2D32f[500];
int[] responses = new int[points.Length];

Random rand = new Random();
for (int i = 0; i < responses.Length; i++)
{
    double x = rand.Next(0, 300);
    double y = rand.Next(0, 300);
    points[i] = new CvPoint2D32f(x, y);
    responses[i] = (y > f(x)) ? 1 : 2;
}

訓練データを可視化

どのように点が作られたのかを見てみましょう。頑張って自分で描画します。y軸方向は下向きなので注意です。

using (IplImage pointsPlot = new IplImage(300, 300, BitDepth.U8, 3))
{
    pointsPlot.Zero();
    // pointsとresponsesの様子を描画
    for (int i = 0; i < points.Length; i++)
    {
        int x = (int)points[i].X;
        int y = (int)(300 - points[i].Y);
        int res = responses[i];
        CvColor color = (res == 1) ? CvColor.Red : CvColor.GreenYellow;
        pointsPlot.Circle(x, y, 2, color, -1);
    }
    // 答えの関数の曲線も描く
    for (int x = 1; x < 300; x++)
    {
        int y1 = (int)(300 - f(x-1));
        int y2 = (int)(300 - f(x));
        pointsPlot.Line(x-1, y1, x, y2, CvColor.LightBlue, 1);
    }
    CvWindow.ShowImages(pointsPlot);
}

f:id:Schima:20131110231806p:plain

訓練 (Train)

訓練データを普通の配列で用意していましたが、OpenCVのmlでは入力をCvMat(またはcv::Mat)で用意する必要があります。その6で紹介したように、CvMatのコンストラクタに配列を指定すればデータを共有できます。dataMatの初期化が特にポイントだと思います。

また、OpenCVSVMでは、入力を0~1の数値で正規化する必要があります。今回は座標を0~300の範囲でランダムに打っているので、300で割っています。本当はCvMatインスタンスを別に用意して、300で割った結果をそこに入れる・・・といういつもの書き方が必要ですが、GCを信じてoperator/を使っても構いません。+ - * / どれも用意しています。

// CvMatに移し替え
CvMat dataMat = new CvMat(points.Length, 2, MatrixType.F32C1, points, true);
CvMat resMat = new CvMat(responses.Length, 1, MatrixType.S32C1, responses, true);

// pointsを正規化
dataMat /= 300.0;

// SVMの用意
CvTermCriteria criteria = new CvTermCriteria(1000, 0.000001);
CvSVMParams param = new CvSVMParams(
    SVMType.CSvc,
    SVMKernelType.Rbf,
    100.0,  // degree
    100.0,  // gamma
    1.0, // coeff0
    1.0, // c
    0.5, // nu
    0.1, // p
    null,
    criteria);
CvSVM svm = new CvSVM();
svm.Train(dataMat, resMat, null, null, param);

予測 (Predict)

Predictの方法

入力座標を正規化させて訓練したので、問い合わせも正規化したものを投げなければなりません。

float[] sample = {x/300f, y/300f};
CvMat sampleMat = new CvMat(1, 2, MatrixType.F32C1, sample);
int ret = (int)svm.Predict(sampleMat);

答え合わせ

300px四方に隙間なく点を打ちまくって、どんな結果になるか見てみましょう。

// 使ってみる
using (IplImage retPlot = new IplImage(300, 300, BitDepth.U8, 3))
{
    retPlot.Zero();

    for (int x = 0; x < 300; x++)
    {
        for (int y = 0; y < 300; y++)
        {
            float[] sample = {x/300f, y/300f};
            CvMat sampleMat = new CvMat(1, 2, MatrixType.F32C1, sample);

            int ret = (int)svm.Predict(sampleMat);
            CvRect plotRect = new CvRect(x, 300 - y, 1, 1);
            if (ret == 1)
                retPlot.Rectangle(plotRect, CvColor.Red);
            else if (ret == 2)
                retPlot.Rectangle(plotRect, CvColor.GreenYellow);
        }
    }
    CvWindow.ShowImages(retPlot);
}

f:id:Schima:20131110235526p:plain

5000点にしたとき

やはり精度は良くなります。
f:id:Schima:20131110234635p:plain
f:id:Schima:20131110235535p:plain

コード一覧

using OpenCvSharp;
using OpenCvSharp.CPlusPlus;

class Program
{
    static double f(double x)
    {
        return x + 50 * Math.Sin(x/15.0);
    }

    static void Main()
    {
        // 訓練データ            
        CvPoint2D32f[] points = new CvPoint2D32f[500];
        int[] responses = new int[points.Length];
        Random rand = new Random();
        for (int i = 0; i < responses.Length; i++)
        {
            double x = rand.Next(0, 300);
            double y = rand.Next(0, 300);
            points[i] = new CvPoint2D32f(x, y);
            responses[i] = (y > f(x)) ? 1 : 2;
        }

        // dataとresponsesの様子を描画
        using (IplImage pointsPlot = new IplImage(300, 300, BitDepth.U8, 3))
        {
            pointsPlot.Zero();                
            for (int i = 0; i < points.Length; i++)
            {
                int x = (int)points[i].X;
                int y = (int)(300 - points[i].Y);
                int res = responses[i];
                CvColor color = (res == 1) ? CvColor.Red : CvColor.GreenYellow;
                pointsPlot.Circle(x, y, 2, color, -1);
            }
            // 答えの関数の曲線も描く
            for (int x = 1; x < 300; x++)
            {
                int y1 = (int)(300 - f(x - 1));
                int y2 = (int)(300 - f(x));
                pointsPlot.Line(x-1, y1, x, y2, CvColor.LightBlue, 1);
            }
            CvWindow.ShowImages(pointsPlot);
        }

        // 訓練
        CvMat dataMat = new CvMat(points.Length, 2, MatrixType.F32C1, points, true);
        CvMat resMat = new CvMat(responses.Length, 1, MatrixType.S32C1, responses, true);
        using (CvSVM svm = new CvSVM())
        {
            // data正規化
            dataMat /= 300.0;

            CvTermCriteria criteria = new CvTermCriteria(1000, 0.000001);
            CvSVMParams param = new CvSVMParams(
                SVMType.CSvc,
                SVMKernelType.Rbf,
                100.0,  // degree
                100.0,  // gamma
                1.0, // coeff0
                1.0, // c
                0.5, // nu
                0.1, // p
                null,
                criteria);
            svm.Train(dataMat, resMat, null, null, param);

            // 使ってみる
            using (IplImage retPlot = new IplImage(300, 300, BitDepth.U8, 3))
            {
                retPlot.Zero();

                for (int x = 0; x < 300; x++)
                {
                    for (int y = 0; y < 300; y++)
                    {
                        float[] sample = {x/300f, y/300f};
                        CvMat sampleMat = new CvMat(1, 2, MatrixType.F32C1, sample);
                        int ret = (int)svm.Predict(sampleMat);
                        CvRect plotRect = new CvRect(x, 300 - y, 1, 1);
                        if (ret == 1)
                            retPlot.Rectangle(plotRect, CvColor.Red);
                        else if (ret == 2)
                            retPlot.Rectangle(plotRect, CvColor.GreenYellow);
                    }
                }
                CvWindow.ShowImages(retPlot);
            }
        }
    }
}