Python画图实践

概述:Python画图实践:plotly、matplotlib、seaborn、pyecharts

plotly

Install

1
pip install plotly

折线图

导入相应的库文件

1
2
3
import plotly
import pandas as pd
import plotly.graph_objects as go

数据读取展示

1
2
# Load the New Zealand weather table (DATE plus one column per city) and preview it.
nz_weather = pd.read_csv("plotly-datasets/nz_weather.csv")
nz_weather.head(n=5)
DATE Auckland Christchurch Dunedin Hamilton Wellington
0 2000-01 115.4 47.2 174.8 96.2 91.8
1 2000-02 8.4 25.2 41 8.2 35.2
2 2000-03 57.2 60.8 74.2 33.8 53.4
3 2000-04 106.8 58.2 50 129.6 109.8
4 2000-05 128.2 62.0 - 98.2 78.2
1
2
3
4
5
6
7
8
9
# One line trace per city, all sharing the DATE column on the x axis.
traces = [
    go.Scatter(x=nz_weather['DATE'], y=nz_weather[city], name=city)
    for city in ('Auckland', 'Christchurch')
]
line1, line2 = traces
fig = go.Figure(traces)
fig.update_layout(title='New Zealand weather', xaxis_title='Date', yaxis_title='Weather')
fig.show()

散点图

数据准备

1
2
# Load the iris measurements and preview the first rows.
iris = pd.read_csv("plotly-datasets/iris.csv")
iris.head(n=5)
SepalLength SepalWidth PetalLength PetalWidth Name
0 5.1 3.5 1.4 0.2 Iris-setosa
1 4.9 3.0 1.4 0.2 Iris-setosa
2 4.7 3.2 1.3 0.2 Iris-setosa
3 4.6 3.1 1.5 0.2 Iris-setosa
4 5.0 3.6 1.4 0.2 Iris-setosa
1
2
3
4
5
6
7
8
# Markers only; mode='markers+lines' would also connect the points with a line.
scatter = go.Scatter(
    x=iris['SepalLength'],
    y=iris['SepalWidth'],
    mode='markers',
    marker=dict(size=10),
    name='散点类别',
)
# Show the legend near the top-right corner.
layout = go.Layout(showlegend=True, legend={'x': 0.9, 'y': 1})
fig = go.Figure(data=scatter, layout=layout)
fig.show()

对不同类别标记不同的颜色:plotly.graph_objects 中 marker 的 color 不能直接使用字符串形式的类别名,因此需要先将类别转换为数字(plotly.express 可以直接按字符串类别列着色,见下文)

1
2
3
4
5
6
7
8
9
# Print the distinct class names so the mapping below can be checked against them.
# (A bare expression here would be computed and discarded, because only a cell's
# last expression is echoed in a notebook.)
print(iris['Name'].unique())

# graph_objects' marker.color needs numbers, so map each class name to an integer code.
name_to_color = {
    'Iris-setosa': 0,
    'Iris-versicolor': 1,
    'Iris-virginica': 2,
}

iris['color'] = iris['Name'].map(name_to_color)
1
iris.head(n=5)  # preview: the new numeric `color` column now appears
SepalLength SepalWidth PetalLength PetalWidth Name color
0 5.1 3.5 1.4 0.2 Iris-setosa 0
1 4.9 3.0 1.4 0.2 Iris-setosa 0
2 4.7 3.2 1.3 0.2 Iris-setosa 0
3 4.6 3.1 1.5 0.2 Iris-setosa 0
4 5.0 3.6 1.4 0.2 Iris-setosa 0
1
2
3
# Colour each point by its numeric class code.
scatter = go.Scatter(
    x=iris['SepalLength'],
    y=iris['SepalWidth'],
    mode='markers',
    marker=dict(color=iris['color']),
)
fig = go.Figure(data=scatter)
fig.show()

使用plotly.express实现散点图(express是用更简短的语句实现)

1
2
3
import plotly.express as px
# plotly.express colours directly by the string column `Name`.
fig = px.scatter(data_frame=iris, x='SepalLength', y='SepalWidth', color='Name')
fig.show()

使用scatter_matrix实现对多个组合散点图可视化

