Skip to content

Instantly share code, notes, and snippets.

@rhoit
Created February 16, 2026 03:19
Show Gist options
  • Select an option

  • Save rhoit/0ed0fd42aac991851176047324c618e3 to your computer and use it in GitHub Desktop.

Select an option

Save rhoit/0ed0fd42aac991851176047324c618e3 to your computer and use it in GitHub Desktop.
import sklearn.tree
import sklearn.metrics
import matplotlib.pyplot as plt
import sklearn.tree
import sklearn.preprocessing
import pandas as pd
# Names of the CSV columns that are used as model inputs.
features = ['outlook', 'temperature', 'humidity', 'wind']

df = pd.read_csv('./data/cricket.csv')

# Replace every feature column with the integer codes produced by
# LabelEncoder.  One encoder instance is re-fitted for each column in turn,
# so after the loop it only holds the classes of the last column.
labelEncoder = sklearn.preprocessing.LabelEncoder()
for column_name in df[features].columns:
    df[column_name] = labelEncoder.fit_transform(df[column_name])
# The dataset is small, so the tree is trained on every row (no train/test
# split) and its depth is left unrestricted.
# NOTE(review): the plotting loop further down rebinds this name before the
# model fitted here is read anywhere in the visible script -- confirm whether
# this fit is still needed.
decisionTreeClassifier = sklearn.tree.DecisionTreeClassifier(
    criterion='entropy',
    random_state=42,
).fit(df[features], df['play'])
# Draw one decision tree per subplot on a 2x2 grid.  The four axes are paired
# with max_depth values 1 to 4; every tree is fitted and scored on the same
# full data set, so the accuracy in each title is the training accuracy.
fig, axes = plt.subplots(2, 2, layout='tight')
for depth, ax in enumerate(axes.ravel(), start=1):
    decisionTreeClassifier = sklearn.tree.DecisionTreeClassifier(
        criterion='entropy',
        random_state=42,
        max_depth=depth,
    )
    decisionTreeClassifier.fit(X=df[features], y=df['play'])

    predicted = decisionTreeClassifier.predict(df[features])
    accuracy = sklearn.metrics.accuracy_score(df['play'], predicted)
    ax.set_title(f'max_depth={depth} accuracy={accuracy}')

    sklearn.tree.plot_tree(
        decisionTreeClassifier,
        ax=ax,
        feature_names=features,
        # No class labels (e.g. ['✗', '✓']) are drawn: the fill color of
        # each node already conveys the class.
        class_names=None,
        filled=True,
        rounded=True,
        impurity=False,
        fontsize=8,
    )

plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment