Support vector machines (SVMs) are a family of classification models that perform well on small-sample, nonlinear, and high-dimensional pattern recognition problems. The method was proposed by Vapnik and his colleagues in 1992. It was originally designed for binary classification and was later extended to handle multi-class problems.
Python is a concise and powerful programming language with many machine learning packages, several of which implement SVM. This article walks through the steps of building a support vector machine classifier in Python with scikit-learn.
1. Prepare data
Let’s construct a simple set of training data. Create an example dataset where x1 represents height, x2 represents weight, and y is the class label (0 or 1).
import numpy as np
import matplotlib.pyplot as plt

np.random.seed(7)  # no random numbers are drawn below; the data is fixed

# Each row is one person: [height, weight]
X_train = np.array([[167, 75], [182, 80], [176, 85], [156, 50],
                    [173, 70], [183, 90], [178, 75], [156, 45],
                    [162, 55], [163, 50], [159, 45], [180, 85]])
# Class label (0 or 1) for each row of X_train
y_train = np.array([0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1])

# Plot each class in its own color
plt.scatter(X_train[y_train == 0][:, 0], X_train[y_train == 0][:, 1], c='r', s=40, label='Male')
plt.scatter(X_train[y_train == 1][:, 0], X_train[y_train == 1][:, 1], c='b', s=40, label='Female')
plt.legend()
plt.xlabel('Height')
plt.ylabel('Weight')
plt.show()
In this dataset, we classify people as male or female. The plot labels class 0 as Male and class 1 as Female.
2. Select a classifier
Next, we select a classifier suited to this problem: SVM. SVM has many variants; here we use a linear SVM.
Let’s construct an SVM model:
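For readers who want to see what "linear" means here, the standard soft-margin formulation is sketched below. This is the textbook form, not something stated in this article; the symbols w, b, ξ and C are introduced only for illustration, and the class labels are assumed to be recoded as -1 and +1.

```latex
% Decision rule of a linear SVM: the sign of an affine function of the input
f(\mathbf{x}) = \operatorname{sign}\left(\mathbf{w}^{\top}\mathbf{x} + b\right)

% Soft-margin training problem (labels y_i \in \{-1, +1\})
\min_{\mathbf{w},\, b,\, \boldsymbol{\xi}} \;
  \frac{1}{2}\lVert \mathbf{w} \rVert^{2} + C \sum_{i=1}^{n} \xi_i
\quad \text{subject to} \quad
  y_i\left(\mathbf{w}^{\top}\mathbf{x}_i + b\right) \ge 1 - \xi_i,
  \qquad \xi_i \ge 0

% Width of the margin between the two classes
\text{margin} = \frac{2}{\lVert \mathbf{w} \rVert}
```

Minimizing the norm of w widens the margin, while the penalty C controls how strongly points inside the margin or on the wrong side are punished.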
from sklearn.svm import SVC

# Create a linear-kernel SVM and fit it to the training data
svm = SVC(kernel='linear')
svm.fit(X_train, y_train)
Here, we use the SVC class and set the kernel parameter to 'linear', indicating that we use a linear kernel.
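To illustrate what the kernel parameter changes, the following sketch fits the same training data with three kernels that SVC accepts and reports the training accuracy of each. The scores dictionary and the loop are illustrative additions, and accuracy on the training set says little about how a model generalizes.

```python
import numpy as np
from sklearn.svm import SVC

# Same toy data as in the article: [height, weight] and a 0/1 label
X_train = np.array([[167, 75], [182, 80], [176, 85], [156, 50],
                    [173, 70], [183, 90], [178, 75], [156, 45],
                    [162, 55], [163, 50], [159, 45], [180, 85]])
y_train = np.array([0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1])

# Fit one model per kernel and record its accuracy on the training set
scores = {}
for kernel in ("linear", "poly", "rbf"):
    model = SVC(kernel=kernel)
    model.fit(X_train, y_train)
    scores[kernel] = model.score(X_train, y_train)

print(scores)
```

On a dataset this small, a linear kernel is usually the easiest to interpret, which is one reason to start with it.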
3. Draw the decision boundary
To get a sense of how the model behaves, we can draw the decision boundary of the classifier:
def plot_decision_boundary(model, ax=None):
    """Shade the region that the model assigns to each class."""
    if ax is None:
        ax = plt.gca()
    # Reuse the current axis limits so the grid covers the plotted points
    x_min, x_max = ax.get_xlim()
    y_min, y_max = ax.get_ylim()
    xx, yy = np.meshgrid(np.linspace(x_min, x_max, 100),
                         np.linspace(y_min, y_max, 100))
    # Predict a class for every grid point, then reshape back to the grid
    Z = model.predict(np.c_[xx.ravel(), yy.ravel()]).reshape(xx.shape)
    ax.contourf(xx, yy, Z, alpha=0.2)
    ax.contour(xx, yy, Z, colors='black', linewidths=0.5)
    ax.set_xlim([x_min, x_max])
    ax.set_ylim([y_min, y_max])

# Draw the training points first so the axis limits are set
plt.scatter(X_train[y_train == 0][:, 0], X_train[y_train == 0][:, 1], c='r', s=40, label='Male')
plt.scatter(X_train[y_train == 1][:, 0], X_train[y_train == 1][:, 1], c='b', s=40, label='Female')
plot_decision_boundary(svm)
plt.legend()
plt.xlabel('Height')
plt.ylabel('Weight')
plt.show()
After the code runs, the plot shows the classifier's decision boundary together with the training points.
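With a linear kernel, the boundary can also be read directly from the fitted model instead of from a plot. The sketch below assumes the same training data and uses the coef_, intercept_ and support_vectors_ attributes of a fitted SVC; the helper boundary_weight and the variable margin_width are hypothetical names introduced for illustration.

```python
import numpy as np
from sklearn.svm import SVC

X_train = np.array([[167, 75], [182, 80], [176, 85], [156, 50],
                    [173, 70], [183, 90], [178, 75], [156, 45],
                    [162, 55], [163, 50], [159, 45], [180, 85]])
y_train = np.array([0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1])

svm = SVC(kernel='linear')
svm.fit(X_train, y_train)

# For a linear kernel the boundary is w[0]*height + w[1]*weight + b = 0
w = svm.coef_[0]
b = svm.intercept_[0]

def boundary_weight(height):
    """Weight at which the decision boundary crosses the given height."""
    return -(w[0] * height + b) / w[1]

# Distance between the two margin lines on either side of the boundary
margin_width = 2 / np.linalg.norm(w)

print("w =", w, "b =", b)
print("support vectors:\n", svm.support_vectors_)
print("margin width:", margin_width)
print("boundary weight at height 170:", boundary_weight(170))
```

The support vectors are the training points that lie on the margin; moving any other point slightly would leave the boundary unchanged.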
4. Predict new data
We can use the trained model to predict new data.
# Three new people: [height, weight]
X_test = np.array([[166, 70], [185, 90], [170, 75]])
# y_test holds the labels predicted by the model
y_test = svm.predict(X_test)
print(y_test)
Here, we use the predict method to classify three new samples. It returns the predicted class label (0 or 1) for each sample.
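Beyond the label itself, it can be useful to see how far each new sample lies from the boundary. The sketch below uses SVC's decision_function method for that, assuming the same training and test data; the variable names labels and scores are illustrative. A positive score corresponds to class 1 and a negative score to class 0, and a larger absolute value means the sample is further from the boundary.

```python
import numpy as np
from sklearn.svm import SVC

X_train = np.array([[167, 75], [182, 80], [176, 85], [156, 50],
                    [173, 70], [183, 90], [178, 75], [156, 45],
                    [162, 55], [163, 50], [159, 45], [180, 85]])
y_train = np.array([0, 1, 1, 0, 0, 1, 0, 0, 0, 0, 0, 1])

svm = SVC(kernel='linear')
svm.fit(X_train, y_train)

X_test = np.array([[166, 70], [185, 90], [170, 75]])
labels = svm.predict(X_test)            # predicted class per sample
scores = svm.decision_function(X_test)  # signed score per sample

for sample, label, score in zip(X_test, labels, scores):
    print(sample, "-> class", label, f"(score {score:+.2f})")
```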
Conclusion
In this article, we showed how to use the support vector machine algorithm in Python. We created a simple training dataset, fitted a linear SVM classifier to it, plotted the classifier's decision boundary, and used the model to predict new samples. SVM is a popular algorithm that performs well in many fields, so it is worth learning if you want to broaden the set of machine learning algorithms you can apply to your data.