分类算法之邻近算法:KNN(应用篇)

分类算法之邻近算法:KNN(应用篇)

起步

这次使用的训练集由 sklearn 模块提供,关于虹膜(一种鸢尾属植物)的数据。



数据载入

from sklearn import datasets
iris = datasets.load_iris()

数据存储在 .data 成员中,它是一个 (n_samples, n_features) numpy 数组:

print(iris.data)
# [[ 5.1  3.5  1.4  0.2]
#  [ 4.9  3.   1.4  0.2]
#  ...

它有四个特征,萼片长度,萼片宽度,花瓣长度,花瓣宽度 (sepal length, sepal width, petal length and petal width)。



它的品种分类有山鸢尾,变色鸢尾,菖蒲锦葵(Iris setosa, Iris versicolor, Iris virginica.)三种。

print iris.data.shape
# output:(150L, 4L)

这是一个含有 150 个数据的训练集。

构造 KNN 分类器

from sklearn import neighbors
knn = neighbors.KNeighborsClassifier(n_neighbors=5)

n_neighbors 参数级是指定获取 K 个邻近点。

训练

训练的函数一般就是 fit

knn.fit(iris.data, iris.target) 

测试

模拟一些测试数据,使用刚刚的模型进行预测:

predict = knn.predict([[0.1, 0.2, 0.3, 0.4]])
print(predict) # output: [0] 

文章被以下专栏收录

    大多数关于人工智能的算法都分为理论篇和应用篇,前者讲述算法由来和推导过程,后者则会结合python来进行算法的实践。如果您有关于AI的文章,欢迎投稿!

    微信公众号同名,欢迎投稿。全平台约20万开发者关注,会员来自全球十多个国家和地区,拥有十多个线上线下技术社群,向本专栏投稿即默认发布到Python中文社区全平台。GitHub:github.com/pycn