近年大人気(?)のSVM[Support Vector Machine]を、Rubyから簡単に使う方法です。
SVMのライブラリとしてはlibsvmが著名です。これをRubyから触るためのgemは複数ありますが、ここではrb-libsvmを使いました。
環境はDebian wheezy, Rubyのバージョンは1.9.3p194です。
導入
まずlibsvm。
sudo apt-get install libsvm-dev libsvm3
続いてrb-libsvm。
sudo gem install rb-libsvm
なお今回、結果を視覚的に出したいので、gnuplotも入れておきます。x11-gnuplot が無くてこの後はまりました。
sudo apt-get install gnuplot x11-gnuplot sudo gem install gnuplot
学習データの用意
お題
今回は、適当に点(x, y)を打った時、こんな曲線の上か下かを判定するマシーンを作ることを目指します。色々とついている係数は、画面に収まりが良いようにしているだけで深い意味はありません。
もちろん実際の問題に適用するときは、このような答えの関数は未知です。
訓練データの点を準備
ここから、Rubyプログラムを書きます。まずは点を表すクラスです。
@label は、その点が曲線の上か下かを示す数値です。座標(@x, @y)と、その時の答え(@label)によって、libsvmに訓練をさせます。
class Point attr_reader :x, :y, :label def initialize(x, y) @x, @y = x, y @label = calc_label(x, y) end def calc_label(x, y) fx = x + 50*Math.sin(x/15.0) (y > fx) ? 1 : 2 end end
訓練データは、今回は500個くらい用意してみましょう。乱数で適当に点を打ちまくります。
points = Array.new(500){ Point.new(rand(300), rand(300)) }
訓練データを可視化
どのように点が作られたのかを見てみましょう。ここでgnuplotを使います。
require 'gnuplot' points = Array.new(500){ Point.new(rand(300), rand(300)) } Gnuplot.open do |gp| Gnuplot::Plot.new(gp) do |plot| plot.title("points") plot.size("ratio 1 1") plot.xlabel("x") plot.ylabel("y") plot.xrange("[0:300]") plot.yrange("[0:300]") p1, p2 = points.partition{|p| p.label == 1 } plot.data << Gnuplot::Dataset.new([p1.collect(&:x), p1.collect(&:y)]) plot.data << Gnuplot::Dataset.new([p2.collect(&:x), p2.collect(&:y)]) plot.data << Gnuplot::Dataset.new("x + 50*sin(x/15)") do |d| d.with = "line" d.notitle end end end
訓練 (Train)
あまり説明するところもない(というかできない)ので一気に。
svm_typeまたはkernel_typeによるのかもしれませんが、以下の例では教師データexamplesを正規化する必要はありませんでした。
require 'libsvm' # SVMパラメータ parameter = Libsvm::SvmParameter.new.tap { |p| p.svm_type = 0 # C_SVC p.kernel_type = 2 # RBF p.cache_size = 100 # in MB p.eps = 0.000001 p.degree = 3 p.c = 1 p.nu = 0.5 p.gamma = 0.001 p.p = 0.1 } # 教師データ(座標) examples = points.collect{ |p| Libsvm::Node.features(p.x, p.y) } # 応答(教師データに対する答え) labels = points.collect(&:label) # 訓練 problem = Libsvm::Problem.new problem.set_examples(labels, examples) model = Libsvm::Model.train(problem, parameter)
訓練結果はファイルに保存しておくことができます。次回以降は訓練を飛ばし、すぐに次述べるPredictを行うことができます。
model.save("train.txt")
予測 (Predict)
predictの方法
たとえば(50, 200) の点が曲線より上(1)なのか下(2)なのかを判定するときは、以下のようになります。predictの返り値predが1または2になります。
query = Libsvm::Node.features(50, 200) pred = model.predict(query)
答え合わせ
300px四方に隙間なく点を打ちまくって、どんな結果になるか見てみましょう。
Gnuplot.open dp |gp| Gnuplot::Plot.new(gp) do |plot| plot.title("points") plot.size("ratio 1 1") plot.xlabel("x") plot.ylabel("y") plot.xrange("[0:300]") plot.yrange("[0:300]") p1 = [] p2 = [] for x in 0...300 for y in 0...300 query = Libsvm::Node.features(x, y) pred = model.predict(query) if pred == 1 p1 << Point.new(x, y) elsif pred == 2 p2 << Point.new(x, y) end end end plot.data << Gnuplot::Dataset.new([p1.collect(&:x), p1.collect(&:y)]) do |d| d.with = "dot" end plot.data << Gnuplot::Dataset.new([p2.collect(&:x), p2.collect(&:y)]) do |d| d.with = "dot" end end end
ちょっとずれが大きいですね。まんべんなく点を打ったわけではないので、境界近くの点のムラ次第で出っ張ったり入り込んだりします。
ここで、教師データの数を500から10倍の5000にすると、以下のようになります。まだ若干よれよれしていますが、かなり真の答えに近づきました。
コード一覧
# -*- coding: utf-8 -*- require 'libsvm' require 'gnuplot' # 点クラス class Point attr_reader :x, :y, :label def initialize(x, y) @x, @y = x, y @label = calc_label(x, y) end def calc_label(x, y) fx = x + 50*Math.sin(x/15.0) (y > fx) ? 1 : 2 end end # 点の初期化 points = Array.new(5000){ Point.new(rand(300), rand(300)) } # pointsを見てみる Gnuplot.open do |gp| Gnuplot::Plot.new(gp) do |plot| plot.title("points") plot.size("ratio 1 1") plot.xlabel("x") plot.ylabel("y") plot.xrange("[0:300]") plot.yrange("[0:300]") p1, p2 = points.partition{|p| p.label == 1} plot.data << Gnuplot::DataSet.new([p1.collect(&:x), p1.collect(&:y)]) plot.data << Gnuplot::DataSet.new([p2.collect(&:x), p2.collect(&:y)]) plot.data << Gnuplot::DataSet.new("x + 50*sin(x/15)") do |d| d.with = "line" d.notitle end end end # SVMパラメータ parameter = Libsvm::SvmParameter.new.tap { |p| p.svm_type = 0 # C_SVC p.kernel_type = 2 # RBF p.cache_size = 100 # in MB p.eps = 0.000001 p.degree = 3 p.c = 1 p.nu = 0.5 p.gamma = 0.001 p.p = 0.1 } # 教師データ(座標) examples = points.collect {|p| Libsvm::Node.features(p.x, p.y) } # 教師データの応答(答え) labels = points.collect(&:label) # 訓練 problem = Libsvm::Problem.new problem.set_examples(labels, examples) model = Libsvm::Model.train(problem, parameter) #model.save("hoge.train") Gnuplot.open do |gp| Gnuplot::Plot.new(gp) do |plot| plot.title("points") plot.size("ratio 1 1") plot.xlabel("x") plot.ylabel("y") plot.xrange("[0:300]") plot.yrange("[0:300]") p1 = [] p2 = [] for x in 0...300 for y in 0...300 query = Libsvm::Node.features(x, y) pred = model.predict(query) if pred == 1 p1 << Point.new(x, y) elsif pred == 2 p2 << Point.new(x, y) end end end plot.data << Gnuplot::DataSet.new([p1.collect(&:x), p1.collect(&:y)]) do |d| d.with = "dot" end plot.data << Gnuplot::DataSet.new([p2.collect(&:x), p2.collect(&:y)]) do |d| d.with = "dot" end end end