From 27f67d70ed54b080eeb0191f9154b54801349243 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9ssica=20Amaral?= Date: Mon, 3 Apr 2023 21:53:08 -0300 Subject: [PATCH] updates librarys and methods on Chapter 9 --- 09_tabular.ipynb | 21 +++++++++++---------- clean/09_tabular.ipynb | 21 +++++++++++---------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/09_tabular.ipynb b/09_tabular.ipynb index e3ed14c..f902be1 100644 --- a/09_tabular.ipynb +++ b/09_tabular.ipynb @@ -26,7 +26,7 @@ "from fastai.tabular.all import *\n", "from sklearn.ensemble import RandomForestRegressor\n", "from sklearn.tree import DecisionTreeRegressor\n", - "from dtreeviz.trees import *\n", + "import dtreeviz\n", "from IPython.display import Image, display_svg, SVG\n", "\n", "pd.options.display.max_rows = 20\n", @@ -4483,9 +4483,10 @@ ], "source": [ "samp_idx = np.random.permutation(len(y))[:500]\n", - "dtreeviz(m, xs.iloc[samp_idx], y.iloc[samp_idx], xs.columns, dep_var,\n", - " fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n", - " orientation='LR')" + "viz_model = dtreeviz.model(m, xs.iloc[samp_idx], y.iloc[samp_idx],\n", + " feature_names=xs.columns, target_name=dep_var)\n", + "viz_model.view(fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n", + " orientation='LR')" ] }, { @@ -7402,9 +7403,10 @@ "source": [ "m = DecisionTreeRegressor(max_leaf_nodes=4).fit(xs, y)\n", "\n", - "dtreeviz(m, xs.iloc[samp_idx], y.iloc[samp_idx], xs.columns, dep_var,\n", - " fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n", - " orientation='LR')" + "viz_model = dtreeviz.model(m, xs.iloc[samp_idx], y.iloc[samp_idx],\n", + " feature_names=xs.columns, target_name=dep_var)\n", + "viz_model.view(fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n", + " orientation='LR')" ] }, { @@ -8672,11 +8674,10 @@ } ], "source": [ - "from sklearn.inspection import plot_partial_dependence\n", + "from sklearn.inspection import PartialDependenceDisplay\n", "\n", "fig,ax = plt.subplots(figsize=(12, 4))\n", - "plot_partial_dependence(m, valid_xs_final, ['YearMade','ProductSize'],\n", - " grid_resolution=20, ax=ax);" + "PartialDependenceDisplay.from_estimator(m, valid_xs_final, ['YearMade','ProductSize'], ax=ax)" ] }, { diff --git a/clean/09_tabular.ipynb b/clean/09_tabular.ipynb index 57ef84d..c62fde1 100644 --- a/clean/09_tabular.ipynb +++ b/clean/09_tabular.ipynb @@ -24,7 +24,7 @@ "from fastai.tabular.all import *\n", "from sklearn.ensemble import RandomForestRegressor\n", "from sklearn.tree import DecisionTreeRegressor\n", - "from dtreeviz.trees import *\n", + "import dtreeviz\n", "from IPython.display import Image, display_svg, SVG\n", "\n", "pd.options.display.max_rows = 20\n", @@ -402,9 +402,10 @@ "outputs": [], "source": [ "samp_idx = np.random.permutation(len(y))[:500]\n", - "dtreeviz(m, xs.iloc[samp_idx], y.iloc[samp_idx], xs.columns, dep_var,\n", - " fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n", - " orientation='LR')" + "viz_model = dtreeviz.model(m, xs.iloc[samp_idx], y.iloc[samp_idx],\n", + " feature_names=xs.columns, target_name=dep_var)\n", + "viz_model.view(fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n", + " orientation='LR')" ] }, { @@ -425,9 +426,10 @@ "source": [ "m = DecisionTreeRegressor(max_leaf_nodes=4).fit(xs, y)\n", "\n", - "dtreeviz(m, xs.iloc[samp_idx], y.iloc[samp_idx], xs.columns, dep_var,\n", - " fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n", - " orientation='LR')" + "viz_model = dtreeviz.model(m, xs.iloc[samp_idx], y.iloc[samp_idx],\n", + " feature_names=xs.columns, target_name=dep_var)\n", + "viz_model.view(fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n", + " orientation='LR')" ] }, { @@ -888,11 +890,10 @@ "metadata": {}, "outputs": [], "source": [ - "from sklearn.inspection import plot_partial_dependence\n", + "from sklearn.inspection import PartialDependenceDisplay\n", "\n", "fig,ax = plt.subplots(figsize=(12, 4))\n", - "plot_partial_dependence(m, valid_xs_final, ['YearMade','ProductSize'],\n", - " grid_resolution=20, ax=ax);" + "PartialDependenceDisplay.from_estimator(m, valid_xs_final, ['YearMade','ProductSize'], ax=ax)" ] }, {