昇思MindSpore学习入门-模型训练

模型训练

模型训练一般分为四个步骤:

  1. 构建数据集。
  2. 定义神经网络模型。
  3. 定义超参、损失函数及优化器。
  4. 输入数据集进行训练与评估。

现在我们有了数据集和模型后,可以进行模型的训练与评估。

构建数据集

首先从数据集 Dataset加载代码,构建数据集。

定义神经网络模型

从网络构建中加载代码,构建一个神经网络模型。

定义超参、损失函数和优化器

超参

超参(Hyperparameters)是可以调整的参数,可以控制模型训练优化的过程,不同的超参数值可能会影响模型训练和收敛速度。目前深度学习模型多采用批量随机梯度下降算法进行优化,随机梯度下降算法的原理如下:

公式中,𝑛是批量大小(batch size),𝜂是学习率(learning rate)。另外,𝑤𝑡为训练轮次𝑡中的权重参数,∇𝑙为损失函数的导数。除了梯度本身,这两个因子直接决定了模型的权重更新,从优化本身来看,它们是影响模型性能收敛最重要的参数。一般会定义以下超参用于训练:

  • 训练轮次(epoch):训练时遍历数据集的次数。
  • 批次大小(batch size):数据集进行分批读取训练,设定每个批次数据的大小。batch size过小,花费时间多,同时梯度震荡严重,不利于收敛;batch size过大,不同batch的梯度方向没有任何变化,容易陷入局部极小值,因此需要选择合适的batch size,可以有效提高模型精度、全局收敛。
  • 学习率(learning rate):如果学习率偏小,会导致收敛的速度变慢,如果学习率偏大,则可能会导致训练不收敛等不可预测的结果。梯度下降法被广泛应用在最小化模型误差的参数优化算法上。梯度下降法通过多次迭代,并在每一步中最小化损失函数来预估模型的参数。学习率就是在迭代过程中,会控制模型的学习进度。

损失函数

损失函数(loss function)用于评估模型的预测值(logits)和目标值(targets)之间的误差。训练模型时,随机初始化的神经网络模型开始时会预测出错误的结果。损失函数会评估预测结果与目标值的相异程度,模型训练的目标即为降低损失函数求得的误差。

常见的损失函数包括用于回归任务的nn.MSELoss(均方误差)和用于分类的nn.NLLLoss(负对数似然)等。 nn.CrossEntropyLoss 结合了nn.LogSoftmaxnn.NLLLoss,可以对logits 进行归一化并计算预测误差。

优化器

模型优化(Optimization)是在每个训练步骤中调整模型参数以减少模型误差的过程。MindSpore提供多种优化算法的实现,称之为优化器(Optimizer)。优化器内部定义了模型的参数优化过程(即梯度如何更新至模型参数),所有优化逻辑都封装在优化器对象中。在这里,我们使用SGD(Stochastic Gradient Descent)优化器。

我们通过model.trainable_params()方法获得模型的可训练参数,并传入学习率超参来初始化优化器。

训练与评估

设置了超参、损失函数和优化器后,我们就可以循环输入数据来训练模型。一次数据集的完整迭代循环称为一轮(epoch)。每轮执行训练时包括两个步骤:

  1. 训练:迭代训练数据集,并尝试收敛到最佳参数。
  2. 验证/测试:迭代测试数据集,以检查模型性能是否提升。

接下来我们定义用于训练的train_loop函数和用于测试的test_loop函数。

使用函数式自动微分,需先定义正向函数forward_fn,使用value_and_grad获得微分函数grad_fn。然后,我们将微分函数和优化器的执行封装为train_step函数,接下来循环迭代数据集进行训练即可。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.mfbz.cn/a/763852.html

如若内容造成侵权/违法违规/事实不符,请联系我们进行投诉反馈qq邮箱809451989@qq.com,一经查实,立即删除!

相关文章

RT-Thread Studio与CubeMX联合编程之rtthread stm32h743的使用(十一)spi设备SFUD驱动的使用

我们要在rtthread studio 开发环境中建立stm32h743xih6芯片的工程。我们使用一块stm32h743及fpga的核心板完成相关实验,核心板如图: 1.建立新工程,选择相应的芯片型号及debug引脚及调试器 2.编译下载,可以看到串口打印正常 3.…

超实用的80个网络基础知识!(非常详细)零基础入门到精通,收藏这一篇就够了

点击上方 网络技术干货圈,选择 设为星标 优质文章,及时送达 转载请注明以下内容: 来源:公众号【网络技术干货圈】 作者:圈圈 ID:wljsghq 基础网络概念 1. 网络基础概述 什么是计算机网络 计算机网络是一…

全自动封箱机:如何助力企业实现智能化升级

在飞速发展的工业自动化时代,全自动封箱机以其高效、精准、稳定的特点,成为了生产线上的不可或缺的一员。它不仅大大地提高了生产效率,降低了人工成本,更在产品质量控制、安全性等方面发挥了重要作用。星派将深入探讨全自动封箱机…

基于SpringBoot民宿管理系统设计和实现(源码+LW+调试文档+讲解等)

💗博主介绍:✌全网粉丝10W,CSDN作者、博客专家、全栈领域优质创作者,博客之星、平台优质作者、专注于Java、小程序技术领域和毕业项目实战✌💗 🌟文末获取源码数据库🌟 感兴趣的可以先收藏起来,…

华为云物联网的使用

