KNN算法直接作用于带标记的样本,属于有监督的算法,既可以用于分类也可以用于预测。它是一种懒惰的算法(lazy learning),因为它实际上并没有“训练”的过程,也不产生一个真实意义的“模型”。而只是一字不差地将所有的训练样本保存下来,等到需要对新样本进行分类的时候,将新样本与所有的训练样本进行比较,找出预期距离最接近的K个样本,然后基于这K个邻居所属的类别进行投票,决定分类效果。
算法过程:
- 计算距离:给定待分类的样本,计算该样本与其他每个样本的距离。
- 距离排序:将计算出的距离降序排序。
- 选择K个近邻:根据排序结果,选择距离最近的K个样本作为待分类样本的K个近邻。
- 决定分类:找出K个近邻的主要类别,即按投票方式决定待分类样本的类别。
KNN三要素:
- 距离度量
- 对欧式距离,点$P_1(x_{11},x_{12},\ldots,x_{1n})$与点$P_2(x_{21},x_{22},\ldots,x_{2n})$之间的距离为:$$d_{12} = \sqrt{\sum_{k=1}^n(x_{1k}-x_{2k})^2}$$
- 对曼哈顿距离,点$P_1(x_{11},x_{12},\ldots,x_{1n})$与点$P_2(x_{21},x_{22},\ldots,x_{2n})$之间的距离为:$$d_{12} = \sum_{k=1}^n|x_{1k}-x_{2k}|$$
- 对切比雪夫距离,点$P_1(x_{11},x_{12},\ldots,x_{1n})$与点$P_2(x_{21},x_{22},\ldots,x_{2n})$之间的距离为:$$d_{12} = \max_{i \in [1,n]}(|x_{1i}-x_{2i}|)$$
- 对夹角余弦距离,点$P_1(x_{11},x_{12},\ldots,x_{1n})$与点$P_2(x_{21},x_{22},\ldots,x_{2n})$之间的距离为:$$cos(\theta) = \frac{\sum_{k=1}^nx_{1k}x_{2k}}{\sqrt{\sum_{k=1}^nx_{1k}^2}\sqrt{\sum_{k=1}^nx_{2k}^2}}$$
- K值选择:选取恰当的K值大小
- 小:过拟合
- 大:平滑简单
- 分类决策
- 对距离加权
- 对样本加权
优点:
- 原理简单,易于理解,实现方便,不需要训练的过程。
- 特别适用于多分类问题,如根据基因的特征判断基因的功能
- 对异常值不敏感
缺点:
- 每次进行预测都要对所有样本进行重新扫描和计算距离,因此当样本集很大时,计算量会很大。
- 需要存储所有的样本。
- 结果的可解释性差,无法给出规则。
- 数据集中如果含有缺失值,需要特殊处理。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 |
from sklearn import datasets from sklearn.neighbors import KNeighborsClassifier# Neighbors:邻居 from sklearn.model_selection import train_test_split from sklearn.metrics import classification_report# 分类报告 from sklearn.metrics import confusion_matrix# 混淆矩阵 from sklearn.metrics import accuracy_score # 准备数据集,并分离训练集和验证集 iris = datasets.load_iris() X = iris.data Y = iris.target validation_size = 20 seed = 1# 随机数种子 X_train, X_validation, Y_train, Y_validation = train_test_split(X, Y, test_size=validation_size, random_state=seed) # 创建KNN分类器,并拟合数据集 knn = KNeighborsClassifier() knn.fit(X_train, Y_train) # 在验证集上进行预测,并输出accuracy score,混淆矩阵和分类报告 predictions = knn.predict(X_validation) print('------------------------验证成功率---------------------------') print(accuracy_score(Y_validation, predictions)) print('\n-------------------------混淆矩阵----------------------------') print(confusion_matrix(Y_validation, predictions)) print('\n-------------------------分类报告----------------------------') print(classification_report(Y_validation, predictions)) |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 |
------------------------验证成功率--------------------------- 1.0 -------------------------混淆矩阵---------------------------- [[8 0 0] [0 8 0] [0 0 4]] -------------------------分类报告---------------------------- precision recall f1-score support 0 1.00 1.00 1.00 8 1 1.00 1.00 1.00 8 2 1.00 1.00 1.00 4 accuracy 1.00 20 macro avg 1.00 1.00 1.00 20 weighted avg 1.00 1.00 1.00 20 |
文章有(1)条网友点评
获益匪浅,感谢大佬 {{keai}} {{keai}} {{keai}} {{keai}} {{keai}}