LibVideo: Ensure that syntax element counts don't overflow

Integer overflow could sometimes occur due to counts going above 255,
where the values should instead be clamped at their maximum to avoid
wrapping to 0.
This commit is contained in:
Zaggy1024 2022-09-25 03:18:55 -05:00 committed by Andrew Kaster
parent 7c87a8e302
commit 7d27273dc7
Notes: sideshowbarker 2024-07-17 06:07:23 +09:00

View File

@ -668,80 +668,83 @@ u8 TreeParser::calculate_token_probability(u8 node)
void TreeParser::count_syntax_element(SyntaxElementType type, int value)
{
auto increment = [](u8& count) {
count = min(static_cast<u32>(count) + 1, 255);
};
switch (type) {
case SyntaxElementType::Partition:
m_decoder.m_syntax_element_counter->m_counts_partition[m_ctx][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_partition[m_ctx][value]);
return;
case SyntaxElementType::IntraMode:
case SyntaxElementType::SubIntraMode:
m_decoder.m_syntax_element_counter->m_counts_intra_mode[m_ctx][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_intra_mode[m_ctx][value]);
return;
case SyntaxElementType::UVMode:
m_decoder.m_syntax_element_counter->m_counts_uv_mode[m_ctx][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_uv_mode[m_ctx][value]);
return;
case SyntaxElementType::Skip:
m_decoder.m_syntax_element_counter->m_counts_skip[m_ctx][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_skip[m_ctx][value]);
return;
case SyntaxElementType::IsInter:
m_decoder.m_syntax_element_counter->m_counts_is_inter[m_ctx][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_is_inter[m_ctx][value]);
return;
case SyntaxElementType::CompMode:
m_decoder.m_syntax_element_counter->m_counts_comp_mode[m_ctx][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_comp_mode[m_ctx][value]);
return;
case SyntaxElementType::CompRef:
m_decoder.m_syntax_element_counter->m_counts_comp_ref[m_ctx][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_comp_ref[m_ctx][value]);
return;
case SyntaxElementType::SingleRefP1:
m_decoder.m_syntax_element_counter->m_counts_single_ref[m_ctx][0][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_single_ref[m_ctx][0][value]);
return;
case SyntaxElementType::SingleRefP2:
m_decoder.m_syntax_element_counter->m_counts_single_ref[m_ctx][1][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_single_ref[m_ctx][1][value]);
return;
case SyntaxElementType::MVSign:
m_decoder.m_syntax_element_counter->m_counts_mv_sign[m_mv_component][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_mv_sign[m_mv_component][value]);
return;
case SyntaxElementType::MVClass0Bit:
m_decoder.m_syntax_element_counter->m_counts_mv_class0_bit[m_mv_component][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_bit[m_mv_component][value]);
return;
case SyntaxElementType::MVBit:
VERIFY(m_mv_bit < MV_OFFSET_BITS);
m_decoder.m_syntax_element_counter->m_counts_mv_bits[m_mv_component][m_mv_bit][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_mv_bits[m_mv_component][m_mv_bit][value]);
m_mv_bit = 0xFF;
return;
case SyntaxElementType::TXSize:
m_decoder.m_syntax_element_counter->m_counts_tx_size[m_decoder.m_max_tx_size][m_ctx][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_tx_size[m_decoder.m_max_tx_size][m_ctx][value]);
return;
case SyntaxElementType::InterMode:
m_decoder.m_syntax_element_counter->m_counts_inter_mode[m_ctx][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_inter_mode[m_ctx][value]);
return;
case SyntaxElementType::InterpFilter:
m_decoder.m_syntax_element_counter->m_counts_interp_filter[m_ctx][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_interp_filter[m_ctx][value]);
return;
case SyntaxElementType::MVJoint:
m_decoder.m_syntax_element_counter->m_counts_mv_joint[value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_mv_joint[value]);
return;
case SyntaxElementType::MVClass:
m_decoder.m_syntax_element_counter->m_counts_mv_class[m_mv_component][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_mv_class[m_mv_component][value]);
return;
case SyntaxElementType::MVClass0FR:
VERIFY(m_mv_class0_bit < CLASS0_SIZE);
m_decoder.m_syntax_element_counter->m_counts_mv_class0_fr[m_mv_component][m_mv_class0_bit][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_fr[m_mv_component][m_mv_class0_bit][value]);
m_mv_class0_bit = 0xFF;
return;
case SyntaxElementType::MVClass0HP:
m_decoder.m_syntax_element_counter->m_counts_mv_class0_hp[m_mv_component][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_mv_class0_hp[m_mv_component][value]);
return;
case SyntaxElementType::MVFR:
m_decoder.m_syntax_element_counter->m_counts_mv_fr[m_mv_component][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_mv_fr[m_mv_component][value]);
return;
case SyntaxElementType::MVHP:
m_decoder.m_syntax_element_counter->m_counts_mv_hp[m_mv_component][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_mv_hp[m_mv_component][value]);
return;
case SyntaxElementType::Token:
m_decoder.m_syntax_element_counter->m_counts_token[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][min(2, value)]++;
increment(m_decoder.m_syntax_element_counter->m_counts_token[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][min(2, value)]);
return;
case SyntaxElementType::MoreCoefs:
m_decoder.m_syntax_element_counter->m_counts_more_coefs[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][value]++;
increment(m_decoder.m_syntax_element_counter->m_counts_more_coefs[m_tx_size][m_plane > 0][m_decoder.m_is_inter][m_band][m_ctx][value]);
return;
case SyntaxElementType::DefaultIntraMode:
case SyntaxElementType::DefaultUVMode: