diff --git a/pkg/decode/decode.go b/pkg/decode/decode.go index 9e298f6a..64826956 100644 --- a/pkg/decode/decode.go +++ b/pkg/decode/decode.go @@ -5,6 +5,7 @@ package decode import ( "bytes" "compress/zlib" + "context" "encoding/hex" "fmt" "io/ioutil" @@ -116,12 +117,12 @@ type FormatOptions struct { func (FormatOptions) decodeOptions() {} // Decode try decode formats and return first success and all other decoder errors -func Decode(name string, description string, bb *bitio.Buffer, formats []*Format, opts ...Options) (*Value, interface{}, error) { +func Decode(ctx context.Context, name string, description string, bb *bitio.Buffer, formats []*Format, opts ...Options) (*Value, interface{}, error) { opts = append(opts, DecodeOptions{IsRoot: true}) - return decode(name, description, bb, formats, opts) + return decode(ctx, name, description, bb, formats, opts) } -func decode(name string, description string, bb *bitio.Buffer, formats []*Format, opts []Options) (*Value, interface{}, error) { +func decode(ctx context.Context, name string, description string, bb *bitio.Buffer, formats []*Format, opts []Options) (*Value, interface{}, error) { if formats == nil { panic("formats is nil, failed to register format?") } @@ -144,7 +145,7 @@ func decode(name string, description string, bb *bitio.Buffer, formats []*Format decodeErr := DecodeFormatsError{} for _, f := range formats { - d := NewDecoder(name, description, f, bb, decodeOpts) + d := NewDecoder(ctx, name, description, f, bb, decodeOpts) var decodeV interface{} @@ -152,6 +153,10 @@ func decode(name string, description string, bb *bitio.Buffer, formats []*Format decodeV = f.DecodeFn(d, formatOpts.InArg) }) + if ctx != nil && ctx.Err() != nil { + return nil, nil, ctx.Err() + } + if !rOk { switch panicV := r.RecoverV.(type) { case IOError, ValidateError, DecodeFormatsError: @@ -203,6 +208,7 @@ func decode(name string, description string, bb *bitio.Buffer, formats []*Format } type D struct { + Ctx context.Context Endian Endian Value *Value Options map[string]interface{} @@ -213,10 +219,11 @@ type D struct { } // TODO: new struct decoder? -func NewDecoder(name string, description string, format *Format, bb *bitio.Buffer, opts DecodeOptions) *D { +func NewDecoder(ctx context.Context, name string, description string, format *Format, bb *bitio.Buffer, opts DecodeOptions) *D { cbb := bb.Copy() return &D{ + Ctx: ctx, Endian: BigEndian, Value: &Value{ Name: name, @@ -449,6 +456,7 @@ func (d *D) AddChild(v *Value) { func (d *D) FieldDecoder(name string, bitBuf *bitio.Buffer, v interface{}) *D { return &D{ + Ctx: d.Ctx, Endian: d.Endian, Value: &Value{ Name: name, @@ -856,7 +864,7 @@ func (d *D) DecodeRangeFn(firstBit int64, nBits int64, fn func(d *D)) { func (d *D) Format(formats []*Format, opts ...Options) interface{} { bb := d.BitBufRange(d.Pos(), d.BitsLeft()) opts = append(opts, DecodeOptions{ReadBuf: d.readBuf, IsRoot: false, StartOffset: d.Pos()}) - dv, v, err := decode("", "", bb, formats, opts) + dv, v, err := decode(d.Ctx, "", "", bb, formats, opts) if dv == nil || dv.Errors() != nil { panic(err) } @@ -884,7 +892,7 @@ func (d *D) Format(formats []*Format, opts ...Options) interface{} { func (d *D) FieldTryFormat(name string, formats []*Format, opts ...Options) (*Value, interface{}, error) { bb := d.BitBufRange(d.Pos(), d.BitsLeft()) opts = append(opts, DecodeOptions{ReadBuf: d.readBuf, IsRoot: false, StartOffset: d.Pos()}) - dv, v, err := decode(name, "", bb, formats, opts) + dv, v, err := decode(d.Ctx, name, "", bb, formats, opts) if dv == nil || dv.Errors() != nil { return nil, nil, err } @@ -908,7 +916,7 @@ func (d *D) FieldFormat(name string, formats []*Format, opts ...Options) (*Value func (d *D) FieldTryFormatLen(name string, nBits int64, formats []*Format, opts ...Options) (*Value, interface{}, error) { bb := d.BitBufRange(d.Pos(), nBits) opts = append(opts, DecodeOptions{ReadBuf: d.readBuf, IsRoot: false, StartOffset: d.Pos()}) - dv, v, err := decode(name, "", bb, formats, opts) + dv, v, err := decode(d.Ctx, name, "", bb, formats, opts) if dv == nil || dv.Errors() != nil { return nil, nil, err } @@ -933,7 +941,7 @@ func (d *D) FieldFormatLen(name string, nBits int64, formats []*Format, opts ... func (d *D) FieldTryFormatRange(name string, firstBit int64, nBits int64, formats []*Format, opts ...Options) (*Value, interface{}, error) { bb := d.BitBufRange(firstBit, nBits) opts = append(opts, DecodeOptions{ReadBuf: d.readBuf, IsRoot: false, StartOffset: firstBit}) - dv, v, err := decode(name, "", bb, formats, opts) + dv, v, err := decode(d.Ctx, name, "", bb, formats, opts) if dv == nil || dv.Errors() != nil { return nil, nil, err } @@ -954,7 +962,7 @@ func (d *D) FieldFormatRange(name string, firstBit int64, nBits int64, formats [ func (d *D) FieldTryFormatBitBuf(name string, bb *bitio.Buffer, formats []*Format, opts ...Options) (*Value, interface{}, error) { opts = append(opts, DecodeOptions{ReadBuf: d.readBuf, IsRoot: true}) - dv, v, err := decode(name, "", bb, formats, opts) + dv, v, err := decode(d.Ctx, name, "", bb, formats, opts) if dv == nil || dv.Errors() != nil { return nil, nil, err } diff --git a/pkg/interp/funcs.go b/pkg/interp/funcs.go index 5acee03f..60c65ef7 100644 --- a/pkg/interp/funcs.go +++ b/pkg/interp/funcs.go @@ -592,7 +592,7 @@ func (i *Interp) _decode(c interface{}, a []interface{}) interface{} { return fmt.Errorf("%s: %w", formatName, err) } - dv, _, err := decode.Decode("", filename, bb, decodeFormats, decode.DecodeOptions{FormatOptions: opts}) + dv, _, err := decode.Decode(i.evalContext.ctx, "", filename, bb, decodeFormats, decode.DecodeOptions{FormatOptions: opts}) if dv == nil { var decodeFormatsErr decode.DecodeFormatsError if errors.As(err, &decodeFormatsErr) {