1
2
# Pairwise scatter plots of the four measurements, coloured by class name.
fig = px.scatter_matrix(
    data_frame=iris,
    dimensions=['SepalLength', 'SepalWidth', 'PetalLength', 'PetalWidth'],
    color='Name',
)
fig.show()

3D散点图

数据准备

1
2
# Load the x/y/z/color point table and preview it.
threedline = pd.read_csv('plotly-datasets/3d-line1.csv')
threedline.head(n=5)
x y z color
0 100.000000 0.613222 0.734706 0
1 99.238875 0.589852 0.781320 0
2 99.559608 0.599743 0.762566 0
3 97.931425 0.549296 0.859966 0
4 96.837832 0.515613 0.927150 0
1
2
3
# Plot the 3-D points as small green markers with no connecting line.
line = go.Scatter3d(
    x=threedline['x'],
    y=threedline['y'],
    z=threedline['z'],
    mode='markers',
    marker=dict(size=3, color='green'),
)
fig = go.Figure(data=line)
fig.show()

1
2
# Express version, coloured by the numeric `color` column.
fig = px.scatter_3d(data_frame=threedline, x='x', y='y', z='z', color='color')
fig.show()

柱状图

1
2
3
4
auckland = nz_weather['Auckland']
# text + textposition='outside' print each bar's value just above the bar.
bar = go.Bar(
    x=nz_weather['DATE'],
    y=auckland,
    text=auckland,
    textposition='outside',
)
fig = go.Figure(data=bar)
fig.show()

动态柱状图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import plotly.express as px
from vega_datasets import data
df = data.disasters()
df = df[df['Year'] > 1990]
# Horizontal bars, one animation frame per year. The x range is fixed to the
# overall maximum so bar lengths stay comparable across frames.
fig = px.bar(
    df,
    x="Deaths",
    y="Entity",
    color="Entity",
    orientation='h',
    animation_frame="Year",
    range_x=[0, df['Deaths'].max()],
)
# Cosmetics: canvas size, no grid lines, title, no legend.
fig.update_layout(
    title_text='Evolution of Natural Disasters',
    width=1000,
    height=800,
    xaxis_showgrid=False,
    yaxis_showgrid=False,
    showlegend=False,
)
fig.update_xaxes(title_text='Number of Deaths')
fig.update_yaxes(title_text='')
fig.show()

直方图

1
2
3
4
# Bins 10 units wide over Auckland's values.
hist = go.Histogram(x=nz_weather['Auckland'], xbins=dict(size=10))
fig = go.Figure(data=hist)
fig.update_layout(bargap=0.1)  # leave a small gap between adjacent bars
fig.show()

饼状图

1
2
3
4
5
6
7
8
9
10
11
12
labels = ['Oxygen', 'Hydrogen', 'Carbon_Dioxide', 'Nitrogen']
values = [4500, 2500, 1053, 500]
colors = ['#FEBFB3', '#E1396C', '#96D38C', '#D0F9B1']

# The pie trace gets its own name; go.Figure then wraps it. Hover shows
# label + percent, slices show the raw value, each slice has a black outline.
trace = go.Pie(labels=labels, values=values,
               hoverinfo='label+percent', textinfo='value',
               textfont=dict(size=20),
               marker=dict(colors=colors,
                           line=dict(color='#000000', width=2)))

fig = go.Figure(trace)
fig.show()

参考手册

官网帮助文档:https://plotly.com/python/#fundamentals

在线制作网站:https://chart-studio.plotly.com/

matplotlib

基础知识:

axes: 设置坐标轴边界和表面的颜色、坐标刻度值大小和网格的显示

backend: 设置绘图后端(目标输出),如TkAgg和GTKAgg

figure: 控制dpi、边界颜色、图形大小、和子区( subplot)设置

font: 字体集(font family)、字体大小和样式设置

grid: 设置网格颜色和线型

legend: 设置图例和其中的文本的显示

line: 设置线条(颜色、线型、宽度等)和标记

patch: 是填充2D空间的图形对象,如多边形和圆。控制线宽、颜色和抗锯齿设置等。

savefig: 可以对保存的图形进行单独设置。例如,设置渲染的文件的背景为白色。

