diff --git a/src/regex_impl.hh b/src/regex_impl.hh index 7997994c6..3f0d5850d 100644 --- a/src/regex_impl.hh +++ b/src/regex_impl.hh @@ -287,7 +287,15 @@ public: ArrayView captures() const { if (m_captures >= 0) - return { m_saves[m_captures].pos, m_program.save_count }; + { + auto& saves = m_saves[m_captures]; + for (int i = 0; i < m_program.save_count; ++i) + { + if ((saves.valid_mask & (1 << i)) == 0) + saves.pos[i] = Iterator{}; + } + return { saves.pos, m_program.save_count }; + } return {}; } @@ -295,12 +303,15 @@ private: struct Saves { int32_t refcount; - int32_t next_free; + union { + int32_t next_free; + uint32_t valid_mask; + }; Iterator* pos; }; template - int16_t new_saves(Iterator* pos) + int16_t new_saves(Iterator* pos, uint32_t valid_mask) { kak_assert(not copy or pos != nullptr); const auto count = m_program.save_count; @@ -310,18 +321,16 @@ private: Saves& saves = m_saves[res]; m_first_free = saves.next_free; kak_assert(saves.refcount == 1); - if (copy) - std::copy_n(pos, count, saves.pos); - else - std::fill_n(saves.pos, count, Iterator{}); - + if constexpr (copy) + std::copy_n(pos, std::bit_width(valid_mask), saves.pos); + saves.valid_mask = valid_mask; return res; } auto* new_pos = reinterpret_cast(operator new (count * sizeof(Iterator))); for (size_t i = 0; i < count; ++i) new (new_pos+i) Iterator{copy ? pos[i] : Iterator{}}; - m_saves.push_back({1, 0, new_pos}); + m_saves.push_back({1, {.valid_mask=valid_mask}, new_pos}); return static_cast(m_saves.size() - 1); } @@ -418,16 +427,17 @@ private: } break; case CompiledRegex::Save: - if (mode & RegexMode::NoSaves) + if constexpr (mode & RegexMode::NoSaves) break; if (thread.saves < 0) - thread.saves = new_saves(nullptr); - else if (m_saves[thread.saves].refcount > 1) + thread.saves = new_saves(nullptr, 0); + else if (auto& saves = m_saves[thread.saves]; saves.refcount > 1) { - --m_saves[thread.saves].refcount; - thread.saves = new_saves(m_saves[thread.saves].pos); + --saves.refcount; + thread.saves = new_saves(saves.pos, saves.valid_mask); } m_saves[thread.saves].pos[inst.param.save_index] = pos; + m_saves[thread.saves].valid_mask |= (1 << inst.param.save_index); break; case CompiledRegex::CharClass: if (pos == config.end)