Skip to content

Instantly share code, notes, and snippets.

@rhoit
Created February 16, 2026 03:00
Show Gist options
  • Select an option

  • Save rhoit/ed9e862a93f99d76c955dcdfef4cf041 to your computer and use it in GitHub Desktop.

Select an option

Save rhoit/ed9e862a93f99d76c955dcdfef4cf041 to your computer and use it in GitHub Desktop.
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import sklearn.preprocessing
import sklearn.tree
# Fit one single-feature decision tree per column ('popcorn', 'coke') against
# the 'watch' target, draw both trees side by side, and save the figure.
# NOTE(review): assumes data/movie.csv has 'popcorn', 'coke' and 'watch'
# columns — confirm against the dataset.
df = pd.read_csv('data/movie.csv')
labelEncoder = sklearn.preprocessing.LabelEncoder()
fig, axes = plt.subplots(1, 2, layout='tight')
for i, name in enumerate(('popcorn', 'coke')):
    decisionTreeClassifier = sklearn.tree.DecisionTreeClassifier(
        criterion='gini',
        random_state=42,  # fixed seed so the drawn tree is reproducible
    )
    # The encoder is re-fit on each column; each tree sees only its own
    # integer-encoded feature.
    df_encoded = pd.DataFrame()
    df_encoded[name] = labelEncoder.fit_transform(df[name])
    decisionTreeClassifier.fit(X=df_encoded, y=df['watch'])
    sklearn.tree.plot_tree(
        decisionTreeClassifier,
        ax=axes[i],
        feature_names=[name],
        filled=True,
        # plot_tree pairs class_names with the classifier's sorted classes_.
        # NOTE(review): assumes the sorted 'watch' labels correspond to
        # ('naah', 'watch') in that order — confirm against the data.
        class_names=['naah', 'watch'],
        rounded=True,
        fontsize=14,
        node_ids=True,
        # Optional display toggles:
        # proportion=True,
        # impurity=False,
    )

# Save before show(): show() may block until the window is closed, and saving
# first guarantees the image is written even if the figure is torn down then.
out_path = Path('./py-sklearn-plt/movie-tree-1d.png')
out_path.parent.mkdir(parents=True, exist_ok=True)  # avoid FileNotFoundError
fig.savefig(out_path)
plt.show()
# A module-level `return` is a SyntaxError outside a function (it only worked
# inside an org-babel wrapper); report the saved path instead.
print(out_path)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment