diff --git a/paddle/parameter/ParameterUpdaterBase.h b/paddle/parameter/ParameterUpdaterBase.h index 88148d9b769e9..b230e170c15f1 100644 --- a/paddle/parameter/ParameterUpdaterBase.h +++ b/paddle/parameter/ParameterUpdaterBase.h @@ -38,7 +38,7 @@ class ParameterUpdater { virtual void startPass() {} // called by Trainer then finishing a pass, ruturn true if pass accepted - virtual bool finishPass(real cost = 0) { return true; } + virtual bool finishPass() { return true; } // called by Trainer before backward() of a batch // Return the type of pass it needs. This pass type will be passed @@ -112,9 +112,9 @@ class ParameterUpdaterComposite : public ParameterUpdater { [&](int tid, size_t numThreads) { updaters_[tid]->startPass(); }); } - virtual bool finishPass(real cost = 0) { + virtual bool finishPass() { syncThreadPool_->execPlusOwner( - [&](int tid, size_t numThreads) { updaters_[tid]->finishPass(cost); }); + [&](int tid, size_t numThreads) { updaters_[tid]->finishPass(); }); return true; } diff --git a/paddle/trainer/ParameterUpdater.h b/paddle/trainer/ParameterUpdater.h index 4dae77567f8f4..c3207e63ce72b 100644 --- a/paddle/trainer/ParameterUpdater.h +++ b/paddle/trainer/ParameterUpdater.h @@ -102,9 +102,9 @@ class SgdLocalUpdater : public ParameterUpdater { * @param cost sum cost during one pass. * @return true if accept (used for owlqn). */ - virtual bool finishPass(real cost) { + virtual bool finishPass() { optimizer_->finishPass(); - return ParameterUpdater::finishPass(cost); + return ParameterUpdater::finishPass(); } /** @@ -220,9 +220,9 @@ class SgdUpdaterWithCpuAverager : public SgdLocalUpdater { averager_->startPass(); SgdLocalUpdater::startPass(); } - virtual bool finishPass(real cost) { + virtual bool finishPass() { averager_->finishPass(); - return SgdLocalUpdater::finishPass(cost); + return SgdLocalUpdater::finishPass(); } /// apply the averaged parameter to PARAMETER_VALUE diff --git a/paddle/trainer/RemoteParameterUpdater.cpp b/paddle/trainer/RemoteParameterUpdater.cpp index 630f55d998d9f..6939738203f41 100644 --- a/paddle/trainer/RemoteParameterUpdater.cpp +++ b/paddle/trainer/RemoteParameterUpdater.cpp @@ -309,7 +309,7 @@ void RemoteParameterUpdater::startPass() { } } -bool RemoteParameterUpdater::finishPass(real cost) { +bool RemoteParameterUpdater::finishPass() { if (localUpdater_) { localUpdater_->finishPass(); } @@ -712,7 +712,7 @@ void SparseRemoteParameterUpdater::startPass() { } } -bool SparseRemoteParameterUpdater::finishPass(real cost) { +bool SparseRemoteParameterUpdater::finishPass() { if (config_.algorithm() == TrainAlgorithm::SGD) { parameterClient_->waitPassFinish(); } else { diff --git a/paddle/trainer/RemoteParameterUpdater.h b/paddle/trainer/RemoteParameterUpdater.h index ec6ed443d33db..7794b209009a3 100644 --- a/paddle/trainer/RemoteParameterUpdater.h +++ b/paddle/trainer/RemoteParameterUpdater.h @@ -90,7 +90,7 @@ class RemoteParameterUpdater : public ParameterUpdater { */ virtual void finishBatch(real cost); virtual void startPass(); - virtual bool finishPass(real cost); + virtual bool finishPass(); #ifndef PADDLE_DISABLE_TIMER virtual void setForwardbackwardTime(uint64_t delta) { @@ -281,7 +281,7 @@ class SparseRemoteParameterUpdater : public ParameterUpdater { /// send all sparse related parameters to all pservers virtual void finishBatch(real cost); virtual void startPass(); - virtual bool finishPass(real cost); + virtual bool finishPass(); virtual void apply(); virtual void restore(); diff --git a/paddle/trainer/ThreadParameterUpdater.cpp b/paddle/trainer/ThreadParameterUpdater.cpp index 2a76d5723ccb6..870d4a4b0246f 100644 --- a/paddle/trainer/ThreadParameterUpdater.cpp +++ b/paddle/trainer/ThreadParameterUpdater.cpp @@ -70,7 +70,7 @@ void SgdThreadUpdater::startPass() { } } -bool SgdThreadUpdater::finishPass(real cost) { +bool SgdThreadUpdater::finishPass() { catchUpWith(); for (auto& para : parameters_) { diff --git a/paddle/trainer/ThreadParameterUpdater.h b/paddle/trainer/ThreadParameterUpdater.h index 198435c0f3005..880f1f9ddc49a 100644 --- a/paddle/trainer/ThreadParameterUpdater.h +++ b/paddle/trainer/ThreadParameterUpdater.h @@ -47,7 +47,7 @@ class SgdThreadUpdater : public ParameterUpdater { virtual void startPass(); // Use the finishPass() function of the base optimizer. - virtual bool finishPass(real cost); + virtual bool finishPass(); virtual void init(const std::vector& parameters); virtual PassType startBatch(int64_t batchSize); diff --git a/paddle/trainer/Trainer.cpp b/paddle/trainer/Trainer.cpp index 1eec2c432d235..031e3b7cf199c 100644 --- a/paddle/trainer/Trainer.cpp +++ b/paddle/trainer/Trainer.cpp @@ -537,7 +537,7 @@ void Trainer::trainOnePassBatch(int passId) { trainerInternal_.getGradientMachine()->onPassEnd(); - bool accepted = trainerInternal_.getParameterUpdater()->finishPass(cost); + bool accepted = trainerInternal_.getParameterUpdater()->finishPass(); globalStat.setThreadInfo(true); globalStat.printAllStatus();