线性判别式分析(Linear Discriminant Analysis)简称LDA,是模式识别的经典算法。
线性判别式分析(Linear Discriminant Analysis)简称LDA,是模式识别的经典算法。通过对历史数据进行投影,以保证投影后同一类别的数据尽量靠近,不同类别的数据尽量分开。并生成线性判别模型对新生成的数据进行分离和预测。本篇文章使用机器学习库scikit-learn建立LDA模型,并通过绘图展示LDA的分类结果。
准备工作
首先是开始前的准备工作,导入需要使用的库文件,本篇文章中除了常规的数值计算库numpy,科学计算库pandas,和绘图库matplotlib以外,还有绘图库中的颜色库,以及机器学习中的数据预处理和LDA库。
1
2
3
4
5
6
7
8
9
10
11
12
|
#导入数值计算库
import numpy as np
#导入科学计算库
import pandas as pd
#导入绘图库
import matplotlib.pyplot as plt
#导入绘图色彩库产生内置颜色
from matplotlib.colors import ListedColormap
#导入数据预处理库
from sklearn import preprocessing
#导入linear discriminant analysis库
from sklearn.lda import LDA
|
读取数据
读取并创建名称为data的数据表,后面我们将使用这个数据表创建LDA模型并绘图。
1
2
|
#读取数据并创建名为data的数据表
data = pd.DataFrame(pd.read_csv( 'LDA_data.csv' ))
|
使用head函数查看数据表的前5行,这里可以看到数据表共有三个字段,分别为贷款金额loan_amnt,用户收入annual_inc和贷款状态loan_status。
1
2
|
#查看数据表的前5行
data.head()
|
设置模型特征X和目标Y
将数据表中的贷款金额和用户收入设置为模型特征X,将贷款状态设置为模型目标Y,也就是我们要分类的结果。
1
2
3
4
|
#设置贷款金额和用户收入为特征X
X = np.array(data[[ 'loan_amnt' , 'annual_inc' ]])
#设置贷款状态为目标Y
Y = np.array(data[ 'loan_status' ])
|
对特征进行标准化处理
贷款金额和用户收入间差异较大,属于两个不同量级的数据。因此需要对数据进行标准化处理,转化为无量纲的纯数值。
1
2
3
|
#特征数据进行标准化
scaler = preprocessing.StandardScaler().fit(X)
X_Standard = scaler.transform(X)
|
下面是经过标准化处理后的特征数据。
1
2
|
#查看标准化后的特征数据
X_Standard
|
创建LDA模型并拟合数据
将标准化后的特征X和目标Y代入到LDA模型中。下面是具体的代码和计算结果。
1
2
3
|
#创建LDA模型
clf = LDA()
clf.fit(X_Standard,Y)
|
绘图数据预处理
对绘图数据进行预处理,计算X和Y的边界值,并使用meshgrid函数计算坐标向量矩阵。
1
2
3
|
#设置X和Y的边界值
x_min, x_max = X_Standard[, 0 ]. min () - 1 , X_Standard[, 0 ]. max () + 1
y_min, y_max = X_Standard[, 1 ]. min () - 1 , X_Standard[, 1 ]. max () + 1
|
1
2
3
|
#使用meshgrid函数返回X和Y两个坐标向量矩阵
xx, yy = np.meshgrid(np.arange(x_min, x_max,h), np.arange(y_min, y_max,h))
Z = clf.predict(np.c_[xx.ravel(), yy.ravel()])
|
设置图表所使用的颜色,这里使用的是HEX值。
1
2
|
#设置colormap颜色
cm_bright = ListedColormap([ '#D9E021' , '#0D8ECF' ])
|
绘制LDA分类图表
首先绘制LDA分类图表的边界,这里使用之前计算的坐标矩阵,并设置的colormap颜色和透明度。
1
2
3
|
#绘制分类边界
Z = Z.reshape(xx.shape)
plt.pcolormesh(xx, yy, Z, cmap = cm_bright,alpha = 0.6 )
|
最后绘制LDA图表中的数据点,并设置colormap颜色以及图表标题。以下是具体代码和图表。
1
2
3
4
5
|
#绘制数据点
plt.scatter(X_Standard[, 0 ], X_Standard[, 1 ], c = Y, cmap = cm_bright)
plt.title( 'Linear Discriminant Analysis Classifiers' )
plt.axis( 'tight' )
plt.show()
|
本文为专栏文章,来自:蓝鲸,内容观点不代表本站立场,如若转载请联系专栏作者,本文链接:https://www.afenxi.com/26914.html 。