verbose: 设置matplotlib在执行期间信息输出,如silent、helpful、debug和debug-annoying(该配置项在较新版本的matplotlib中已被移除)。

xticks和yticks: 为x,y轴的主刻度和次刻度设置颜色、大小、方向,以及标签大小。

绘图过程涉及中文/负号:

步骤一:打开设置文件

1
2
import matplotlib
# Print the path of the active matplotlibrc file (works in scripts too, where a
# bare expression would display nothing).
print(matplotlib.matplotlib_fname())

会显示matplotlibrc文件的地址

步骤二:修改matplotlibrc文件

将文件中的

1
#font.family: sans-serif

去掉注释,修改为

1
font.family: Microsoft YaHei

即可正常显示中文。也可以不修改配置文件,直接在代码中设置:

1
2
3
import matplotlib.pyplot as plt  # this snippet runs before pyplot is imported elsewhere

plt.rcParams['font.sans-serif'] = ['SimHei']  # font with Chinese glyphs for labels
plt.rcParams['axes.unicode_minus'] = False  # draw ASCII '-' so minus signs don't show as boxes

图部分名称:

折线图

(1)导入库

1
2
import matplotlib.pyplot as plt
%matplotlib inline

(2)生成数据

1
2
# x: the integers 0..99; y: their squares, as a list.
x = range(100)
y = list(map(lambda v: v ** 2, x))

(3)绘图

1
2
3
# Dotted red line with '|' markers, labelled "test" for the legend.
plt.plot(x, y, color='red', linestyle=':', linewidth=1, marker='|', label="test")
plt.legend(loc='upper left')  # legend in the top-left corner
plt.show()

其中:

linestyle可选参数:

1
2
3
4
'-'       solid line style
'--' dashed line style
'-.' dash-dot line style
':' dotted line style

marker可选参数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
'.'       point marker
',' pixel marker
'o' circle marker
'v' triangle_down marker
'^' triangle_up marker
'<' triangle_left marker
'>' triangle_right marker
'1' tri_down marker
'2' tri_up marker
'3' tri_left marker
'4' tri_right marker
's' square marker
'p' pentagon marker
'*' star marker
'h' hexagon1 marker
'H' hexagon2 marker
'+' plus marker
'x' x marker
'D' diamond marker
'd' thin_diamond marker
'|' vline marker
'_' hline marker

颜色参考:

另一例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import numpy as np  # numpy has not been imported yet at this point of the tutorial

# Data: 50 points on [-3, 3], a straight line and a parabola
x = np.linspace(-3, 3, 50)
y1 = 2*x + 1
y2 = x**2

# Figure size in inches
plt.figure(figsize=[8, 8])
# Default-styled line for y2
l1, = plt.plot(x, y2)
# color: line colour, linewidth: line width, linestyle: dash pattern (also '-.' or ':')
l2, = plt.plot(x, y1, color='red', linewidth=1.0, linestyle='--')

# Restrict the visible axis ranges
plt.xlim((-1, 2))
plt.ylim((-2, 3))

# Axis labels
plt.xlabel('I am x')
plt.ylabel('I am y')

# Use a font with CJK glyphs so the Chinese title renders
plt.rcParams['font.sans-serif'] = ['SimHei']
# Draw minus signs with ASCII '-' so they don't render as boxes
plt.rcParams['axes.unicode_minus'] = False
# Chart title
plt.title('点线图')

# Legend: one label per line handle
plt.legend(handles=[l1, l2], labels=['aaa', 'bbb'], loc='upper left')

# Five evenly spaced x ticks across the visible range
new_ticks = np.linspace(-1, 2, 5)
plt.xticks(new_ticks)
# Custom y tick labels rendered as mathtext ($...$), where "\ " is a space.
# Raw strings keep the backslash without an invalid-escape warning
# (SyntaxWarning on Python 3.12+).
plt.yticks([-2, -1.8, -1, 1.22, 3],
           [r'$very\ bad$', r'$bad$', r'$normal$', r'$good$', r'$very\ good$'])
# Show the figure
plt.show()

子图例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import matplotlib.pyplot as plt

