鱼C论坛

 找回密码
 立即注册
查看: 143|回复: 1

[基础知识] 【深度学习实战】线性回归问题

[复制链接]
发表于 2024-6-7 17:01:01 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
import numpy as np
import torch
import torch.nn as nn 
np.random.seed(42)
# 生成随机数据
x = np.random.rand(100,1)
y = 1 + 2*x + 0.1*np.random.randn(100,1)
# 将数据转化成pytorch tensor
x_tensor = torch.from_numpy(x).float()
y_tensor = torch.from_numpy(y).float()
# 设置超参数
learning_rate = 0.1
num_epochs = 1000
input_dim = 1
output_dim = 1
model = nn.Linear(input_dim,output_dim)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(),lr = learning_rate)
# 开始训练
for epoch in range(num_epochs):
    y_pred = model(x_tensor)
    loss = criterion(y_pred,y_tensor)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
print("w:",model.weight.data)
print("b:",model.bias.data)
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

发表于 2024-6-7 17:01:09 | 显示全部楼层
这段代码是一个简单的线性回归问题的实现。让我们逐行来解释一下。

首先,导入必要的库并设置随机种子,以确保结果的可重复性:
import numpy as np
import torch
import torch.nn as nn 
np.random.seed(42)

然后,生成随机数据集。这里使用了numpy库来生成100个随机数作为输入特征x,并生成对应的输出标签y:
x = np.random.rand(100,1)
y = 1 + 2*x + 0.1*np.random.randn(100,1)

接下来,将生成的数据转化成PyTorch Tensor,以便与PyTorch模型进行计算:
x_tensor = torch.from_numpy(x).float()
y_tensor = torch.from_numpy(y).float()

然后,设置学习率、迭代次数、输入和输出维度,并初始化线性模型、损失函数和优化器:
learning_rate = 0.1
num_epochs = 1000
input_dim = 1
output_dim = 1
model = nn.Linear(input_dim,output_dim)
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

接下来是模型的训练过程。在每个epoch中,通过模型前向传播获得预测结果y_pred,计算当前的损失值loss,然后通过反向传播更新模型参数:
for epoch in range(num_epochs):
    y_pred = model(x_tensor)
    loss = criterion(y_pred, y_tensor)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

最后,输出线性模型的权重和偏置:
print("w:", model.weight.data)
print("b:", model.bias.data)

这样就完成了线性回归问题的训练和预测过程。希望这个解释对你有帮助!

以上回复来自 FishC AI 小助理,如未能正确解答您的问题,请继续追问。
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复 支持 反对

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-9-8 08:48

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表