import matplotlib.pyplot as pltimport numpy as npdef draw_common_graph():x = np.arange(1, 10)y = x * 2# 指定x和y的最小值和最大值,依次为:x_min, x_max, y_min, y_max# plt.axis([-np.pi, np.pi, -2 * np.pi, 2 * np.pi])# 指定x和y的值,以及线条的样式设置# 通过设置label,用于添加对应的标题,在后续创建图例中会显示plt.plot(x, y, color="r", linestyle="--", marker="*", linewidth=1.0, label="test")# 指定x轴和y轴的标题名称plt.xlabel("epoch")plt.ylabel("acc")# 多图叠加plt.plot(x, x ** 2, color="b", linestyle="--", marker="*", linewidth=1.0, label="train")# 为生成的图像设置标题plt.title("draw common graph")# 创建图例# loc指定图例的位置# bbox_to_anchor用于细化图例的位置,是一个二元组,第一个数值用于控制legend的向右移动,第二个数值用于控制legend的向上移动# 图例会展示对应的图像plot里面的label值plt.legend(loc="upper left", bbox_to_anchor=(0.2, 0.95))# 为图像添加网格# 默认是给x和y都加网格,即axis="both",也可以通过设置axis="x"或"y"指定只在x或y上添加网格plt.grid(color="k", linestyle=":")# 在指定的点位置上添加注释plt.text(4, 20, "test text", fontsize=10)# 设置X轴的间距跨度plt.xticks(range(1, x.max(), 2))# 设置Y轴的间距刻度plt.yticks(range(1, 100, 20))plt.show()def main():draw_common_graph()if __name__ == "__main__":main()
最终生成的图像如下:
参考博客:
matplotlib画图基础教程
pyplot通过xticks和yticks修改轴距
