R の多項式回帰

Sheeraz Gul 2024年2月15日
R の多項式回帰

多項式回帰は、独立した x と従属 y の間の関係が n 次多項式としてモデル化される線形回帰として定義できます。 このチュートリアルでは、R で多項式回帰を実行する方法を示します。

R の多項式回帰

多項式回帰は、xy の平均の間の非線形関係に適合します。 多項式または二次項を回帰に追加します。

この回帰は、1つの結果変数と予測子に使用されます。 多項式回帰は、主に次の場合に使用されます。

  1. 流行病の進行
  2. 組織の成長率の計算
  3. 堆積物中の炭素同位体分布

ggplot2 を使用して、R の多項式回帰をプロットできます。このパッケージがまだインストールされていない場合。 まず、インストールする必要があります。

install.packages('ggplot2')

ここでは、多項式回帰の段階的なプロセスを示します。

データ作成

delftstack の生徒のデータを含むデータ フレームを作成します。学習時間数、最終試験の点数、クラスの生徒総数である 60 人です。

例:

#create data frame
delftstack <- data.frame(hours = runif(60, 6, 20), marks=60)
delftstack$marks = delftstack$marks + delftstack$hours^3/160 + delftstack$hours*runif(60, 1, 2)

#view the head of the data
head(delftstack)

このコードは、後で多項式回帰に使用されるデータを作成します。

出力:

      hours     marks
1  7.106636  71.33509
2  8.501039  74.93339
3 18.051042 124.92229
4 19.153316 141.40656
5 18.306620 118.47464
6  6.240467  70.53522

データの可視化

次のステップは、データを視覚化することです。 回帰モデルを作成する前に、学習時間数と最終試験の点数との関係を示す必要があります。

例:

# Visualization
library(ggplot2)

ggplot(delftstack, aes(x=hours, y=marks)) + geom_point()

上記のコードは、データのグラフをプロットします。

可視化プロット

多項式回帰モデルのあてはめ

次のステップは、次数 1 ~ 6 の多項式回帰モデルと k=10 の k 分割交差検証です。

例:

#shuffle data
delftstack.shuffled <- delftstack[sample(nrow(df)),]

# number of k-fold cross-validation
K <- 10

#define the degree of polynomials to fit
degree <- 6

# now create k equal-sized folds
fold <- cut(seq(1,nrow(delftstack.shuffled)),breaks=K,labels=FALSE)

#The object to hold MSE's of models
mse_object = matrix(data=NA,nrow=K,ncol=degree)

#K-fold cross validation
for(i in 1:K){
    #testing and training data
    test_indexes <- which(fold==i,arr.ind=TRUE)
    test_data <- delftstack.shuffled[test_indexes, ]
    train_data <- delftstack.shuffled[-test_indexes, ]

    # using k-fold cv for models evaluation
    for (j in 1:degree){
        fit.train = lm(marks ~ poly(hours,j), data=train_data)
        fit.test = predict(fit.train, newdata=test_data)
        mse_object[i,j] = mean((fit.test-test_data$marks)^2)
    }
}

# MSE for each degree
colMeans(mse_object)

出力:

[1] 26.13112 15.45428 15.87187 16.88782 18.13103 19.10502

6つのモデルがあり、各モデルの MSE は上記のコードの出力に示されています。 この出力は、それぞれ h=1 から h=6 までの角度です。

最小の MSE を持つモデルは、h=2 の MSE 値が他のすべてよりも小さいため、h=2 である多項式回帰モデルになります。

最終モデルの分析

最後に、最終モデルを分析し、最適なモデルの概要を示します。

例:

#fitting the best model
best_model = lm(marks ~ poly(hours,2, raw=T), data=delftstack)

#summary of the best model
summary(best_model)

上記のコードは、最適なモデルの概要を表示します。

出力:

Call:
lm(formula = marks ~ poly(hours, 2, raw = T), data = delftstack)

Residuals:
   Min     1Q Median     3Q    Max
-8.797 -2.598  0.337  2.443  9.872

Coefficients:
                         Estimate Std. Error t value Pr(>|t|)
(Intercept)              68.42847    5.54533  12.340  < 2e-16 ***
poly(hours, 2, raw = T)1 -1.07557    0.93476  -1.151    0.255
poly(hours, 2, raw = T)2  0.22958    0.03577   6.418 2.95e-08 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 4.204 on 57 degrees of freedom
Multiple R-squared:  0.9669,    Adjusted R-squared:  0.9657
F-statistic: 831.9 on 2 and 57 DF,  p-value: < 2.2e-16

出力から、マーク = 68.42847 - 1.07557*(時間) + .22958*(時間)2 であることがわかります。 この方程式を使用して、学習時間数に基づいて生徒が何点取得するかを予測できます。

たとえば、生徒が 5 時間勉強した場合、計算は次のようになります。

marks = 68.42847 - 1.07557*(5) + .22958*(5)2

marks = 68.42847 - 1.07557*5 + .22958*25

marks = 68.42847 - 5.37785 + 5.7395

marks = 68.79012

学生が 5 時間勉強すると、最終試験で 68.79012 点を取得します。

最後に、当てはめたモデルをプロットして、生データにどれだけ対応しているかを確認できます。

ggplot(delftstack, aes(x=hours, y=marks)) +
    geom_point() +
    stat_smooth(method='lm', formula = y ~ poly(x,2), size = 1) +
    xlab('Hours Studied') +
    ylab('Marks')

出力 (プロット):

最終プロット

著者: Sheeraz Gul
Sheeraz Gul avatar Sheeraz Gul avatar

Sheeraz is a Doctorate fellow in Computer Science at Northwestern Polytechnical University, Xian, China. He has 7 years of Software Development experience in AI, Web, Database, and Desktop technologies. He writes tutorials in Java, PHP, Python, GoLang, R, etc., to help beginners learn the field of Computer Science.

LinkedIn Facebook

関連記事 - R Regression