Merged PR 23407: Fix incorrect/missing gradient accumulation for affine biases

This PR fixes incorrect/missing gradient accumulation for the biases of affine operations when training with delay > 1 or a large effective batch size.
This commit is contained in:
Marcin Junczys-Dowmunt 2022-04-08 16:00:04 +00:00
parent 16bfa0c913
commit d5c7372a67
3 changed files with 6 additions and 5 deletions

View File

@ -11,6 +11,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
### Added ### Added
### Fixed ### Fixed
- Fixed incorrect/missing gradient accumulation for the biases of affine operations when using delay > 1 or a large effective batch size.
- Fixed case augmentation with multi-threaded reading. - Fixed case augmentation with multi-threaded reading.
- Scripts using PyYAML now use `safe_load`; see https://msg.pyyaml.org/load - Scripts using PyYAML now use `safe_load`; see https://msg.pyyaml.org/load

View File

@ -1 +1 @@
v1.11.5 v1.11.6

View File

@ -334,7 +334,7 @@ public:
false, false,
1.0, 1.0,
scalar_, computeTypeB)), scalar_, computeTypeB)),
NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC)) NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 1.f, 1.f, computeTypeC))
}; };
if(transA_ && !transB_) if(transA_ && !transB_)
@ -353,7 +353,7 @@ public:
false, false,
1.0, 1.0,
scalar_, computeTypeB)), scalar_, computeTypeB)),
NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC)) NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 1.f, 1.f, computeTypeC))
}; };
if(transA_ && transB_) if(transA_ && transB_)
@ -372,7 +372,7 @@ public:
true, true,
1.0, 1.0,
scalar_, computeTypeB)), scalar_, computeTypeB)),
NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC)) NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 1.f, 1.f, computeTypeC))
}; };
return { return {
@ -390,7 +390,7 @@ public:
false, false,
1.0, 1.0,
scalar_, computeTypeB)), scalar_, computeTypeB)),
NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 0.f, 1.f, computeTypeC)) NodeOp(Prod(child(2)->grad(), child(3)->val(), adj_, true, false, 1.f, 1.f, computeTypeC))
}; };
} }