# Figure 1: a 2x2 grid; panel k draws a segment from (0, 0) to (1, k).
plt.figure()
for k in range(1, 5):
    plt.subplot(2, 2, k)
    plt.plot([0, 1], [0, k])

# Figure 2: a full-width top panel (slot 1 of a 2x1 grid) above three panels
# that occupy the bottom row of a 2x3 grid (slots 4, 5, 6).
plt.figure(num=2)
plt.subplot(2, 1, 1)
plt.plot([0, 1], [1, 1])
for slot in (4, 5, 6):
    plt.subplot(2, 3, slot)
    plt.plot([0, 1], [0, 2])

plt.show()

柱状图

(1)导入库

1
2
import numpy as np
import matplotlib.pyplot as plt

(2)生成数据

1
2
population_ages = 50 * np.random.rand(50)  # 50 random values in [0, 50)
bins = np.arange(population_ages.size)  # integer bin edges 0..49

(3)绘图

1
2
3
4
5
6
7
8
9
10
# Histogram of the sampled values; rwidth=0.8 leaves gaps between bars.
# histtype options: 'bar' (default), 'barstacked', 'step', 'stepfilled';
# with a single dataset 'barstacked' looks the same as 'bar'.
# label= is needed for legend() below; without a labelled artist legend()
# only warns "No artists with labels found" and draws nothing.
plt.hist(population_ages, bins, histtype='bar', rwidth=0.8, label='population ages')
# Axis labels
plt.xlabel('x')
plt.ylabel('y')
plt.title('Hist Graph')  # chart title
plt.legend()  # legend for the labelled histogram
plt.show()

另一例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
n = 12
X = np.arange(n)

# Heights shrink linearly with X, scaled by uniform noise from [0.5, 1.0)
Y1 = (1 - X / float(n)) * np.random.uniform(0.5, 1.0, n)
Y2 = (1 - X / float(n)) * np.random.uniform(0.5, 1.0, n)

# Y1 as upward bars and -Y2 as downward bars; face colour and white edges
plt.bar(X, +Y1, facecolor='#9999ff', edgecolor='white')
plt.bar(X, -Y2, facecolor='#ff9999', edgecolor='white')

# Write each bar's value just beyond its end (loop bodies must be indented)
for x, y in zip(X, Y1):
    # ha: horizontal alignment
    plt.text(x, y + 0.15, '%.2f' % y, ha='center')
for x, y in zip(X, Y2):
    plt.text(x, -y - 0.15, '-%.2f' % y, ha='center')

# Font with CJK glyphs so the Chinese title renders
plt.rcParams['font.sans-serif'] = ['SimHei']
# Draw minus signs with ASCII '-' so they don't show as boxes
plt.rcParams['axes.unicode_minus'] = False
# Title, axis limits, and hidden ticks
plt.title("条形图")
plt.xlim(-.5, n)
plt.xticks(())
plt.ylim(-1.25, 1.25)
plt.yticks(())
# Show the chart
plt.show()

饼图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# Slice sizes, their labels and colours (matched by position)
slices = [7, 2, 2, 13]
activities = ['sleeping', 'eating', 'working', 'playing']
cols = ['c', 'm', 'r', 'b']

# startangle=90 starts drawing from the positive y axis instead of the positive x axis;
# shadow adds a drop shadow; explode pulls the second slice away from the centre;
# autopct prints each share as a percentage with two decimals.
plt.pie(
    slices,
    labels=activities,
    colors=cols,
    startangle=90,
    shadow=True,
    explode=(0, 0.12, 0, 0),
    autopct='%1.2f%%',
)
plt.title('Interesting Graph Check it out')
plt.show()

散点图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
# n samples from a standard normal (mean 0, std 1) for each coordinate
n = 1024
X = np.random.normal(0, 1, n)
Y = np.random.normal(0, 1, n)
# Colour value for each point: its angle around the origin
T = np.arctan2(Y, X)

# s: marker size, c: colour values, alpha: opacity
plt.scatter(X, Y, s=30, c=T, alpha=0.5)

# Zoom in on the central square, then hide the tick marks
plt.xlim((-1.5, 1.5))
plt.ylim((-1.5, 1.5))
plt.xticks(())
plt.yticks(())

