RubyでSVMをつかう [rb-libsvm]

近年大人気(?)のSVM[Support Vector Machine]を、Rubyから簡単に使う方法です。

SVMのライブラリとしてはlibsvmが著名です。これをRubyから触るためのgemは複数ありますが、ここではrb-libsvmを使いました。

環境はDebian wheezy, Rubyのバージョンは1.9.3p194です。

導入

まずlibsvm

sudo apt-get install libsvm-dev libsvm3

続いてrb-libsvm

sudo gem install rb-libsvm

なお今回、結果を視覚的に出したいので、gnuplotも入れておきます。x11-gnuplot が無くてこの後はまりました。

sudo apt-get install gnuplot x11-gnuplot 
sudo gem install gnuplot

学習データの用意

お題

今回は、適当に点(x, y)を打った時、こんな曲線の上か下かを判定するマシーンを作ることを目指します。色々とついている係数は、画面に収まりが良いようにしているだけで深い意味はありません。
 y = x + 50 \sin \left( \frac{x}{15} \right)
f:id:Schima:20131105122545p:plain

もちろん実際の問題に適用するときは、このような答えの関数は未知です。

訓練データの点を準備

ここから、Rubyプログラムを書きます。まずは点を表すクラスです。
@label は、その点が曲線の上か下かを示す数値です。座標(@x, @y)と、その時の答え(@label)によって、libsvmに訓練をさせます。

class Point
  attr_reader :x, :y, :label
  def initialize(x, y)
    @x, @y = x, y
    @label = calc_label(x, y)
  end 
  def calc_label(x, y)
    fx = x + 50*Math.sin(x/15.0)
    (y > fx) ? 1 : 2
  end
end

訓練データは、今回は500個くらい用意してみましょう。乱数で適当に点を打ちまくります。

points = Array.new(500){ Point.new(rand(300), rand(300)) }

訓練データを可視化

どのように点が作られたのかを見てみましょう。ここでgnuplotを使います。

require 'gnuplot'

points = Array.new(500){ Point.new(rand(300), rand(300)) }

Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title("points")
    plot.size("ratio 1 1")
    plot.xlabel("x")
    plot.ylabel("y")
    plot.xrange("[0:300]")
    plot.yrange("[0:300]")

    p1, p2 = points.partition{|p| p.label == 1 }

    plot.data << Gnuplot::Dataset.new([p1.collect(&:x), p1.collect(&:y)])
    plot.data << Gnuplot::Dataset.new([p2.collect(&:x), p2.collect(&:y)])
    plot.data << Gnuplot::Dataset.new("x + 50*sin(x/15)") do |d|
      d.with = "line"
      d.notitle
    end
  end
end

f:id:Schima:20131105133329p:plain

訓練 (Train)

あまり説明するところもない(というかできない)ので一気に。

svm_typeまたはkernel_typeによるのかもしれませんが、以下の例では教師データexamplesを正規化する必要はありませんでした。

require 'libsvm'

# SVMパラメータ
parameter = Libsvm::SvmParameter.new.tap { |p|
  p.svm_type = 0 # C_SVC
  p.kernel_type = 2 # RBF
  p.cache_size = 100 # in MB
  p.eps = 0.000001
  p.degree = 3
  p.c = 1
  p.nu = 0.5
  p.gamma = 0.001
  p.p = 0.1
}

# 教師データ(座標)
examples = points.collect{ |p|
  Libsvm::Node.features(p.x, p.y)
}
# 応答(教師データに対する答え)
labels = points.collect(&:label)

# 訓練
problem = Libsvm::Problem.new
problem.set_examples(labels, examples)
model = Libsvm::Model.train(problem, parameter)

訓練結果はファイルに保存しておくことができます。次回以降は訓練を飛ばし、すぐに次述べるPredictを行うことができます。

model.save("train.txt")

予測 (Predict)

predictの方法

たとえば(50, 200) の点が曲線より上(1)なのか下(2)なのかを判定するときは、以下のようになります。predictの返り値predが1または2になります。

query = Libsvm::Node.features(50, 200)
pred = model.predict(query)

答え合わせ

