impl mul for int types. fix alloc

This commit is contained in:
collin 2020-07-11 04:00:42 -07:00
parent 615c3a763a
commit 5d25770b72
9 changed files with 338 additions and 48 deletions

View File

@ -20,7 +20,7 @@ where
} }
// Generic impl // Generic impl
impl RippleCarryAdder for &[Boolean] { impl RippleCarryAdder for Vec<Boolean> {
fn add_bits<F: PrimeField, CS: ConstraintSystem<F>>( fn add_bits<F: PrimeField, CS: ConstraintSystem<F>>(
&self, &self,
mut cs: CS, mut cs: CS,

View File

@ -25,13 +25,12 @@ macro_rules! add_int_impl {
impl Add for $gadget { impl Add for $gadget {
fn add<F: PrimeField, CS: ConstraintSystem<F>>(&self, mut cs: CS, other: &Self) -> Result<Self, IntegerError> { fn add<F: PrimeField, CS: ConstraintSystem<F>>(&self, mut cs: CS, other: &Self) -> Result<Self, IntegerError> {
// Compute the maximum value of the sum // Compute the maximum value of the sum
let mut max_bits = <$gadget as Int>::SIZE; let max_bits = <$gadget as Int>::SIZE;
// Make some arbitrary bounds for ourselves to avoid overflows // Make some arbitrary bounds for ourselves to avoid overflows
// in the scalar field // in the scalar field
assert!(F::Params::MODULUS_BITS >= max_bits as u32); assert!(F::Params::MODULUS_BITS >= max_bits as u32);
// Accumulate the value // Accumulate the value
let result_value = match (self.value, other.value) { let result_value = match (self.value, other.value) {
(Some(a), Some(b)) => { (Some(a), Some(b)) => {
@ -60,7 +59,7 @@ macro_rules! add_int_impl {
// we discard the carry since we check for overflow above // we discard the carry since we check for overflow above
let _carry = bits.pop(); let _carry = bits.pop();
// Iterate over each bit_gadget of self and add each bit to // Iterate over each bit_gadget of result and add each bit to
// the linear combination // the linear combination
let mut coeff = F::one(); let mut coeff = F::one();
for bit in bits { for bit in bits {
@ -87,6 +86,7 @@ macro_rules! add_int_impl {
coeff.double_in_place(); coeff.double_in_place();
} }
// The value of the actual result is modulo 2 ^ $size // The value of the actual result is modulo 2 ^ $size
let modular_value = result_value.map(|v| v as <$gadget as Int>::IntegerType); let modular_value = result_value.map(|v| v as <$gadget as Int>::IntegerType);
@ -102,11 +102,13 @@ macro_rules! add_int_impl {
// Allocate each bit_gadget of the result // Allocate each bit_gadget of the result
let mut coeff = F::one(); let mut coeff = F::one();
let mut i = 0; for i in 0..max_bits {
while max_bits != 0 { // get bit value
let mask = 1 << i as <$gadget as Int>::IntegerType;
// Allocate the bit_gadget // Allocate the bit_gadget
let b = AllocatedBit::alloc(cs.ns(|| format!("result bit_gadget {}", i)), || { let b = AllocatedBit::alloc(cs.ns(|| format!("result bit_gadget {}", i)), || {
result_value.map(|v| (v >> i) & 1 == 1).get() result_value.map(|v| (v & mask) == mask).get()
})?; })?;
// Subtract this bit_gadget from the linear combination to ensure that the sums // Subtract this bit_gadget from the linear combination to ensure that the sums
@ -115,8 +117,6 @@ macro_rules! add_int_impl {
result_bits.push(b.into()); result_bits.push(b.into());
max_bits -= 1;
i += 1;
coeff.double_in_place(); coeff.double_in_place();
} }

View File

@ -1,7 +1,24 @@
use crate::{binary::RippleCarryAdder, errors::IntegerError, sign_extend::SignExtend, Int, Int16, Int32, Int64, Int8}; use crate::{
binary::RippleCarryAdder,
errors::IntegerError,
sign_extend::SignExtend,
Int,
Int128,
Int16,
Int32,
Int64,
Int8,
};
use snarkos_models::{ use snarkos_models::{
curves::PrimeField, curves::{FpParameters, PrimeField},
gadgets::{r1cs::ConstraintSystem, utilities::boolean::Boolean}, gadgets::{
r1cs::{Assignment, ConstraintSystem, LinearCombination},
utilities::{
alloc::AllocGadget,
boolean::{AllocatedBit, Boolean},
select::CondSelectGadget,
},
},
}; };
/// Multiplication for a signed integer gadget /// Multiplication for a signed integer gadget
@ -13,32 +30,165 @@ where
Self: std::marker::Sized, Self: std::marker::Sized,
{ {
#[must_use] #[must_use]
fn mul<F: PrimeField, CS: ConstraintSystem<F>>(&self, cs: CS, other: &Self) -> Result<(), IntegerError>; fn mul<F: PrimeField, CS: ConstraintSystem<F>>(&self, cs: CS, other: &Self) -> Result<Self, IntegerError>;
} }
macro_rules! mul_int_impl { macro_rules! mul_int_impl {
($($gadget: ident)*) => ($( ($($gadget: ident)*) => ($(
impl Mul for $gadget { impl Mul for $gadget {
fn mul<F: PrimeField, CS: ConstraintSystem<F>>(&self, cs: CS, other: &Self) -> Result<(), IntegerError> { fn mul<F: PrimeField, CS: ConstraintSystem<F>>(&self, mut cs: CS, other: &Self) -> Result<Self, IntegerError> {
// let is_constant = Boolean::constant(Self::result_is_constant(&self, &other)); // Conditionally select constant result
// let constant_result = Self::constant(0 as <$gadget as Int>::) let is_constant = Boolean::constant(Self::result_is_constant(&self, &other));
// let allocated_false = Boolean::from(AllocatedBit::alloc(&mut cs.ns(|| "false"), || Ok(false)).unwrap());
// let double = <$gadget as Int>::SIZE * 2; let false_bit = Boolean::conditionally_select(
// &mut cs.ns(|| "constant_or_allocated_false"),
// let a = Boolean::sign_extend(&self.bits, double); &is_constant,
// let b = Boolean::sign_extend(&other.bits, double); &Boolean::constant(false),
// &allocated_false,
// let result = )?;
//
// for bit in b.iter() {
//
// }
Ok(()) // Sign extend to double precision
let size = <$gadget as Int>::SIZE * 2;
let a = Boolean::sign_extend(&self.bits, size);
let b = Boolean::sign_extend(&other.bits, size);
let mut bits = vec![false_bit; size];
// Compute double and add algorithm
for (i, b_bit) in b.iter().enumerate() {
// double
let mut a_shifted = vec![false_bit.clone(); i];
a_shifted.append(&mut a.clone());
a_shifted.truncate(size);
// conditionally add
let mut to_add = vec![];
for (j, a_bit) in a_shifted.iter().enumerate() {
let selected_bit = Boolean::conditionally_select(
&mut cs.ns(|| format!("select product bit {} {}", i, j)),
b_bit,
a_bit,
&false_bit,
)?;
to_add.push(selected_bit);
}
bits = bits.add_bits(
&mut cs.ns(|| format!("add bit {}", i)),
&to_add
)?;
let _carry = bits.pop();
}
// Compute the maximum value of the sum
let max_bits = <$gadget as Int>::SIZE;
// Truncate the bits to the size of the integer
bits.truncate(max_bits);
// Make some arbitrary bounds for ourselves to avoid overflows
// in the scalar field
assert!(F::Params::MODULUS_BITS >= max_bits as u32);
// Accumulate the value
let result_value = match (self.value, other.value) {
(Some(a), Some(b)) => {
// check for addition overflow here
let val = match a.checked_mul(b) {
Some(val) => val,
None => return Err(IntegerError::Overflow)
};
Some(val)
},
_ => {
// If any of the operands have unknown value, we won't
// know the value of the result
None
}
};
// This is a linear combination that we will enforce to be zero
let mut lc = LinearCombination::zero();
let mut all_constants = true;
// Iterate over each bit_gadget of result and add each bit to
// the linear combination
let mut coeff = F::one();
for bit in bits {
match bit {
Boolean::Is(ref bit) => {
all_constants = false;
// Add the coeff * bit_gadget
lc = lc + (coeff, bit.get_variable());
}
Boolean::Not(ref bit) => {
all_constants = false;
// Add coeff * (1 - bit_gadget) = coeff * ONE - coeff * bit_gadget
lc = lc + (coeff, CS::one()) - (coeff, bit.get_variable());
}
Boolean::Constant(bit) => {
if bit {
lc = lc + (coeff, CS::one());
}
}
}
coeff.double_in_place();
}
// The value of the actual result is modulo 2 ^ $size
let modular_value = result_value.map(|v| v as <$gadget as Int>::IntegerType);
if all_constants && modular_value.is_some() {
// We can just return a constant, rather than
// unpacking the result into allocated bits.
return Ok(Self::constant(modular_value.unwrap()));
}
// Storage area for the resulting bits
let mut result_bits = vec![];
// Allocate each bit_gadget of the result
let mut coeff = F::one();
for i in 0..max_bits {
// get bit value
let mask = 1 << i as <$gadget as Int>::IntegerType;
// Allocate the bit_gadget
let b = AllocatedBit::alloc(cs.ns(|| format!("result bit_gadget {}", i)), || {
result_value.map(|v| (v & mask) == mask).get()
})?;
// Subtract this bit_gadget from the linear combination to ensure that the sums
// balance out
lc = lc - (coeff, b.get_variable());
result_bits.push(b.into());
coeff.double_in_place();
}
// Enforce that the linear combination equals zero
cs.enforce(|| "modular multiplication", |lc| lc, |lc| lc, |_| lc);
// Discard carry bits we don't care about
result_bits.truncate(<$gadget as Int>::SIZE);
Ok(Self {
bits: result_bits,
value: modular_value,
})
} }
} }
)*) )*)
} }
// mul_int_impl!(Int8 Int16 Int32 Int64); mul_int_impl!(Int8 Int16 Int32 Int64 Int128);
mul_int_impl!(Int8);

View File

@ -1,6 +1,9 @@
use crate::{binary::RippleCarryAdder, errors::IntegerError, signed_integer::*}; use crate::{binary::RippleCarryAdder, errors::IntegerError, signed_integer::*};
use snarkos_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem}; use snarkos_models::{
curves::PrimeField,
gadgets::{r1cs::ConstraintSystem, utilities::boolean::Boolean},
};
/// Inverts the given number and adds 1 to the lsb of the result /// Inverts the given number and adds 1 to the lsb of the result
pub trait TwosComplement pub trait TwosComplement
@ -11,6 +14,22 @@ where
fn twos_comp<F: PrimeField, CS: ConstraintSystem<F>>(&self, cs: CS) -> Result<Self, IntegerError>; fn twos_comp<F: PrimeField, CS: ConstraintSystem<F>>(&self, cs: CS) -> Result<Self, IntegerError>;
} }
impl TwosComplement for Vec<Boolean> {
fn twos_comp<F: PrimeField, CS: ConstraintSystem<F>>(&self, mut cs: CS) -> Result<Self, IntegerError> {
// flip all bits
let flipped: Self = self.iter().map(|bit| bit.not()).collect();
// add one
let mut one = vec![Boolean::constant(true)];
one.append(&mut vec![Boolean::Constant(false); self.len() - 1]);
let mut bits = flipped.add_bits(cs.ns(|| format!("add one")), &one)?;
let _carry = bits.pop(); // we already accounted for overflow above
Ok(bits)
}
}
macro_rules! twos_comp_int_impl { macro_rules! twos_comp_int_impl {
($($gadget: ident)*) => ($( ($($gadget: ident)*) => ($(
impl TwosComplement for $gadget { impl TwosComplement for $gadget {

View File

@ -30,17 +30,17 @@ macro_rules! int_impl {
pub fn constant(value: $type_) -> Self { pub fn constant(value: $type_) -> Self {
let mut bits = Vec::with_capacity($size); let mut bits = Vec::with_capacity($size);
let mut tmp = value; for i in 0..$size {
// shift value by i
let mask = 1 << i as $type_;
let result = value & mask;
for _ in 0..$size {
// If last bit is one, push one. // If last bit is one, push one.
if tmp & 1 == 1 { if result == mask {
bits.push(Boolean::constant(true)) bits.push(Boolean::constant(true))
} else { } else {
bits.push(Boolean::constant(false)) bits.push(Boolean::constant(false))
} }
tmp >>= 1;
} }
Self { Self {

View File

@ -7,3 +7,6 @@ pub use self::select::*;
pub mod sign_extend; pub mod sign_extend;
pub use self::sign_extend::*; pub use self::sign_extend::*;
pub mod zero_extend;
pub use self::zero_extend::*;

View File

@ -13,9 +13,10 @@ where
impl SignExtend for Boolean { impl SignExtend for Boolean {
fn sign_extend(bits: &[Boolean], length: usize) -> Vec<Self> { fn sign_extend(bits: &[Boolean], length: usize) -> Vec<Self> {
let msb = bits.last().expect("empty bit list"); let msb = bits.last().expect("empty bit list");
let mut extension = vec![msb.clone(); length]; let bits_needed = length - bits.len();
let mut result = Vec::from(bits); let mut extension = vec![msb.clone(); bits_needed];
let mut result = Vec::from(bits);
result.append(&mut extension); result.append(&mut extension);
result result

View File

@ -0,0 +1,23 @@
use snarkos_models::gadgets::utilities::boolean::Boolean;
/// Zero extends an array of bits to the desired length.
/// Least significant bit first
pub trait ZeroExtend
where
Self: std::marker::Sized,
{
#[must_use]
fn zero_extend(&self, zero: Boolean, length: usize) -> Self;
}
impl ZeroExtend for Vec<Boolean> {
fn zero_extend(&self, zero: Boolean, length: usize) -> Self {
let bits_needed = length - self.len();
let mut extension = vec![zero.clone(); bits_needed];
let mut result = self.clone();
result.append(&mut extension);
result
}
}

View File

@ -10,33 +10,64 @@ use snarkos_models::{
use rand::{Rng, SeedableRng}; use rand::{Rng, SeedableRng};
use rand_xorshift::XorShiftRng; use rand_xorshift::XorShiftRng;
fn check_all_constant_bits(mut expected: i8, actual: Int8) { fn check_all_constant_bits(expected: i8, actual: Int8) {
for b in actual.bits.iter() { for (i, b) in actual.bits.iter().enumerate() {
// shift value by i
let mask = 1 << i as i8;
let result = expected & mask;
match b { match b {
&Boolean::Is(_) => panic!(), &Boolean::Is(_) => panic!(),
&Boolean::Not(_) => panic!(), &Boolean::Not(_) => panic!(),
&Boolean::Constant(b) => { &Boolean::Constant(b) => {
assert!(b == (expected & 1 == 1)); let bit = result == mask;
assert_eq!(b, bit);
} }
} }
expected >>= 1;
} }
} }
fn check_all_allocated_bits(mut expected: i8, actual: Int8) { fn check_all_allocated_bits(expected: i8, actual: Int8) {
for b in actual.bits.iter() { for (i, b) in actual.bits.iter().enumerate() {
// shift value by i
let mask = 1 << i as i8;
let result = expected & mask;
match b { match b {
&Boolean::Is(ref b) => { &Boolean::Is(ref b) => {
assert!(b.get_value().unwrap() == (expected & 1 == 1)); let bit = result == mask;
assert_eq!(b.get_value().unwrap(), bit);
} }
&Boolean::Not(ref b) => { &Boolean::Not(ref b) => {
assert!(!b.get_value().unwrap() == (expected & 1 == 1)); let bit = result == mask;
assert_eq!(!b.get_value().unwrap(), bit);
} }
&Boolean::Constant(_) => unreachable!(), &Boolean::Constant(_) => unreachable!(),
} }
}
}
expected >>= 1; #[test]
fn test_int8_constant_and_alloc() {
let mut rng = XorShiftRng::seed_from_u64(1231275789u64);
for _ in 0..1000 {
let mut cs = TestConstraintSystem::<Fr>::new();
let a: i8 = rng.gen();
let a_const = Int8::constant(a);
assert!(a_const.value == Some(a));
check_all_constant_bits(a, a_const);
let a_bit = Int8::alloc(cs.ns(|| "a_bit"), || Ok(a)).unwrap();
assert!(cs.is_satisfied());
assert!(a_bit.value == Some(a));
check_all_allocated_bits(a, a_bit);
} }
} }
@ -178,3 +209,66 @@ fn test_int8_sub() {
assert!(!cs.is_satisfied()); assert!(!cs.is_satisfied());
} }
} }
#[test]
fn test_int8_mul_constants() {
let mut rng = XorShiftRng::seed_from_u64(1231275789u64);
for _ in 0..1000 {
let mut cs = TestConstraintSystem::<Fr>::new();
let a: i8 = rng.gen();
let b: i8 = rng.gen();
let expected = match a.checked_mul(b) {
Some(valid) => valid,
None => continue,
};
let a_bit = Int8::constant(a);
let b_bit = Int8::constant(b);
let r = a_bit.mul(cs.ns(|| "multiplication"), &b_bit).unwrap();
assert!(r.value == Some(expected));
check_all_constant_bits(expected, r);
}
}
#[test]
fn test_int8_mul() {
let mut rng = XorShiftRng::seed_from_u64(1231275789u64);
for _ in 0..1000 {
let mut cs = TestConstraintSystem::<Fr>::new();
let a: i8 = rng.gen();
let b: i8 = rng.gen();
let expected = match a.checked_mul(b) {
Some(valid) => valid,
None => continue,
};
let a_bit = Int8::alloc(cs.ns(|| "a_bit"), || Ok(a)).unwrap();
let b_bit = Int8::alloc(cs.ns(|| "b_bit"), || Ok(b)).unwrap();
let r = a_bit.mul(cs.ns(|| "multiplication"), &b_bit).unwrap();
assert!(cs.is_satisfied());
assert!(r.value == Some(expected));
check_all_allocated_bits(expected, r);
// Flip a bit_gadget and see if the multiplication constraint still works
if cs.get("multiplication/result bit_gadget 0/boolean").is_zero() {
cs.set("multiplication/result bit_gadget 0/boolean", Fr::one());
} else {
cs.set("multiplication/result bit_gadget 0/boolean", Fr::zero());
}
assert!(!cs.is_satisfied());
}
}