机器学习中的欠拟合和过拟合
定义
欠拟合 (Underfitting):
- 定义:模型在训练数据和测试数据上都表现不佳。这表明模型没有很好地捕捉数据中的趋势。
- 现象:训练集和测试集准确率都很低。
- 原因:模型过于简单,参数过少,无法捕捉数据的复杂关系。
过拟合 (Overfitting):
- 定义:模型在训练数据上表现非常好,但在测试数据上表现不佳。这表明模型捕捉到了训练数据中的噪声和细节,而不是数据的整体趋势。
- 现象:训练集准确率高,测试集准确率低。
- 原因:模型过于复杂,参数过多,导致对训练数据的过度拟合。
图形表达——以线性回归为例
import numpy as np import matplotlib.pyplot as plt from sklearn.linear_model import LinearRegression from sklearn.metrics import mean_squared_error from sklearn.model_selection import train_test_split # 设置随机数种子 np.random.seed(666) # 解决中文显示问题 plt.rcParams['font.sans-serif'] = ['SimHei'] plt.rcParams['axes.unicode_minus'] = False # 生成数据 x = np.random.uniform(-3, 3, size=100) X = x.reshape(-1, 1) y = 0.5 * x**2 + x + np.random.normal(0, 1, size=100) # 绘制原始数据 plt.figure(figsize=(12, 8)) plt.scatter(X, y, label='原始数据', color='blue') # 模拟欠拟合:线性回归 X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=5) linear = LinearRegression() linear.fit(X_train, y_train) y_predict = linear.predict(X_test) plt.plot(x, linear.predict(X), color='red', label='线性回归 (欠拟合)') print(f"线性回归训练集均方误差: {mean_squared_error(y_train, linear.predict(X_train)):.4f}") print(f"线性回归测试集均方误差: {mean_squared_error(y_test, y_predict):.4f}") # 模拟合适拟合:二次回归 X2 = np.hstack([X, X**2]) X_train, X_test, y_train, y_test = train_test_split(X2, y, random_state=5) linear.fit(X_train, y_train) y_predict2 = linear.predict(X_test) plt.plot(np.sort(x), linear.predict(X2)[np.argsort(x)], color='green', label='二次回归 (合适拟合)') print(f"二次回归训练集均方误差: {mean_squared_error(y_train, linear.predict(X_train)):.4f}") print(f"二次回归测试集均方误差: {mean_squared_error(y_test, y_predict2):.4f}") # 模拟过拟合:高次多项式回归 X10 = np.hstack([X2, X**3, X**4, X**5, X**6, X**7, X**8, X**9, X**10]) X_train, X_test, y_train, y_test = train_test_split(X10, y, random_state=5) linear.fit(X_train, y_train) y_predict3 = linear.predict(X_test) plt.plot(np.sort(x), linear.predict(X10)[np.argsort(x)], color='orange', label='高次多项式回归 (过拟合)') print(f"高次多项式回归训练集均方误差: {mean_squared_error(y_train, linear.predict(X_train)):.4f}") print(f"高次多项式回归测试集均方误差: {mean_squared_error(y_test, y_predict3):.4f}") # 添加图例和标签 plt.xlabel('x 值', fontsize=14) plt.ylabel('y 值', fontsize=14) plt.title('欠拟合、合适拟合和过拟合示例', fontsize=16) plt.legend(fontsize=12) plt.grid(True) # 显示图形 plt.show()一次回归训练集均方误差: 3.0496
一次回归测试集均方误差: 3.1531
二次回归训练集均方误差: 1.0951
二次回归测试集均方误差: 1.1119
高次多项式回归训练集均方误差: 0.9992
高次多项式回归测试集均方误差: 1.4146
测试集和训练集上 的均方误差随着模型复杂度提高而减小,拟合效果越好,但在很多高次项加入时出现了过拟合。
免责声明:我们致力于保护作者版权,注重分享,被刊用文章因无法核实真实出处,未能及时与作者取得联系,或有版权异议的,请联系管理员,我们会立即处理! 部分文章是来自自研大数据AI进行生成,内容摘自(百度百科,百度知道,头条百科,中国民法典,刑法,牛津词典,新华词典,汉语词典,国家院校,科普平台)等数据,内容仅供学习参考,不准确地方联系删除处理! 图片声明:本站部分配图来自人工智能系统AI生成,觅知网授权图片,PxHere摄影无版权图库和百度,360,搜狗等多加搜索引擎自动关键词搜索配图,如有侵权的图片,请第一时间联系我们,邮箱:ciyunidc@ciyunshuju.com。本站只作为美观性配图使用,无任何非法侵犯第三方意图,一切解释权归图片著作权方,本站不承担任何责任。如有恶意碰瓷者,必当奉陪到底严惩不贷!

