Tuesday, August 28, 2018

Machine Learning 2. Partial Dependence Plots (PDP).

Sometimes ML models seem like black boxes: you can't see how the model works, so it is hard to inspect and improve its logic. Partial dependence plots (PDP) are one way to look inside. A PDP shows how each variable or predictor (feature) affects the model's predictions, and it can be interpreted similarly to the coefficients in linear or logistic regression models.
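The idea behind a PDP fits in a few lines: pick one feature, force it to each value on a grid for every row of the data, and average the model's predictions. This is a minimal brute-force sketch. partial_dependence_1d and the toy predict function are hypothetical names made up for illustration, and scikit-learn's gradient boosting code uses a faster tree-based method rather than this loop.

```python
# Brute-force partial dependence: a minimal sketch, not scikit-learn's implementation.
# partial_dependence_1d and predict are hypothetical names used only for illustration.


def partial_dependence_1d(predict, rows, feature, grid):
    """Average prediction for each grid value of one feature.

    predict: function mapping one row (a list of feature values) to a number
    rows:    the data the model was trained on, as a list of rows
    feature: index of the feature whose effect we want to see
    grid:    values to assign to that feature
    """
    curve = []
    for value in grid:
        total = 0.0
        for row in rows:
            modified = list(row)       # copy, so the original data is unchanged
            modified[feature] = value  # force the feature to the grid value
            total += predict(modified)
        curve.append(total / len(rows))  # average over all rows
    return curve


def predict(row):
    """Toy 'model' with made-up coefficients: price = 100 * rooms + 2 * area."""
    rooms, area = row
    return 100 * rooms + 2 * area


if __name__ == "__main__":
    data = [[1, 30], [1, 50], [3, 65]]  # (Rooms, Area) pairs from test.csv
    print(partial_dependence_1d(predict, data, feature=0, grid=[1, 2, 3]))
```

For a linear model like the toy one, the curve for a feature is a straight line whose slope is that feature's coefficient (here 100 per room), which is why a PDP can be read much like a regression coefficient.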

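Our data will be a small CSV file with a few missing values, saved as ~/test.csv (paste the lines, then press Ctrl-D to finish):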
[admin@localhost ~]$ cat > test.csv
Rooms,Price,Floors,Area,HouseColor
1,300,1,30,red
1,400,1,50,green
3,400,1,65,blue
2,200,1,45,green
5,700,3,120,yellow
,400,2,70,blue
,300,1,40,blue
4,,2,95,brown
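Before modelling it helps to see where values are missing, because that decides the cleaning steps: a row with no Price can't be used as a training target, while missing Rooms values can be filled in. A small standard-library sketch; count_missing is a hypothetical helper written for this example, not part of pandas.

```python
import csv
import io

# Same rows as test.csv above, embedded so the sketch runs on its own.
CSV_TEXT = """Rooms,Price,Floors,Area,HouseColor
1,300,1,30,red
1,400,1,50,green
3,400,1,65,blue
2,200,1,45,green
5,700,3,120,yellow
,400,2,70,blue
,300,1,40,blue
4,,2,95,brown
"""


def count_missing(text):
    """Return {column name: number of empty cells} for CSV text with a header row.

    count_missing is a hypothetical helper written for this example.
    """
    reader = csv.DictReader(io.StringIO(text))
    missing = {name: 0 for name in reader.fieldnames}
    for row in reader:
        for name, value in row.items():
            if value == "":  # an empty field is a missing value
                missing[name] += 1
    return missing


if __name__ == "__main__":
    print(count_missing(CSV_TEXT))
    # {'Rooms': 2, 'Price': 1, 'Floors': 0, 'Area': 0, 'HouseColor': 0}
```

With pandas the same check is test_data.isnull().sum().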

We'll use PDP to understand the relationship between Price and the other variables. A PDP helps find insights in the data and check whether a feature you expect to be important actually matters for the model's predictions. Note that a PDP is calculated only after the model has been trained (fit).

>>> test_file_path = "~/test.csv"
>>> import pandas as pd
>>> test_data = pd.read_csv(test_file_path)
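>>> # drop rows where the target (Price) is missing
>>> test_data.dropna(axis=0, subset=['Price'], inplace=True)
>>> y = test_data.Price
>>> X = test_data.drop(['Price'], axis=1)
>>> # keep only numeric columns (drops HouseColor)
>>> X = X.select_dtypes(exclude=['object'])
>>> # Imputer was deprecated in scikit-learn 0.20 in favor of sklearn.impute.SimpleImputer
>>> from sklearn.preprocessing import Imputer
>>> test_imputer = Imputer()  # default strategy='mean'
>>> X = test_imputer.fit_transform(X)  # fills the missing Rooms values with the column mean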
>>> # at the time of writing (scikit-learn 0.19/0.20), sklearn.ensemble.partial_dependence supports only gradient boosting models
>>> from sklearn.ensemble import GradientBoostingRegressor
>>> test_model = GradientBoostingRegressor()
>>> test_model.fit(X,y)
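>>> # scikit-learn 0.19/0.20 API; newer versions provide partial dependence tools in sklearn.inspection
>>> from sklearn.ensemble.partial_dependence import partial_dependence, plot_partial_dependence
>>> fig, axs = plot_partial_dependence(gbrt=test_model, X=X, features=[0, 1, 2], feature_names=['Rooms', 'Floors', 'Area'], grid_resolution=10)
>>> import matplotlib.pyplot as plt
>>> plt.show()  # needed to display the figure outside a notebook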
Options described:

  • gbrt - the fitted gradient boosting model to plot
  • X - the data the model in gbrt was trained on; it is used to build the grid of values for each plotted feature
  • features - indices of the columns in X to plot (each index creates one PDP)
  • feature_names - names of the features, where feature_names[i] is the name of column i in X
  • grid_resolution - number of equally spaced values on the x axis of each plot
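The X and grid_resolution options work together: the values on the x axis are taken from the data in X. As far as I can tell from the scikit-learn source of this version, a feature with fewer distinct values than grid_resolution is plotted at each of those values, and otherwise grid_resolution equally spaced points are placed between the 5th and 95th percentiles (the percentiles option). grid_for_feature below is a hypothetical helper sketching this; scikit-learn computes the percentiles with scipy's mquantiles, so its end points can differ slightly from the linear interpolation used here.

```python
# Simplified sketch of how the x-axis values of a PDP are chosen.
# grid_for_feature and percentile are hypothetical helpers; scikit-learn uses
# scipy's mquantiles for the percentiles, so its end points can differ slightly.


def percentile(sorted_values, q):
    """Linear-interpolation percentile of already sorted values, q in [0, 1]."""
    pos = q * (len(sorted_values) - 1)
    lower = int(pos)
    upper = min(lower + 1, len(sorted_values) - 1)
    frac = pos - lower
    return sorted_values[lower] + (sorted_values[upper] - sorted_values[lower]) * frac


def grid_for_feature(column, grid_resolution=100, percentiles=(0.05, 0.95)):
    """Return the x-axis values at which the PDP of one feature is evaluated."""
    uniques = sorted(set(column))
    if len(uniques) < grid_resolution:
        # few distinct values: evaluate at each of them
        return uniques
    values = sorted(column)
    low = percentile(values, percentiles[0])
    high = percentile(values, percentiles[1])
    step = (high - low) / (grid_resolution - 1)
    # grid_resolution equally spaced points from low to high (inclusive)
    return [low + i * step for i in range(grid_resolution)]


if __name__ == "__main__":
    floors = [1, 1, 1, 1, 3, 2, 1]  # Floors column after dropping the row with no Price
    print(grid_for_feature(floors, grid_resolution=10))  # -> [1, 2, 3]
```

With our small dataset every column should have fewer than 10 distinct values after imputation (Floors 3, Rooms 5, Area 7), so grid_resolution=10 changes nothing here and each plot just connects those points.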
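The y axis shows how the predicted Price changes as one feature moves along the x axis, averaged over all the houses in X. In this version of scikit-learn the gradient boosting PDP leaves out the model's initial prediction, which for the default loss is the mean Price, so negative values mean the predicted Price at that feature value is below the average Price, and positive values mean it is above.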
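A PDP reading can be turned back into a price by adding the mean Price. A small sketch of that arithmetic, assuming the default least-squares GradientBoostingRegressor (whose initial prediction is the mean of y) and that the plotted values exclude that initial prediction; average_predicted_price is a hypothetical helper and the PDP readings are made up, not taken from a real plot.

```python
# Sketch: turning a value read off the PDP y axis back into a predicted Price.
# Assumes the default GradientBoostingRegressor (least-squares loss), whose initial
# prediction is the mean of y, and that the plotted values exclude that prediction.

prices = [300, 400, 400, 200, 700, 400, 300]  # Price column after dropping the row with no Price
baseline = sum(prices) / len(prices)           # mean Price = 2700 / 7, about 385.71


def average_predicted_price(pdp_value, baseline):
    """Average predicted Price at a grid point whose PDP value is pdp_value (hypothetical helper)."""
    return baseline + pdp_value


if __name__ == "__main__":
    for pdp_value in (-50.0, 0.0, 80.0):  # made-up readings, not from a real plot
        price = average_predicted_price(pdp_value, baseline)
        side = "below" if pdp_value < 0 else "at or above"
        print(f"PDP value {pdp_value:+.1f} -> about {price:.2f} ({side} the average Price)")
```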
