线性回归
线性回归是机器学习中最基本的模型,也是必须掌握的模型,其中涉及到最小二乘法,均方误差等,目的就是求得一条直线拟合一些点,首先生成如下图片的点,下面会用pytorch代码实现直线的拟合。
PyTorch代码实现
1 | import torch |
可以调整学习率、迭代次数、优化方法,看看不同的调整,会有什么不同的结果
- 学习率太小,可能导致迭代速度过慢,迭代次数结束后,损失还很大
- 学习率太大,可能导致无法收敛,并出现振荡,迭代结束后,损失依然大
- Adam使用的学习率与SDG使用的学习率并不一定会一样
最后拟合出来的图像如下图:
代码也可以查看我的GitHub仓库LinearRegression,如有错误,欢迎指出。