友之介的つぶやきブログ

PCやらプログラムやらぶつぶつ言ってます

ID3

Wikipediaの情報を元にpythonでID3の決定木を試してみました。

ID3 - Wikipedia

 

今回は画面に表示するだけにしました。
”余分な変数を削除し”と書いてあった通り、変数が消えてしまいましたが、バグではなくこういうものなんだろうか?

 

 # coding: utf8
import math

label = ['食性', '発生形態', '体温']
data = [['ペンギン', '肉食', '卵生', '恒温', '鳥類'],
    ['ライオン', '肉食', '胎生', '恒温', '哺乳類'],
    ['ウシ', '草食', '胎生', '恒温', '哺乳類'],
    ['トカゲ', '肉食', '卵生', '変温', '爬虫類'],
    ['ブンチョウ', '草食', '卵生', '恒温', '鳥類']]

def getKeys(data, idx):
    return list(set(map(lambda x:x[idx], data)))

def getEntropy(data, keys):
    cls = map(lambda x:x[-1], data)
    denominator = float(len(cls))
    if denominator == 0: return 0

    m = 0
    for key in keys:
        p = cls.count(key) / denominator
        if p != 0:
            m += -p * math.log(p, 3)

    return m

def getExpectation(data, keys, mc, idx, ckeys):
    denominator = float(len(data))
    if denominator == 0: return 0

    m = mc
    for key in ckeys:
        c2 = filter(lambda x:x[idx] == key, data)
        m2 = getEntropy(c2, keys)
        m += -(m2 * (len(c2) / denominator))

    return m

def carateTree(data, keys, idxs, ts=""):
    mc = getEntropy(data, keys)

    maxm = 0
    idx = 0
    for i in idxs:
        m = getExpectation(data, keys, mc, i, getKeys(data, i))
        if maxm < m:
            maxm = m
            idx = i

    print ts + label[idx - 1]
    ts += "\t"
    idxs = filter(lambda x:x != idx, idxs)
    for key in getKeys(data, idx):
        fdata = filter(lambda x:x[idx] == key, data)
        if len(getKeys(fdata, -1)) <= 1:
            print ts + key, fdata[0][-1]
            continue

        print ts + key
        carateTree(fdata, keys, idxs, ts + "\t")


keys = getKeys(data, -1)
idxs = range(1, len(data[0]) - 1)

carateTree(data, keys, idxs)