博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
对随机梯度下降的一些使用心得
阅读量:4078 次
发布时间:2019-05-25

本文共 4630 字,大约阅读时间需要 15 分钟。

1:对于随机梯度下降SGD可能大家都比较了解,也很熟悉,说起来也很简单,在使用中我们一般用的是带mini batch的SGD。这个也描述起来很简单,但是在使用中还是有一些trick才可以的,最近在做一个实验,由于L-BFGS的速度太慢所以选择了带Mini-batch的SGD,我来说下我的心得。

2:其实梯度下降算法,在使用的时候无非是要考虑到2个方面,一个是方向,一个是步长,方向决定你是否走在了优化的道路上还是优化道路的负方向,步长是决定你要走多久才能到最优的地方。对于第一个问题很好解决,就是求梯度,梯度的负方向就是了。难的是求步长,如果步子太小,则需要很长的时间才能走到目的地,如果步子过大可能在目的地的周围来走震荡。所以重点在于如何选择步长。

3:对于随机梯度中,步长的选择方法有很多,最简单的莫过于设置一个比较小的步长,让算法慢慢去运行去就是了,也有别的方法就是可以计算步长的算法,这个我也试过了,反正不好弄,我就选择了最简单的小步长。但是何时算法自己知道差不多了可以停止了呢?我主要想说下这个问题:很多人都说设置迭代一定的次数或者比较两次梯度的变化,或者两次cost的变化,这个地方我不是特别同意,因为尤其是设置了一个小步长的时候,迭代一定次数当然可以,但是这个次数到底多少可以?没办法知道,所以如果设置了一定次数,次数过小的话肯定此时并没有达到最优的或者很接近最优点的地方,如果过大理论上是可以的,但是多少才算过大?你觉得10万次很多,但是不一定10万次算法可以达到,所以这个我觉得不太靠谱,对于比较两次梯度变化或者cost变化,同样存在这个问题,如果步长很小的话,那么同样连续两次之间的梯度和cost变化很小也是无法保证此时接近最优点的啊。

4:这里我介绍一个方法叫做early-stop,其实也是很成熟的方法了,大概思路是在训练的过程中,使用验证集周期性的来测试当前计算出来的参数,在验证集上测试当前参数对验证集的效果,如果效果可以,就保存起来,当很长一段时间都是此效果的话那么就迭代停止,该组参数就认为是不错的参数。这个方法叫做交叉验证,但是我看到有的地方写的是交叉验证用于超参的选择,而这个地方我不是选取的超参,所以不知道到底用对了没有。

5:我在下面贴出来我的整个sgd的matlab的代码,来大概说下这个early-stop。

function [ optParams ] = SGD( funObj,theta,data,labels,options )% Runs stochastic gradient descent with momentum to optimize the% parameters for the given objective.%% Parameters:%  funObj     -  function handle which accepts as input theta,%                data, labels and returns cost and gradient w.r.t%                to theta.%  theta      -  unrolled parameter vector%  data       -  stores data in m x n x numExamples tensor%  labels     -  corresponding labels in numExamples x 1 vector%  options    -  struct to store specific options for optimization%% Returns:%  opttheta   -  optimized parameter vector%% Options (* required)%  epochs*     - number of epochs through data%  alpha*      - initial learning rate%  minibatch*  - size of minibatch%  momentum    - momentum constant, defualts to 0.9%% Setupassert(all(isfield(options,{'epochs','alpha','minibatch'})),...        'Some options not defined');if ~isfield(options,'momentum')    options.momentum = 0.9;end;epochs = options.epochs;alpha = options.alpha;minibatch = options.minibatch;m = length(labels); % training set size% Setup for momentummom = 0.5;momIncrease = 20;velocity = zeros(size(theta));%%======================================================================%% SGD looppatience = options.patience;patienceIncreasement = options.patienceIncreasement;improvement = options.improvement;validationHandler = options.validationHandler;bestParams = [];bestValidationLoss = inf;validationFrequency = min(ceil(m/minibatch), patience/2);doneLooping = false;it = 0;e = 0;while (e < epochs) && (~doneLooping)	e = e + 1;        % randomly permute indices of data for quick minibatch sampling    rp = randperm(m);        for s=1:minibatch:(m-minibatch+1)        it = it + 1;        % increase momentum after momIncrease iterations        if it == momIncrease            mom = options.momentum;        end;        % get next randomly selected minibatch        mb_data = data(:,rp(s:s+minibatch-1));        mb_labels = labels(rp(s:s+minibatch-1));        % evaluate the objective function on the next minibatch        [cost grad] = funObj(theta,mb_data,mb_labels);                % early stop        if mod(it, validationFrequency) == 0            validationLoss = validationHandler(theta);            if validationLoss < bestValidationLoss                fprintf('validate=====================================current cost:%f, last cost:%f\n', validationLoss, bestValidationLoss);                if validationLoss < bestValidationLoss*improvement                    patience = max(patience, it* patienceIncreasement);                    bestParams.param = theta;                    bestParams.loss = validationLoss;                end                 bestValidationLoss = validationLoss;            end        end                if patience < it            doneLooping = true;            fprintf('stop due to patience[%d] greater than iterate[%d]\n', patience, it);            break;        end                % Instructions: Add in the weighted velocity vector to the        % gradient evaluated above scaled by the learning rate.        % Then update the current weights theta according to the        % sgd update rule                %%% YOUR CODE HERE %%%        velocity = mom*velocity + alpha*grad;        theta = theta - velocity;       % fprintf('Epoch %d: Cost on iteration %d is %f\n',e,it,cost);    end;    % aneal learning rate by factor of two after each epoch    alpha = alpha/2.0;end;optParams = theta;end
这个代码中包含了对步长的一些侧率,比如momentum了,还有每个epoch之后除以2了,这些可以忽略的,直接看early-stop就好了。我想说明的是对patience的改变,因为当前得到一个更好的结果的话,说明迭代需要时间久点,那么就把patience变大让多迭代点次数,最外层的迭代次数我一般设置的非常大。。
patience

转载地址:http://xdini.baihongyu.com/

你可能感兴趣的文章
缓存篇-使用Redis进行分布式锁应用
查看>>
缓存篇-Redisson的使用
查看>>
phpquery抓取网站内容简单介绍
查看>>
找工作准备的方向(4月22日写的)
查看>>
关于fwrite写入文件后打开查看是乱码的问题
查看>>
用结构体指针前必须要用malloc,不然会出现段错误
查看>>
Linux系统中的美
查看>>
一些实战项目(linux应用层编程,多线程编程,网络编程)
查看>>
我觉得专注于去学东西就好了,与世无争。
查看>>
原来k8s docker是用go语言写的,和现在所讲的go是一个东西!
查看>>
进程的创建分为两步,先fork(),再exec()
查看>>
可折叠机架
查看>>
不要用XXD电调
查看>>
弄底层基础的东西往往慢,枯燥,要慢慢磨
查看>>
使用STM32Cube可以直接生成使用FreeRTOS的工程
查看>>
STM32CubeMX 真的不要太好用
查看>>
STM32CubeMX介绍、下载与安装
查看>>
感觉也可以自己做公众号,写一些好点的文章,也和讲课类似,输出倒逼输入
查看>>
STM32CubeMX使用方法及功能介绍
查看>>
pixhawk固件的安装
查看>>