The purpose of this project is to demonstrate data science techniques on datasets provided by the State Farm insurance company. The first step is to load and clean the data and conduct exploratory data analysis (EDA) to understand it. Following EDA, several classification models will be built and compared. A logistic regression and one other model will be chosen as the final models, and we will compare and contrast them based on their respective strengths and weaknesses. Finally, predictions will be made on the test data in the form of probabilities of belonging to the positive class.
# import libraries
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import numpy as np
from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
from sklearn.dummy import DummyClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn import svm
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier
import lightgbm as lgb
from imblearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OrdinalEncoder, LabelEncoder
from sklearn.neural_network import MLPClassifier
from tensorflow import keras
from tensorflow.keras.optimizers import Adam
from sklearn.experimental import enable_iterative_imputer
from sklearn.impute import SimpleImputer, KNNImputer, IterativeImputer
from sklearn.metrics import accuracy_score, auc, roc_auc_score, roc_curve, f1_score, classification_report, confusion_matrix
from imblearn.over_sampling import SMOTE
# show graphs in html
import plotly.io as pio
pio.renderers.default = "plotly_mimetype+notebook"
# read dataset
train = pd.read_csv('datasets/exercise_40_train.csv')
test = pd.read_csv('datasets/exercise_40_test.csv')
# set max column length to 110
pd.set_option('display.max_columns', 110)
# look at dataset
train.head()
[Output of train.head(): the first five rows across all 101 columns (y, x1–x100). The preview shows mixed numeric and categorical features, percent-suffixed strings in x7 (e.g. '0.0062%'), dollar-prefixed strings in x19 (e.g. '$-908.650758424405'), and NaNs scattered throughout.]
At first glance, we see several problems with the dataset and collect some ideas for how to deal with them:

  * label encode x3, x33, x39, x60, x65, and x77
  * remove the '%' in x7 and the dollar sign in x19
  * binarize x24, x31, x93, and x99
  * fill missing values throughout

The most efficient approach would be a pipeline that label encodes and imputes missing values.
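For example, binarizing one of the yes/no columns is a one-line mapping. A minimal sketch (x31_bin is a hypothetical throwaway variable; the notebook actually handles the categorical columns in a pipeline later):
# illustrative sketch: map a yes/no column to 1/0
x31_bin = train['x31'].map({'yes': 1, 'no': 0})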
# summary info on columns
train.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 40000 entries, 0 to 39999 Columns: 101 entries, y to x100 dtypes: float64(86), int64(3), object(12) memory usage: 30.8+ MB
# looking at shape of data
train.shape
(40000, 101)
# looking at column names
train.columns
Index(['y', 'x1', 'x2', 'x3', 'x4', 'x5', 'x6', 'x7', 'x8', 'x9', ... 'x91', 'x92', 'x93', 'x94', 'x95', 'x96', 'x97', 'x98', 'x99', 'x100'], dtype='object', length=101)
# remove special characters
# regex=False treats '%' and '$' as literal strings and avoids the FutureWarning
train.x7 = train.x7.str.replace('%', '', regex=False).astype(float)
train.x19 = train.x19.str.replace('$', '', regex=False).astype(float)
# Check proper implementation
train[['x7', 'x19']].head()
x7 | x19 | |
---|---|---|
0 | 0.0062 | -908.650758 |
1 | 0.0064 | -1864.962288 |
2 | -0.0008 | -543.187403 |
3 | -0.0057 | -182.626381 |
4 | 0.0109 | 967.007091 |
We removed the special characters from the dataset and converted those columns to float. Note that x19 is only displayed to six decimal places by default; the underlying float64 values retain their full precision.
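As a quick sanity check (not part of the original analysis), printing a raw value confirms the stored float64 keeps its full precision:
# display is truncated to 6 decimals, but the stored value is exact
print(train.x19.iloc[0])  # -908.650758424405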
# looking at categories
train.select_dtypes(['object'])
x3 | x24 | x31 | x33 | x39 | x60 | x65 | x77 | x93 | x99 | |
---|---|---|---|---|---|---|---|---|---|---|
0 | Wed | female | no | Colorado | 5-10 miles | August | farmers | mercedes | no | yes |
1 | Friday | male | no | Tennessee | 5-10 miles | April | allstate | mercedes | no | yes |
2 | Thursday | male | no | Texas | 5-10 miles | September | geico | subaru | no | yes |
3 | Tuesday | male | no | Minnesota | 5-10 miles | September | geico | nissan | no | yes |
4 | Sunday | male | yes | New York | 5-10 miles | January | geico | toyota | yes | yes |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
39995 | Sun | female | no | NaN | 5-10 miles | July | farmers | NaN | no | yes |
39996 | Thursday | male | yes | Illinois | 5-10 miles | July | progressive | ford | no | yes |
39997 | Monday | male | yes | NaN | 5-10 miles | August | geico | ford | no | yes |
39998 | Tuesday | male | no | Ohio | 5-10 miles | December | farmers | NaN | no | yes |
39999 | Thursday | NaN | no | Florida | 5-10 miles | January | progressive | toyota | no | NaN |
40000 rows × 10 columns
We need to take a better look at the object columns with EDA.
# rows with missing values
train.isna().any(axis=1).sum()
39999
We see that nearly every row (39,999 of 40,000) has at least one missing value.
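To see which columns drive the missingness, a quick per-column tally helps (an illustrative check):
# columns with the most missing values
train.isna().sum().sort_values(ascending=False).head(10)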
# checking for rows where all values are missing
train.isna().all(axis=1).sum()
0
The dataset does not contain any rows where all values are missing.
# looking for duplicates
train.duplicated().sum()
0
# look at test set
test.head()
[Output of test.head(): the first five rows across all 100 columns (x1–x100; no target y). The same quirks appear: '%' strings in x7, '$' strings in x19, the categorical columns, and scattered NaNs.]
# shape of dataset
test.shape
(10000, 100)
# look at info on columns
test.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 10000 entries, 0 to 9999 Columns: 100 entries, x1 to x100 dtypes: float64(86), int64(2), object(12) memory usage: 7.6+ MB [per-column listing condensed; non-null counts range from 10000 down to 1434 (x44), 1915 (x30), and 1923 (x57), confirming substantial missingness in many columns]
# looking at missing values
test.isna().sum()
x1 0 x2 0 x3 0 x4 0 x5 602 ... x96 1628 x97 0 x98 0 x99 3300 x100 0 Length: 100, dtype: int64
# remove special characters
# regex=False treats '%' and '$' as literal strings and avoids the FutureWarning
test.x7 = test.x7.str.replace('%', '', regex=False).astype(float)
test.x19 = test.x19.str.replace('$', '', regex=False).astype(float)
# Check proper implementation
test[['x7', 'x19']].head()
x7 | x19 | |
---|---|---|
0 | 0.0098 | 120.216190 |
1 | 0.0076 | -267.562586 |
2 | -0.0005 | -311.292903 |
3 | -0.0160 | 2229.149400 |
4 | 0.0186 | -469.049530 |
We cleaned up the obvious issues, removing the special characters and correcting the dtypes. The dataset still contains many missing values as well as categorical columns. We applied the same cleaning methods to both the training and test sets.
# values of column
train.x3.value_counts(dropna=False)
Wednesday 4930 Monday 4144 Friday 3975 Tuesday 3915 Sunday 3610 Saturday 3596 Tue 2948 Thursday 2791 Mon 2200 Wed 2043 Sat 1787 Thur 1643 Fri 1620 Sun 798 Name: x3, dtype: int64
# being consistent with labeling, short notation
day_map = {'Sunday': 'Sun', 'Monday': 'Mon', 'Tuesday': 'Tue', 'Wednesday': 'Wed',
           'Thursday': 'Thur', 'Friday': 'Fri', 'Saturday': 'Sat'}
train.x3 = train.x3.replace(day_map)
We mapped the full day names to their shorthand notation.
# values of column
train.x24.value_counts(dropna=False)
female 18158 male 17986 NaN 3856 Name: x24, dtype: int64
# check values
train.x33.value_counts(dropna=False)
NaN 7171 California 3393 Texas 2252 Florida 1802 New York 1714 Illinois 1240 Pennsylvania 1233 Ohio 1114 Michigan 982 Georgia 918 North Carolina 910 New Jersey 870 Virginia 791 Washington 750 Tennessee 690 Indiana 674 Arizona 665 Massachusetts 638 Wisconsin 635 Missouri 634 Minnesota 611 Maryland 581 Alabama 554 Colorado 530 Louisiana 501 South Carolina 491 Kentucky 478 Oregon 452 Connecticut 422 Oklahoma 397 Kansas 378 Nevada 373 Utah 370 Mississippi 361 Iowa 353 Arkansas 346 New Mexico 333 Nebraska 323 West Virginia 305 Hawaii 282 Idaho 277 Maine 247 Rhode Island 246 New Hampshire 231 Montana 195 Vermont 195 Wyoming 189 DC 186 South Dakota 183 North Dakota 181 Delaware 177 Alaska 176 Name: x33, dtype: int64
Including NaN, there are 52 distinct values in what is a US state column; 50 states plus DC account for 51 of them. The NaN entries therefore do not correspond to a single missing state and are unlikely to be a territory absent from the list, so these values will be imputed in the pipeline.
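As a sketch of one option, a most-frequent imputer could fill these categorical gaps directly (illustrative only; state_imputer and x33_filled are hypothetical names, and the notebook's pipeline instead imputes after ordinal encoding):
# illustrative: fill categorical NaNs with the modal value
state_imputer = SimpleImputer(strategy='most_frequent')
x33_filled = state_imputer.fit_transform(train[['x33']])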
# Change values to 1
train.x39 = train.x39.str.replace('5-10 miles', '1').astype(int)
Every row of this column holds the same value ('5-10 miles'), so we replace it with the constant 1.
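Since a constant column carries no predictive signal, such columns can also be detected programmatically (a sketch; constant_cols is a hypothetical name):
# illustrative: list columns with a single unique non-null value
constant_cols = [c for c in train.columns if train[c].nunique(dropna=True) <= 1]
print(constant_cols)  # expected to include 'x39'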
# checking values
train.x60.value_counts(dropna=False)
December 8136 January 7922 July 7912 August 7907 June 1272 September 1245 February 1213 November 1043 April 951 March 807 May 799 October 793 Name: x60, dtype: int64
This column represents months. No duplicate naming is seen here, and all 12 months are present.
# checking values
train.x65.value_counts(dropna=False)
progressive 10877 allstate 10859 esurance 7144 farmers 5600 geico 5520 Name: x65, dtype: int64
This column represents the different insurance companies.
# checking values
train.x77.value_counts(dropna=False)
NaN 9257 ford 9005 subaru 5047 chevrolet 5011 mercedes 4494 toyota 3555 nissan 2575 buick 1056 Name: x77, dtype: int64
This column represents different vehicle manufacturers. As it is unlikely that the missing values are all one manufacturer missing from the list, these values will have to be imputed.
# checking values
train.x93.value_counts(dropna=False)
no 35506 yes 4494 Name: x93, dtype: int64
# values of column
train.x99.value_counts(dropna=False)
yes 27164 NaN 12836 Name: x99, dtype: int64
Missing values in this column are more likely to be 'no' than unrecorded 'yes' values; therefore, we will fill the missing values with 'no'.
# fill missing values with no
train['x99'] = train['x99'].fillna('no')
# check proper implementation
train.x99.value_counts(dropna=False)
yes 27164 no 12836 Name: x99, dtype: int64
Filled missing values with no.
# summary statistics on data
train.describe()
[Output of train.describe(): count, mean, std, min, 25%, 50%, 75%, and max for the 91 numeric columns (y plus 90 features). Non-null counts range from 40000 down to 5753 (x44); most features are centered near zero or near round values with modest spreads, while a few such as x19 (std ≈ 1001) and x41 (mean ≈ 99994) span much larger scales.]
# show correlation
fig = px.imshow(train.corr(), aspect='auto', title='Train Correlations')
fig.show()
This heatmap shows the pairwise correlations among the numeric features and the target variable. Overall, no strong correlations stand out.
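To complement the heatmap, the features can also be ranked by their absolute correlation with the target (a quick check reusing the same correlation matrix; corr_with_y is a hypothetical name):
# rank features by |correlation| with y
corr_with_y = train.corr()['y'].drop('y').abs().sort_values(ascending=False)
corr_with_y.head(10)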
# distribution of object columns
for col in train.select_dtypes('object'):
fig = px.histogram(train[col], title='Distribution of '+str(col), template='plotly_white')
fig.show()
The most common days are Wednesday, Tuesday, and Monday. The gender distribution is balanced. Column x31 leans heavily towards 'no', and the most common states are California and Texas. The months are concentrated in the winter and summer. The most popular insurance companies are Progressive and Allstate, while the least common is Geico. The most common car manufacturer is Ford and the least common is Buick. Column x93 leans towards 'no', while x99 leans towards 'yes'. The distributions of these columns are likely to change after imputation.
# values of column
test.x3.value_counts(dropna=False)
Wednesday 1224 Friday 1089 Tuesday 1010 Monday 1005 Sunday 953 Saturday 846 Thursday 702 Tue 688 Wed 524 Mon 522 Thur 426 Sat 425 Fri 382 Sun 204 Name: x3, dtype: int64
# being consistent with labeling, short notation
# reuse the same day name mapping as for the training set
test.x3 = test.x3.replace(day_map)
We mapped the full day names to their shorthand notation.
# values of column
test.x24.value_counts(dropna=False)
female 4532 male 4499 NaN 969 Name: x24, dtype: int64
Missing values need to be imputed.
# check values
test.x33.value_counts(dropna=False)
NaN 1770 California 841 Texas 593 Florida 475 New York 462 Pennsylvania 321 Illinois 306 Ohio 278 Michigan 245 North Carolina 238 Georgia 236 New Jersey 204 Washington 189 Virginia 188 Massachusetts 178 Indiana 162 Colorado 160 Tennessee 157 Oklahoma 153 Missouri 153 Alabama 149 Minnesota 148 Wisconsin 145 Maryland 139 South Carolina 132 Arizona 124 Louisiana 119 Kentucky 114 Arkansas 113 Utah 109 Oregon 102 Connecticut 100 Iowa 89 Nevada 88 Kansas 87 Mississippi 85 Nebraska 77 New Hampshire 73 Idaho 67 West Virginia 65 New Mexico 62 Rhode Island 57 Maine 54 South Dakota 50 North Dakota 48 Hawaii 46 Alaska 45 DC 44 Vermont 41 Wyoming 41 Montana 40 Delaware 38 Name: x33, dtype: int64
Again, there are 52 distinct values, with NaN being the most frequent.
# Change values to 1
test.x39 = test.x39.str.replace('5-10 miles', '1').astype(int)
Every row of this column holds the same value ('5-10 miles'), so we replace it with the constant 1.
# checking values
test.x60.value_counts(dropna=False)
August 2055 July 2050 December 2028 January 1935 September 295 June 279 February 277 April 240 November 238 May 211 March 210 October 182 Name: x60, dtype: int64
No duplicate naming is seen here, and all 12 months are present.
# checking values
test.x65.value_counts(dropna=False)
progressive 2703 allstate 2686 esurance 1828 farmers 1451 geico 1332 Name: x65, dtype: int64
This column represents the different insurance companies.
# checking values
test.x77.value_counts(dropna=False)
ford 2325 NaN 2318 chevrolet 1265 subaru 1209 mercedes 1081 toyota 903 nissan 617 buick 282 Name: x77, dtype: int64
This column represents different vehicle manufacturers.
# checking values
test.x93.value_counts(dropna=False)
no 8848 yes 1152 Name: x93, dtype: int64
# values of column
test.x99.value_counts(dropna=False)
yes 6700 NaN 3300 Name: x99, dtype: int64
Missing values in this column are more likely to be 'no' than unrecorded 'yes' values; therefore, we will fill the missing values with 'no', just as we did with the training set.
# fill missing values with no
test['x99'] = test['x99'].fillna('no')
# check proper implementation
test.x99.value_counts(dropna=False)
yes 6700 no 3300 Name: x99, dtype: int64
Filled missing values with no.
# summary statistics on data
test.describe()
[Output of test.describe(): the same summary statistics for the 90 numeric feature columns, closely matching the training set's distributions; non-null counts range from 10000 down to 1434 (x44).]
# distribution of object columns
for col in test.select_dtypes('object'):
fig = px.histogram(test[col], title='Distribution of '+ str(col), template='plotly_white')
fig.show()
We see similar distributions in these columns to the respective columns in the training set.
We observe some patterns in the dataset. We see certain weekdays and certain months are more prevalent in the datasets. Comparing the train and test datasets, we see many columns have similar distributions.
# separate features and target
X = train.drop(columns='y')
y = train.y
# values of the target
y.value_counts()
0 34197 1 5803 Name: y, dtype: int64
Target values are very imbalanced; therefore, we will train models to optimize the AUC or F1 score. The appropriate metric depends on the specific problem and the business needs.
If the business problem involves minimizing false positives and false negatives equally, then optimizing on AUC may be appropriate, as AUC measures the ability of a model to distinguish between positive and negative classes.
However, if the business problem is such that minimizing false positives is more important than minimizing false negatives, or vice versa, then optimizing on F1 score may be more appropriate. F1 score is the harmonic mean of precision and recall and is a good metric to use when there is an uneven class distribution.
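Both metrics are straightforward to compute with the sklearn functions imported above; a tiny self-contained sketch with hypothetical labels and scores:
# illustrative: AUC is computed from probabilities, F1 from hard labels
toy_true = np.array([0, 0, 1, 1])
toy_score = np.array([0.1, 0.4, 0.35, 0.8])  # hypothetical P(y=1)
print(roc_auc_score(toy_true, toy_score))                  # 0.75
print(f1_score(toy_true, (toy_score >= 0.5).astype(int)))  # ~0.67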
# ordinal encoding days and months in order
weekday_names = ['Mon', 'Tue', 'Wed', 'Thur', 'Fri', 'Sat', 'Sun']
month_names = ['January', 'February', 'March', 'April', 'May', 'June', 'July', 'August', 'September', 'October', 'November', 'December']
encoder_day = OrdinalEncoder(categories=[weekday_names])
encoder_month = OrdinalEncoder(categories=[month_names])
# replace the columns with their ordinal codes
X['x3'] = encoder_day.fit_transform(X[['x3']]).ravel()
X['x60'] = encoder_month.fit_transform(X[['x60']]).ravel()
# check for proper implementation
X.head()
[Output of X.head(): x3 now holds ordinal day codes (e.g. 2.0 for Wed) and x60 ordinal month codes (e.g. 7.0 for August), while the remaining categorical columns (x24, x31, x33, x65, x77, x93, x99) are still strings.]
Encoding every column with a single ordinal encoder did not retain the natural order of the days and months. Since the data shows a trend with respect to days and months, we want to preserve the proper order of these labels, so we encode these two columns first and encode the remaining categorical columns later.
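For reference, here is a minimal sketch of how order-preserving encoders can be set up; the encoder_day and encoder_month objects reused on the test set later are assumed to have been created along these lines earlier in the notebook, and the label spellings below are placeholders that must match the actual values in x3 and x60.
# sketch: pass an explicit category order so the encoding respects it
day_order = ['Mon', 'Tue', 'Wed', 'Thur', 'Fri', 'Sat', 'Sun']  # assumed spellings
month_order = ['January', 'February', 'March', 'April', 'May', 'June',
               'July', 'August', 'September', 'October', 'November', 'December']
encoder_day = OrdinalEncoder(categories=[day_order], handle_unknown='use_encoded_value', unknown_value=-1)
encoder_month = OrdinalEncoder(categories=[month_order], handle_unknown='use_encoded_value', unknown_value=-1)
# missing values, if any, pass through to the downstream imputer on recent scikit-learn versions
days = pd.DataFrame(encoder_day.fit_transform(train.x3.to_numpy().reshape(-1, 1)), columns=['day'])
months = pd.DataFrame(encoder_month.fit_transform(train.x60.to_numpy().reshape(-1, 1)), columns=['month'])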
# preprocessing steps
preprocessor = Pipeline([
    ('ordinal_encoder', OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)),
    ('imputer', SimpleImputer()),
    ('scaler', StandardScaler())
])
# fit the preprocessor on the training features and transform them
X_processed = preprocessor.fit_transform(X)
# implement SMOTE for class balance
oversampler = SMOTE(random_state=19)
X_final, y_final = oversampler.fit_resample(X_processed, y)
# shape of the final dataframe
X_final.shape
(68394, 100)
# targets are now balanced
y_final.value_counts()
0    34197
1    34197
Name: y, dtype: int64
# train and valid split
X_train, X_valid, y_train, y_valid = train_test_split(
X_final, y_final, test_size=0.25, random_state=19)
We preprocessed the data to convert the categorical columns into numerically labeled columns. Although some of the models we selected can handle categorical values directly, we prefer to train all of them on numeric inputs. We imputed the missing values with SimpleImputer, scaled the data, and then applied SMOTE to address the class imbalance. Finally, we split the data into training and validation sets for hyperparameter tuning.
# Gridsearch CV for hyperparameter tuning
# Create LightGBM datasets (only needed for the native lgb.train API;
# the GridSearchCV below uses the sklearn-style LGBMClassifier instead)
lgb_train = lgb.Dataset(X_train, y_train)
lgb_valid = lgb.Dataset(X_valid, y_valid, reference=lgb_train)
# Define the parameter grid for the LightGBM model
param_grid = {
'boosting_type': ['gbdt'],
'num_leaves': [10, 15, 20],
'max_depth': [3, 4, 5],
'learning_rate': [0.1, 0.2],
'n_estimators': [100, 200, 300],
'random_state': [19]
}
# Define the parameters for the LightGBM model
params = {
'objective': 'binary',
'metric': 'auc',
}
# Create the GridSearchCV object
grid_search = GridSearchCV(LGBMClassifier(**params), param_grid, cv=2, scoring='roc_auc', verbose=3, n_jobs=-1)
# Fit the GridSearchCV object to the data
grid_search.fit(X_train, y_train)
# Print the best parameters and the best score
print("Best parameters: ", grid_search.best_params_)
print("Best score: ", grid_search.best_score_)
Fitting 2 folds for each of 54 candidates, totalling 108 fits
Best parameters:  {'boosting_type': 'gbdt', 'learning_rate': 0.1, 'max_depth': 5, 'n_estimators': 200, 'num_leaves': 20, 'random_state': 19}
Best score:  0.963786706705483
# XG boost hyperparameter tuning
param_grid = {
'booster':['gbtree'],
'max_depth': [3, 4],
'learning_rate': [0.1],
#'n_estimators': [100, 200, 300],
'eval_metric':['auc']
}
# Create the XGBoost model
xgb = XGBClassifier(random_state=19)
# Create the GridSearchCV object
grid_search = GridSearchCV(xgb, param_grid, cv=2, scoring='roc_auc', verbose=3, n_jobs=-1)
# Fit the GridSearchCV object to the data
grid_search.fit(X_train, y_train)
# Print the best parameters and the best score
print("Best parameters: ", grid_search.best_params_)
print("Best score: ", grid_search.best_score_)
Fitting 2 folds for each of 2 candidates, totalling 4 fits
Best parameters:  {'booster': 'gbtree', 'eval_metric': 'auc', 'learning_rate': 0.1, 'max_depth': 4}
Best score:  0.9624760773067411
We used GridSearchCV to tune the hyperparameters of each model we selected, and we will plug the best parameters into the classifier pipeline below.
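As an aside, the tuned settings can also be pulled straight from the fitted search object rather than retyped; a small illustrative snippet (tuned_xgb is a hypothetical name):
# reuse the tuned settings found by the grid search
tuned_xgb = XGBClassifier(random_state=19, **grid_search.best_params_)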
# tuning neural network
optimizer = Adam(learning_rate=0.001)
model = keras.models.Sequential()
model.add(
keras.layers.Dense(
units=100, input_dim=X_train.shape[1], activation='relu'
))
model.add(keras.layers.Dense(
units=75, activation='relu'
))
model.add(keras.layers.Dense(
units=50, activation='relu'
))
model.add(keras.layers.Dense(
units=25, activation='relu'
))
model.add(keras.layers.Dense(
units=5, activation='relu'
))
model.add(keras.layers.Dense(
units=1, activation='sigmoid'
))
model.compile(loss='binary_crossentropy', optimizer=optimizer, metrics=['AUC'])
model.fit(X_train, y_train, epochs=10, verbose=2,
validation_data=(X_valid, y_valid))
Epoch 1/10
1603/1603 - 18s - loss: 0.4934 - auc: 0.8402 - val_loss: 0.4097 - val_auc: 0.8971 - 18s/epoch - 11ms/step
Epoch 2/10
1603/1603 - 9s - loss: 0.3667 - auc: 0.9168 - val_loss: 0.3585 - val_auc: 0.9216 - 9s/epoch - 6ms/step
Epoch 3/10
1603/1603 - 9s - loss: 0.3070 - auc: 0.9423 - val_loss: 0.3318 - val_auc: 0.9334 - 9s/epoch - 6ms/step
Epoch 4/10
1603/1603 - 11s - loss: 0.2631 - auc: 0.9577 - val_loss: 0.3301 - val_auc: 0.9362 - 11s/epoch - 7ms/step
Epoch 5/10
1603/1603 - 10s - loss: 0.2297 - auc: 0.9676 - val_loss: 0.3399 - val_auc: 0.9381 - 10s/epoch - 6ms/step
Epoch 6/10
1603/1603 - 10s - loss: 0.2064 - auc: 0.9738 - val_loss: 0.3424 - val_auc: 0.9422 - 10s/epoch - 6ms/step
Epoch 7/10
1603/1603 - 10s - loss: 0.1821 - auc: 0.9795 - val_loss: 0.3408 - val_auc: 0.9417 - 10s/epoch - 6ms/step
Epoch 8/10
1603/1603 - 10s - loss: 0.1648 - auc: 0.9831 - val_loss: 0.3326 - val_auc: 0.9454 - 10s/epoch - 6ms/step
Epoch 9/10
1603/1603 - 10s - loss: 0.1497 - auc: 0.9859 - val_loss: 0.3425 - val_auc: 0.9464 - 10s/epoch - 6ms/step
Epoch 10/10
1603/1603 - 10s - loss: 0.1390 - auc: 0.9877 - val_loss: 0.3196 - val_auc: 0.9487 - 10s/epoch - 6ms/step
<keras.callbacks.History at 0x1a876c620a0>
A more complicated neural network, with more layers and epochs, can overfit: variants we tried reached 0.99 AUC on the training set but only about 0.95 AUC on the validation set.
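One common guard is to stop training once the validation AUC stops improving. A minimal sketch, assuming the compiled metric is tracked as val_auc (Keras lowercases the 'AUC' metric name in its logs):
# sketch: early stopping on validation AUC to limit overfitting
from tensorflow.keras.callbacks import EarlyStopping
early_stop = EarlyStopping(monitor='val_auc', mode='max', patience=3, restore_best_weights=True)
model.fit(X_train, y_train, epochs=50, verbose=2,
          validation_data=(X_valid, y_valid), callbacks=[early_stop])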
# Classifier pipeline
pipe_lr = Pipeline([('lr_classifier', LogisticRegression(random_state=19, max_iter=2000))])
pipe_dt = Pipeline([('dt_classifier', DecisionTreeClassifier(random_state=19, max_depth=3))])
pipe_rf = Pipeline([('rf_classifier', RandomForestClassifier(random_state=19, n_estimators=40))])
pipe_sv = Pipeline([('svm_classifier', svm.LinearSVC(random_state=19, max_iter=2000))])
pipe_xg = Pipeline([('xg_classifier', XGBClassifier(random_state=19, n_estimators=200, learning_rate=0.1, eval_metric='auc', max_depth=4, subsample=1.0))])
pipe_lb = Pipeline([('lb_classifier', LGBMClassifier(boosting_type='gbdt', random_state=19, objective='binary', metric='auc', n_estimators=200, learning_rate=0.1, num_leaves=20, max_depth=5))])
pipe_ml = Pipeline([('ml_classifier', MLPClassifier(max_iter=200, random_state=19, early_stopping=True, n_iter_no_change=10))])
pipelines = [pipe_lr, pipe_dt, pipe_rf, pipe_sv, pipe_xg, pipe_lb, pipe_ml]
best_auc = 0
best_classifier = 0
best_pipeline = ""
pipe_dict = {0: 'Logistic Regression', 1: 'Decision Tree', 2: 'Random Forest', 3: 'SVM', 4: 'XG Boost', 5: 'Light GBM', 6:'Neural Network'}
# Use cross-validation to evaluate the models
for i, model in enumerate(pipelines):
    model.fit(X_train, y_train)
    scores = cross_val_score(model, X_final, y_final, cv=5, scoring='roc_auc')
    print('{} Cross-Validation AUC: {:.2f}'.format(pipe_dict[i], scores.mean()))
    if scores.mean() > best_auc:
        best_auc = scores.mean()
        best_pipeline = model
        best_classifier = i
# Print the best classifier
print('\nClassifier with the best AUC-ROC score: {}'.format(pipe_dict[best_classifier]))
Logistic Regression Cross-Validation AUC: 0.77
Decision Tree Cross-Validation AUC: 0.73
Random Forest Cross-Validation AUC: 0.98
c:\Users\XIX\anaconda3\lib\site-packages\sklearn\svm\_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations. [warning repeated for each CV fit]
SVM Cross-Validation AUC: 0.77
XG Boost Cross-Validation AUC: 0.96
Light GBM Cross-Validation AUC: 0.96
Neural Network Cross-Validation AUC: 0.94

Classifier with the best AUC-ROC score: Random Forest
We attempted two other imputers, KNNImputer and IterativeImputer, but they were too computationally intensive for this system. Both use machine learning to impute the missing values, and the increased accuracy of the imputed values comes at a cost in training time. Consequently, we use simple imputation. The models were trained on the training set, and cross-validation was used to determine average AUC scores.
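For reference, a sketch of what the heavier imputers would look like if compute allowed; each simply swaps the imputer step of the preprocessing pipeline (the parameter values are illustrative defaults):
# KNN imputation: fill each missing value from the k most similar rows
knn_preprocessor = Pipeline([
    ('ordinal_encoder', OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)),
    ('imputer', KNNImputer(n_neighbors=5)),
    ('scaler', StandardScaler())
])
# iterative imputation: regress each feature with missing values on the other features
iterative_preprocessor = Pipeline([
    ('ordinal_encoder', OrdinalEncoder(handle_unknown='use_encoded_value', unknown_value=-1)),
    ('imputer', IterativeImputer(random_state=19)),
    ('scaler', StandardScaler())
])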
# dummy model
pipe_dm = Pipeline([('dm_classifier', DummyClassifier(random_state=19, strategy='most_frequent'))])
pipe_dm.fit(X_processed, y)
scores = cross_val_score(pipe_dm, X_processed, y, cv=5, scoring='roc_auc')
final_score = sum(scores) / len(scores)
print('Average model AUC ROC score:', final_score)
Average model AUC ROC score: 0.5
# accuracy of the dummy model on the imbalanced data
accuracy_score(y, pipe_dm.predict(X))
0.854925
# accuracy of the dummy model on the balanced data
accuracy_score(y_final, pipe_dm.predict(X_final))
0.5
A dummy model was trained to illustrate two things: the effect of class imbalance, and the difference between AUC and accuracy. This dummy is a baseline model that disregards the features and always predicts the majority class, 0. On the imbalanced data its accuracy is 0.85 even though its AUC is only 0.5. Accuracy is therefore a misleading metric with imbalanced targets, because it hides the model's complete failure on the minority class: every positive instance becomes a false negative. Once we balance the classes, the accuracy of the dummy model drops to 0.5.
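To make the failure on the minority class explicit, we can check recall on the positive class directly (a small illustrative check; recall_score is an extra import not loaded above):
# the dummy never predicts class 1, so every positive instance is a false negative
from sklearn.metrics import recall_score
print(recall_score(y, pipe_dm.predict(X)))
# expected output: 0.0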
# series of model scores
data = {'Logistic Regression': 0.7728, 'Decision Tree': 0.7378 , 'Random Forest': 0.9754, 'SVM': 0.7729, 'XG Boost': 0.9661, 'Light GBM': 0.9665, 'Neural Network': 0.94}
comp = pd.Series(data, name='AUC Score')
# model scores
fig = px.scatter(comp, color=comp.index, size=comp, title='Model Score Comparison', symbol=comp, labels={'index': 'Model', 'value': 'AUC Score'}, template='plotly_white')
fig.show()
# dummy model
probabilities_valid = pipe_dm.predict_proba(X_valid)
probabilities_one_valid = probabilities_valid[:, 1]
auc_roc = roc_auc_score(y_valid, probabilities_one_valid)
print(auc_roc)
# ROC AUC curve of results
fpr, tpr, thresholds = roc_curve(y_valid, probabilities_one_valid)
fig = px.area(
x=fpr, y=tpr,
title=f'ROC Curve (AUC={auc(fpr, tpr):.4f})',
labels=dict(x='False Positive Rate', y='True Positive Rate'),
width=700, height=500
)
fig.add_shape(
type='line', line=dict(dash='dash'),
x0=0, x1=1, y0=0, y1=1
)
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.update_xaxes(constrain='domain')
fig.show()
0.5
probabilities_valid = pipe_lr.predict_proba(X_valid)
probabilities_one_valid = probabilities_valid[:, 1]
auc_roc = roc_auc_score(y_valid, probabilities_one_valid)
print(auc_roc)
# ROC AUC curve of results
fpr, tpr, thresholds = roc_curve(y_valid, probabilities_one_valid)
fig = px.area(
x=fpr, y=tpr,
title=f'ROC Curve (AUC={auc(fpr, tpr):.4f})',
labels=dict(x='False Positive Rate', y='True Positive Rate'),
width=700, height=500
)
fig.add_shape(
type='line', line=dict(dash='dash'),
x0=0, x1=1, y0=0, y1=1
)
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.update_xaxes(constrain='domain')
fig.show()
0.7728281007460819
probabilities_valid = pipe_dt.predict_proba(X_valid)
probabilities_one_valid = probabilities_valid[:, 1]
auc_roc = roc_auc_score(y_valid, probabilities_one_valid)
print(auc_roc)
# ROC AUC curve of results
fpr, tpr, thresholds = roc_curve(y_valid, probabilities_one_valid)
fig = px.area(
x=fpr, y=tpr,
title=f'ROC Curve (AUC={auc(fpr, tpr):.4f})',
labels=dict(x='False Positive Rate', y='True Positive Rate'),
width=700, height=500
)
fig.add_shape(
type='line', line=dict(dash='dash'),
x0=0, x1=1, y0=0, y1=1
)
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.update_xaxes(constrain='domain')
fig.show()
0.7377709495708669
probabilities_valid = pipe_rf.predict_proba(X_valid)
probabilities_one_valid = probabilities_valid[:, 1]
auc_roc = roc_auc_score(y_valid, probabilities_one_valid)
print(auc_roc)
# ROC AUC curve of results
fpr, tpr, thresholds = roc_curve(y_valid, probabilities_one_valid)
fig = px.area(
x=fpr, y=tpr,
title=f'ROC Curve (AUC={auc(fpr, tpr):.4f})',
labels=dict(x='False Positive Rate', y='True Positive Rate'),
width=700, height=500
)
fig.add_shape(
type='line', line=dict(dash='dash'),
x0=0, x1=1, y0=0, y1=1
)
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.update_xaxes(constrain='domain')
fig.show()
0.9754461250233604
probabilities_valid = pipe_sv.decision_function(X_valid)
auc_roc = roc_auc_score(y_valid, probabilities_valid)
print(auc_roc)
# ROC AUC curve of results
fpr, tpr, thresholds = roc_curve(y_valid, probabilities_valid)
fig = px.area(
x=fpr, y=tpr,
title=f'ROC Curve (AUC={auc(fpr, tpr):.4f})',
labels=dict(x='False Positive Rate', y='True Positive Rate'),
width=700, height=500
)
fig.add_shape(
type='line', line=dict(dash='dash'),
x0=0, x1=1, y0=0, y1=1
)
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.update_xaxes(constrain='domain')
fig.show()
0.772904564076928
probabilities_valid = pipe_xg.predict_proba(X_valid)
probabilities_one_valid = probabilities_valid[:, 1]
auc_roc = roc_auc_score(y_valid, probabilities_one_valid)
print(auc_roc)
# ROC AUC curve of results
fpr, tpr, thresholds = roc_curve(y_valid, probabilities_one_valid)
fig = px.area(
x=fpr, y=tpr,
title=f'ROC Curve (AUC={auc(fpr, tpr):.4f})',
labels=dict(x='False Positive Rate', y='True Positive Rate'),
width=700, height=500
)
fig.add_shape(
type='line', line=dict(dash='dash'),
x0=0, x1=1, y0=0, y1=1
)
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.update_xaxes(constrain='domain')
fig.show()
0.9660532498899361
probabilities_valid = pipe_lb.predict_proba(X_valid)
probabilities_one_valid = probabilities_valid[:, 1]
auc_roc = roc_auc_score(y_valid, probabilities_one_valid)
print(auc_roc)
# ROC AUC curve of results
fpr, tpr, thresholds = roc_curve(y_valid, probabilities_one_valid)
fig = px.area(
x=fpr, y=tpr,
title=f'ROC Curve (AUC={auc(fpr, tpr):.4f})',
labels=dict(x='False Positive Rate', y='True Positive Rate'),
width=700, height=500
)
fig.add_shape(
type='line', line=dict(dash='dash'),
x0=0, x1=1, y0=0, y1=1
)
fig.update_yaxes(scaleanchor="x", scaleratio=1)
fig.update_xaxes(constrain='domain')
fig.show()
0.9665248900523244
AUC-ROC is a metric that compares the true positive rate with the false positive rate. The dashed diagonal through the plot represents 0.50, the score of a random model. AUC scores near 0.5 indicate a model no better than chance, while a perfect score is 1. Most models performed well, and some performed excellently, compared to a random model.
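The per-model cells above repeat the same plotting code; a small helper, sketched here with only names already defined in this notebook, would condense them:
# reusable ROC plot to replace the repeated per-model cells
def plot_roc(y_true, scores, title='ROC Curve'):
    fpr, tpr, _ = roc_curve(y_true, scores)
    fig = px.area(x=fpr, y=tpr, title=f'{title} (AUC={auc(fpr, tpr):.4f})',
                  labels=dict(x='False Positive Rate', y='True Positive Rate'),
                  width=700, height=500)
    fig.add_shape(type='line', line=dict(dash='dash'), x0=0, x1=1, y0=0, y1=1)
    fig.update_yaxes(scaleanchor="x", scaleratio=1)
    fig.update_xaxes(constrain='domain')
    fig.show()
# usage: plot_roc(y_valid, pipe_rf.predict_proba(X_valid)[:, 1], 'Random Forest')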
Each of the candidate models has distinct strengths and weaknesses:

- Logistic regression is simple, fast, and easily interpretable. It works well with linearly separable data and handles large datasets at low computational cost. Its weakness is the assumption that the classes are linearly separable, which can lead to poor performance, high bias, and underfitting when the data is too complex.
- Decision trees are also easily interpretable, can handle categorical data (for example, via one-hot encoding), and can capture non-linear relationships. Their weakness is a tendency to overfit the training data and generalize poorly to new data.
- Random forest also handles categorical and continuous data, and it reduces overfitting by combining multiple trees. It is less interpretable than the previous methods and requires hyperparameter tuning to control overfitting.
- Linear SVC works well for binary classification tasks and can handle high-dimensional data. SVC models do not cope well with imbalanced classes, can be sensitive to outliers, and are slow to train on large datasets.
- XG Boost handles both categorical and continuous data and reduces overfitting by using multiple trees, but it may require significant tuning, which is a downside for those unfamiliar with the algorithm.
- Light GBM is similar to XG Boost but handles larger datasets faster and with less memory; it likewise requires hyperparameter tuning to limit overfitting.
- MLPs and other neural networks can capture complex relationships between features and targets, but they can be computationally expensive, require hyperparameter tuning, and can suffer from overfitting.
Overall, the best model to use depends on the problem at hand, the size and complexity of the data, and the level of interpretability.
# Logistic regression pipeline feature importance
pipe_lr.fit(X_train, y_train)
logreg_classifier = pipe_lr.named_steps['lr_classifier']
logreg_importances = logreg_classifier.coef_
logreg_indices = np.argsort(logreg_importances)[::-1]
# making dataframe of important coefficients
lr_importance = pd.DataFrame(logreg_importances, columns=X.columns)
lr_importance = lr_importance.T
lr_top_10_df = lr_importance.nlargest(10, columns=0)
fig = px.pie(lr_top_10_df, names=lr_top_10_df.index, values=0, title='Top 10 Logistic Regression Coefficients')
fig.show()
# decision tree pipeline feature importance
pipe_dt.fit(X_train, y_train)
dt_classifier = pipe_dt.named_steps['dt_classifier']
dt_importances = dt_classifier.feature_importances_
dt_indices = np.argsort(dt_importances)[::-1]
top_10_features = []
for f in range(10):
    feature_index = dt_indices[f]
    # look up names in X.columns; train.columns includes the target y, which would shift every name by one
    feature_name = X.columns[feature_index]
    top_10_features.append((feature_name, dt_importances[feature_index]))
dt_top_10_df = pd.DataFrame(top_10_features, columns=['Feature', 'Importance'])
fig = px.pie(dt_top_10_df.head(2), title='Top Features of Decision Tree', names='Feature', values='Importance')
fig.show()
# random forest pipeline feature importance
pipe_rf.fit(X_train, y_train)
rf_classifier = pipe_rf.named_steps['rf_classifier']
rf_importances = rf_classifier.feature_importances_
rf_indices = np.argsort(rf_importances)[::-1]
top_10_features = []
for f in range(10):
    feature_index = rf_indices[f]
    feature_name = X.columns[feature_index]  # X.columns, not train.columns (train includes y)
    top_10_features.append((feature_name, rf_importances[feature_index]))
rf_top_10_df = pd.DataFrame(top_10_features, columns=['Feature', 'Importance'])
fig = px.pie(rf_top_10_df, title='Top 10 Features of Random Forest', names='Feature', values='Importance')
fig.show()
# Support vector pipeline feature importance
pipe_sv.fit(X_train, y_train)
c:\Users\XIX\anaconda3\lib\site-packages\sklearn\svm\_base.py:1206: ConvergenceWarning: Liblinear failed to converge, increase the number of iterations.
svm_classifier = pipe_sv.named_steps['svm_classifier']
svm_importances = svm_classifier.coef_
svm_indices = np.argsort(svm_importances)[::-1]
# making dataframe of important coefficients
sv_importance = pd.DataFrame(svm_importances, columns=X.columns)
sv_importance = sv_importance.T
sv_top_10_df = sv_importance.nlargest(10, columns=0)
fig = px.pie(sv_top_10_df, names=sv_top_10_df.index, values=0, title='Top 10 Support Vector Coefficients')
fig.show()
# xg boost pipeline feature importance
pipe_xg.fit(X_train, y_train)
xg_classifier = pipe_xg.named_steps['xg_classifier']
xg_importances = xg_classifier.feature_importances_
xg_indices = np.argsort(xg_importances)[::-1]
top_10_features = []
for f in range(10):
    feature_index = xg_indices[f]
    feature_name = X.columns[feature_index]  # X.columns, not train.columns (train includes y)
    top_10_features.append((feature_name, xg_importances[feature_index]))
xg_top_10_df = pd.DataFrame(top_10_features, columns=['Feature', 'Importance'])
fig = px.pie(xg_top_10_df, title='Top 10 Features of XG Boost', names='Feature', values='Importance')
fig.show()
# light boost pipeline feature importance
pipe_lb.fit(X_train, y_train)
lb_classifier = pipe_lb.named_steps['lb_classifier']
lb_importances = lb_classifier.feature_importances_
lb_indices = np.argsort(lb_importances)[::-1]
top_10_features = []
for f in range(10):
    feature_index = lb_indices[f]
    feature_name = X.columns[feature_index]  # X.columns, not train.columns (train includes y)
    top_10_features.append((feature_name, lb_importances[feature_index]))
lb_top_10_df = pd.DataFrame(top_10_features, columns=['Feature', 'Importance'])
fig = px.pie(lb_top_10_df, title='Top 10 Features of Light GBM', names='Feature', values='Importance')
fig.show()
If scoring metrics cannot be used to choose a model, feature importance can help pick one based on explainability. Explainability means expressing a model's behavior in human terms. With complex models, we cannot fully understand how the model parameters drive predictions, but feature importance shows how each model makes its predictions and which features matter most to it. Even without feature importance, a model can still be selected based on its interpretability, as simpler models are easier to explain to stakeholders.
Another factor in choosing a model is the resource requirement of the algorithm. More complex models need more memory or computing power to train and to make predictions, so with limited resources, model selection is restricted to simpler models.
Furthermore, we can use visualizations to show how the predictions of two models differ from the actual values. A confusion matrix counts true and false positives and negatives, and visualizing it illustrates the results of a classification model's predictions.
# legend heatmap: labels the quadrants of the confusion matrices plotted below
# (go.Heatmap draws row 0 of z at the bottom, so this layout matches the px.imshow plots)
fig = go.Figure(data=go.Heatmap(z=[[1205, 185], [8557, 53]],
                                text=[['False Negatives', 'True Positives'], ['True Negatives', 'False Positives']],
                                texttemplate="%{text}", textfont={"size": 20}))
fig.show()
# validation predictions of logistic regression
valid_pred_lr = pipe_lr.predict(X_valid)
# confusion matrix of validation set of logistic regression
fig = px.imshow(confusion_matrix(y_valid, valid_pred_lr), text_auto=True, labels=dict(y="Actual", x="Predicted"),
x=['Negative', 'Positive'],
y=['Negative', 'Positive'], title='Confusion Matrix of Logistic Regression')
fig.show()
The true negative count is 5841 and the true positive count is 6154, so the model performed moderately well on both classes. Its 2718 false positives (8559 − 5841) are nearly half as numerous as its correct positive predictions, and its 2386 false negatives (8540 − 6154) are well under half as numerous as its correct negative predictions.
# validation predictions of random forest
valid_pred_rf = pipe_rf.predict(X_valid)
# confusion matrix of validation set of random forest
fig = px.imshow(confusion_matrix(y_valid, valid_pred_rf), text_auto=True, labels=dict(y="Actual", x="Predicted"),
x=['Negative', 'Positive'],
y=['Negative', 'Positive'], title='Confusion Matrix of Random Forest')
fig.show()
The confusion matrix shows a true negative count of 8195 and a true positive count of 7625: the predictions that match the actual values. Overall, the model was excellent at predicting the negative class and fairly good at predicting the positive class. This is further reflected in the false negative count of 915, the instances where the model incorrectly predicted the negative class. Our models performed best when using SMOTE to balance the dataset. SMOTE uses the k-nearest-neighbors algorithm to create synthetic examples of the minority class, thereby balancing the data.
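Concretely, for a minority-class sample x_i and one of its k nearest minority-class neighbors x_nn, SMOTE generates a synthetic point by interpolating between them: x_new = x_i + λ · (x_nn − x_i), where λ is drawn uniformly from [0, 1].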
The confusion matrix on the validation set is used to illustrate how we expect the model will perform on the test set.
# validation f1 score of logistic regression
f1_score(y_valid, valid_pred_lr)
0.7068688260969446
# classification report
print(classification_report(y_valid, valid_pred_lr))
              precision    recall  f1-score   support

           0       0.71      0.68      0.70      8559
           1       0.69      0.72      0.71      8540

    accuracy                           0.70     17099
   macro avg       0.70      0.70      0.70     17099
weighted avg       0.70      0.70      0.70     17099
The classification report breaks down the precision and recall of the model for each class. Precision tells us how many of the instances the model flagged as positive are actually relevant, while recall tells us how many of the relevant instances the model captured. A model with both high precision and high recall is a strong model. The logistic regression model shows moderate precision and recall for the negative class, and similar values for the positive class, so the F1 scores of both classes are moderate.
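In terms of the confusion-matrix counts: precision = TP / (TP + FP), recall = TP / (TP + FN), and F1 = 2 · precision · recall / (precision + recall).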
# validation f1 score of random forest
f1_score(y_valid, valid_pred_rf)
0.9226208482061831
# classification report
print(classification_report(y_valid, valid_pred_rf))
              precision    recall  f1-score   support

           0       0.90      0.96      0.93      8559
           1       0.95      0.89      0.92      8540

    accuracy                           0.93     17099
   macro avg       0.93      0.93      0.93     17099
weighted avg       0.93      0.93      0.93     17099
For the random forest, we see high precision and recall in the negative class; the positive class has high precision and slightly lower recall. Since the F1 score is the harmonic mean of precision and recall, both classes have a high F1 score.
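These figures can be verified from the confusion-matrix counts above: with 7625 true positives, 915 false negatives, and 8559 − 8195 = 364 false positives, recall for the positive class is 7625 / (7625 + 915) ≈ 0.89 and precision is 7625 / (7625 + 364) ≈ 0.95, matching the report.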
Based on the confusion matrices and classification reports, we expect the random forest model to perform better. The random forest model had more true positive and true negative values than the logistic regression model, when comparing performance on the validation set.
# final logistic regression model
final_lr = pipe_lr.fit(X_final, y_final)
# final random forest model
final_rf = pipe_rf.fit(X_final, y_final)
Now that we have selected our final models, we use the full training set to fit the models.
# ordinal encoding days and months in order (transform only: the encoders were already fit on the training data)
days_test = pd.DataFrame(encoder_day.transform(test.x3.to_numpy().reshape(-1, 1)), columns=['day'])
months_test = pd.DataFrame(encoder_month.transform(test.x60.to_numpy().reshape(-1, 1)), columns=['month'])
# replace columns with ordinal columns
test['x3'] = days_test
test['x60'] = months_test
test.head()
[test.head() output truncated: the 100-column header and first five rows. x3 and x60 now hold the ordinal-encoded day and month values; the remaining categorical columns are still strings.]
# Preprocess the test data
X_test_transformed = preprocessor.transform(test)
# shape of test set
X_test_transformed.shape
(10000, 100)
We follow the same preprocessing steps as the training set, to transform the test set for the model.
# test set predictions
valid_pred_lr = final_lr.predict_proba(X_test_transformed)
valid_pred_rf = final_rf.predict_proba(X_test_transformed)
We run predictions on the transformed test set and extract the probability of each class. The model assigns the class with the highest predicted probability; for binary classification the default threshold is 0.5. If the predicted probability of the positive class exceeds 0.5, the model predicts class 1; otherwise it predicts class 0. The probabilities also indicate how confident the model is in each prediction, as values close to 1 (or to 0) are more certain than those near 0.5.
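For the logistic regression, that equivalence can be checked directly (an illustrative snippet; manual_pred is a hypothetical name):
# predict() matches thresholding the positive-class probability at 0.5
manual_pred = (valid_pred_lr[:, 1] >= 0.5).astype(int)
print(np.array_equal(manual_pred, final_lr.predict(X_test_transformed)))
# expected output: True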
# probabilities of positive class
lr_list = valid_pred_lr[:,1].tolist()
We extract the predicted probabilities of the positive class. In other words, these values represent the predicted probability that the target is class 1.
# Create a DataFrame from the list
lr_df = pd.DataFrame(lr_list)
# Save the DataFrame to a CSV file
#lr_df.to_csv('predictions/glmresults.csv', index=False, header=False)
We save the predictions to a csv file, where each value is the predicted probability of the positive class.
# probabilities of positive class
rf_list = valid_pred_rf[:,1].tolist()
We extract the probabilities of the positive class.
# Create a DataFrame from the list
rf_df = pd.DataFrame(rf_list)
# Save the DataFrame to a CSV file
#rf_df.to_csv('predictions/nonglmresults.csv', index=False, header=False)
We save the values to another csv file.
# logistic regression class predictions
lr_class = pd.DataFrame(final_lr.predict(X_test_transformed))
# class prediction counts
lr_class.value_counts()
1    9700
0     300
dtype: int64
The logistic regression model made only 300 negative predictions and 9700 positive predictions.
# random forest class predictions
rf_class = pd.DataFrame(final_rf.predict(X_test_transformed))
# class prediction counts
rf_class.value_counts()
0    7726
1    2274
dtype: int64
The random forest model made far more negative predictions than the logistic regression model: 7726 versus 300.
One of the main issues with the dataset was the amount of missing values. Deleting missing values leads to a loss of valuable information, and model performance would suffer, unless the proportion of missing data is minimal. Imputing these missing values could recover some of the missing information, which can result in a better model. However, the reason why the data is missing, as well as the imputation method implemented, can have a significant impact on the model performance.
The datasets were cleaned and preprocessed with ordinal encoding, standard scaling, simple imputation, and SMOTE. Ordinal encoding allowed us to convert categorical features into numerical labels so the models could be trained. Standard scaling was implemented to improve model performance, since features with much larger magnitudes would otherwise be given greater weight merely because of their scale. By putting all features on a common scale, we prevent any feature from dominating simply because of its units. SimpleImputer was used to fill in the missing values, while SMOTE was used to balance the minority class.
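Concretely, standard scaling replaces each value with its z-score computed from the training data: z = (x − μ) / σ, where μ and σ are the feature's mean and standard deviation.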
Several models were trained, and two were selected: logistic regression and random forest. Logistic regression serves as a baseline for the more complex random forest. The logistic model is simple and easy to interpret; however, it assumes a linear relationship between the features and the target. Random forest is a more powerful model that can handle a large number of inputs and is relatively robust to outliers and noisy data, although it can still overfit when a dataset contains many irrelevant features. Furthermore, random forest models are harder to interpret, as they are ensembles of many decision trees.
When it comes to selecting between the two models, our main determinant was model performance. If our decision were based on interpretability rather than performance, we would choose logistic regression. However, based on performance on the validation set, we expect random forest to perform better on the test set. In addition, the models in the pipeline that did not assume linearity generally performed better.
AUC, or area under the ROC curve, summarizes the trade-off between the true positive rate and the false positive rate, where 1 represents a perfect model and 0.5 is no better than random. Since test-set AUC is typically somewhat lower than validation-set AUC, we still expect the random forest to perform significantly better than logistic regression on the test set, given its high validation AUC. In addition, the random forest made more correct predictions in both the positive and negative classes, as shown by the validation confusion matrices. We estimate the AUC of the logistic regression model to fall between 0.60 and 0.80, and that of the random forest between 0.90 and 1.00. AUC was the appropriate metric for evaluating our models, as accuracy is not suitable for imbalanced classes, which the dummy model illustrated. We also found that SMOTE significantly improved performance among the non-linear models by balancing the minority class.
If we could not use a scoring metric to compare the two models, we could compare their predictions on the test set directly. Once actual outcomes are available, the model with more true positive and true negative predictions is the better performer, and the false positive and false negative counts can be compared the same way. Even without labels, we can measure how often the two models agree with each other.
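For example (an illustrative snippet using the prediction frames built above):
# fraction of test rows where the two final models predict the same class
print((lr_class[0] == rf_class[0]).mean())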
Overall, the appropriate model and scoring metric depend on the data and the business needs. Many factors can determine the right machine learning algorithm to use, from limited resources and large datasets to categorical values and model interpretability.