机器学习(二)-简单线性回归

news/2024/12/23 16:23:14 标签: 机器学习, 线性回归, 人工智能

文章目录

    • 1. 简单线性回归理论
    • 2. python通过简单线性回归预测房价
      • 2.1 预测数据
      • 2.2导入标准库
      • 2.3 导入数据
      • 2.4 划分数据集
      • 2.5 导入线性回归模块
      • 2.6 对测试集进行预测
      • 2.7 计算均方误差 J
      • 2.8 计算参数 w0、w1
      • 2.9 可视化训练集拟合结果
      • 2.10 可视化测试集拟合结果
      • 2.11 保存模型
      • 2.12 加载模型并预测

机器学习和统计学中,简单线性回归是一种基础而强大的工具,用于建立自变量与因变量之间的关系。

假设你是一个房产中介,想通过房屋面积来预测房价。简单线性回归可以帮助你找到房屋面积与房价之间的线性关系,进而为客户提供更合理的报价。

本文将带你深入了解简单线性回归的理论基础、公式推导以及如何在Python中实现这一模型。

1. 简单线性回归理论

简单线性回归的基本假设是,因变量 Y(例如房价)与自变量 X(例如人口)之间存在线性关系。我们可以用以下的线性方程来表示这种关系:
在这里插入图片描述

其中:

  • y 是因变量(我们要预测的变量)。

  • x 是自变量(我们用来进行预测的变量)。

  • w0是截距(当x=0) 时,y的值)。

  • w1是斜率(自变量变化一个单位时,因变量的变化量)。

我们的目标是求 w0和w1的值,来找到一条跟预测值相关的直线。

从图中我们可以看出预测值与真实值之间存在误差,那么我们引入机器学习中的一个概念均方误差,它表示的是这些差值的平方和的平均数。这些误差的表达式如下:
在这里插入图片描述

均方误差的表达式如下:
在这里插入图片描述

2. python通过简单线性回归预测房价

2.1 预测数据

数据如下:

polulation,median_house_value
961,3.03
234,0.68
1074,2.92
1547,4.24
805,2.39
597,1.59
784,2.21
498,1.31
1602,4.28
292,0.54
1499,4.18
718,1.95
180,0.43
1202,3.62
1258,3.48
453,1.08
845,2.31
1032,2.96
384,0.68
896,2.62
425,0.82
928,2.95
1324,3.59
1435,4.02
543,1.62
1132,3.34
328,0.76
638,1.54
1389,3.78
692,1.79

x 轴是人口数量,y轴是房价

2.2导入标准库

# 导入标准库
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
matplotlib.use('TkAgg')

2.3 导入数据

# 导入数据集
dataset = pd.read_csv('Data.csv')
x = dataset.iloc[:, :-1]
y = dataset.iloc[:, 1]

2.4 划分数据集

# 数据集划分 训练集/测试集
from sklearn.model_selection import train_test_split
X_train,X_test,y_train,y_test = train_test_split(x,y,test_size=0.2,random_state=0)

2.5 导入线性回归模块

# 简单线性回归算法
from sklearn.linear_model import LinearRegression
regressor = LinearRegression()
regressor.fit(X_train, y_train)

2.6 对测试集进行预测

# 对测试集进行预测
y_pred = regressor.predict(X_test)

2.7 计算均方误差 J

# 计算J
J = 1/X_train.shape[0] * np.sum((regressor.predict(X_train) - y_train)**2)
print("J = {}".format(J))

输出结果:

J = 0.031198935319832692

2.8 计算参数 w0、w1

# 计算参数 w0、w1
w0 = regressor.intercept_
w1 = regressor.coef_[0]
print("w0 = {}, w1 = {}".format(w0, w1))

输出结果:

w0 = -0.16411984840092098, w1 = 0.0029383965595942067

2.9 可视化训练集拟合结果

# 可视化训练集拟合结果
plt.figure(1)
plt.scatter(X_train, y_train, color = 'red')
plt.plot(X_train, regressor.predict(X_train), color = 'blue')
plt.title('population VS median_house_value (training set)')
plt.xlabel('population')
plt.ylabel('median_house_value')
plt.show()

输出结果:
在这里插入图片描述

