【Python】決定係数と相関係数の計算
目的
前回、あるデータXからデータYの関係を単回帰による線形方程式で表すことを行いました。ただし、これはあくまで近似ですので、どれくらい実際のデータをうまく近似(説明)できているかを評価したい、と思うのが自然かと思います。
今回は単回帰でデータを表した際のフィット具合を評価する方法について実践しながら勉強したいと思います。
決定係数とは?
回帰で求めた予測式がどれくらい元のデータにフィットしているかを見るための指標として決定係数というものがあります。
決定係数の考え方としては、以下の通りです。
データx_iとy_iについて単回帰を考え、以下のような一次の線形回帰の式が求まったとします。
y=ax+b
ここであるデータx_iを上記の式に代入すると、y_{x_i}=ax_i+bとなります。
このy_{x_i}は線形回帰の近似式で求まるx=x_iの時の予測値ということになります。
次にこの予測値と実際のy_iとの誤差を計算すると、以下のようになります。
y_i-y_{x_i}=y_i-(ax_i+b)
これはある一点のx_iにおける誤差なので、これを二乗して全体にわたって足し、平均するとすべてのデータを加味した分散のようなバラつきを評価できる指標になります。それが次の式です。
s_{yx}^2 =\frac{1}{n}\sum_{i=1}^n\{ y_i-(ax_i+b) \}^2
これに対し、データyの分散s_y^2は単に以下の式になります。
s_y^2 = \frac{1}{n}\sum_{i=1}^n (y_i- \overline{y})^2
そしてこれらの差分s_r^2を以下の式で定義します。
s_r^2 = s_y^2-s_{yx}^2
この式の意味としては通常のデータyの分散から、単回帰でデータyを近似した時になお残る分散値を引いていることになります。つまり、残っているのはデータxの影響による偏差分という解釈になります。
このs_r^2とs_y^2の比をとることで、データyの分散のうちのどれだけがデータxの影響を受けているかを測れることになりますこれが決定係数r^2です。
r^2 = \frac{s_r^2}{s_y^2}
Pythonによる決定係数の計算
実際のデータに対しpythonの計算を計算します。
データは前回、単回帰を計算した時と同様にkaggleの以下で公開されているオープンデータを使用します。 https://www.kaggle.com/testpython/linear-regression
01 02 03 04 05 06 07 08 09 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 | import pandas as pd import numpy as np import matplotlib.pyplot as plt #データ読み込み df = pd.read_csv( 'C:\\Users\\Yusan\\Desktop\\dataset\\test.csv' ,engine = 'python' ) # scikit-learnの linear_modelを読み込み from sklearn import linear_model #インスタンス作成 clf = linear_model.LinearRegression() # x軸データ x = df[[ 'x' ]].values # y軸データ y = df[ 'y' ].values # 単回帰モデルを作成 clf.fit(x,y) # y=ax+bのa a = clf.coef_ # y=ax+bのb b = clf.intercept_ #yの標本分散 sy2 = df[ 'y' ].var(ddof = False ) #線形回帰式との誤差を計算 d = df[ 'y' ] - a * df[ 'x' ] - b #誤差の2乗平均を計算 syx2 = np.mean(d * * 2 ) #sr^2を計算 sr2 = sy2 - syx2 #決定係数計算 r2 = sr2 / sy2 print (r2) #0.9891203611402715 |
このrの値が1に近いほどデータyがデータxの影響を受けているということになります。
相関係数
相関係数rは決定係数r^2の平方根に当たります。
r = \begin{cases} \sqrt{r^2}\ \ \ (a>0) \\ -\sqrt{r^2}\ \ \ (a<0) \end{cases}
pythonでのコードは以下になります。
01 | np.sqrt(r2) #0.9945453037143513 |