plt.show()

3D散点:

1
2
3
4
5
6
7
8
9
10
11
12
13
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  registers the '3d' projection on old matplotlib

# 100 random points, each coordinate drawn from a standard normal
x = np.random.normal(0, 1, 100)
y = np.random.normal(0, 1, 100)
z = np.random.normal(0, 1, 100)

# Create the figure
fig = plt.figure()
# Ask the figure for a 3-D axes. Axes3D(fig) no longer attaches itself to the
# figure (deprecated in matplotlib 3.4, removed in 3.6), which left an empty plot.
ax = fig.add_subplot(111, projection='3d')
# Draw the scatter plot
ax.scatter(x, y, z)
# Show the chart
plt.show()

等高线图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Height function used to draw the contour plots below.
def f(x, y):
    """Return (1 - x/2 + x**5 + y**3) * exp(-x**2 - y**2); works element-wise on arrays."""
    return (1 - x / 2 + x**5 + y**3) * np.exp(-x**2 - y**2)

# Sample an n x n grid over [-3, 3] x [-3, 3]
n = 256
x = np.linspace(-3, 3, n)
y = np.linspace(-3, 3, n)
# meshgrid expands the two 1-D axes into 2-D coordinate arrays
X, Y = np.meshgrid(x, y)

# Filled contours. The argument after the data (8) sets the levels: with an int,
# matplotlib picks at most 8+1 "nice" boundaries between the min and max of f.
plt.contourf(X, Y, f(X, Y), 8, alpha=0.8, cmap=plt.cm.hot)

# Hide the tick marks
plt.xticks(())
plt.yticks(())

plt.show()

添加标签:

1
2
3
4
# Black contour lines over the same 8 levels
C = plt.contour(
    X, Y, f(X, Y), 8,
    alpha=0.8,
    colors='black',
)
# Label the lines in place (inline=True clears the line under each label)
plt.clabel(C, inline=True, fontsize=10)

3D平面图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
from mpl_toolkits.mplot3d import Axes3D  # noqa: F401  registers the '3d' projection on old matplotlib

# Create the figure and a 3-D axes. Axes3D(fig) no longer attaches itself to
# the figure (deprecated in matplotlib 3.4, removed in 3.6).
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# Grid over [-4, 4) with step 0.25
X = np.arange(-4, 4, 0.25)
Y = np.arange(-4, 4, 0.25)
X, Y = np.meshgrid(X, Y)
R = X + Y
Z = np.sin(R)

# Draw the surface. rstride/cstride are the row/column sampling strides;
# smaller values give a finer surface.
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap=plt.get_cmap('rainbow'))

# Optional: project filled contours onto a plane. With zdir='z', offset is the
# z position of that plane.
# ax.contourf(X, Y, Z, zdir='z', offset=-2, cmap='rainbow')

ax.set_zlim(-2, 2)

plt.show()

seaborn

参考文档:http://seaborn.pydata.org/index.html

热力图

(1)导入库

1
2
3
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

(2)生成数据

1
2
sns.set()  # apply seaborn's default theme to subsequent matplotlib plots
np.random.seed(0)  # fixed seed so the random data below is reproducible

(3)绘图

1
2
plt.figure(figsize=(10, 8))  # canvas size in inches
uniform_data = np.random.rand(10, 12)  # 10x12 values drawn from [0, 1)
ax = sns.heatmap(uniform_data)  # draw the heatmap (the data was never plotted before)

(4)结合公开数据集绘图

1
2
3
flights = sns.load_dataset("flights")  # load the example dataset
# Rows: month, columns: year, cells: passenger count.
# pandas >= 2.0 accepts these arguments only as keywords.
flights = flights.pivot(index="month", columns="year", values="passengers")
# annot=True writes each value in its cell, fmt="d" formats them as integers,
# linewidths=.5 draws thin separators between cells.
ax = sns.heatmap(flights, annot=True, fmt="d", linewidths=.5)

若以上图片出现首尾两行显示不全,原因在于matplotlib 3.1.1版本的问题:降级安装3.1.0,或升级到3.1.2及以上版本即可解决:

1
pip install matplotlib==3.1.0

切换风格:

