ID3
Wikipediaの情報を元にpythonでID3の決定木を試してみました。
今回は画面に表示するだけにしました。
”余分な変数を削除し”と書いてあった通り、変数が消えてしまいましたが、バグではなくこういうものなんだろうか?
# coding: utf8
import math
label = ['食性', '発生形態', '体温']
data = [['ペンギン', '肉食', '卵生', '恒温', '鳥類'],
['ライオン', '肉食', '胎生', '恒温', '哺乳類'],
['ウシ', '草食', '胎生', '恒温', '哺乳類'],
['トカゲ', '肉食', '卵生', '変温', '爬虫類'],
['ブンチョウ', '草食', '卵生', '恒温', '鳥類']]
def getKeys(data, idx):
return list(set(map(lambda x:x[idx], data)))
def getEntropy(data, keys):
cls = map(lambda x:x[-1], data)
denominator = float(len(cls))
if denominator == 0: return 0
m = 0
for key in keys:
p = cls.count(key) / denominator
if p != 0:
m += -p * math.log(p, 3)
return m
def getExpectation(data, keys, mc, idx, ckeys):
denominator = float(len(data))
if denominator == 0: return 0
m = mc
for key in ckeys:
c2 = filter(lambda x:x[idx] == key, data)
m2 = getEntropy(c2, keys)
m += -(m2 * (len(c2) / denominator))
return m
def carateTree(data, keys, idxs, ts=""):
mc = getEntropy(data, keys)
maxm = 0
idx = 0
for i in idxs:
m = getExpectation(data, keys, mc, i, getKeys(data, i))
if maxm < m:
maxm = m
idx = i
print ts + label[idx - 1]
ts += "\t"
idxs = filter(lambda x:x != idx, idxs)
for key in getKeys(data, idx):
fdata = filter(lambda x:x[idx] == key, data)
if len(getKeys(fdata, -1)) <= 1:
print ts + key, fdata[0][-1]
continue
print ts + key
carateTree(fdata, keys, idxs, ts + "\t")
keys = getKeys(data, -1)
idxs = range(1, len(data[0]) - 1)
carateTree(data, keys, idxs)