LSTM/GRU in TF: House price regression on the Boston housing data set with LSTM and GRU models in TensorFlow (batch_size tuning comparison)
This case study builds LSTM and GRU regression models with TensorFlow to predict house prices on the Boston housing data set, and compares the results obtained with batch_size values of 8, 16, 32 and 64.
# 1. Define the data set
        CRIM    ZN  INDUS  CHAS    NOX  ...    TAX  PTRATIO       B  LSTAT  target
0    0.00632  18.0   2.31   0.0  0.538  ...  296.0     15.3  396.90   4.98    24.0
1    0.02731   0.0   7.07   0.0  0.469  ...  242.0     17.8  396.90   9.14    21.6
2    0.02729   0.0   7.07   0.0  0.469  ...  242.0     17.8  392.83   4.03    34.7
3    0.03237   0.0   2.18   0.0  0.458  ...  222.0     18.7  394.63   2.94    33.4
4    0.06905   0.0   2.18   0.0  0.458  ...  222.0     18.7  396.90   5.33    36.2
..       ...   ...    ...   ...    ...  ...    ...      ...     ...    ...     ...
501  0.06263   0.0  11.93   0.0  0.573  ...  273.0     21.0  391.99   9.67    22.4
502  0.04527   0.0  11.93   0.0  0.573  ...  273.0     21.0  396.90   9.08    20.6
503  0.06076   0.0  11.93   0.0  0.573  ...  273.0     21.0  396.90   5.64    23.9
504  0.10959   0.0  11.93   0.0  0.573  ...  273.0     21.0  393.45   6.48    22.0
505  0.04741   0.0  11.93   0.0  0.573  ...  273.0     21.0  396.90   7.88    11.9

[506 rows x 14 columns]
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 506 entries, 0 to 505
Data columns (total 14 columns):
 #   Column   Non-Null Count  Dtype
---  ------   --------------  -----
 0   CRIM     506 non-null    float64
 1   ZN       506 non-null    float64
 2   INDUS    506 non-null    float64
 3   CHAS     506 non-null    float64
 4   NOX      506 non-null    float64
 5   RM       506 non-null    float64
 6   AGE      506 non-null    float64
 7   DIS      506 non-null    float64
 8   RAD      506 non-null    float64
 9   TAX      506 non-null    float64
 10  PTRATIO  506 non-null    float64
 11  B        506 non-null    float64
 12  LSTAT    506 non-null    float64
 13  target   506 non-null    float64
dtypes: float64(14)
memory usage: 55.5 KB
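The article does not show how it loaded the data. The printout is the classic 506-row Boston table: 13 feature columns plus a `target` column. The sketch below wraps raw arrays into a DataFrame with that layout; `to_boston_frame` and `FEATURE_NAMES` are hypothetical names, and the data source is left open.

```python
import numpy as np
import pandas as pd

# Column order of the 13 Boston housing features, as in the printout above.
FEATURE_NAMES = ["CRIM", "ZN", "INDUS", "CHAS", "NOX", "RM", "AGE",
                 "DIS", "RAD", "TAX", "PTRATIO", "B", "LSTAT"]


def to_boston_frame(features, target) -> pd.DataFrame:
    """Hypothetical helper: wrap raw arrays in a DataFrame with a 'target' column."""
    features = np.asarray(features, dtype="float64")
    target = np.asarray(target, dtype="float64")
    if features.ndim != 2 or features.shape[1] != len(FEATURE_NAMES):
        raise ValueError(f"expected shape (n, {len(FEATURE_NAMES)}), got {features.shape}")
    if len(target) != len(features):
        raise ValueError("features and target must have the same number of rows")
    df = pd.DataFrame(features, columns=FEATURE_NAMES)
    df["target"] = target
    return df


if __name__ == "__main__":
    # Offline demo: random numbers stand in for the real 506-row data.
    rng = np.random.default_rng(0)
    demo = to_boston_frame(rng.random((5, 13)), rng.random(5) * 50)
    print(demo.shape)  # (5, 14)
    demo.info()
```

With TensorFlow installed, `(x, y), _ = tf.keras.datasets.boston_housing.load_data(test_split=0.0)` returns all 506 rows, which could be passed to `to_boston_frame(x, y)`. This assumes the Keras copy uses the same column order. That loader shuffles rows, so the row order will differ from the printout.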
# 2. Data preprocessing
# 2.1. Separate features and labels
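Separating features from labels means taking every column except `target` as the input matrix X and the `target` column as the label vector y. A minimal sketch, assuming the DataFrame layout shown above; `split_features_labels` is a hypothetical helper name.

```python
import numpy as np
import pandas as pd


def split_features_labels(df: pd.DataFrame, label: str = "target"):
    """Return (X, y) as float64 arrays: all columns except `label`, and `label` itself."""
    X = df.drop(columns=[label]).to_numpy(dtype="float64")
    y = df[label].to_numpy(dtype="float64")
    return X, y


if __name__ == "__main__":
    # Two rows taken from the printout (CRIM, ZN, target only).
    demo = pd.DataFrame({"CRIM": [0.00632, 0.02731], "ZN": [18.0, 0.0],
                         "target": [24.0, 21.6]})
    X, y = split_features_labels(demo)
    print(X.shape, y.shape)  # (2, 2) (2,)
```

On the full table this gives X with shape (506, 13) and y with shape (506,).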
# 3. Model training and inference
# 3.1. Split the data set
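The article does not show its split call. Its 455 training rows out of 506 are consistent with holding out 10% and rounding the test count up: ceil(0.1 × 506) = ceil(50.6) = 51 test rows. That is how scikit-learn's `train_test_split(..., test_size=0.1)` sizes a float test fraction, so test_size=0.1 is an inference, not something the source states. The sketch below is a NumPy stand-in for `train_test_split`, not scikit-learn itself. `shuffle_split` and `seed=42` are hypothetical.

```python
import math

import numpy as np


def shuffle_split(X, y, test_size=0.1, seed=42):
    """Shuffle rows and hold out ceil(test_size * n) of them as the test set.

    NumPy stand-in for sklearn.model_selection.train_test_split with a float
    test_size; the set sizes match, but the rows chosen will differ.
    """
    X, y = np.asarray(X), np.asarray(y)
    n = len(X)
    n_test = math.ceil(test_size * n)
    idx = np.random.default_rng(seed).permutation(n)
    test_idx, train_idx = idx[:n_test], idx[n_test:]
    # Same return order as train_test_split: X_train, X_test, y_train, y_test.
    return X[train_idx], X[test_idx], y[train_idx], y[test_idx]


if __name__ == "__main__":
    X = np.zeros((506, 13))
    y = np.arange(506, dtype=float)
    X_train, X_test, y_train, y_test = shuffle_split(X, y)
    print(X_train.shape, X_test.shape)  # (455, 13) (51, 13)
```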
# 3.2. Further data processing
# Reshape the input data into a 3D tensor (number of samples, time steps, number of features)
<class 'numpy.ndarray'> (455, 1, 13)
X_train
[[[6.04700e-02 0.00000e+00 2.46000e+00 ... 1.78000e+01 3.87110e+02 1.31500e+01]]

 [[6.29760e-01 0.00000e+00 8.14000e+00 ... 2.10000e+01 3.96900e+02 8.26000e+00]]

 [[7.99248e+00 0.00000e+00 1.81000e+01 ... 2.02000e+01 3.96900e+02 2.45600e+01]]

 ...

 [[3.51140e-01 0.00000e+00 7.38000e+00 ... 1.96000e+01 3.96900e+02 7.70000e+00]]

 [[9.18702e+00 0.00000e+00 1.81000e+01 ... 2.02000e+01 3.96900e+02 2.36000e+01]]

 [[4.55587e+00 0.00000e+00 1.81000e+01 ... 2.02000e+01 3.54700e+02 7.12000e+00]]]
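The printed shape (455, 1, 13) corresponds to a reshape like the one below. The article's own code is not shown, so `to_3d` is a hypothetical helper. With only one time step, the LSTM/GRU recurrent kernel only ever multiplies the all-zero initial state. The comparison therefore tests two gated single-step transforms of the 13 features rather than sequence modelling. The printed values are on their raw scales, roughly 10^-2 to 10^2. The sketch leaves them unscaled to match.

```python
import numpy as np


def to_3d(X) -> np.ndarray:
    """Reshape (samples, features) into (samples, time_steps=1, features).

    Keras LSTM/GRU layers expect 3D input; each house becomes a one-step
    "sequence" that holds all of its features.
    """
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"expected a 2D array, got shape {X.shape}")
    return X.reshape((X.shape[0], 1, X.shape[1]))


if __name__ == "__main__":
    X_train = to_3d(np.zeros((455, 13)))
    print(type(X_train), X_train.shape)  # <class 'numpy.ndarray'> (455, 1, 13)
```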
# 3.3. Model building and training
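The log below gives a few clues about the setup. There are 15 steps per epoch, consistent with batch_size=32, since ceil(455 / 32) = 15. The final val_loss (46.0986) equals the reported test-set MSE (46.099). That suggests MSE loss with the test split passed as `validation_data`. The layer sizes are not shown. The sketch below assumes one recurrent layer with a hypothetical `units=50` followed by a one-unit Dense output, trained with Adam. `build_rnn_regressor` is a hypothetical name, not the author's code.

```python
import numpy as np
import tensorflow as tf


def build_rnn_regressor(kind: str, n_features: int, units: int = 50) -> tf.keras.Model:
    """Hypothetical architecture: one LSTM or GRU layer + a 1-unit Dense output."""
    rnn_layers = {"LSTM": tf.keras.layers.LSTM, "GRU": tf.keras.layers.GRU}
    if kind not in rnn_layers:
        raise ValueError(f"kind must be one of {sorted(rnn_layers)}")
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(1, n_features)),  # (time_steps=1, features)
        rnn_layers[kind](units),                 # assumed layer size
        tf.keras.layers.Dense(1),                # one house-price output
    ])
    # MSE loss, consistent with val_loss matching the test-set MSE in the log.
    model.compile(optimizer="adam", loss="mse")
    return model


if __name__ == "__main__":
    # Random stand-in data with the article's shapes: 455 train / 51 test rows.
    rng = np.random.default_rng(0)
    X_train, y_train = rng.random((455, 1, 13)), rng.random(455)
    X_test, y_test = rng.random((51, 1, 13)), rng.random(51)
    model = build_rnn_regressor("LSTM", n_features=13)
    history = model.fit(X_train, y_train, epochs=2, batch_size=32,
                        validation_data=(X_test, y_test), verbose=0)
    print(history.history["val_loss"][-1])
```

To mirror the comparison, one could loop over `kind in ("LSTM", "GRU")` and `batch_size in (8, 16, 32, 64)`. Each pair would call `model.fit(..., epochs=800, batch_size=batch_size, validation_data=(X_test, y_test))` on a freshly built model. Keeping each returned `History` allows plotting the loss curves later.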
Epoch 1/800 15/15 [==============================] - 2s 22ms/step - loss: 367.5489 - val_loss: 189.7880
Epoch 2/800 15/15 [==============================] - 0s 3ms/step - loss: 95.0314 - val_loss: 133.7617
Epoch 3/800 15/15 [==============================] - 0s 3ms/step - loss: 63.8908 - val_loss: 110.7545
Epoch 4/800 15/15 [==============================] - 0s 3ms/step - loss: 54.9615 - val_loss: 108.7314
Epoch 5/800 15/15 [==============================] - 0s 3ms/step - loss: 53.5053 - val_loss: 104.1971
Epoch 6/800 15/15 [==============================] - 0s 3ms/step - loss: 50.4742 - val_loss: 111.0977
Epoch 7/800 15/15 [==============================] - 0s 4ms/step - loss: 46.4744 - val_loss: 100.7286
Epoch 8/800 15/15 [==============================] - 0s 3ms/step - loss: 46.6553 - val_loss: 99.4326
Epoch 9/800 15/15 [==============================] - 0s 3ms/step - loss: 48.1464 - val_loss: 96.9524
Epoch 10/800 15/15 [==============================] - 0s 3ms/step - loss: 46.4484 - val_loss: 96.3056
Epoch 11/800 15/15 [==============================] - 0s 3ms/step - loss: 41.9167 - val_loss: 92.1237
Epoch 12/800 15/15 [==============================] - 0s 3ms/step - loss: 40.4515 - val_loss: 89.9320
Epoch 13/800 15/15 [==============================] - 0s 3ms/step - loss: 46.7765 - val_loss: 91.3324
Epoch 14/800 15/15 [==============================] - 0s 3ms/step - loss: 45.2451 - val_loss: 83.1068
Epoch 15/800 15/15 [==============================] - 0s 4ms/step - loss: 44.0281 - val_loss: 77.3420
Epoch 16/800 15/15 [==============================] - 0s 3ms/step - loss: 42.0810 - val_loss: 85.3165
Epoch 17/800 15/15 [==============================] - 0s 4ms/step - loss: 37.4590 - val_loss: 70.4207
...
Epoch 757/800 15/15 [==============================] - 0s 3ms/step - loss: 4.9589 - val_loss: 37.6601
Epoch 758/800 15/15 [==============================] - 0s 3ms/step - loss: 4.6070 - val_loss: 36.7595
Epoch 759/800 15/15 [==============================] - 0s 3ms/step - loss: 5.8827 - val_loss: 41.7672
Epoch 760/800 15/15 [==============================] - 0s 3ms/step - loss: 5.3787 - val_loss: 42.0669
Epoch 761/800 15/15 [==============================] - 0s 3ms/step - loss: 5.2201 - val_loss: 47.2067
Epoch 762/800 15/15 [==============================] - 0s 3ms/step - loss: 4.5653 - val_loss: 46.1523
Epoch 763/800 15/15 [==============================] - 0s 3ms/step - loss: 5.7319 - val_loss: 43.6643
Epoch 764/800 15/15 [==============================] - 0s 3ms/step - loss: 4.3259 - val_loss: 41.5630
Epoch 765/800 15/15 [==============================] - 0s 3ms/step - loss: 4.2562 - val_loss: 40.0810
Epoch 766/800 15/15 [==============================] - 0s 3ms/step - loss: 5.4430 - val_loss: 37.8770
Epoch 767/800 15/15 [==============================] - 0s 3ms/step - loss: 5.2480 - val_loss: 46.7311
Epoch 768/800 15/15 [==============================] - 0s 3ms/step - loss: 4.7596 - val_loss: 40.3361
Epoch 769/800 15/15 [==============================] - 0s 3ms/step - loss: 4.8908 - val_loss: 42.5085
Epoch 770/800 15/15 [==============================] - 0s 3ms/step - loss: 4.7232 - val_loss: 39.5460
Epoch 771/800 15/15 [==============================] - 0s 3ms/step - loss: 4.6950 - val_loss: 41.4992
Epoch 772/800 15/15 [==============================] - 0s 3ms/step - loss: 4.9918 - val_loss: 42.5983
Epoch 773/800 15/15 [==============================] - 0s 3ms/step - loss: 5.0848 - val_loss: 50.5700
Epoch 774/800 15/15 [==============================] - 0s 3ms/step - loss: 7.3065 - val_loss: 30.6110
Epoch 775/800 15/15 [==============================] - 0s 3ms/step - loss: 8.4268 - val_loss: 42.6159
Epoch 776/800 15/15 [==============================] - 0s 3ms/step - loss: 9.8120 - val_loss: 45.9334
Epoch 777/800 15/15 [==============================] - 0s 3ms/step - loss: 5.8191 - val_loss: 47.6144
Epoch 778/800 15/15 [==============================] - 0s 3ms/step - loss: 7.1561 - val_loss: 50.9774
Epoch 779/800 15/15 [==============================] - 0s 3ms/step - loss: 5.4377 - val_loss: 32.8492
Epoch 780/800 15/15 [==============================] - 0s 3ms/step - loss: 6.8226 - val_loss: 25.2842
Epoch 781/800 15/15 [==============================] - 0s 4ms/step - loss: 6.6671 - val_loss: 32.6189
Epoch 782/800 15/15 [==============================] - 0s 3ms/step - loss: 5.4168 - val_loss: 48.3716
Epoch 783/800 15/15 [==============================] - 0s 3ms/step - loss: 6.1797 - val_loss: 26.0899
Epoch 784/800 15/15 [==============================] - 0s 3ms/step - loss: 7.9516 - val_loss: 23.3273
Epoch 785/800 15/15 [==============================] - 0s 3ms/step - loss: 7.4502 - val_loss: 29.4346
Epoch 786/800 15/15 [==============================] - 0s 3ms/step - loss: 6.0257 - val_loss: 40.4242
Epoch 787/800 15/15 [==============================] - 0s 3ms/step - loss: 5.0951 - val_loss: 49.2547
Epoch 788/800 15/15 [==============================] - 0s 3ms/step - loss: 4.8296 - val_loss: 37.7677
Epoch 789/800 15/15 [==============================] - 0s 3ms/step - loss: 6.2584 - val_loss: 40.6493
Epoch 790/800 15/15 [==============================] - 0s 4ms/step - loss: 6.6211 - val_loss: 31.1713
Epoch 791/800 15/15 [==============================] - 0s 4ms/step - loss: 5.2383 - val_loss: 55.5448
Epoch 792/800 15/15 [==============================] - 0s 3ms/step - loss: 5.4096 - val_loss: 43.4175
Epoch 793/800 15/15 [==============================] - 0s 3ms/step - loss: 5.0846 - val_loss: 39.2757
Epoch 794/800 15/15 [==============================] - 0s 4ms/step - loss: 4.5122 - val_loss: 41.4339
Epoch 795/800 15/15 [==============================] - 0s 3ms/step - loss: 4.8337 - val_loss: 47.6556
Epoch 796/800 15/15 [==============================] - 0s 3ms/step - loss: 5.1507 - val_loss: 44.6231
Epoch 797/800 15/15 [==============================] - 0s 3ms/step - loss: 5.2973 - val_loss: 44.1294
Epoch 798/800 15/15 [==============================] - 0s 3ms/step - loss: 4.5970 - val_loss: 46.2830
Epoch 799/800 15/15 [==============================] - 0s 3ms/step - loss: 4.4145 - val_loss: 45.5006
Epoch 800/800 15/15 [==============================] - 0s 4ms/step - loss: 4.1411 - val_loss: 46.0986
# 3.4. Model training
test set mean square error: 46.099
boston_val_MAE: 3.6606673203262625
boston_val_MSE: 46.098603197697834
boston_val_RMSE: 6.7895952160418105
boston_val_R2: 0.5967941831772676
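The four metrics follow from their definitions:

- MAE = mean(|y - ŷ|)
- MSE = mean((y - ŷ)²)
- RMSE = √MSE
- R² = 1 - Σ(y - ŷ)² / Σ(y - ȳ)²

For example, √46.0986 ≈ 6.7896, matching boston_val_RMSE above. The NumPy sketch below reuses the printout's key names. It is not necessarily how the article computed them, although scikit-learn's `mean_absolute_error`, `mean_squared_error` and `r2_score` give the same values. `regression_report` is a hypothetical name.

```python
import numpy as np


def regression_report(y_true, y_pred, prefix="boston"):
    """Return MAE, MSE, RMSE and R^2 keyed like the printout (e.g. 'boston_val_MAE')."""
    # ravel() matters: Keras predict() returns shape (n, 1), and subtracting that
    # from an (n,) label array would silently broadcast to an (n, n) matrix.
    y_true = np.asarray(y_true, dtype="float64").ravel()
    y_pred = np.asarray(y_pred, dtype="float64").ravel()
    err = y_true - y_pred
    mse = float(np.mean(err ** 2))
    ss_res = float(np.sum(err ** 2))
    ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))
    return {
        f"{prefix}_val_MAE": float(np.mean(np.abs(err))),
        f"{prefix}_val_MSE": mse,
        f"{prefix}_val_RMSE": mse ** 0.5,
        f"{prefix}_val_R2": 1.0 - ss_res / ss_tot,
    }


if __name__ == "__main__":
    # Toy numbers, with predictions shaped (n, 1) as model.predict returns them.
    report = regression_report([3, -0.5, 2, 7], [[2.5], [0.0], [2], [8]])
    for name, value in report.items():
        print(f"{name}: {value}")
```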
# Draw the loss curves of the training set and test set
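A minimal matplotlib sketch of a loss-curve plot. It draws the `loss` and `val_loss` lists that Keras stores in `History.history`, one point per epoch. In this article the validation data is the test split. The function name, figure size and labels are assumptions; the article's plotting code is not shown.

```python
import matplotlib

matplotlib.use("Agg")  # render straight to files; no display needed
import matplotlib.pyplot as plt


def plot_loss_curves(history: dict, path: str, title: str = "Loss curves") -> None:
    """Plot per-epoch training and validation (test-set) MSE and save to `path`.

    `history` is a Keras History.history dict with 'loss' and 'val_loss' lists.
    """
    epochs = range(1, len(history["loss"]) + 1)
    fig, ax = plt.subplots(figsize=(10, 7))
    ax.plot(epochs, history["loss"], label="training loss")
    ax.plot(epochs, history["val_loss"], label="test (validation_data) loss")
    ax.set_xlabel("epoch")
    ax.set_ylabel("MSE")
    ax.set_title(title)
    ax.legend()
    fig.savefig(path, dpi=120)
    plt.close(fig)


if __name__ == "__main__":
    # First three epochs of the log above, as a stand-in for history.history.
    demo = {"loss": [367.5489, 95.0314, 63.8908],
            "val_loss": [189.7880, 133.7617, 110.7545]}
    plot_loss_curves(demo, "loss_curves_demo.png")
```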
[Figure: training-set and test-set loss curves (//i2.wp.com/img-blog.csdnimg.cn/b30a0c7b20d240f4a67aa93a8e6ae708.png)]
[Figure: training-set and test-set loss curves, continued (//i2.wp.com/img-blog.csdnimg.cn/0cfd7b6e6df34315973dbe5f12571dda.png)]
# 3.5. Model evaluation
boston_LSTM_8_val_MAE: 3.642160939235313
boston_LSTM_8_val_MSE: 45.455696025386075
boston_LSTM_8_val_RMSE: 6.742083952709732
boston_LSTM_8_val_R2: 0.6024174319000424
model_name: boston_LSTM_8_0.6024

boston_GRU_8_val_MAE: 2.9867677613800647
boston_GRU_8_val_MSE: 27.82258275867772
boston_GRU_8_val_RMSE: 5.2747116280113095
boston_GRU_8_val_R2: 0.7566471339875444
model_name: boston_GRU_8_0.7566

boston_LSTM_16_val_MAE: 3.33460968615962
boston_LSTM_16_val_MSE: 31.442621036483846
boston_LSTM_16_val_RMSE: 5.607372025867719
boston_LSTM_16_val_R2: 0.7249841249268865
model_name: boston_LSTM_16_0.7250

boston_GRU_16_val_MAE: 3.244604744630701
boston_GRU_16_val_MSE: 31.916550943083372
boston_GRU_16_val_RMSE: 5.649473510255922
boston_GRU_16_val_R2: 0.7208388519283171
model_name: boston_GRU_16_0.7208

boston_LSTM_32_val_MAE: 3.289968232547536
boston_LSTM_32_val_MSE: 32.27946842476036
boston_LSTM_32_val_RMSE: 5.6815023035074415
boston_LSTM_32_val_R2: 0.7176645596615586
model_name: boston_LSTM_32_0.7177

boston_GRU_32_val_MAE: 3.6742190005732516
boston_GRU_32_val_MSE: 38.03228366757621
boston_GRU_32_val_RMSE: 6.167031998261093
boston_GRU_32_val_R2: 0.6673470140504223
model_name: boston_GRU_32_0.6673

boston_LSTM_64_val_MAE: 3.2653301463407627
boston_LSTM_64_val_MSE: 31.565284368449653
boston_LSTM_64_val_RMSE: 5.618299063635688
boston_LSTM_64_val_R2: 0.7239112384286261
model_name: boston_LSTM_64_0.7239

boston_GRU_64_val_MAE: 3.738212989358341
boston_GRU_64_val_MSE: 35.85519371644912
boston_GRU_64_val_RMSE: 5.987920650480358
boston_GRU_64_val_R2: 0.6863891383481191
model_name: boston_GRU_64_0.6864
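To compare the batch sizes at a glance, the sketch below copies the R² values from the printout and ranks them. `model_name` reproduces the printed label format, `boston_<model>_<batch_size>_<R² to 4 decimals>`. The other names are hypothetical. The printout shows one run per configuration, so these rankings may shift with a different random seed.

```python
# R^2 values copied from the evaluation printout above, keyed by (model, batch_size).
R2_BY_RUN = {
    ("LSTM", 8): 0.6024174319000424, ("GRU", 8): 0.7566471339875444,
    ("LSTM", 16): 0.7249841249268865, ("GRU", 16): 0.7208388519283171,
    ("LSTM", 32): 0.7176645596615586, ("GRU", 32): 0.6673470140504223,
    ("LSTM", 64): 0.7239112384286261, ("GRU", 64): 0.6863891383481191,
}


def model_name(kind: str, batch_size: int, r2: float) -> str:
    """Rebuild the printed label, e.g. 'boston_LSTM_8_0.6024'."""
    return f"boston_{kind}_{batch_size}_{r2:.4f}"


def best_batch_size(results: dict, kind: str):
    """Return (batch_size, r2) with the highest R^2 for one model type."""
    runs = {bs: r2 for (k, bs), r2 in results.items() if k == kind}
    bs = max(runs, key=runs.get)
    return bs, runs[bs]


if __name__ == "__main__":
    # All eight runs, best R^2 first.
    for (kind, bs), r2 in sorted(R2_BY_RUN.items(), key=lambda kv: kv[1], reverse=True):
        print(model_name(kind, bs, r2))
    print("best LSTM:", best_batch_size(R2_BY_RUN, "LSTM"))
    print("best GRU:", best_batch_size(R2_BY_RUN, "GRU"))
```

On these numbers, GRU with batch_size=8 scores highest (R² 0.7566). For LSTM, the best batch_size is 16 (R² 0.7250).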
# Draw prediction results
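A minimal matplotlib sketch of a prediction plot. It overlays actual and predicted prices for each test sample. The article's plotting code is not shown, so `plot_predictions`, the markers, the labels and the file name are assumptions. The target is the Boston median home value, in units of $1000s.

```python
import matplotlib

matplotlib.use("Agg")  # render straight to files; no display needed
import matplotlib.pyplot as plt
import numpy as np


def plot_predictions(y_true, y_pred, path, title="Test-set predictions"):
    """Overlay actual and predicted prices per test sample and save to `path`."""
    # Flatten so (n, 1) Keras predictions line up with (n,) labels.
    y_true = np.asarray(y_true, dtype="float64").ravel()
    y_pred = np.asarray(y_pred, dtype="float64").ravel()
    idx = np.arange(len(y_true))
    fig, ax = plt.subplots(figsize=(10, 7))
    ax.plot(idx, y_true, marker="o", label="actual")
    ax.plot(idx, y_pred, marker="x", label="predicted")
    ax.set_xlabel("test sample index")
    ax.set_ylabel("target (median value, $1000s)")
    ax.set_title(title)
    ax.legend()
    fig.savefig(path, dpi=120)
    plt.close(fig)


if __name__ == "__main__":
    # Stand-in values; with a trained model use y_test and model.predict(X_test).
    plot_predictions([24.0, 21.6, 34.7], [[22.9], [23.4], [30.1]], "predictions_demo.png")
```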
[Figure: prediction results (//i2.wp.com/img-blog.csdnimg.cn/4b6aceff47164181b81476fecaa3ff80.png)]