300px四方に隙間なく点を打ちまくって、どんな結果になるか見てみましょう。

Gnuplot.open dp |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title("points")
    plot.size("ratio 1 1")
    plot.xlabel("x")
    plot.ylabel("y")
    plot.xrange("[0:300]")
    plot.yrange("[0:300]")

    p1 = []
    p2 = []

    for x in 0...300
      for y in 0...300
        query = Libsvm::Node.features(x, y)
        pred = model.predict(query)
        if pred == 1
           p1 << Point.new(x, y)
        elsif pred == 2
           p2 << Point.new(x, y)
        end
      end
    end

    plot.data << Gnuplot::Dataset.new([p1.collect(&:x), p1.collect(&:y)]) do |d|
      d.with = "dot"
    end
    plot.data << Gnuplot::Dataset.new([p2.collect(&:x), p2.collect(&:y)]) do |d|
      d.with = "dot"
    end

  end
end

f:id:Schima:20131105140117p:plain

ちょっとずれが大きいですね。まんべんなく点を打ったわけではないので、境界近くの点のムラ次第で出っ張ったり入り込んだりします。

ここで、教師データの数を500から10倍の5000にすると、以下のようになります。まだ若干よれよれしていますが、かなり真の答えに近づきました。
f:id:Schima:20131105140559p:plain

コード一覧

# -*- coding: utf-8 -*-
require 'libsvm'
require 'gnuplot'

# 点クラス
class Point
  attr_reader :x, :y, :label
  def initialize(x, y)
    @x, @y = x, y
    @label = calc_label(x, y)
  end
  def calc_label(x, y)
    fx = x + 50*Math.sin(x/15.0)
    (y > fx) ? 1 : 2
  end
end

# 点の初期化
points = Array.new(5000){ Point.new(rand(300), rand(300)) }
# pointsを見てみる
Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title("points")
    plot.size("ratio 1 1")
    plot.xlabel("x")
    plot.ylabel("y")
    plot.xrange("[0:300]")
    plot.yrange("[0:300]")

    p1, p2 = points.partition{|p| p.label == 1}

    plot.data << Gnuplot::DataSet.new([p1.collect(&:x), p1.collect(&:y)]) 
    plot.data << Gnuplot::DataSet.new([p2.collect(&:x), p2.collect(&:y)]) 

    plot.data << Gnuplot::DataSet.new("x + 50*sin(x/15)") do |d|
      d.with = "line"
      d.notitle
    end
  end
end


# SVMパラメータ
parameter = Libsvm::SvmParameter.new.tap { |p|
  p.svm_type = 0 # C_SVC
  p.kernel_type = 2 # RBF
  p.cache_size = 100 # in MB
  p.eps = 0.000001
  p.degree = 3
  p.c = 1
  p.nu = 0.5
  p.gamma = 0.001
  p.p = 0.1
}

# 教師データ(座標)
examples = points.collect {|p| 
  Libsvm::Node.features(p.x, p.y) 
}
# 教師データの応答(答え)
labels = points.collect(&:label)

# 訓練
problem = Libsvm::Problem.new
problem.set_examples(labels, examples)
model = Libsvm::Model.train(problem, parameter)
#model.save("hoge.train")

Gnuplot.open do |gp|
  Gnuplot::Plot.new(gp) do |plot|
    plot.title("points")
    plot.size("ratio 1 1")
    plot.xlabel("x")
    plot.ylabel("y")
    plot.xrange("[0:300]")
    plot.yrange("[0:300]")

    p1 = []
    p2 = []

    for x in 0...300
      for y in 0...300
        query = Libsvm::Node.features(x, y)
        pred = model.predict(query)
        if pred == 1
          p1 << Point.new(x, y)
        elsif pred == 2
          p2 << Point.new(x, y)
        end
      end
    end

    plot.data << Gnuplot::DataSet.new([p1.collect(&:x), p1.collect(&:y)]) do |d|
      d.with = "dot"
    end
    plot.data << Gnuplot::DataSet.new([p2.collect(&:x), p2.collect(&:y)]) do |d|
      d.with = "dot"
    end

  end
end