当前位置:嗨网首页>书籍在线阅读

17-经典的鸢尾花数据集

  
选择背景色: 黄橙 洋红 淡粉 水蓝 草绿 白色 选择字体: 宋体 黑体 微软雅黑 楷体 选择字体大小: 恢复默认

7.5.2 经典的鸢尾花数据集

就像经典计算机科学问题一样,机器学习中也有经典的数据集。这些数据集可用于验证新技术,并与现有技术进行比较。它们还能用作首次学习机器学习算法的良好起点。最著名的机器学习数据集或许就是鸢尾花数据集了。该数据集最初收集于20世纪30年代,包含150个鸢尾花(花很漂亮)植物样本,分为3个不同的品种,每个品种50个样本。每种植物以4种不同的属性进行考量:萼片长度、萼片宽度、花瓣长度和花瓣宽度。

值得注意的是,神经网络并不关心各个属性所代表的含义。它的训练模型并不会区分萼片长度和花瓣长度的重要程度。如果需要进行这种区分,则由该神经网络的用户进行适当的调整。

本书附带的源码库包含了一个以鸢尾花数据集为特征值的逗号分隔值(CSV)文件[2]。鸢尾花数据集来自美国加利福尼亚大学的UCI机器学习库[3]。CSV文件只是一个文本文件,其值以逗号分隔。CSV文件是表格式数据(包括电子表格)的通用交换格式。

以下是iris.csv中的一些数据行:

5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa

每行代表一个数据点,其中的4个数字分别代表4种属性(萼片长度、萼片宽度、花瓣长度和花瓣宽度),再次声明,它们实际代表的意义是无所谓的。每行末尾的名称代表鸢尾花的特定品种。这5行都属于同一品种,因为此样本是从iris.csv文件开头读取的,3类品种数据是各自放在一起保存的,每个品种都有50行数据。

为了从磁盘读取CSV文件,我们将会用到Python标准库中的一些函数。 csv 模块将有助于我们以结构化的方式读取数据。内置的 open() 函数将会创建一个用于传给 csv.reader() 的文件对象。在代码清单7-14中,除这几行读取文件的代码之外,其余都只是对CSV文件中的数据进行重新排列,以备神经网络训练和验证之用。

代码清单7-14 iris_test.py

import csv
from typing import List
from util import normalize_by_feature_scaling
from network import Network
from random import shuffle
if __name__ == "__main__":
   iris_parameters: List[List[float]] = []
   iris_classifications: List[List[float]] = []
   iris_species: List[str] = []
   with open('iris.csv', mode='r') as iris_file:
       irises: List = list(csv.reader(iris_file))
       shuffle(irises) # get our lines of data in random order
       for iris in irises:
           parameters: List[float] = [float(n) for n in iris[0:4]]
           iris_parameters.append(parameters)
           species: str = iris[4]
           if species == "Iris-setosa":
               iris_classifications.append([1.0, 0.0, 0.0])
           elif species == "Iris-versicolor":
               iris_classifications.append([0.0, 1.0, 0.0])
           else:
               iris_classifications.append([0.0, 0.0, 1.0])
           iris_species.append(species)
    normalize_by_feature_scaling(iris_parameters)

iris_parameters 代表每个样本的4种属性集,这些样本将用于对鸢尾花进行分类。 iris_classifications 是每个样本的实际分类。此处的神经网络将包含3个输出神经元,每个神经元代表一种可能的品种。例如,最终输出的 [0.9, 0.3, 0.1] 将代表山鸢尾(iris-setosa),因为第一个神经元代表该品种,这里它的数值最大。

为了训练,正确答案是已知的,因此每条鸢尾花数据都带有预先标记的答案。对于应为山鸢尾的花朵数据, iris_classifications 中的数据项将会是 [1.0, 0.0, 0.0] 。这些值将用于计算每步训练后的误差。 iris_species 直接对应每条花朵数据应该归属的英文类别名称。山鸢尾在数据集中将被标记为 "Iris-setosa"

警告  上述代码中缺少了错误检查代码,这会让代码变得相当危险,因此这些代码不适用于生产环境,但用来测试是没有问题的。

代码清单7-15中的代码定义了神经网络对象。

代码清单7-15 iris_test.py(续)

iris_network: Network = Network([4, 6, 3], 0.3)

layer_structure 参数给定了包含3层(1个输入层、1个隐藏层和1个输出层)的网络 [4, 6, 3] 。输入层包含4个神经元,隐藏层包含6个神经元,输出层包含3个神经元。输入层中的4个神经元直接映射到用于对每个样本进行分类的4个参数。输出层中的3个神经元直接映射到3个不同的品种,对于每次的输入,我们都要分类为这3个品种。与其他一些公式相比,隐藏层的6个神经元存放的更多是一些尝试和误差的结果。 learning_rate 也是如此。如果神经网络算法的准确度不够理想,不妨多次尝试这两个值(隐藏层中的神经元数量和学习率)。具体代码如代码清单7-16所示。

代码清单7-16 iris_test.py(续)

def iris_interpret_output(output: List[float]) -> str:
    if max(output) == output[0]:
        return "Iris-setosa"
    elif max(output) == output[1]:
        return "Iris-versicolor"
    else:
        return "Iris-virginica"

iris_interpret_output() 是一个实用函数,将会被传给神经网络对象的 validate() 方法,用于识别正确的分类。

至此,终于可以对神经网络对象进行训练了。具体代码如代码清单7-17所示。

代码清单7-17 iris_test.py(续)

# train over the first 140 irises in the data set 50 times
iris_trainers: List[List[float]] = iris_parameters[0:140]
iris_trainers_corrects: List[List[float]] = iris_classifications[0:140]
for _ in range(50):
    iris_network.train(iris_trainers, iris_trainers_corrects)

这里将对150条鸢尾花数据集的前140条进行训练。还记得吧,从CSV文件中读取的数据行是经过重新排列的。这确保了每次运行程序时,训练的都是数据集的不同子集。注意,这140条鸢尾花数据会被训练50次。训练的次数将对神经网络的训练时间产生很大影响。一般来说,训练次数越多,神经网络算法就越准确。最后的测试代码将会用数据集中的最后10条鸢尾花数据来验证分类的正确性。具体代码如代码清单7-18所示。

代码清单7-18 iris_test.py(续)

# test over the last 10 of the irises in the data set
iris_testers: List[List[float]] = iris_parameters[140:150]
iris_testers_corrects: List[str] = iris_species[140:150]
iris_results = iris_network.validate(iris_testers, iris_testers_corrects, iris_
  interpret_output)
print(f"{iris_results[0]} correct of {iris_results[1]} = {iris_results[2] * 100}%")

上述所有工作引出了最终求解的问题:在数据集中随机选取10条鸢尾花数据,这里的神经网络对象可以对其中多少条数据进行正确分类?每个神经元的起始权重都是随机的,因此每次不同的运行都可能会得出不同的结果。不妨试着对学习率、隐藏神经元的数量和训练迭代次数进行调整,以便让神经网络对象变得更加准确。

最终应该会得出类似如下的结果:

9 correct of 10 = 90.0%