This commit is contained in:
Jéssica Amaral 2025-02-03 13:00:50 +00:00 committed by GitHub
commit 6d12243a7c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 22 additions and 20 deletions

View File

@ -26,7 +26,7 @@
"from fastai.tabular.all import *\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"from dtreeviz.trees import *\n",
"import dtreeviz\n",
"from IPython.display import Image, display_svg, SVG\n",
"\n",
"pd.options.display.max_rows = 20\n",
@ -4483,9 +4483,10 @@
],
"source": [
"samp_idx = np.random.permutation(len(y))[:500]\n",
"dtreeviz(m, xs.iloc[samp_idx], y.iloc[samp_idx], xs.columns, dep_var,\n",
" fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n",
" orientation='LR')"
"viz_model = dtreeviz.model(m, xs.iloc[samp_idx], y.iloc[samp_idx],\n",
" feature_names=xs.columns, target_name=dep_var)\n",
"viz_model.view(fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n",
" orientation='LR')"
]
},
{
@ -7402,9 +7403,10 @@
"source": [
"m = DecisionTreeRegressor(max_leaf_nodes=4).fit(xs, y)\n",
"\n",
"dtreeviz(m, xs.iloc[samp_idx], y.iloc[samp_idx], xs.columns, dep_var,\n",
" fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n",
" orientation='LR')"
"viz_model = dtreeviz.model(m, xs.iloc[samp_idx], y.iloc[samp_idx],\n",
" feature_names=xs.columns, target_name=dep_var)\n",
"viz_model.view(fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n",
" orientation='LR')"
]
},
{
@ -8672,11 +8674,10 @@
}
],
"source": [
"from sklearn.inspection import plot_partial_dependence\n",
"from sklearn.inspection import PartialDependenceDisplay\n",
"\n",
"fig,ax = plt.subplots(figsize=(12, 4))\n",
"plot_partial_dependence(m, valid_xs_final, ['YearMade','ProductSize'],\n",
" grid_resolution=20, ax=ax);"
"PartialDependenceDisplay.from_estimator(m, valid_xs_final, ['YearMade','ProductSize'], ax=ax)"
]
},
{

View File

@ -24,7 +24,7 @@
"from fastai.tabular.all import *\n",
"from sklearn.ensemble import RandomForestRegressor\n",
"from sklearn.tree import DecisionTreeRegressor\n",
"from dtreeviz.trees import *\n",
"import dtreeviz\n",
"from IPython.display import Image, display_svg, SVG\n",
"\n",
"pd.options.display.max_rows = 20\n",
@ -402,9 +402,10 @@
"outputs": [],
"source": [
"samp_idx = np.random.permutation(len(y))[:500]\n",
"dtreeviz(m, xs.iloc[samp_idx], y.iloc[samp_idx], xs.columns, dep_var,\n",
" fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n",
" orientation='LR')"
"viz_model = dtreeviz.model(m, xs.iloc[samp_idx], y.iloc[samp_idx],\n",
" feature_names=xs.columns, target_name=dep_var)\n",
"viz_model.view(fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n",
" orientation='LR')"
]
},
{
@ -425,9 +426,10 @@
"source": [
"m = DecisionTreeRegressor(max_leaf_nodes=4).fit(xs, y)\n",
"\n",
"dtreeviz(m, xs.iloc[samp_idx], y.iloc[samp_idx], xs.columns, dep_var,\n",
" fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n",
" orientation='LR')"
"viz_model = dtreeviz.model(m, xs.iloc[samp_idx], y.iloc[samp_idx],\n",
" feature_names=xs.columns, target_name=dep_var)\n",
"viz_model.view(fontname='DejaVu Sans', scale=1.6, label_fontsize=10,\n",
" orientation='LR')"
]
},
{
@ -888,11 +890,10 @@
"metadata": {},
"outputs": [],
"source": [
"from sklearn.inspection import plot_partial_dependence\n",
"from sklearn.inspection import PartialDependenceDisplay\n",
"\n",
"fig,ax = plt.subplots(figsize=(12, 4))\n",
"plot_partial_dependence(m, valid_xs_final, ['YearMade','ProductSize'],\n",
" grid_resolution=20, ax=ax);"
"PartialDependenceDisplay.from_estimator(m, valid_xs_final, ['YearMade','ProductSize'], ax=ax)"
]
},
{