Python

【Python】決定係数と相関係数の計算

投稿日:2019年6月27日 更新日:

目的

前回、あるデータXからデータYの関係を単回帰による線形方程式で表すことを行いました。ただし、これはあくまで近似ですので、どれくらい実際のデータをうまく近似(説明)できているかを評価したい、と思うのが自然かと思います。
今回は単回帰でデータを表した際のフィット具合を評価する方法について実践しながら勉強したいと思います。

決定係数とは?

回帰で求めた予測式がどれくらい元のデータにフィットしているかを見るための指標として決定係数というものがあります。

決定係数の考え方としては、以下の通りです。
データ\(x_i\)と\(y_i\)について単回帰を考え、以下のような一次の線形回帰の式が求まったとします。

$$y=ax+b$$

ここであるデータ\(x_i\)を上記の式に代入すると、\(y_{x_i}=ax_i+b\)となります。
この\(y_{x_i}\)は線形回帰の近似式で求まる\(x=x_i\)の時の予測値ということになります。

次にこの予測値と実際のy_iとの誤差を計算すると、以下のようになります。

$$y_i-y_{x_i}=y_i-(ax_i+b)$$

これはある一点の\(x_i\)における誤差なので、これを二乗して全体にわたって足し、平均するとすべてのデータを加味した分散のようなバラつきを評価できる指標になります。それが次の式です。

$$s_{yx}^2 =\frac{1}{n}\sum_{i=1}^n\{ y_i-(ax_i+b) \}^2 $$

これに対し、データ\(y\)の分散\(s_y^2\)は単に以下の式になります。

$$s_y^2 = \frac{1}{n}\sum_{i=1}^n (y_i- \overline{y})^2$$

そしてこれらの差分\(s_r^2\)を以下の式で定義します。

$$s_r^2 = s_y^2-s_{yx}^2$$

この式の意味としては通常のデータ\(y\)の分散から、単回帰でデータ\(y\)を近似した時になお残る分散値を引いていることになります。つまり、残っているのはデータ\(x\)の影響による偏差分という解釈になります。

この\(s_r^2\)と\(s_y^2\)の比をとることで、データ\(y\)の分散のうちのどれだけがデータ\(x\)の影響を受けているかを測れることになりますこれが決定係数\(r^2\)です。

$$r^2 = \frac{s_r^2}{s_y^2}$$

Pythonによる決定係数の計算

実際のデータに対しpythonの計算を計算します。

データは前回、単回帰を計算した時と同様にkaggleの以下で公開されているオープンデータを使用します。 https://www.kaggle.com/testpython/linear-regression

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

#データ読み込み
df = pd.read_csv('C:\\Users\\Yusan\\Desktop\\dataset\\test.csv',engine  ='python')

# scikit-learnの linear_modelを読み込み
from sklearn import linear_model

#インスタンス作成
clf = linear_model.LinearRegression()

# x軸データ
x = df[['x']].values
 
# y軸データ
y = df['y'].values
 
# 単回帰モデルを作成
clf.fit(x,y)
 
# y=ax+bのa
a = clf.coef_

# y=ax+bのb
b = clf.intercept_

#yの標本分散
sy2 = df['y'].var(ddof=False)

#線形回帰式との誤差を計算
d = df['y']-a*df['x']-b

#誤差の2乗平均を計算
syx2 = np.mean(d**2)

#sr^2を計算
sr2 = sy2 - syx2

#決定係数計算
r2 = sr2/sy2

print(r2) #0.9891203611402715

この\(r\)の値が1に近いほどデータ\(y\)がデータ\(x\)の影響を受けているということになります。

相関係数

相関係数\(r\)は決定係数\(r^2\)の平方根に当たります。

$$ r = \begin{cases} \sqrt{r^2}\ \ \ (a>0) \\ -\sqrt{r^2}\ \ \ (a<0) \end{cases} $$

pythonでのコードは以下になります。

np.sqrt(r2) #0.9945453037143513

-Python

執筆者:


comment

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です

関連記事

【Python】単回帰の計算と表示

Contents1 目的2 回帰の式3 単回帰とは?4 最小二乗法5 pythonによる実装6 注意点 目的 ある得られたデータをもとに、今後発生するデータが何かを予測したいという場合があります。例え …

【Python】階乗と順列と組合せ

Contents1 目的2 階乗の計算方法3 順列の計算4 組合せの計算5 まとめ 目的 統計的なデータ分析アプローチの中には確率的な考え方も多く使われます。今回は確率的な手法を用いる際に必要となる基 …

Pythonの文字列操作

Contents1 目的2 文字列の連結3 文字列の抽出4 文字列の抽出(スライス)5 文字列の繰り返し6 文字数取得7 文字列の分割8 文字列の結合9 文字列の置換 目的 pythonの文字列操作に …

【Python】BeautifulSoupでtableが最初の数行しか取得できない場合の対処

Contents1 概要2 環境3 発生事象3.0.1 実行結果4 対処法 概要 PythonでWebスクレイピングをするときの定番であるBeautifulsoupですが、tableを取得しようとした …

【Python】PaSoRiによるICカードの読み取り

Contents1 目的2 環境3 方針4 Step1 : WinUSBのインストール5 Step2 : libusbのインストール6 Step3 : Pythonによる実装7 結果 目的 交通機関な …

言語切り替え

カテゴリー