From 2a65036680710e5d5b667e3e888f5b37dbdc4961 Mon Sep 17 00:00:00 2001 From: twystd Date: Sun, 25 Aug 2024 20:06:41 -0700 Subject: [PATCH] midi: fixed fuzzing errors --- format/midi/metaevents.go | 187 ++++++++++++++++++++++++++++---------- format/midi/midi.go | 15 ++- format/midi/sysex.go | 36 +++++++- 3 files changed, 185 insertions(+), 53 deletions(-) diff --git a/format/midi/metaevents.go b/format/midi/metaevents.go index 34accdfd..817d3138 100644 --- a/format/midi/metaevents.go +++ b/format/midi/metaevents.go @@ -203,16 +203,19 @@ func decodeMetaEvent(d *decode.D, event uint8, ctx *context) { func decodeSequenceNumber(d *decode.D) { d.FieldUintFn("sequenceNumber", func(d *decode.D) uint64 { - data := vlf(d) seqno := uint64(0) - if len(data) > 0 { - seqno += uint64(data[0]) - } + if data, err := vlf(d); err != nil { + d.Errorf("%v", err) + } else { + if len(data) > 0 { + seqno += uint64(data[0]) + } - if len(data) > 1 { - seqno <<= 8 - seqno += uint64(data[1]) + if len(data) > 1 { + seqno <<= 8 + seqno += uint64(data[1]) + } } return seqno @@ -221,66 +224,123 @@ func decodeSequenceNumber(d *decode.D) { func decodeText(d *decode.D) { d.FieldStrFn("text", func(d *decode.D) string { - return string(vlf(d)) + if data, err := vlf(d); err != nil { + d.Errorf("%v", err) + } else { + return string(data) + } + + return "" }) } func decodeCopyright(d *decode.D) { d.FieldStrFn("copyright", func(d *decode.D) string { - return string(vlf(d)) + if data, err := vlf(d); err != nil { + d.Errorf("%v", err) + } else { + return string(data) + } + + return "" }) } func decodeTrackName(d *decode.D) { d.FieldStrFn("name", func(d *decode.D) string { - return string(vlf(d)) + if data, err := vlf(d); err != nil { + d.Errorf("%v", err) + } else { + return string(data) + } + + return "" }) } func decodeInstrumentName(d *decode.D) { d.FieldStrFn("instrument", func(d *decode.D) string { - return string(vlf(d)) + if data, err := vlf(d); err != nil { + d.Errorf("%v", err) + } else { + return string(data) + } + + return "" }) } func decodeLyric(d *decode.D) { d.FieldStrFn("lyric", func(d *decode.D) string { - return string(vlf(d)) + if data, err := vlf(d); err != nil { + d.Errorf("%v", err) + } else { + return string(data) + } + + return "" }) } func decodeMarker(d *decode.D) { d.FieldStrFn("marker", func(d *decode.D) string { - return string(vlf(d)) + if data, err := vlf(d); err != nil { + d.Errorf("%v", err) + } else { + return string(data) + } + + return "" }) } func decodeCuePoint(d *decode.D) { d.FieldStrFn("cue", func(d *decode.D) string { - return string(vlf(d)) + if data, err := vlf(d); err != nil { + d.Errorf("%v", err) + } else { + return string(data) + } + + return "" }) } func decodeProgramName(d *decode.D) { d.FieldStrFn("program", func(d *decode.D) string { - return string(vlf(d)) + if data, err := vlf(d); err != nil { + d.Errorf("%v", err) + } else { + return string(data) + } + + return "" }) } func decodeDeviceName(d *decode.D) { d.FieldStrFn("device", func(d *decode.D) string { - return string(vlf(d)) + if data, err := vlf(d); err != nil { + d.Errorf("%v", err) + } else { + return string(data) + } + + return "" }) } func decodeMIDIChannelPrefix(d *decode.D) { d.FieldUintFn("channel", func(d *decode.D) uint64 { channel := uint64(0) - data := vlf(d) - for _, b := range data { - channel <<= 8 - channel |= uint64(b & 0x00ff) + if data, err := vlf(d); err != nil { + d.Errorf("%v", err) + } else { + for _, b := range data { + channel <<= 8 + channel |= uint64(b & 0x00ff) + } } return channel @@ -290,11 +350,14 @@ func decodeMIDIChannelPrefix(d *decode.D) { func decodeMIDIPort(d *decode.D) { d.FieldUintFn("port", func(d *decode.D) uint64 { channel := uint64(0) - data := vlf(d) - for _, b := range data { - channel <<= 8 - channel |= uint64(b & 0x00ff) + if data, err := vlf(d); err != nil { + d.Errorf("%v", err) + } else { + for _, b := range data { + channel <<= 8 + channel |= uint64(b & 0x00ff) + } } return channel @@ -304,11 +367,14 @@ func decodeMIDIPort(d *decode.D) { func decodeTempo(d *decode.D) { d.FieldUintFn("tempo", func(d *decode.D) uint64 { tempo := uint64(0) - data := vlf(d) - for _, b := range data { - tempo <<= 8 - tempo |= uint64(b & 0x00ff) + if data, err := vlf(d); err != nil { + d.Errorf("%v", err) + } else { + for _, b := range data { + tempo <<= 8 + tempo |= uint64(b & 0x00ff) + } } return tempo @@ -318,10 +384,16 @@ func decodeTempo(d *decode.D) { func decodeSMPTEOffset(d *decode.D) { d.FieldStruct("offset", func(d *decode.D) { var data []uint8 - d.FieldStrFn("bytes", func(d *decode.D) string { - data = vlf(d) + var err error - return fmt.Sprintf("%v", data) + d.FieldStrFn("bytes", func(d *decode.D) string { + if data, err = vlf(d); err != nil { + d.Errorf("%v", err) + } else { + return fmt.Sprintf("%v", data) + } + + return "[]" }) if len(data) > 0 { @@ -353,10 +425,16 @@ func decodeSMPTEOffset(d *decode.D) { func decodeTimeSignature(d *decode.D) { d.FieldStruct("signature", func(d *decode.D) { var data []uint8 - d.FieldStrFn("bytes", func(d *decode.D) string { - data = vlf(d) + var err error - return fmt.Sprintf("%v", data) + d.FieldStrFn("bytes", func(d *decode.D) string { + if data, err = vlf(d); err != nil { + d.Errorf("%v", err) + } else { + return fmt.Sprintf("%v", data) + } + + return "[]" }) if len(data) > 0 { @@ -384,17 +462,20 @@ func decodeTimeSignature(d *decode.D) { func decodeKeySignature(d *decode.D) { d.FieldUintFn("key", func(d *decode.D) uint64 { - data := vlf(d) key := uint64(0) - if len(data) > 0 { - key <<= 8 - key |= uint64(data[0]) & 0x00ff - } + if data, err := vlf(d); err != nil { + d.Errorf("%v", err) + } else { + if len(data) > 0 { + key <<= 8 + key |= uint64(data[0]) & 0x00ff + } - if len(data) > 1 { - key <<= 8 - key |= uint64(data[1]) & 0x00ff + if len(data) > 1 { + key <<= 8 + key |= uint64(data[1]) & 0x00ff + } } return key @@ -404,17 +485,31 @@ func decodeKeySignature(d *decode.D) { func decodeEndOfTrack(d *decode.D) { d.FieldUintFn("length", func(d *decode.D) uint64 { - return uint64(len(vlf(d))) + length := 0 + + if data, err := vlf(d); err != nil { + d.Errorf("%v", err) + } else { + length = len(data) + } + + return uint64(length) }) } func decodeSequencerSpecificEvent(d *decode.D) { d.FieldStruct("info", func(d *decode.D) { var data []uint8 - d.FieldStrFn("bytes", func(d *decode.D) string { - data = vlf(d) + var err error - return fmt.Sprintf("%v", data) + d.FieldStrFn("bytes", func(d *decode.D) string { + if data, err = vlf(d); err != nil { + d.Errorf("%v", err) + } else { + return fmt.Sprintf("%v", data) + } + + return "[]" }) if len(data) > 2 && data[0] == 0x00 { diff --git a/format/midi/midi.go b/format/midi/midi.go index d3dafef3..305dbff2 100644 --- a/format/midi/midi.go +++ b/format/midi/midi.go @@ -27,7 +27,7 @@ func init() { format.MIDI, &decode.Format{ Description: "Standard MIDI file", - Groups: []*decode.Group{format.Probe}, + Groups: []*decode.Group{format.Probe}, DecodeFn: decodeMIDI, }) @@ -38,14 +38,19 @@ func decodeMIDI(d *decode.D) any { d.Endian = decode.BigEndian // ... decode header + println(">> 1") if err := skipTo(d, "MThd"); err != nil { d.Errorf("%v", err) } else { + println(">> 2") d.FieldStruct("header", decodeMThd) // ... decode tracks + println(">> 3") d.FieldArray("tracks", func(d *decode.D) { + println(">> 4") for d.BitsLeft() > 0 { + println(">> 5", d.BitsLeft()) if err := skipTo(d, "MTrk"); err != nil { d.Errorf("%v", err) } else { @@ -188,10 +193,14 @@ func vlq(d *decode.D) uint64 { return vlq } -func vlf(d *decode.D) []uint8 { +func vlf(d *decode.D) ([]uint8, error) { N := int(vlq(d)) - return d.BytesLen(N) + if int64(N*8) > d.BitsLeft() { + return nil, fmt.Errorf("invalid field length") + } else { + return d.BytesLen(N), nil + } } func flush(d *decode.D, format string, args ...any) { diff --git a/format/midi/sysex.go b/format/midi/sysex.go index 970d60f4..d555c173 100644 --- a/format/midi/sysex.go +++ b/format/midi/sysex.go @@ -67,10 +67,16 @@ func decodeSysExEvent(d *decode.D, status uint8, ctx *context) { func decodeSysExMessage(d *decode.D, ctx *context) { var bytes []uint8 + var err error d.FieldStrFn("bytes", func(d *decode.D) string { - bytes = vlf(d) - return fmt.Sprintf("%v", bytes) + if bytes, err = vlf(d); err != nil { + d.Errorf("%v", err) + } else { + return fmt.Sprintf("%v", bytes) + } + + return "[]" }) if len(bytes) < 1 { @@ -102,7 +108,18 @@ func decodeSysExMessage(d *decode.D, ctx *context) { func decodeSysExContinuation(d *decode.D, ctx *context) { d.FieldStrFn("data", func(d *decode.D) string { - data := vlf(d) + var data []uint8 + var err error + + d.FieldStrFn("bytes", func(d *decode.D) string { + if data, err = vlf(d); err != nil { + d.Errorf("%v", err) + } else { + return fmt.Sprintf("%v", data) + } + + return "[]" + }) if len(data) > 0 && data[len(data)-1] == 0xf7 { ctx.casio = false @@ -120,7 +137,18 @@ func decodeSysExContinuation(d *decode.D, ctx *context) { func decodeSysExEscape(d *decode.D, ctx *context) { d.FieldStrFn("data", func(d *decode.D) string { - data := vlf(d) + var data []uint8 + var err error + + d.FieldStrFn("bytes", func(d *decode.D) string { + if data, err = vlf(d); err != nil { + d.Errorf("%v", err) + } else { + return fmt.Sprintf("%v", data) + } + + return "[]" + }) return fmt.Sprintf("%v", data) })