实践一下人工智能编程,一个简单的线性回归问题
hive-105017·@cheva·
0.000 HBD实践一下人工智能编程,一个简单的线性回归问题
继续学习pytorch深度学习的内容。之前都是讲的理论,今天实际上手写一写代码。要解决的问题很简单。也是所谓机器学习中最常见的入门题材——线性回归。这种关系在现实中很常见。比如身高和鞋子的码数一般都存在线性关系。通过这个小题目,主要是看一眼人工智能到底是如何学习的。 这里有先构造一个存在线性关系的数据组。 X和Y之间的关系式Y=WX+b。然后我们把生成的数据分成两部分,80%作为训练集,20%作为测试集。我们程序的目的就是从这些数字中猜出w和b。 程序解决问题的思路其实很简单,就是通过不断的猜测来逼近真实值。大家应该都玩过一种猜数字的游戏,就是一个人心里想一个数,然后让你来猜。你先随便说一个数如,出题人就会告诉你这个数与他心里想的那个数相比,是偏大还是偏小。经过反复几次尝试,并根据出题人的反馈调整之后,最后一般都能猜得出来。人工智能的做法其实跟这个游戏一模一样。它会先随机设参数。然后通过这个参数计算结果,将结果与实际值进行比较。看还差多远。当然,计算机中所有的一切都要转化成数字,所以我们就要构造一个损失函数,来告诉人工智能他的猜测与真实答案到底有多大的距离。然后人工智能就会通过一系列的优化算法。来调整这些参数,然后再猜一遍,如此往往复循环很多次以后就能够逼近我们预先设置的真实值了。 具体到pytorch中的使用步骤大体如下: - 第一步是创建模型。先用pytorch.nn类来创建一个子类。在其中定义W和B两个需要优化的参数。然后创立一个实例,这个就是我们需要优化的模型。 - 第二步用随机数先设置一下w,b参数,用这个随机设置的参数来计算预测值,这一步其实就是瞎猜。 - 第三步调用损失函数进行计算出一个数值来衡量与预测值和实际值的差距有多大。 - 第四步将损失函数反向传播,调用优化器更新模型的参数。 最后把这段代码放进一个循环里面。反复执行直到得到理想的,接近真实数据的参数。以下是具体的代码: ``` weight = 0.7 bias = 0.3 #创建数据 start =0 end = 1 step = 0.02 x = torch.arange(start,end,step).unsqueeze(dim=1) y = weight*x+bias x[:10],y[:10] #分割训练集和测试集 train_split=int(0.8*len(x)) x_train,y_train = x[:train_split],y[:train_split] x_test,y_test = x[train_split:],y[train_split:] len(x_train),len(y_train),len(x_test),len(y_test) ##################################################### from torch import nn class LinearRegressionModel(nn.Module): def __init__(self): super().__init__() #初始化模型参数 self.weights = nn.Parameter(torch.randn(1, requires_grad=True, dtype=torch.float)) self.bias = nn.Parameter(torch.randn(1, requires_grad=True, dtype=torch.float)) #定义模型的计算方式,这里是线性方程: def forward(self,x:torch.Tensor) -> torch.Tensor: return self.weights*x+self.bias # 设置seed值, torch.manual_seed(42) #创建一个模型实例: model_0=LinearRegressionModel() #查看模型网络参数: print(list(model_0.parameters())) with torch.inference_mode(): y_preds = model_0(x_test) print(f"测试次数:{len(x_test)}") print(f'推断次数:{len(y_test)}') print(f'预测值:\n{y_preds}') ########################################### # 建立训练循环 #设置纪元数量 epochs=100 #建立空列表,跟踪训练过程 train_loss_value=[] test_loss_value=[] epoch_count=[] print(f'更新前模型参数值:{model_0.state_dict()}') for epoch in range(100): y_preds = model_0(x_train) #将模型设为训练模式 model_0.train() loss = loss_fn(y_preds,y_train) if epoch%10 == 0: print(f'损失函数值:{loss}') optimizer.zero_grad() loss.backward() optimizer.step() if epoch%10 == 0: print(f'纪元{epoch}次更新后的模型参数值:{model_0.state_dict()}') #测试模型 model_0.eval() with torch.inference_mode(): test_pred = model_0(x_test) test_loss = loss_fn(test_pred,y_test) epoch_count.append(epoch) train_loss_value.append(loss.detach().numpy()) test_loss_value.append(test_loss.detach().numpy()) if epoch%10 == 0: print(f'测试集损失值:{test_loss}') ``` 训练过程的输出: 更新前模型参数值:OrderedDict([('weights', tensor([0.3367])), ('bias', tensor([0.1288]))]) 损失函数值:0.31288138031959534 纪元0次更新后的模型参数值:OrderedDict([('weights', tensor([0.3406])), ('bias', tensor([0.1388]))]) 测试集损失值:0.48106518387794495 损失函数值:0.1976713240146637 纪元10次更新后的模型参数值:OrderedDict([('weights', tensor([0.3796])), ('bias', tensor([0.2388]))]) 测试集损失值:0.3463551998138428 损失函数值:0.08908725529909134 纪元20次更新后的模型参数值:OrderedDict([('weights', tensor([0.4184])), ('bias', tensor([0.3333]))]) 测试集损失值:0.21729660034179688 损失函数值:0.053148526698350906 纪元30次更新后的模型参数值:OrderedDict([('weights', tensor([0.4512])), ('bias', tensor([0.3768]))]) 测试集损失值:0.14464017748832703 损失函数值:0.04543796554207802 纪元40次更新后的模型参数值:OrderedDict([('weights', tensor([0.4748])), ('bias', tensor([0.3868]))]) 测试集损失值:0.11360953003168106 损失函数值:0.04167863354086876 纪元50次更新后的模型参数值:OrderedDict([('weights', tensor([0.4938])), ('bias', tensor([0.3843]))]) 测试集损失值:0.09919948130846024 损失函数值:0.03818932920694351 纪元60次更新后的模型参数值:OrderedDict([('weights', tensor([0.5116])), ('bias', tensor([0.3788]))]) 测试集损失值:0.08886633068323135 损失函数值:0.03476089984178543 纪元70次更新后的模型参数值:OrderedDict([('weights', tensor([0.5288])), ('bias', tensor([0.3718]))]) 测试集损失值:0.0805937647819519 损失函数值:0.03132382780313492 纪元80次更新后的模型参数值:OrderedDict([('weights', tensor([0.5459])), ('bias', tensor([0.3648]))]) 测试集损失值:0.07232122868299484 损失函数值:0.02788739837706089 纪元90次更新后的模型参数值:OrderedDict([('weights', tensor([0.5629])), ('bias', tensor([0.3573]))]) 测试集损失值:0.06473556160926819 训练过程中,损失函数不断下降的图表: 
👍 wherein, backscratcher, sinzzer, theb0red1, vonaurolacu, nworb, unitqm, cnstm, cugel, mnurhiver, waivio.curator, crypt0gnome, joshglen, resiliencia, otage, jasonmunapasee, sanach, truce, softmetal, rightwing670, bittrio, e-r-k-a-n, brada2550, aekraj, odditiesandends, photosnap, randomblock1, reloadbeatbox, greendeliverence, koenau, ireenchew, solaiman, zellypearl, xyzxyz, kryptoking1, zanoz, duwiky, hattaarshavin, numpypython, aaronli, elderson, moeenali, funkymunky20000, idayrus, ferrate, paleotwist, rocketpower, xylliana, mistural, huolian, kaylinart, webdeals, sixexgames, sames, bitrocker2020, mawit07, myach, cheva, ikrahch, mimidee74, yameen, marfonso, nureza, steemegg, hans001, h-hamilton, zartisht, ricestrela, eythorphoto, genepoolrentsclr, mikehere, zainalbakri, fatman, dbooster, babarakas43, indiebandguru, nattybongo, jluvs2fly, maeusenews, brucutu, hivevote, killerwot, joelibra, kevinwong, mattroconnor, steemxp, curtley, logicforce, elizacheng, mrpointp, mrspointm, jychbetter, winniex, windowglass, auleo, cherryng, nostalgic1212, aellly, tanlikming, graceli, julian2013, cherryzz, ying82, atyh, bo022, lovequeen, jingjing1616, emma-emma, ericaliu, abundancelife, lovelingling, everlandd, celeste413, zhangyan-123, alpha-omega, arconite, archisteem, justinchicken, bert0, memeteca, tresor, bnk, badmusgreene, ragnar94, mangou007, sharkface, roncoejr, misbitcap, thomasbrown, headcrypto, byebyehamburgers, steemturbo, evanstinger, naej, randumb, happyphoenix, pouchon, moemanmoesly, hypersonic1, minhaz007, cryptovues, intrepidthinker, lilzmom902, supremebape, linakay, lordgod, rymlen, jaforce, tworealsolutions, slacktmusic, makingthebest09, chinyerevivian, foreversteem, ano123, q-news, coreyssteemit77, moro1992, ragnarhewins90, bigtakosensei, c4cristi3, chimzycash, b0s, davidke20, good-karma, esteemapp, esteem.app, ecency, ecency.stats, kimzwarch, namchau, halleyleow, joeliew, vamos-amigo, bichen, mia-cc, fintian, twicejoy, xecency, marriot5464, drwom, sasaadrian, photographercr, mao317, rivalhw, cn-reader, philipmak,