这里我们设置三个属性 1.温度DHT11_T 上传 2.湿度DHT11_H 上传 3.风扇motor 远程控制(云平台控制设备端) 发布主题: $oc/devices/{device_id}/sys/properties/report 发布主题时,需要上传数据,这个数据格式是JSON格式…

充气膜羽毛球馆投资需要多少钱—轻空间

充气膜羽毛球馆是一种现代化的运动设施,以其灵活的结构设计和高效的能耗管理受到广泛关注。投资建设一个充气膜羽毛球馆,涉及多个方面的成本,包括基础建设、膜材选择、系统配置以及运营维护费用。轻空间将详细分析投资建设充气膜羽毛球馆的成…

【C++知识点总结全系列 (06)】:STL六大组件详细介绍与总结(配置器、容器、迭代器、适配器、算法、仿函数)

STL六大组件目录 前言1、配置器(1)What(2)Why(3)HowA.调用new和delete实现内存分配与销毁B.STL Allocator (4)allocator类A.WhatB.HowC.allocator的算法 2、容器(1)What(2)Which(有哪些容器)(3)序列容器(顺序容器)A.WhichB.array&…

Langchain-Chatchat本地部署记录,三分钟学会!

1.前言: 最近AI爆发式的火,忆往昔尤记得16,17那会移动互联网是特别火热的,也造富了一批公司和个人,出来了很多精妙的app应用。现在轮到AI发力了,想想自己也应该参与到这场时代的浪潮之中,所以就找了开源的…

【微服务网关——https与http2代理实现】

1.https与http2代理 1.1 重新认识https与http2 https是http安全版本http2是一种传输协议两者并没有本质联系 1.1.1 https与http的区别 HTTP(超文本传输协议)和 HTTPS(安全超文本传输协议)是用于在网络上交换数据的两种协议。H…

7月刷题指南|考研数学强化30天吃透《严选题》

马上就要进入7月份了,相信很多小伙伴的基础阶段已经接近尾声了。特别是数二的同学们,应该已经完成了基础部分。而数一和数三的同学由于多了一门概率论,可能需要更多的时间。不管是哪种情况,我个人认为,最晚也应该在暑假…

Qt 使用代码布局,而不使用UI布局

一、工程的建立: 1、打开Qt Creator,文件,新建文件或项目 2、选择Application,Qt Widgets Application 3、写入名称,选择qmake 4、选择基类Base class,去除Generate form 务必选择QWidget,若…

django开源电子文档管理系统_Django简介、ORM、核心模块

Django简介 Django是一种开源的大而且全的Web应用框架,是由python语言来编写的。他采用了MVC模式,Django最初是被开发来用于管理劳伦斯出版集团下的一些以新闻为主内容的网站。一款CMS(内容管理系统)软件。并于 2005 年 7 月在 BSD 许可证下发布。这套框…

传神论文中心|第15期人工智能领域论文推荐

在人工智能领域的快速发展中,我们不断看到令人振奋的技术进步和创新。近期,开放传神(OpenCSG)社区发现了一些值得关注的成就。传神社区本周也为对AI和大模型感兴趣的读者们提供了一些值得一读的研究工作的简要概述以及它们各自的论…

什么是脏读、幻读、不可重复读

数据库事务 数据库事务是指作为单个逻辑工作单元执行的一系列操作,这些操作要么全部成功执行,要么全部失败回滚,以保持数据库的一致性和完整性。在多线程或多用户同时操作时,难免会出现错乱与冲突,这就需要引入事务的…

【C# winForm】ProgressBar进度条

1.控件介绍 进度条通常用于显示代码的执行进程进度,在一些复杂功能交互体验时告知用户进程还在继续。 在属性栏中,有三个值常用: Value表示当前值,Minimum表示进度条范围下限,Maximum表示进度条范围上限。 2.简单实…

【产品经理】订单处理12-订单的取消与反取消

在电商ERP系统中,订单取消与反取消也是常见功能之一。 订单取消与反取消也是电商ERP系统的常见功能,本次主要讲解下订单取消与反取消的逻辑。 一、订单取消 在电商ERP系统中,订单取消一般由审单员操作,此类取消一般是由于上下游…

商家团购app微信小程序模板

手机微信商家团购小程序页面,商家订餐外卖小程序前端模板下载。包含:团购主页、购物车订餐页面、我的订单、个人主页等。 商家团购app微信小程序模板

sublime如何运行Html文件?

背景: 在sublime上面写了html代码以后,怎么运行html文件来进行debug呢?如果去点击保存的HTML文件,每次这样就会很麻烦,能不能直接在sublime里面点什么就可以直接打开浏览器运行呢?答案是OK的。 1-确认Vie…

Android面试题经典之Glide取消加载以及线程池优化

本文首发于公众号“AntDream”,欢迎微信搜索“AntDream”或扫描文章底部二维码关注,和我一起每天进步一点点 Glide通过生命周期取消加载 生命周期回调过程 onStop —>RequestManager.onStop –>RequestTracker.pauseRequest –> SingleRequest…

SpringSecurity6 | 基于数据库实现登录认证

SpringSecurity6 | 基于数据库认证 ✅作者简介:大家好,我是Leo,热爱Java后端开发者,一个想要与大家共同进步的男人😉😉 🍎个人主页:Leo的博客 💞当前专栏: 循序渐进学SpringSecurity6 ✨特色专栏: MySQL学习 🥭本文内容: SpringSecurity6 | 基于数据库实现登…
最新文章