This commit is contained in:
Jéssica Amaral 2025-02-03 13:00:50 +00:00 committed by GitHub
commit 6d12243a7c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 22 additions and 20 deletions

View File

@ -26,7 +26,7 @@
"from fastai.tabular.all import *\n", "from fastai.tabular.all import *\n",
"from sklearn.ensemble import RandomForestRegressor\n", "from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.tree import DecisionTreeRegressor\n", "from sklearn.tree import DecisionTreeRegressor\n",
"from dtreeviz.trees import *\n", "import dtreeviz\n",
"from IPython.display import Image, display_svg, SVG\n", "from IPython.display import Image, display_svg, SVG\n",
"\n", "\n",
"pd.options.display.max_rows = 20\n", "pd.options.display.max_rows = 20\n",
@ -4483,8 +4483,9 @@
], ],
"source": [ "source": [
"samp_idx = np.random.permutation(len(y))[:500]\n", "samp_idx = np.random.permutation(len(y))[:500]\n",
"dtreeviz(m, xs.iloc[samp_idx], y.iloc[samp_idx], xs.columns, dep_var,\n", "viz_model = dtreeviz.model(m, xs.iloc[samp_idx], y.iloc[samp_idx],\n",
" fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n", " feature_names=xs.columns, target_name=dep_var)\n",
"viz_model.view(fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n",
" orientation='LR')" " orientation='LR')"
] ]
}, },
@ -7402,8 +7403,9 @@
"source": [ "source": [
"m = DecisionTreeRegressor(max_leaf_nodes=4).fit(xs, y)\n", "m = DecisionTreeRegressor(max_leaf_nodes=4).fit(xs, y)\n",
"\n", "\n",
"dtreeviz(m, xs.iloc[samp_idx], y.iloc[samp_idx], xs.columns, dep_var,\n", "viz_model = dtreeviz.model(m, xs.iloc[samp_idx], y.iloc[samp_idx],\n",
" fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n", " feature_names=xs.columns, target_name=dep_var)\n",
"viz_model.view(fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n",
" orientation='LR')" " orientation='LR')"
] ]
}, },
@ -8672,11 +8674,10 @@
} }
], ],
"source": [ "source": [
"from sklearn.inspection import plot_partial_dependence\n", "from sklearn.inspection import PartialDependenceDisplay\n",
"\n", "\n",
"fig,ax = plt.subplots(figsize=(12, 4))\n", "fig,ax = plt.subplots(figsize=(12, 4))\n",
"plot_partial_dependence(m, valid_xs_final, ['YearMade','ProductSize'],\n", "PartialDependenceDisplay.from_estimator(m, valid_xs_final, ['YearMade','ProductSize'], ax=ax)"
" grid_resolution=20, ax=ax);"
] ]
}, },
{ {

View File

@ -24,7 +24,7 @@
"from fastai.tabular.all import *\n", "from fastai.tabular.all import *\n",
"from sklearn.ensemble import RandomForestRegressor\n", "from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.tree import DecisionTreeRegressor\n", "from sklearn.tree import DecisionTreeRegressor\n",
"from dtreeviz.trees import *\n", "import dtreeviz\n",
"from IPython.display import Image, display_svg, SVG\n", "from IPython.display import Image, display_svg, SVG\n",
"\n", "\n",
"pd.options.display.max_rows = 20\n", "pd.options.display.max_rows = 20\n",
@ -402,8 +402,9 @@
"outputs": [], "outputs": [],
"source": [ "source": [
"samp_idx = np.random.permutation(len(y))[:500]\n", "samp_idx = np.random.permutation(len(y))[:500]\n",
"dtreeviz(m, xs.iloc[samp_idx], y.iloc[samp_idx], xs.columns, dep_var,\n", "viz_model = dtreeviz.model(m, xs.iloc[samp_idx], y.iloc[samp_idx],\n",
" fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n", " feature_names=xs.columns, target_name=dep_var)\n",
"viz_model.view(fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n",
" orientation='LR')" " orientation='LR')"
] ]
}, },
@ -425,8 +426,9 @@
"source": [ "source": [
"m = DecisionTreeRegressor(max_leaf_nodes=4).fit(xs, y)\n", "m = DecisionTreeRegressor(max_leaf_nodes=4).fit(xs, y)\n",
"\n", "\n",
"dtreeviz(m, xs.iloc[samp_idx], y.iloc[samp_idx], xs.columns, dep_var,\n", "viz_model = dtreeviz.model(m, xs.iloc[samp_idx], y.iloc[samp_idx],\n",
" fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n", " feature_names=xs.columns, target_name=dep_var)\n",
"viz_model.view(fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n",
" orientation='LR')" " orientation='LR')"
] ]
}, },
@ -888,11 +890,10 @@
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
"source": [ "source": [
"from sklearn.inspection import plot_partial_dependence\n", "from sklearn.inspection import PartialDependenceDisplay\n",
"\n", "\n",
"fig,ax = plt.subplots(figsize=(12, 4))\n", "fig,ax = plt.subplots(figsize=(12, 4))\n",
"plot_partial_dependence(m, valid_xs_final, ['YearMade','ProductSize'],\n", "PartialDependenceDisplay.from_estimator(m, valid_xs_final, ['YearMade','ProductSize'], ax=ax)"
" grid_resolution=20, ax=ax);"
] ]
}, },
{ {