Python - scikit-learn による初歩の機械学習(Hello, world!)

公開日:2019-06-30 更新日:2019-06-30
[Python]

1. 概要

scikit-learn を使って論理演算の機械学習を行い、論理演算の結果を予測します。



2. 論理演算の機械学習

from sklearn.metrics import accuracy_score

# アルゴリズム
from sklearn.svm       import LinearSVC
from sklearn.neighbors import KNeighborsClassifier
from sklearn.ensemble  import RandomForestClassifier

import warnings
warnings.filterwarnings('ignore')

# 学習データ:入力
in_data = [
    [0, 0],
    [0, 1],
    [1, 0],
    [1, 1]
]

# 学習データ:出力(ラベル)
out_data = [0, 0, 0, 1] # and
#out_data = [0, 1, 1, 1] # or
#out_data = [0, 1, 1, 0] # xor

# アルゴリズムの設定
clf = LinearSVC()
#clf = RandomForestClassifier()
#clf = KNeighborsClassifier(n_neighbors = 1)

# 学習
clf.fit(in_data, out_data)

# テストデータ(予測したいデータ))
test_data = [
    [0, 0], 
    [0, 1], 
    [1, 0], 
    [1, 1],
]

# 予測
result = clf.predict(test_data)
#print("正解:", out_data)
print("予測結果:", result)

#print("正解率 = " , accuracy_score(out_data, result))

3. RGBから色の判定

import pandas as pd

from sklearn.metrics import accuracy_score

from sklearn.svm       import SVC
from sklearn.ensemble  import RandomForestClassifier
from sklearn.svm       import LinearSVC
from sklearn.neighbors import KNeighborsClassifier

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

import warnings
warnings.filterwarnings('ignore')

color_data = [
                      # Idx RGB 0-255 -> 0 or 1
    ['k', 'black'],   # 0   000 
    ['b', 'blue'],    # 1   001 
    ['g', 'green'],   # 2   010 
    ['c', 'cyan'],    # 3   011 
    ['r', 'red'],     # 4   100 
    ['m', 'magenta'], # 5   101 
    ['y', 'yellow'],  # 6   110 
    ['w', 'white'],   # 7   111 
]

def getColorIdx(r, g, b):
    r = 1 if r >= 128 else 0
    g = 1 if g >= 128 else 0
    b = 1 if b >= 128 else 0
    return r * 4 + g * 2 + b

##############################
# 学習データの作成
##############################
color_range = range(0, 256, 16)
color_list = [
        [ getColorIdx(r, g, b), r, g, b ]
        for r in color_range
        for g in color_range
        for b in color_range
]
# 学習データを入力と出力に分割
color_df = pd.DataFrame(
    color_list,
    columns = ['color', 'r', 'g', 'b']
)
in_data  = color_df.loc[:, ["r", "g", "b"]]
out_data = color_df.loc[:, "color"]

##############################
# 学習
##############################
#clf = LinearSVC()
#clf = RandomForestClassifier()
clf = KNeighborsClassifier(n_neighbors = 1)
clf.fit(in_data, out_data)

##############################
# 予測
##############################
# テストデータの作成
test_color_range = range(5, 256, 16)
# 入力
test_in_data = [
    [r, g, b]
    for r in test_color_range
    for g in test_color_range
    for b in test_color_range
]
# 出力(判定用の予測結果として使用)
test_out_data = [
    getColorIdx(r, g, b)
    for r in test_color_range
    for g in test_color_range
    for b in test_color_range
]

# 予測
result = clf.predict(test_in_data)
print("正解率:", accuracy_score(result, test_out_data))

##############################
# 3D散布図の表示
##############################
fig = plt.figure()
ax = Axes3D(fig)

# 視点の角度
ax.view_init(elev = 45, azim = 30) 

# ラベル
ax.set_xlabel("R")
ax.set_ylabel("G")
ax.set_zlabel("B")

# グラフデータ
test_color_df = pd.DataFrame(test_in_data)
test_color_df['color'] = result # 結果列の追加

# 色で絞り込んで1色ごとに描画します
for c in range(0, 8):
    # 色で絞り込み
    data = test_color_df[test_color_df['color'] == c]
    
    r = data[0]
    g = data[1]
    b = data[2]
    colorChar = color_data[c][0] # c == data[3]
    ax.plot(r, g, b, 
            marker = "o", 
            color = colorChar, 
            linestyle = 'None')

plt.show()