plot_decision_boundaries(X, y, model_class, **model_params)
Description
Plots the decision boundaries of a classification model.
Only the first two columns (features) of X are used to fit the model,
because the model's prediction has to be evaluated at every point of a
two-dimensional scatter plot.
Arguments:
X: Feature data as a NumPy array.
y: Label data as a NumPy array.
model_class: A scikit-learn estimator class (the class itself, not an instance), e.g. GaussianNB
(imported from sklearn.naive_bayes) or LogisticRegression (imported from sklearn.linear_model).
**model_params: Keyword arguments passed on to the estimator's constructor.
Returns:
A Matplotlib figure object (matplotlib.figure.Figure)
Typical code example:
import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import make_classification

X1, y1 = make_classification(n_features=10, n_samples=100,
                             n_redundant=0, n_informative=10,
                             n_clusters_per_class=1, class_sep=0.5)
# Only the first two of the 10 features are used for fitting and plotting
_ = plot_decision_boundaries(X1, y1, KNeighborsClassifier, n_neighbors=5)
plt.show()
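The Description says the model is fit on the first two columns and evaluated at every point of the plot. The following is a minimal sketch of that technique, not this function's actual source: it assumes a dense grid built with numpy.meshgrid, a predict call on every grid point, and matplotlib's contourf for the shaded regions. The name sketch_decision_boundaries, the 200-point grid, the margin of 1 around the data, the Agg backend and random_state=0 are all illustrative choices.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier

GRID_POINTS = 200  # hypothetical grid resolution per axis


def sketch_decision_boundaries(X, y, model_class, **model_params):
    """Hypothetical re-implementation: fit on the first two features and
    shade each grid point by the class the model predicts there."""
    X2 = X[:, :2]                        # keep only the first two columns
    model = model_class(**model_params)  # instantiate the estimator class
    model.fit(X2, y)

    # Grid covering the data range plus a margin of 1 on each side
    x_min, x_max = X2[:, 0].min() - 1, X2[:, 0].max() + 1
    y_min, y_max = X2[:, 1].min() - 1, X2[:, 1].max() + 1
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, GRID_POINTS),
                         np.linspace(y_min, y_max, GRID_POINTS))

    # Predict a class for every grid point, then reshape back to the grid
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    fig, ax = plt.subplots()
    ax.contourf(xx, yy, Z, alpha=0.3)                   # decision regions
    ax.scatter(X2[:, 0], X2[:, 1], c=y, edgecolor="k")  # training points
    ax.set_xlabel("Feature 1")
    ax.set_ylabel("Feature 2")
    return fig


# Usage, mirroring the example above (random_state added for reproducibility)
X1, y1 = make_classification(n_features=10, n_samples=100,
                             n_redundant=0, n_informative=10,
                             n_clusters_per_class=1, class_sep=0.5,
                             random_state=0)
fig = sketch_decision_boundaries(X1, y1, KNeighborsClassifier, n_neighbors=5)
```

Passing the class plus **model_params, rather than a fitted instance, lets the function refit on the two-column slice itself, which is why the Arguments section asks for an estimator class.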