1
ax = sns.heatmap(data=flights, cmap="YlGnBu")  # same heatmap with the YlGnBu colour map

联合分布图

(1)导入库

1
2
3
import tushare as ts
import seaborn as sns  # was misspelled "seabron", which raised ModuleNotFoundError
import matplotlib.pyplot as plt

(2)生成数据

1
stockdata = ts.get_k_data('000001')  # K-line data for stock code '000001'; provides the 'open'/'close' columns used below

(3)绘图

1
2
# x/y must be passed by keyword: seaborn >= 0.12 no longer accepts them positionally.
sns.jointplot(x="open", y="close", data=stockdata)
plt.show()  # a bare `plt` only echoes the module object

其中可以添加回归属性:

1
sns.jointplot(x="open", y="close", data=stockdata, kind='reg')  # keyword x/y for seaborn >= 0.12; kind='reg' adds a regression fit

多变量图

(1)导入库

1
2
3
4
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
from sklearn.datasets import load_iris

(2)生成数据

1
2
3
4
5
from sklearn.datasets import load_iris
bunch = load_iris()
# Feature columns from the bunch, plus the integer class target as "labels".
irisdata = pd.DataFrame(data=bunch.data, columns=bunch.feature_names)
labels = list(bunch.target)
irisdata['labels'] = labels

查看数据属性:

1
irisdata.columns  # the four iris feature names plus the added 'labels' column
1
['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)','petal width (cm)', 'labels']

(3)绘图

1
2
3
sns.set()  # apply seaborn's default theme
# Scatter grid over every pair of columns, coloured by the "labels" column
sns.pairplot(data=irisdata, hue="labels")
plt.show()

设置调色板:

1
sns.pairplot(data=irisdata, hue="labels", palette="husl")  # husl colour palette

添加回归设置:

1
sns.pairplot(data=irisdata, hue="labels", palette="husl", kind="reg")  # add regression fits

指定变量:

1
sns.pairplot(data=irisdata, vars=["sepal length (cm)", "sepal width (cm)"], hue="labels", palette="husl")  # only these two columns

箱体图

(1)导入库

(2)使用iris数据

(3)绘图

1
2
# One box per class label showing the spread of sepal length
sns.boxplot(x=irisdata["labels"], y=irisdata["sepal length (cm)"])
plt.show()  # a bare `plt` only echoes the module object

pyecharts

参考文档:https://pyecharts.org/#/zh-cn/quickstart

查看库版本

1
2
import pyecharts
print(f"{pyecharts.__version__}")  # installed pyecharts version

柱状图

1
2
3
4
5
6
from pyecharts.charts import Bar
bar = (
    Bar()
    .add_xaxis(["衬衫", "羊毛衫", "雪纺衫", "裤子", "高跟鞋", "袜子"])
    .add_yaxis("商家A", [5, 20, 36, 10, 75, 90])
)
# bar.render() would write an HTML file instead
bar.render_notebook()  # display inline in Jupyter

其中,render 会生成本地 HTML 文件,默认会在当前目录生成 render.html 文件,也可以传入路径参数,如 bar.render("mycharts.html")

上述绘制方法也支持链式调用,同时可以使用options添加主副标题等操作

1
2
3
4
5
6
7
8
9
from pyecharts.charts import Bar
from pyecharts import options as opts
header_opts = opts.TitleOpts(title="主标题", subtitle="副标题")  # main title + subtitle
bar = (
    Bar()
    .add_xaxis(["衬衫", "羊毛衫", "雪纺衫", "裤子", "高跟鞋", "袜子"])
    .add_yaxis("商家A", [5, 20, 36, 10, 75, 90])
    .set_global_opts(title_opts=header_opts)
)
bar.render_notebook()

多变量柱状图:

1
2
3
4
5
6
7
8
9
10
11
from pyecharts.charts import Bar
from pyecharts import options as opts

goods = ["衬衫", "毛衣", "领带", "裤子", "风衣", "高跟鞋", "袜子"]
# Two series side by side over the same categories
bar = Bar()
bar.add_xaxis(goods)
bar.add_yaxis("商家A", [114, 55, 27, 101, 125, 27, 105])
bar.add_yaxis("商家B", [57, 134, 137, 129, 145, 60, 49])
bar.set_global_opts(title_opts=opts.TitleOpts(title="某商场销售情况"))
bar.render_notebook()

