Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove unused cost parameter in ParameterUpdater #970

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions paddle/parameter/ParameterUpdaterBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class ParameterUpdater {
virtual void startPass() {}

// called by Trainer then finishing a pass, ruturn true if pass accepted
virtual bool finishPass(real cost = 0) { return true; }
virtual bool finishPass() { return true; }
Copy link
Member

@jacquesqiao jacquesqiao Dec 21, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个 cost一开始是用来做啥的,看起来是记录了每一个batch的cost,现在是使用另外一个机制替代了么


// called by Trainer before backward() of a batch
// Return the type of pass it needs. This pass type will be passed
Expand Down Expand Up @@ -112,9 +112,9 @@ class ParameterUpdaterComposite : public ParameterUpdater {
[&](int tid, size_t numThreads) { updaters_[tid]->startPass(); });
}

virtual bool finishPass(real cost = 0) {
virtual bool finishPass() {
syncThreadPool_->execPlusOwner(
[&](int tid, size_t numThreads) { updaters_[tid]->finishPass(cost); });
[&](int tid, size_t numThreads) { updaters_[tid]->finishPass(); });
return true;
}

Expand Down
8 changes: 4 additions & 4 deletions paddle/trainer/ParameterUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,9 @@ class SgdLocalUpdater : public ParameterUpdater {
* @param cost sum cost during one pass.
* @return true if accept (used for owlqn).
*/
virtual bool finishPass(real cost) {
virtual bool finishPass() {
optimizer_->finishPass();
return ParameterUpdater::finishPass(cost);
return ParameterUpdater::finishPass();
}

/**
Expand Down Expand Up @@ -220,9 +220,9 @@ class SgdUpdaterWithCpuAverager : public SgdLocalUpdater {
averager_->startPass();
SgdLocalUpdater::startPass();
}
virtual bool finishPass(real cost) {
virtual bool finishPass() {
averager_->finishPass();
return SgdLocalUpdater::finishPass(cost);
return SgdLocalUpdater::finishPass();
}

/// apply the averaged parameter to PARAMETER_VALUE
Expand Down
4 changes: 2 additions & 2 deletions paddle/trainer/RemoteParameterUpdater.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -309,7 +309,7 @@ void RemoteParameterUpdater::startPass() {
}
}

bool RemoteParameterUpdater::finishPass(real cost) {
bool RemoteParameterUpdater::finishPass() {
if (localUpdater_) {
localUpdater_->finishPass();
}
Expand Down Expand Up @@ -711,7 +711,7 @@ void SparseRemoteParameterUpdater::startPass() {
}
}

bool SparseRemoteParameterUpdater::finishPass(real cost) {
bool SparseRemoteParameterUpdater::finishPass() {
if (config_.algorithm() == TrainAlgorithm::SGD) {
parameterClient_->waitPassFinish();
} else {
Expand Down
4 changes: 2 additions & 2 deletions paddle/trainer/RemoteParameterUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ class RemoteParameterUpdater : public ParameterUpdater {
*/
virtual void finishBatch(real cost);
virtual void startPass();
virtual bool finishPass(real cost);
virtual bool finishPass();

#ifndef PADDLE_DISABLE_TIMER
virtual void setForwardbackwardTime(uint64_t delta) {
Expand Down Expand Up @@ -281,7 +281,7 @@ class SparseRemoteParameterUpdater : public ParameterUpdater {
/// send all sparse related parameters to all pservers
virtual void finishBatch(real cost);
virtual void startPass();
virtual bool finishPass(real cost);
virtual bool finishPass();

virtual void apply();
virtual void restore();
Expand Down
2 changes: 1 addition & 1 deletion paddle/trainer/ThreadParameterUpdater.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ void SgdThreadUpdater::startPass() {
}
}

bool SgdThreadUpdater::finishPass(real cost) {
bool SgdThreadUpdater::finishPass() {
catchUpWith();

for (auto& para : parameters_) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/trainer/ThreadParameterUpdater.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class SgdThreadUpdater : public ParameterUpdater {
virtual void startPass();

// Use the finishPass() function of the base optimizer.
virtual bool finishPass(real cost);
virtual bool finishPass();

virtual void init(std::vector<ParameterPtr>& parameters);
virtual PassType startBatch(int64_t batchSize);
Expand Down
2 changes: 1 addition & 1 deletion paddle/trainer/Trainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,7 @@ void Trainer::trainOnePassBatch(int passId) {

trainerInternal_.getGradientMachine()->onPassEnd();

bool accepted = trainerInternal_.getParameterUpdater()->finishPass(cost);
bool accepted = trainerInternal_.getParameterUpdater()->finishPass();

globalStat.setThreadInfo(true);
globalStat.printAllStatus();
Expand Down