可以很好的看到拟合的直线可以很好的表示原始数据的人口和房价的走势

2.10 可视化测试集拟合结果

# 可视化测试集拟合结果
plt.figure(2)
plt.scatter(X_test, y_test, color = 'red')
plt.plot(X_train, regressor.predict(X_train), color = 'blue')
plt.title('population VS median_house_value (test set)')
plt.xlabel('population')
plt.ylabel('median_house_value')
plt.show()

输出结果:
在这里插入图片描述

可以看到,拟合的直线在测试集上的表现是相当不错了,说明我们训练的线性模型有很好的应用效果。

2.11 保存模型

# 保存模型
import pickle
with open('../model/simple_house_price_model.pkl','wb') as file:
    pickle.dump(regressor,file);

2.12 加载模型并预测

import pickle
import numpy as np
import pandas as pd
# 加载模型并预测
with open('../model/simple_house_price_model.pkl','rb') as file:
    model = pickle.load(file)

x_test = np.array([693,694])
x_test = pd.DataFrame(x_test)
x_test.columns=['polulation']
y_pred = model.predict(x_test)
print(y_pred)

输出结果:

[1.87218897 1.87512736]

http://www.niftyadmin.cn/n/5796747.html

相关文章

tslib(触摸屏输入设备的轻量级库)的学习、编译及测试记录

目录 tslib的简介tslib的源码和make及make install后得到的文件下载tslib的主要功能tslib的工作原理tslib的核心组成部分tslib的框架和核心函数分析tslib的框架tslib的核心函数ts_setup()的分析(对如何获取设备名和数据处理流程的分析)函数ts_setup()自身的主要代码ts_setup()对…

材料性质预测、分子生成、分类等研究方向的大语言模型构建与应用

流程 数据准备 收集和预处理大规模材料相关数据集。格式化数据以适应模型输入。 模型预训练 基于Transformer架构进行大规模无监督预训练。任务:掩码语言模型(MLM)或自回归生成任务。 任务微调 针对特定任务(性质预测、分子生成、…

Docker Compose 安装 Harbor

我使用的系统是rocky Linux 9 1. 准备环境 确保你的系统已经安装了以下工具: DockerDocker ComposeOpenSSL(用于生成证书)#如果不需要通过https连接的可以不设置 1.1 安装 Docker 如果尚未安装 Docker,可以参考以下命令安装&…

15款行业大数据报告下载网站

1、CAICT中国信通院 http://www.caict.ac.cn/ 国家高端产业智库,研究报告免费下载PDF版本。 2、阿里研究院 http://www.aliresearch.com/ 阿里出品,阿里相关产品数据报告。 3、企鹅智库 https://re.qq.com/ 腾讯旗下数据报告。 4、CBNData https:…

数据结构经典算法总复习(下卷)

第五章:树和二叉树 先序遍历二叉树的非递归算法。 void PreOrderTraverse(BiTree T, void (*Visit)(TElemType)) {//表示用于查找的函数的指针Stack S; BiTree p T;InitStack(S);//S模拟工作栈while (p || !StackEmpty(S)) {//S为空且下一个结点为空,意味着结束遍…

centos集群部署seata

文章目录 场景环境介绍100.64.0.3节点部署seata100.64.0.4节点部署seatanacos注册的效果我的配置,可以参考下 场景 生产环境都是以集群的方式部署seata, 这里演示下部署方式 环境介绍 2台centos7.9的开发机(内网ip100.64.0.4 ,100.64.0.3)jdk17一个nacos服务一个8.0.40版本的m…

Kotlin - 协程结构化并发Structured Concurrency

前言 Kotlin的Project Lead,Roman Elizarov的一片文章https://elizarov.medium.com/structured-concurrency-722d765aa952介绍了Structured Concurrency发展的背景。相对Kotlin1.1时代,后来新增的Structured Concurrency理念,也就是我们现在所…

【Linux系统编程】:信号(2)——信号的产生

1.前言 我们会讲解五种信号产生的方式: 通过终端按键产生信号,比如键盘上的CtrlC。kill命令。本质上是调用kill()调用函数接口产生信号硬件异常产生信号软件条件产生信号 前两种在前一篇文章中做了介绍,本文介绍下面三种. 2. 调用函数产生信号 2.1 k…