多变量横向展示(从右往左):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
from pyecharts import options as opts
from pyecharts.charts import Bar

goods = ["衬衫", "毛衣", "领带", "裤子", "风衣", "高跟鞋", "袜子"]
bar = Bar()
bar.add_xaxis(goods)
bar.add_yaxis("商家A", [114, 55, 27, 101, 125, 27, 105])
bar.add_yaxis("商家B", [57, 134, 137, 129, 145, 60, 49])
bar.reversal_axis()  # swap the axes so bars run horizontally
bar.set_series_opts(label_opts=opts.LabelOpts(position="right"))  # value labels at the bar ends
bar.set_global_opts(title_opts=opts.TitleOpts(title="主标题"))

bar.render_notebook()

折线图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import pyecharts.options as opts
from pyecharts.charts import Line
x_data = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
y_data1 = [820, 932, 901, 934, 1290, 1330, 1320]
y_data2 = [444, 24, 553, 222, 900, 113, 30]

# Two series over the weekdays: d1 with filled circles (size 7), d2 with hollow
# circles. Point labels and the tooltip are turned off.
chart = Line()
chart.add_xaxis(xaxis_data=x_data)
chart.add_yaxis(
    series_name="d1",
    y_axis=y_data1,
    # marker shape: 'emptyCircle', 'circle', 'rect', 'roundRect', 'triangle', 'diamond', 'pin', 'arrow', 'none'
    symbol="circle",
    symbol_size=7,
    is_symbol_show=True,  # draw a marker at each data point
    label_opts=opts.LabelOpts(is_show=False),
)
chart.add_yaxis(
    series_name="d2",
    y_axis=y_data2,
    symbol="emptyCircle",
    is_symbol_show=True,
    label_opts=opts.LabelOpts(is_show=False),
)
chart.set_global_opts(
    tooltip_opts=opts.TooltipOpts(is_show=False),
    yaxis_opts=opts.AxisOpts(
        type_="value",
        axistick_opts=opts.AxisTickOpts(is_show=True),
        splitline_opts=opts.SplitLineOpts(is_show=True),
    ),
    xaxis_opts=opts.AxisOpts(type_="category"),
)
chart.render_notebook()

K线图

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from pyecharts import options as opts
from pyecharts.charts import Kline

# One row per day in pyecharts Kline order: [open, close, lowest, highest].
data = [
    [2320.26, 2320.26, 2287.3, 2362.94],
    [2300, 2291.3, 2288.26, 2308.38],
    [2295.35, 2346.5, 2295.35, 2345.92],
    [2347.22, 2358.98, 2337.35, 2363.8],
    [2360.75, 2382.48, 2347.89, 2383.76],
    [2383.43, 2385.42, 2371.23, 2391.82],
    [2377.41, 2419.02, 2369.57, 2421.15],
    [2425.92, 2428.15, 2417.58, 2440.38],
    [2411, 2433.13, 2403.3, 2437.42],
    # close was 2334.48, below this day's lowest (2427.7); 2434.48 is consistent
    [2432.68, 2434.48, 2427.7, 2441.73],
    [2430.69, 2418.53, 2394.22, 2433.89],
    [2416.62, 2432.4, 2414.4, 2443.03],
]


def kline_base() -> Kline:
    """Build a basic candlestick chart with one x-axis date per row of `data`."""
    c = (
        Kline()
        # one date label per data row, so labels stay in sync with `data`
        .add_xaxis(["2017/7/{}".format(i + 1) for i in range(len(data))])
        .add_yaxis("kline", data)
        .set_global_opts(
            yaxis_opts=opts.AxisOpts(is_scale=True),
            xaxis_opts=opts.AxisOpts(is_scale=True),
            title_opts=opts.TitleOpts(title="Kline-基本示例"),
        )
    )
    return c


kline_base().render_notebook()

调整主题颜色:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
from pyecharts import options as opts
from pyecharts.charts import Kline

