This is the plot module, which provides utility functions for plotting.

plot_decision_boundaries

plot_decision_boundaries(X, y, model_class, **model_params)

Description

Plots the decision boundaries of a classification model. Only the first two columns of X are used to fit the model, because the model's predicted class must be computed for every point of the two-dimensional scatter plot.
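One common way to implement this idea is to evaluate the fitted model on a dense grid covering the plot area and shade each region by its predicted class. The sketch below illustrates that technique under stated assumptions; it is not this module's actual implementation. The name sketch_decision_boundaries, the keyword-only grid_steps parameter, the one-unit padding around the data, and the use of contourf are all hypothetical choices, and the class labels are assumed to be numeric.

```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.neighbors import KNeighborsClassifier


def sketch_decision_boundaries(X, y, model_class, *, grid_steps=200, **model_params):
    """Hypothetical sketch: fit on the first two columns and shade predicted classes."""
    X2 = np.asarray(X)[:, :2]            # keep only the first two features
    model = model_class(**model_params)  # forward keyword arguments to the estimator
    model.fit(X2, y)

    # Regular grid covering the data range plus one unit of padding (an assumption)
    x_min, x_max = X2[:, 0].min() - 1, X2[:, 0].max() + 1
    y_min, y_max = X2[:, 1].min() - 1, X2[:, 1].max() + 1
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, grid_steps),
                         np.linspace(y_min, y_max, grid_steps))

    # Predict a class for every grid point, then reshape back to the grid shape
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)

    fig, ax = plt.subplots()
    ax.contourf(xx, yy, Z, alpha=0.3)                    # shaded decision regions
    ax.scatter(X2[:, 0], X2[:, 1], c=y, edgecolor="k")   # original data points
    ax.set_xlabel("feature 0")
    ax.set_ylabel("feature 1")
    return fig


# Usage on synthetic data with 10 features; only the first two are plotted
X1, y1 = make_classification(n_features=10, n_samples=100, n_redundant=0,
                             n_informative=10, n_clusters_per_class=1,
                             class_sep=0.5, random_state=0)
fig = sketch_decision_boundaries(X1, y1, KNeighborsClassifier, n_neighbors=5)
```

Because the grid is built only from two features, the model must also be fitted on those two features; otherwise its predict call would expect all ten columns.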

Arguments:

X: Feature data as a NumPy array.

y: Label data as a NumPy array.

model_class: A scikit-learn estimator class (not an instance), e.g. GaussianNB (imported from sklearn.naive_bayes) or LogisticRegression (imported from sklearn.linear_model).

**model_params: Keyword arguments (model parameters) passed on to the estimator class, e.g. n_neighbors=5 for KNeighborsClassifier.
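Since model_class is an estimator class rather than an instance, the parameters in **model_params are presumably forwarded to its constructor. The sketch below shows that class-plus-keyword-arguments pattern with the two estimators named above; build_estimator is a hypothetical helper for illustration, not part of this module.

```python
from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import GaussianNB


def build_estimator(model_class, **model_params):
    """Hypothetical helper: instantiate an estimator class with keyword arguments."""
    return model_class(**model_params)


# The class itself is passed, not an instance; its parameters travel separately
logreg = build_estimator(LogisticRegression, C=0.5, max_iter=500)
nb = build_estimator(GaussianNB, var_smoothing=1e-8)

print(logreg.get_params()["C"])          # 0.5
print(nb.get_params()["var_smoothing"])  # 1e-08
```

Passing the class lets the function create a fresh, unfitted estimator on each call, so the same parameters can be reused across datasets without carrying over a previous fit.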

Returns:

A Matplotlib figure object (matplotlib.figure.Figure).

Typical code example:

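import matplotlib.pyplot as plt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.datasets import make_classification

# plot_decision_boundaries is assumed to be imported from this plot module

# Synthetic two-class dataset with 10 features; only the first two are used
X1, y1 = make_classification(n_features=10, n_samples=100,
                             n_redundant=0, n_informative=10,
                             n_clusters_per_class=1, class_sep=0.5)

# n_neighbors=5 is forwarded to KNeighborsClassifier
_ = plot_decision_boundaries(X1, y1, KNeighborsClassifier, n_neighbors=5)
plt.show()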