IT Learning

実践形式でITのお勉強

Python

【Python】決定係数と相関係数の計算

投稿日:2019年6月27日 更新日:

目的

前回、あるデータXからデータYの関係を単回帰による線形方程式で表すことを行いました。ただし、これはあくまで近似ですので、どれくらい実際のデータをうまく近似(説明)できているかを評価したい、と思うのが自然かと思います。
今回は単回帰でデータを表した際のフィット具合を評価する方法について実践しながら勉強したいと思います。

決定係数とは?

回帰で求めた予測式がどれくらい元のデータにフィットしているかを見るための指標として決定係数というものがあります。

決定係数の考え方としては、以下の通りです。
データ\(x_i\)と\(y_i\)について単回帰を考え、以下のような一次の線形回帰の式が求まったとします。

$$y=ax+b$$

ここであるデータ\(x_i\)を上記の式に代入すると、\(y_{x_i}=ax_i+b\)となります。
この\(y_{x_i}\)は線形回帰の近似式で求まる\(x=x_i\)の時の予測値ということになります。

次にこの予測値と実際のy_iとの誤差を計算すると、以下のようになります。

$$y_i-y_{x_i}=y_i-(ax_i+b)$$

これはある一点の\(x_i\)における誤差なので、これを二乗して全体にわたって足し、平均するとすべてのデータを加味した分散のようなバラつきを評価できる指標になります。それが次の式です。

$$s_{yx}^2 =\frac{1}{n}\sum_{i=1}^n\{ y_i-(ax_i+b) \}^2 $$

これに対し、データ\(y\)の分散\(s_y^2\)は単に以下の式になります。

$$s_y^2 = \frac{1}{n}\sum_{i=1}^n (y_i- \overline{y})^2$$

そしてこれらの差分\(s_r^2\)を以下の式で定義します。

$$s_r^2 = s_y^2-s_{yx}^2$$

この式の意味としては通常のデータ\(y\)の分散から、単回帰でデータ\(y\)を近似した時になお残る分散値を引いていることになります。つまり、残っているのはデータ\(x\)の影響による偏差分という解釈になります。

この\(s_r^2\)と\(s_y^2\)の比をとることで、データ\(y\)の分散のうちのどれだけがデータ\(x\)の影響を受けているかを測れることになりますこれが決定係数\(r^2\)です。

$$r^2 = \frac{s_r^2}{s_y^2}$$

Pythonによる決定係数の計算

実際のデータに対しpythonの計算を計算します。

データは前回、単回帰を計算した時と同様にkaggleの以下で公開されているオープンデータを使用します。 https://www.kaggle.com/testpython/linear-regression

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

#データ読み込み
df = pd.read_csv('C:\\Users\\Yusan\\Desktop\\dataset\\test.csv',engine  ='python')

# scikit-learnの linear_modelを読み込み
from sklearn import linear_model

#インスタンス作成
clf = linear_model.LinearRegression()

# x軸データ
x = df[['x']].values
 
# y軸データ
y = df['y'].values
 
# 単回帰モデルを作成
clf.fit(x,y)
 
# y=ax+bのa
a = clf.coef_

# y=ax+bのb
b = clf.intercept_

#yの標本分散
sy2 = df['y'].var(ddof=False)

#線形回帰式との誤差を計算
d = df['y']-a*df['x']-b

#誤差の2乗平均を計算
syx2 = np.mean(d**2)

#sr^2を計算
sr2 = sy2 - syx2

#決定係数計算
r2 = sr2/sy2

print(r2) #0.9891203611402715

この\(r\)の値が1に近いほどデータ\(y\)がデータ\(x\)の影響を受けているということになります。

相関係数

相関係数\(r\)は決定係数\(r^2\)の平方根に当たります。

$$ r = \begin{cases} \sqrt{r^2}\ \ \ (a>0) \\ -\sqrt{r^2}\ \ \ (a<0) \end{cases} $$

pythonでのコードは以下になります。

np.sqrt(r2) #0.9945453037143513

Related

-Python

執筆者:


comment

メールアドレスが公開されることはありません。 が付いている欄は必須項目です