# One row per day in pyecharts Kline order: [open, close, lowest, highest].
data = [
    [2320.26, 2320.26, 2287.3, 2362.94],
    [2300, 2291.3, 2288.26, 2308.38],
    [2295.35, 2346.5, 2295.35, 2345.92],
    [2347.22, 2358.98, 2337.35, 2363.8],
    [2360.75, 2382.48, 2347.89, 2383.76],
    [2383.43, 2385.42, 2371.23, 2391.82],
    [2377.41, 2419.02, 2369.57, 2421.15],
    [2425.92, 2428.15, 2417.58, 2440.38],
    [2411, 2433.13, 2403.3, 2437.42],
    # fixed transcription typo in this row so that lowest <= open/close <= highest
    [2432.68, 2434.48, 2427.7, 2441.73],
    [2430.69, 2418.53, 2394.22, 2433.89],
    [2416.62, 2432.4, 2414.4, 2443.03],
]


def kline_base() -> Kline:
    """Build a candlestick chart with custom rising/falling candle colours."""
    c = (
        Kline()
        # one date label per data row, so labels stay in sync with `data`
        .add_xaxis(["2017/7/{}".format(i + 1) for i in range(len(data))])
        .add_yaxis(
            "kline",
            data,
            # color/border_color style one candle direction, color0/border_color0 the other
            itemstyle_opts=opts.ItemStyleOpts(
                color="#ec0000",
                color0="#00da3c",
                border_color="#8A0000",
                border_color0="#008F28",
            ),
        )
        .set_global_opts(
            yaxis_opts=opts.AxisOpts(is_scale=True),
            xaxis_opts=opts.AxisOpts(is_scale=True),
            title_opts=opts.TitleOpts(title="Kline-基本示例"),
        )
    )
    return c


kline_base().render_notebook()

实践

绘制ROC

1
2
3
4
5
6
7
8
9
10
11
12
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
from sklearn.naive_bayes import GaussianNB
import matplotlib.pyplot as plt
import scikitplot as skplt
X, y = load_digits(return_X_y=True)
# Hold out a third of the samples for evaluation (random split, no fixed seed)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33)
nb = GaussianNB()
nb.fit(X_train, y_train)
# Per-class probabilities on the held-out set feed the one-vs-rest ROC curves
predicted_probas = nb.predict_proba(X_test)
skplt.metrics.plot_roc(y_test, predicted_probas)
plt.show()

绘制混淆矩阵

1
2
3
4
5
6
7
8
9
10
from sklearn.ensemble import RandomForestClassifier
from sklearn.datasets import load_digits as load_data
from sklearn.model_selection import cross_val_predict
import matplotlib.pyplot as plt
import scikitplot as skplt
X, y = load_data(return_X_y=True)
classifier = RandomForestClassifier()
# Cross-validated predictions: each sample is predicted by a model not fitted on it
predictions = cross_val_predict(classifier, X, y)
# Normalised confusion matrix of true vs. predicted labels
plot = skplt.metrics.plot_confusion_matrix(
    y,
    predictions,
    normalize=True,
)
plt.show()

卡尔曼滤波

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
import numpy as np
from pykalman import KalmanFilter
import matplotlib.pyplot as plt

# 1-D model: the state carries over unchanged (transition matrix [1]) and is
# observed directly (observation matrix [1]). Observation noise variance is 4,
# process noise covariance is eye(1), and the filter starts at 23 with variance 5.
kf = KalmanFilter(n_dim_obs=1,
                  n_dim_state=1,
                  initial_state_mean=23,
                  initial_state_covariance=5,
                  transition_matrices=[1],
                  observation_matrices=[1],
                  observation_covariance=4,
                  transition_covariance=np.eye(1),
                  transition_offsets=None)

# Ground truth is the constant 23. list + ndarray adds element-wise, so `sim`
# is a length-100 array of noisy measurements.
actual = [23]*100
sim = actual + np.random.normal(0, 1, 100)
state_means, state_covariance = kf.filter(sim)

# Label the three curves so the legend tells them apart
plt.plot(actual, 'r-', label='actual')
plt.plot(sim, 'k-', label='measured')
plt.plot(state_means, 'g-', label='filtered')
plt.legend()
plt.show()