Skip to content

Instantly share code, notes, and snippets.

@sepiabrown
Created January 28, 2026 08:25
Show Gist options
  • Select an option

  • Save sepiabrown/80f244e8a4770a3379a863934240220a to your computer and use it in GitHub Desktop.

Select an option

Save sepiabrown/80f244e8a4770a3379a863934240220a to your computer and use it in GitHub Desktop.
def find_optimal_threshold(y_true, y_scores, priority='sensitivity', min_value=0.95):
    """
    Find optimal threshold with constrained optimization.

    Among all ROC operating points whose ``priority`` metric is at least
    ``min_value``, return the one that maximizes the other metric. Ties on
    the maximized metric are broken in favour of the higher ``priority``
    metric, so a dominated point is never returned.

    Parameters
    ----------
    y_true : array-like
        Ground truth labels (0 or 1)
    y_scores : array-like
        Predicted probabilities or continuous scores
    priority : str, 'sensitivity' or 'specificity'
        The metric to constrain (must be >= min_value)
    min_value : float
        Minimum acceptable value for the priority metric

    Returns
    -------
    dict with threshold, sensitivity, specificity, fpr
        Values for the chosen ROC point. ``threshold`` comes straight from
        ``roc_curve``; its first entry is a sentinel above every score
        (``np.inf`` in recent scikit-learn), meaning "predict no positives".

    Raises
    ------
    ValueError
        If ``priority`` is not 'sensitivity' or 'specificity'.

    Warns
    -----
    UserWarning
        If no threshold satisfies the constraint. The point with the highest
        achievable ``priority`` metric is returned instead.
    """
    import warnings

    import numpy as np

    # NOTE(review): roc_curve is not defined in this block; presumably
    # sklearn.metrics.roc_curve imported at module level — confirm.
    fpr, tpr, thresholds = roc_curve(y_true, y_scores)
    specificity = 1 - fpr

    if priority == 'sensitivity':
        constrained, to_maximize = tpr, specificity
        constrained_name = 'sensitivity'
    elif priority == 'specificity':
        constrained, to_maximize = specificity, tpr
        constrained_name = 'specificity'
    else:
        raise ValueError("priority must be 'sensitivity' or 'specificity'")

    valid_mask = constrained >= min_value
    if valid_mask.any():
        candidates = np.flatnonzero(valid_mask)
        primary, secondary = to_maximize, constrained
    else:
        warnings.warn(
            f"No threshold achieves {constrained_name} >= {min_value}; "
            f"max achievable {constrained_name}: {constrained.max():.4f}",
            stacklevel=2,
        )
        # Fall back to the best achievable value of the constrained metric.
        candidates = np.arange(constrained.size)
        primary, secondary = constrained, to_maximize

    # np.argmax alone returns the *first* maximum. roc_curve orders points by
    # decreasing threshold, so that first maximum can have a strictly worse
    # secondary metric than a later point with the same primary value.
    # Break ties explicitly. Equal FPR gives bit-identical specificity,
    # so exact equality is safe here.
    primary_vals = primary[candidates]
    tied = candidates[primary_vals == primary_vals.max()]
    best_idx = tied[np.argmax(secondary[tied])]

    return {
        'threshold': thresholds[best_idx],
        'sensitivity': tpr[best_idx],
        'specificity': specificity[best_idx],
        'fpr': fpr[best_idx],
    }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment