2010年11月6日土曜日

Rubyでベイジアンフィルタを作成

そろそろRubyにも手を出しておこうかということで,Rubyでベイジアンフィルタを作ってみました.

作ったといっても,元々Perlで書かれたものをRubyで書き直しただけです.元のプログラムは,WEB+DB PRESS Vol.56の記事[1]で解説されていたものです.作ったプログラムは,bayes_sample.rbとClassifier.rbの二つです.両方とも以下に記しておきます.

Rubyは,初めて使ったので,お粗末なコードになっているかもしれません.でも,とりあえず動作すると思います.ただし,事前にMecabとMecab-rubyをインストールしておく必要があります.

[1] 伊藤 直也: ベイジアンフィルタに挑戦―未知のデータを学習して分類―, アルゴリズム実践教室 第1回, WEB+DB PRESS, Vol. 56, pp. 134-142, 2010-05



bayes_sample.rb
require 'MeCab'
require 'Classifier'

def text2vec(text)
  mecab = MeCab::Tagger.new
  node = mecab.parseToNode(text)
  vec = Hash.new(0)
  while node do
    if (node.posid >= 1 and node.posid <= 4) or node.posid == "?" then
      vec[node.surface] = vec[node.surface] + 1
    end
    node = node.next
  end
  return vec
end

cl = Classifier.new()
cl.train(text2vec("perlやpythonはスクリプト言語です"), "it")
cl.train(text2vec("perlでベイジアンフィルタを作りました"), "it")
cl.train(text2vec("pythonはニシキヘビ科のヘビの総称"), "science")

print "1, 推定カテゴリ: ", cl.predict(text2vec("perlは楽しい")), "\n"
print "2, 推定カテゴリ: ", cl.predict(text2vec("pythonとperl")), "\n"
print "3, 推定カテゴリ: ", cl.predict(text2vec("pythonとヘビ")), "\n"


Classifier.rb
class Classifier
  def initialize
    @term_count = Hash.new {|h, k| h[k] = Hash.new(0)}
    @category_count = Hash.new(0)
  end

  def train(vec, cat)
    vec.each do |term, count|
      @term_count[term][cat] = @term_count[term][cat] + count
      @category_count[cat] = @category_count[term] + 1
    end
  end

  def predict(vec)
    scores = Hash.new
    @category_count.keys.each do |cat|
      scores[cat] = self.score(vec, cat)
    end
    classes = scores.to_a
    classes.sort! do |a, b|
      (b[1] <=> a[1]) * 2 + (a[0] <=> b[0])
    end
    return classes[0][0]
  end

  def score(vec, cat)
    cat_prob = Math.log(self.cat_prob(cat))
    not_likely = 1.0 / (self.total_term_count() * 10)
    doc_prob = 0.0
    vec.each do |term, count|
      term_prob = self.term_prob(term, cat)
      if term_prob == 0.0 then
        term_prob = not_likely
      end
      doc_prob += Math.log(term_prob) * count;
    end
    return cat_prob + doc_prob
  end

  def cat_prob(cat)
    return @category_count[cat].to_f / self.total_term_count()
  end

  def term_prob(term, cat)
    return self.term_count(term, cat).to_f / @category_count[cat]
  end

  def term_count(term, cat)
    return @term_count[term][cat]
  end

  def total_term_count
    total = 0
    @category_count.values.each do |count|
      total += count;
    end
    return total
  end

  def dump
    p @term_count
    p @category_count
  end
end

0 件のコメント: