[protobuf] Fix parsing of oneof fields (#302)

Closes #297 

In protobuf, unions (oneof fields) are themselves optional, so if a member of
the union fails to parse, its default value must not be used as a fallback.
When the default is used and the first member of the union is not set, the
union is still parsed as the default of that first member. For example, given:

oneof foo {
  int32 foo_int = 1;
  string foo_str = 2;
}

foo_str will never be parsed, because foo_int has a default (0); as a result,
foo_str "blah" is always wrongly parsed as foo_int 0.

Co-authored-by: Akshay Mankar <itsakshaymankar@gmail.com>
This commit is contained in:
Alejandro Serrano 2021-04-30 13:17:11 +02:00 committed by GitHub
parent 28ef6fc97d
commit 1d8319ef98
7 changed files with 176 additions and 83 deletions

View File

@ -611,7 +611,5 @@ instance ( ProtoBridgeOneFieldValue sch t, KnownNat thisId
protoToUnionFieldValue
= Z <$> p <|> S <$> protoToUnionFieldValue @_ @_ @restIds
where fieldId = fromInteger $ natVal (Proxy @thisId)
p = case defaultOneFieldValue of
Nothing -> do r <- one (Just <$> protoToOneFieldValue) Nothing `at` fieldId
maybe empty pure r
Just d -> one protoToOneFieldValue d `at` fieldId <|> pure d
p = do r <- one (Just <$> protoToOneFieldValue) Nothing `at` fieldId
maybe empty pure r

View File

@ -1,11 +1,17 @@
{-# language DataKinds #-}
{-# language DeriveAnyClass #-}
{-# language DeriveGeneric #-}
{-# language DerivingStrategies #-}
{-# language OverloadedStrings #-}
{-# language ScopedTypeVariables #-}
{-# language TypeApplications #-}
{-# language TypeFamilies #-}
{-# language CPP #-}
{-# language DataKinds #-}
{-# language DeriveAnyClass #-}
{-# language DeriveGeneric #-}
{-# language DerivingVia #-}
{-# language EmptyCase #-}
{-# language FlexibleInstances #-}
{-# language MultiParamTypeClasses #-}
{-# language OverloadedStrings #-}
{-# language ScopedTypeVariables #-}
{-# language TemplateHaskell #-}
{-# language TypeApplications #-}
{-# language TypeFamilies #-}
{-# language TypeOperators #-}
module Main where
import qualified Data.ByteString as BS
@ -17,22 +23,43 @@ import qualified Proto3.Wire.Decode as PBDec
import qualified Proto3.Wire.Encode as PBEnc
import System.Environment
import Data.Int
import Mu.Adapter.ProtoBuf
import Mu.Quasi.ProtoBuf
import Mu.Schema
import Mu.Schema.Examples
#if __GHCIDE__
protobuf "ExampleSchema" "adapter/protobuf/test/protobuf/example.proto"
#else
protobuf "ExampleSchema" "test/protobuf/example.proto"
#endif
-- | Haskell-side gender enumeration used by the test program.
--
-- The 'ToSchema' / 'FromSchema' instances for the @"gender"@ type of
-- 'ExampleSchema' are derived via 'CustomFieldMapping', which pairs the
-- constructors @NB@, @Male@ and @Female@ with the names @"nb"@, @"male"@
-- and @"female"@ respectively.
data MGender = NB | Male | Female
deriving (Eq, Show, Generic)
deriving (ToSchema ExampleSchema "gender", FromSchema ExampleSchema "gender")
via CustomFieldMapping "gender"
["NB" ':-> "nb", "Male" ':-> "male", "Female" ':-> "female" ] MGender
data MPerson
= MPerson { firstName :: T.Text
, lastName :: T.Text
, age :: Maybe Int
, gender :: Gender
, address :: MAddress
, lucky_numbers :: [Int]
, things :: M.Map T.Text Int }
, age :: Int32
, gender :: MGender
, address :: Maybe MAddress
, lucky_numbers :: [Int32]
, things :: M.Map T.Text Int32
, foo :: Maybe MFoo
}
deriving (Eq, Show, Generic)
deriving (ToSchema ExampleSchema "person")
deriving (FromSchema ExampleSchema "person")
-- | Haskell counterpart of the @"Foo"@ type of 'ExampleSchema'.
--
-- Wraps a single field, 'fooChoice', holding either an 'Int32' or a 'T.Text'.
-- NOTE(review): presumably 'Left' corresponds to @foo_int@ and 'Right' to
-- @foo_string@ of the oneof declared in example.proto -- confirm.
newtype MFoo
= MFoo { fooChoice :: Either Int32 T.Text }
deriving (Eq, Show, Generic)
deriving (ToSchema ExampleSchema "Foo")
deriving (FromSchema ExampleSchema "Foo")
data MAddress
= MAddress { postcode :: T.Text
, country :: T.Text }
@ -40,42 +67,19 @@ data MAddress
deriving (ToSchema ExampleSchema "address")
deriving (FromSchema ExampleSchema "address")
type instance AnnotatedSchema ProtoBufAnnotation ExampleSchema
= '[ 'AnnField "gender" "male" ('ProtoBufId 1 '[])
, 'AnnField "gender" "female" ('ProtoBufId 2 '[])
, 'AnnField "gender" "nb" ('ProtoBufId 3 '[])
, 'AnnField "gender" "gender0" ('ProtoBufId 4 '[])
, 'AnnField "gender" "gender1" ('ProtoBufId 5 '[])
, 'AnnField "gender" "gender2" ('ProtoBufId 6 '[])
, 'AnnField "gender" "gender3" ('ProtoBufId 7 '[])
, 'AnnField "gender" "gender4" ('ProtoBufId 8 '[])
, 'AnnField "gender" "gender5" ('ProtoBufId 9 '[])
, 'AnnField "gender" "gender6" ('ProtoBufId 10 '[])
, 'AnnField "gender" "gender7" ('ProtoBufId 11 '[])
, 'AnnField "gender" "gender8" ('ProtoBufId 12 '[])
, 'AnnField "gender" "gender9" ('ProtoBufId 13 '[])
, 'AnnField "gender" "unspecified" ('ProtoBufId 0 '[])
, 'AnnField "address" "postcode" ('ProtoBufId 1 '[])
, 'AnnField "address" "country" ('ProtoBufId 2 '[])
, 'AnnField "person" "firstName" ('ProtoBufId 1 '[])
, 'AnnField "person" "lastName" ('ProtoBufId 2 '[])
, 'AnnField "person" "age" ('ProtoBufId 3 '[])
, 'AnnField "person" "gender" ('ProtoBufId 4 '[])
, 'AnnField "person" "address" ('ProtoBufId 5 '[])
, 'AnnField "person" "lucky_numbers" ('ProtoBufId 6 '[ '("packed", 'ProtoBufOptionConstantBool 'True) ])
, 'AnnField "person" "things" ('ProtoBufId 7 '[]) ]
exampleAddress :: MAddress
exampleAddress = MAddress "1111BB" "Spain"
exampleAddress :: Maybe MAddress
exampleAddress = Just $ MAddress "0000AA" "Nederland"
examplePerson1, examplePerson2 :: MPerson
examplePerson1 = MPerson "Haskellio" "Gómez"
(Just 30) Male
examplePerson1 = MPerson "Pythonio" "van Gogh"
30 Male
exampleAddress [1,2,3]
(M.fromList [("pepe", 1), ("juan", 2)])
(M.fromList [("hola", 1), ("hello", 2), ("hallo", 3)])
(Just $ MFoo $ Right "blah")
examplePerson2 = MPerson "Cuarenta" "Siete"
Nothing Unspecified
0 NB
exampleAddress [] M.empty
(Just $ MFoo $ Left 3)
main :: IO ()
main = do -- Obtain the filenames
@ -83,8 +87,12 @@ main = do -- Obtain the filenames
-- Read the file produced by Python
putStrLn "haskell/consume"
cbs <- BS.readFile conFile
let Right people = PBDec.parse (fromProtoViaSchema @_ @_ @ExampleSchema) cbs
print (people :: MPerson)
let Right parsedPerson1 = PBDec.parse (fromProtoViaSchema @_ @_ @ExampleSchema) cbs
if parsedPerson1 == examplePerson1
then putStrLn $ "Parsed correctly as: \n" <> show parsedPerson1
else putStrLn $ "Parsed person does not match expected person\n"
<> "Parsed person: \n" <> show parsedPerson1
<> "\nExpected person: \n" <> show examplePerson1
-- Encode a couple of values
putStrLn "haskell/generate"
print examplePerson1

View File

@ -1,7 +1,8 @@
from example_pb2 import *
import sys
f = open(sys.argv[1], "rb")
example_person = person()
example_person.ParseFromString(f.read())
f.close()
print(example_person)
print(example_person)

View File

@ -8,6 +8,7 @@ message person {
address address = 5;
repeated int32 lucky_numbers = 6 [packed=true];
map<string, int32> things = 7;
Foo foo = 8;
}
message address {
@ -20,3 +21,10 @@ enum gender {
male = 1;
female = 2;
}
// Message whose only content is the oneof `fooChoice`: it carries either an
// int32 (`foo_int`, field 1) or a string (`foo_string`, field 2).
message Foo {
oneof fooChoice {
int32 foo_int = 1;
string foo_string = 2;
}
}

View File

@ -1,9 +1,7 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: example.proto
import sys
_b=sys.version_info[0]<3 and (lambda x:x) or (lambda x:x.encode('latin1'))
"""Generated protocol buffer code."""
from google.protobuf.internal import enum_type_wrapper
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
@ -21,7 +19,8 @@ DESCRIPTOR = _descriptor.FileDescriptor(
package='',
syntax='proto3',
serialized_options=None,
serialized_pb=_b('\n\rexample.proto\"\xdd\x01\n\x06person\x12\x11\n\tfirstName\x18\x01 \x01(\t\x12\x10\n\x08lastName\x18\x02 \x01(\t\x12\x0b\n\x03\x61ge\x18\x03 \x01(\x05\x12\x17\n\x06gender\x18\x04 \x01(\x0e\x32\x07.gender\x12\x19\n\x07\x61\x64\x64ress\x18\x05 \x01(\x0b\x32\x08.address\x12\x19\n\rlucky_numbers\x18\x06 \x03(\x05\x42\x02\x10\x01\x12#\n\x06things\x18\x07 \x03(\x0b\x32\x13.person.ThingsEntry\x1a-\n\x0bThingsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x05:\x02\x38\x01\",\n\x07\x61\x64\x64ress\x12\x10\n\x08postcode\x18\x01 \x01(\t\x12\x0f\n\x07\x63ountry\x18\x02 \x01(\t*&\n\x06gender\x12\x06\n\x02nb\x10\x00\x12\x08\n\x04male\x10\x01\x12\n\n\x06\x66\x65male\x10\x02\x62\x06proto3')
create_key=_descriptor._internal_create_key,
serialized_pb=b'\n\rexample.proto\"\xf0\x01\n\x06person\x12\x11\n\tfirstName\x18\x01 \x01(\t\x12\x10\n\x08lastName\x18\x02 \x01(\t\x12\x0b\n\x03\x61ge\x18\x03 \x01(\x05\x12\x17\n\x06gender\x18\x04 \x01(\x0e\x32\x07.gender\x12\x19\n\x07\x61\x64\x64ress\x18\x05 \x01(\x0b\x32\x08.address\x12\x19\n\rlucky_numbers\x18\x06 \x03(\x05\x42\x02\x10\x01\x12#\n\x06things\x18\x07 \x03(\x0b\x32\x13.person.ThingsEntry\x12\x11\n\x03\x66oo\x18\x08 \x01(\x0b\x32\x04.Foo\x1a-\n\x0bThingsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\x05:\x02\x38\x01\",\n\x07\x61\x64\x64ress\x12\x10\n\x08postcode\x18\x01 \x01(\t\x12\x0f\n\x07\x63ountry\x18\x02 \x01(\t\"5\n\x03\x46oo\x12\x11\n\x07\x66oo_int\x18\x01 \x01(\x05H\x00\x12\x14\n\nfoo_string\x18\x02 \x01(\tH\x00\x42\x05\n\x03\x46oo*&\n\x06gender\x12\x06\n\x02nb\x10\x00\x12\x08\n\x04male\x10\x01\x12\n\n\x06\x66\x65male\x10\x02\x62\x06proto3'
)
_GENDER = _descriptor.EnumDescriptor(
@ -29,24 +28,28 @@ _GENDER = _descriptor.EnumDescriptor(
full_name='gender',
filename=None,
file=DESCRIPTOR,
create_key=_descriptor._internal_create_key,
values=[
_descriptor.EnumValueDescriptor(
name='nb', index=0, number=0,
serialized_options=None,
type=None),
type=None,
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor(
name='male', index=1, number=1,
serialized_options=None,
type=None),
type=None,
create_key=_descriptor._internal_create_key),
_descriptor.EnumValueDescriptor(
name='female', index=2, number=2,
serialized_options=None,
type=None),
type=None,
create_key=_descriptor._internal_create_key),
],
containing_type=None,
serialized_options=None,
serialized_start=287,
serialized_end=325,
serialized_start=361,
serialized_end=399,
)
_sym_db.RegisterEnumDescriptor(_GENDER)
@ -63,35 +66,36 @@ _PERSON_THINGSENTRY = _descriptor.Descriptor(
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='key', full_name='person.ThingsEntry.key', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='value', full_name='person.ThingsEntry.value', index=1,
number=2, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=_b('8\001'),
serialized_options=b'8\001',
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
],
serialized_start=194,
serialized_end=239,
serialized_start=213,
serialized_end=258,
)
_PERSON = _descriptor.Descriptor(
@ -100,56 +104,64 @@ _PERSON = _descriptor.Descriptor(
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='firstName', full_name='person.firstName', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='lastName', full_name='person.lastName', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='age', full_name='person.age', index=2,
number=3, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='gender', full_name='person.gender', index=3,
number=4, type=14, cpp_type=8, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='address', full_name='person.address', index=4,
number=5, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='lucky_numbers', full_name='person.lucky_numbers', index=5,
number=6, type=5, cpp_type=1, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=_b('\020\001'), file=DESCRIPTOR),
serialized_options=b'\020\001', file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='things', full_name='person.things', index=6,
number=7, type=11, cpp_type=10, label=3,
has_default_value=False, default_value=[],
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='foo', full_name='person.foo', index=7,
number=8, type=11, cpp_type=10, label=1,
has_default_value=False, default_value=None,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
@ -163,7 +175,7 @@ _PERSON = _descriptor.Descriptor(
oneofs=[
],
serialized_start=18,
serialized_end=239,
serialized_end=258,
)
@ -173,21 +185,22 @@ _ADDRESS = _descriptor.Descriptor(
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='postcode', full_name='address.postcode', index=0,
number=1, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='country', full_name='address.country', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=_b("").decode('utf-8'),
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR),
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
@ -200,16 +213,68 @@ _ADDRESS = _descriptor.Descriptor(
extension_ranges=[],
oneofs=[
],
serialized_start=241,
serialized_end=285,
serialized_start=260,
serialized_end=304,
)
_FOO = _descriptor.Descriptor(
name='Foo',
full_name='Foo',
filename=None,
file=DESCRIPTOR,
containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[
_descriptor.FieldDescriptor(
name='foo_int', full_name='Foo.foo_int', index=0,
number=1, type=5, cpp_type=1, label=1,
has_default_value=False, default_value=0,
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
_descriptor.FieldDescriptor(
name='foo_string', full_name='Foo.foo_string', index=1,
number=2, type=9, cpp_type=9, label=1,
has_default_value=False, default_value=b"".decode('utf-8'),
message_type=None, enum_type=None, containing_type=None,
is_extension=False, extension_scope=None,
serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key),
],
extensions=[
],
nested_types=[],
enum_types=[
],
serialized_options=None,
is_extendable=False,
syntax='proto3',
extension_ranges=[],
oneofs=[
_descriptor.OneofDescriptor(
name='Foo', full_name='Foo.Foo',
index=0, containing_type=None,
create_key=_descriptor._internal_create_key,
fields=[]),
],
serialized_start=306,
serialized_end=359,
)
_PERSON_THINGSENTRY.containing_type = _PERSON
_PERSON.fields_by_name['gender'].enum_type = _GENDER
_PERSON.fields_by_name['address'].message_type = _ADDRESS
_PERSON.fields_by_name['things'].message_type = _PERSON_THINGSENTRY
_PERSON.fields_by_name['foo'].message_type = _FOO
_FOO.oneofs_by_name['Foo'].fields.append(
_FOO.fields_by_name['foo_int'])
_FOO.fields_by_name['foo_int'].containing_oneof = _FOO.oneofs_by_name['Foo']
_FOO.oneofs_by_name['Foo'].fields.append(
_FOO.fields_by_name['foo_string'])
_FOO.fields_by_name['foo_string'].containing_oneof = _FOO.oneofs_by_name['Foo']
DESCRIPTOR.message_types_by_name['person'] = _PERSON
DESCRIPTOR.message_types_by_name['address'] = _ADDRESS
DESCRIPTOR.message_types_by_name['Foo'] = _FOO
DESCRIPTOR.enum_types_by_name['gender'] = _GENDER
_sym_db.RegisterFileDescriptor(DESCRIPTOR)
@ -235,6 +300,13 @@ address = _reflection.GeneratedProtocolMessageType('address', (_message.Message,
})
_sym_db.RegisterMessage(address)
Foo = _reflection.GeneratedProtocolMessageType('Foo', (_message.Message,), {
'DESCRIPTOR' : _FOO,
'__module__' : 'example_pb2'
# @@protoc_insertion_point(class_scope:Foo)
})
_sym_db.RegisterMessage(Foo)
_PERSON_THINGSENTRY._options = None
_PERSON.fields_by_name['lucky_numbers']._options = None

View File

@ -1,9 +1,13 @@
from example_pb2 import *
import sys
example_address = address()
example_address.postcode = "0000AA"
example_address.country = "Nederland"
example_foo = Foo()
example_foo.foo_string = "blah"
example_person = person()
example_person.firstName = "Pythonio"
example_person.lastName = "van Gogh"
@ -15,6 +19,7 @@ example_person.address.CopyFrom(example_address)
example_person.things["hola"] = 1
example_person.things["hello"] = 2
example_person.things["hallo"] = 3
example_person.foo.CopyFrom(example_foo)
f = open(sys.argv[1], "wb")
f.write(example_person.SerializeToString())

View File

@ -10,6 +10,7 @@ stack test-avro dist/avro-haskell.avro dist/avro-python.avro
# Read the Haskell-produced Avro file back from Python.
echo "python/consume"
python3 adapter/avro/test/avro/consume.py adapter/avro/test/avro/example.avsc dist/avro-haskell.avro
# if protobuf is not installed, do so with 'pip install protobuf'
echo "\nPROTOBUF\n========\n"
# Write the Python-side protobuf sample to dist/protobuf-python.pbuf.
echo "python/generate"
python2 adapter/protobuf/test/protobuf/generate.py dist/protobuf-python.pbuf