scikit-learn K近邻法类库使用的经验总结

电子说

1.3w人已加入

描述

本文对scikit-learn中KNN相关的类库使用做了一个总结,主要关注于类库调参时的一个经验总结,且该文非常详细地介绍了类的参数含义,这是小编见过最详细的KNN类库参数介绍 。

目录

scikit-learn 中KNN相关的类库概述

K近邻法和限定半径最近邻法类库参数小结

使用KNeighborsClassifier做分类的实例

1. scikit-learn中KNN相关的类库概述

在scikit-learn 中,与近邻法这一大类相关的类库都在sklearn.neighbors包之中。KNN分类树的类是KNeighborsClassifier,KNN回归树的类是KNeighborsRegressor。除此之外,还有KNN的扩展,即限定半径最近邻分类树的类RadiusNeighborsClassifier和限定半径最近邻回归树的类RadiusNeighborsRegressor, 以及最近质心分类算法NearestCentroid。

在这些算法中,KNN分类和回归的类参数完全一样。限定半径最近邻法分类和回归的类的主要参数也和KNN基本一样。

比较特别是的最近质心分类算法,由于它是直接选择最近质心来分类,所以仅有两个参数,距离度量和特征选择距离阈值,比较简单,因此后面就不再专门讲述最近质心分类算法的参数。

另外几个在sklearn.neighbors包中但不是做分类回归预测的类也值得关注。kneighbors_graph类返回用KNN时和每个样本最近的K个训练集样本的位置。radius_neighbors_graph返回用限定半径最近邻法时和每个样本在限定半径内的训练集样本的位置。NearestNeighbors是个大杂烩,它即可以返回用KNN时和每个样本最近的K个训练集样本的位置,也可以返回用限定半径最近邻法时和每个样本最近的训练集样本的位置,常常用在聚类模型中。

2.  K近邻法和限定半径最近邻法类库参数小结

本节对K近邻法和限定半径最近邻法类库参数做一个总结。包括KNN分类树的类KNeighborsClassifier,KNN回归树的类KNeighborsRegressor, 限定半径最近邻分类树的类RadiusNeighborsClassifier和限定半径最近邻回归树的类RadiusNeighborsRegressor。这些类的重要参数基本相同,因此我们放到一起讲:

函数

函数

函数

函数

3. 使用KNeighborsClassifier做分类的实例

完整代码见github: 

https://github.com/ljpzzz/machinelearning/blob/master/classic-machine-learning/knn_classifier.ipynb

3.1 生成随机数据

首先,我们生成我们分类的数据,代码如下:

import numpy as np import matplotlib.pyplot as plt from sklearn.datasets.samples_generator import make_classification # X为样本特征,Y为样本类别输出, 共1000个样本,每个样本2个特征,输出有3个类别,没有冗余特征,每个类别一个簇 X, Y = make_classification(n_samples=1000, n_features=2, n_redundant=0,                    n_clusters_per_class=1, n_classes=3)plt.scatter(X[:, 0], X[:, 1], marker='o', c=Y)plt.show()

先看看我们生成的数据图如下。由于是随机生成,如果你也跑这段代码,生成的随机数据分布会不一样。下面是我某次跑出的原始数据图。

函数

接着我们用KNN来拟合模型,我们选择K=15,权重为距离远近。代码如下:

from matplotlib.colors import ListedColormap cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF'])cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF']) #确认训练集的边界 x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 #生成随机数据来做测试集然后预测 xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02),                         np.arange(y_min, y_max, 0.02))Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) # 画出测试集数据 Z = Z.reshape(xx.shape)plt.figure()plt.pcolormesh(xx, yy, Z, cmap=cmap_light) # 也画出所有的训练集数据 plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=cmap_bold)plt.xlim(xx.min(), xx.max())plt.ylim(yy.min(), yy.max())plt.title("3-Class classification (k = 15, weights = 'distance')" )生成的图如下,可以看到大多数数据拟合不错,仅有少量的异常点不在范围内。

函数

打开APP阅读更多精彩内容
声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表电子发烧友网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉

全部0条评论

快来发表一下你的评论吧 !

×
20
完善资料,
赚取积分