From cb868cfa41072e08987e1c32f117483445ba197d Mon Sep 17 00:00:00 2001 From: Ben Wiederhake Date: Sat, 23 Oct 2021 15:43:59 +0200 Subject: [PATCH] AK+Everywhere: Make Base64 decoding fallible --- AK/Base64.cpp | 39 ++++++++++++------- AK/Base64.h | 3 +- Tests/AK/TestBase64.cpp | 14 ++++--- Userland/Applications/Mail/MailWidget.cpp | 2 +- Userland/Applications/PixelPaint/Image.cpp | 4 +- Userland/Libraries/LibCrypto/ASN1/PEM.cpp | 6 ++- Userland/Libraries/LibHTTP/HttpRequest.cpp | 5 ++- .../LibWeb/Bindings/WindowObject.cpp | 6 ++- .../LibWeb/Loader/ResourceLoader.cpp | 14 +++++-- Userland/Utilities/base64.cpp | 6 ++- Userland/Utilities/telws.cpp | 6 ++- 11 files changed, 73 insertions(+), 32 deletions(-) diff --git a/AK/Base64.cpp b/AK/Base64.cpp index 006eeebf1d8..4d0a419f4ad 100644 --- a/AK/Base64.cpp +++ b/AK/Base64.cpp @@ -6,10 +6,7 @@ #include #include -#include -#include #include -#include #include #include @@ -33,7 +30,8 @@ static constexpr auto make_alphabet() static constexpr auto make_lookup_table() { constexpr auto alphabet = make_alphabet(); - Array table {}; + Array table; + table.fill(-1); for (size_t i = 0; i < alphabet.size(); ++i) { table[alphabet[i]] = i; } @@ -50,19 +48,31 @@ size_t calculate_base64_encoded_length(ReadonlyBytes input) return ((4 * input.size() / 3) + 3) & ~3; } -ByteBuffer decode_base64(const StringView& input) +Optional decode_base64(const StringView& input) { - auto get = [&](const size_t offset, bool* is_padding = nullptr) -> u8 { + auto get = [&](const size_t offset, bool* is_padding) -> Optional { constexpr auto table = make_lookup_table(); if (offset >= input.length()) return 0; if (input[offset] == '=') { - if (is_padding) - *is_padding = true; + if (!is_padding) + return {}; + *is_padding = true; return 0; } - return table[static_cast(input[offset])]; + i16 result = table[static_cast(input[offset])]; + if (result < 0) + return {}; + VERIFY(result < 256); + return { result }; }; +#define TRY_GET(index, is_padding) \ + ({ \ + auto _temporary_result = get(index, is_padding); \ + if (!_temporary_result.has_value()) \ + return {}; \ + _temporary_result.value(); \ + }) Vector output; output.ensure_capacity(calculate_base64_decoded_length(input)); @@ -71,10 +81,10 @@ ByteBuffer decode_base64(const StringView& input) bool in2_is_padding = false; bool in3_is_padding = false; - const u8 in0 = get(i); - const u8 in1 = get(i + 1); - const u8 in2 = get(i + 2, &in2_is_padding); - const u8 in3 = get(i + 3, &in3_is_padding); + const u8 in0 = TRY_GET(i, nullptr); + const u8 in1 = TRY_GET(i + 1, nullptr); + const u8 in2 = TRY_GET(i + 2, &in2_is_padding); + const u8 in3 = TRY_GET(i + 3, &in3_is_padding); const u8 out0 = (in0 << 2) | ((in1 >> 4) & 3); const u8 out1 = ((in1 & 0xf) << 4) | ((in2 >> 2) & 0xf); @@ -87,8 +97,7 @@ ByteBuffer decode_base64(const StringView& input) output.append(out2); } - // FIXME: Handle OOM failure. - return ByteBuffer::copy(output).release_value(); + return ByteBuffer::copy(output); } String encode_base64(ReadonlyBytes input) diff --git a/AK/Base64.h b/AK/Base64.h index 51ca2e66305..c1c5388629c 100644 --- a/AK/Base64.h +++ b/AK/Base64.h @@ -7,6 +7,7 @@ #pragma once #include +#include #include #include @@ -16,7 +17,7 @@ size_t calculate_base64_decoded_length(const StringView&); size_t calculate_base64_encoded_length(ReadonlyBytes); -ByteBuffer decode_base64(const StringView&); +Optional decode_base64(const StringView&); String encode_base64(ReadonlyBytes); diff --git a/Tests/AK/TestBase64.cpp b/Tests/AK/TestBase64.cpp index 95a90d06c79..ffc7d15c15e 100644 --- a/Tests/AK/TestBase64.cpp +++ b/Tests/AK/TestBase64.cpp @@ -13,7 +13,9 @@ TEST_CASE(test_decode) { auto decode_equal = [&](const char* input, const char* expected) { - auto decoded = decode_base64(StringView(input)); + auto decoded_option = decode_base64(StringView(input)); + EXPECT(decoded_option.has_value()); + auto decoded = decoded_option.value(); EXPECT(String::copy(decoded) == String(expected)); EXPECT(StringView(expected).length() <= calculate_base64_decoded_length(StringView(input).bytes())); }; @@ -27,12 +29,12 @@ TEST_CASE(test_decode) decode_equal("Zm9vYmFy", "foobar"); } -TEST_CASE(test_decode_nocrash) +TEST_CASE(test_decode_invalid) { - // Any output is fine, we only check that we don't crash here. - decode_base64(StringView("asdf\xffqwer")); - decode_base64(StringView("asdf\x80qwer")); - // TODO: Handle decoding failure. + EXPECT(!decode_base64(StringView("asdf\xffqwe")).has_value()); + EXPECT(!decode_base64(StringView("asdf\x80qwe")).has_value()); + EXPECT(!decode_base64(StringView("asdf:qwe")).has_value()); + EXPECT(!decode_base64(StringView("asdf=qwe")).has_value()); } TEST_CASE(test_encode) diff --git a/Userland/Applications/Mail/MailWidget.cpp b/Userland/Applications/Mail/MailWidget.cpp index 709f7406b50..3a5d9636f47 100644 --- a/Userland/Applications/Mail/MailWidget.cpp +++ b/Userland/Applications/Mail/MailWidget.cpp @@ -493,7 +493,7 @@ void MailWidget::selected_email_to_load() if (selected_alternative_encoding.equals_ignoring_case("7bit") || selected_alternative_encoding.equals_ignoring_case("8bit")) { decoded_data = encoded_data; } else if (selected_alternative_encoding.equals_ignoring_case("base64")) { - decoded_data = decode_base64(encoded_data); + decoded_data = decode_base64(encoded_data).value_or(ByteBuffer()); } else if (selected_alternative_encoding.equals_ignoring_case("quoted-printable")) { decoded_data = IMAP::decode_quoted_printable(encoded_data); } else { diff --git a/Userland/Applications/PixelPaint/Image.cpp b/Userland/Applications/PixelPaint/Image.cpp index 5439c9168e5..5796472f3ee 100644 --- a/Userland/Applications/PixelPaint/Image.cpp +++ b/Userland/Applications/PixelPaint/Image.cpp @@ -100,8 +100,10 @@ Result, String> Image::try_create_from_pixel_paint_json(Jso auto bitmap_base64_encoded = layer_object.get("bitmap").as_string(); auto bitmap_data = decode_base64(bitmap_base64_encoded); + if (!bitmap_data.has_value()) + return String { "Base64 decode failed"sv }; - auto bitmap = try_decode_bitmap(bitmap_data); + auto bitmap = try_decode_bitmap(bitmap_data.value()); if (!bitmap) return String { "Layer bitmap decode failed"sv }; diff --git a/Userland/Libraries/LibCrypto/ASN1/PEM.cpp b/Userland/Libraries/LibCrypto/ASN1/PEM.cpp index 53768f67b4c..c901febe416 100644 --- a/Userland/Libraries/LibCrypto/ASN1/PEM.cpp +++ b/Userland/Libraries/LibCrypto/ASN1/PEM.cpp @@ -35,7 +35,11 @@ ByteBuffer decode_pem(ReadonlyBytes data) break; } auto b64decoded = decode_base64(lexer.consume_line().trim_whitespace(TrimMode::Right)); - if (!decoded.try_append(b64decoded.data(), b64decoded.size())) { + if (!b64decoded.has_value()) { + dbgln("Failed to decode PEM, likely bad Base64"); + return {}; + } + if (!decoded.try_append(b64decoded.value().data(), b64decoded.value().size())) { dbgln("Failed to decode PEM, likely OOM condition"); return {}; } diff --git a/Userland/Libraries/LibHTTP/HttpRequest.cpp b/Userland/Libraries/LibHTTP/HttpRequest.cpp index 5175f75d9c7..fc8e0036944 100644 --- a/Userland/Libraries/LibHTTP/HttpRequest.cpp +++ b/Userland/Libraries/LibHTTP/HttpRequest.cpp @@ -197,7 +197,10 @@ Optional HttpRequest::parse_http_ba auto token = value.substring_view(6); if (token.is_empty()) return {}; - auto decoded_token = String::copy(decode_base64(token)); + auto decoded_token_bb = decode_base64(token); + if (!decoded_token_bb.has_value()) + return {}; + auto decoded_token = String::copy(decoded_token_bb.value()); auto colon_index = decoded_token.find(':'); if (!colon_index.has_value()) return {}; diff --git a/Userland/Libraries/LibWeb/Bindings/WindowObject.cpp b/Userland/Libraries/LibWeb/Bindings/WindowObject.cpp index c2073534b22..3de8c062336 100644 --- a/Userland/Libraries/LibWeb/Bindings/WindowObject.cpp +++ b/Userland/Libraries/LibWeb/Bindings/WindowObject.cpp @@ -386,11 +386,15 @@ JS_DEFINE_OLD_NATIVE_FUNCTION(WindowObject::atob) } auto string = TRY_OR_DISCARD(vm.argument(0).to_string(global_object)); auto decoded = decode_base64(StringView(string)); + if (!decoded.has_value()) { + vm.throw_exception(global_object, JS::ErrorType::InvalidFormat, "Base64"); + return {}; + } // decode_base64() returns a byte string. LibJS uses UTF-8 for strings. Use Latin1Decoder to convert bytes 128-255 to UTF-8. auto decoder = TextCodec::decoder_for("windows-1252"); VERIFY(decoder); - return JS::js_string(vm, decoder->to_utf8(decoded)); + return JS::js_string(vm, decoder->to_utf8(decoded.value())); } JS_DEFINE_OLD_NATIVE_FUNCTION(WindowObject::btoa) diff --git a/Userland/Libraries/LibWeb/Loader/ResourceLoader.cpp b/Userland/Libraries/LibWeb/Loader/ResourceLoader.cpp index 02ba6d19b01..1c2bb062088 100644 --- a/Userland/Libraries/LibWeb/Loader/ResourceLoader.cpp +++ b/Userland/Libraries/LibWeb/Loader/ResourceLoader.cpp @@ -153,10 +153,18 @@ void ResourceLoader::load(LoadRequest& request, Functionsend(buffer, false); + if (buffer.has_value()) { + socket->send(buffer.value(), false); + } else { + outln("Could not send message : Base64 string contains an invalid character."); + } continue; } if (line == ".exit") {