mirror of
https://github.com/ProvableHQ/leo.git
synced 2024-12-25 19:22:01 +03:00
impl mul for int types. fix alloc
This commit is contained in:
parent
615c3a763a
commit
5d25770b72
@ -20,7 +20,7 @@ where
|
||||
}
|
||||
|
||||
// Generic impl
|
||||
impl RippleCarryAdder for &[Boolean] {
|
||||
impl RippleCarryAdder for Vec<Boolean> {
|
||||
fn add_bits<F: PrimeField, CS: ConstraintSystem<F>>(
|
||||
&self,
|
||||
mut cs: CS,
|
||||
|
@ -25,13 +25,12 @@ macro_rules! add_int_impl {
|
||||
impl Add for $gadget {
|
||||
fn add<F: PrimeField, CS: ConstraintSystem<F>>(&self, mut cs: CS, other: &Self) -> Result<Self, IntegerError> {
|
||||
// Compute the maximum value of the sum
|
||||
let mut max_bits = <$gadget as Int>::SIZE;
|
||||
let max_bits = <$gadget as Int>::SIZE;
|
||||
|
||||
// Make some arbitrary bounds for ourselves to avoid overflows
|
||||
// in the scalar field
|
||||
assert!(F::Params::MODULUS_BITS >= max_bits as u32);
|
||||
|
||||
|
||||
// Accumulate the value
|
||||
let result_value = match (self.value, other.value) {
|
||||
(Some(a), Some(b)) => {
|
||||
@ -60,7 +59,7 @@ macro_rules! add_int_impl {
|
||||
// we discard the carry since we check for overflow above
|
||||
let _carry = bits.pop();
|
||||
|
||||
// Iterate over each bit_gadget of self and add each bit to
|
||||
// Iterate over each bit_gadget of result and add each bit to
|
||||
// the linear combination
|
||||
let mut coeff = F::one();
|
||||
for bit in bits {
|
||||
@ -87,6 +86,7 @@ macro_rules! add_int_impl {
|
||||
coeff.double_in_place();
|
||||
}
|
||||
|
||||
|
||||
// The value of the actual result is modulo 2 ^ $size
|
||||
let modular_value = result_value.map(|v| v as <$gadget as Int>::IntegerType);
|
||||
|
||||
@ -102,11 +102,13 @@ macro_rules! add_int_impl {
|
||||
|
||||
// Allocate each bit_gadget of the result
|
||||
let mut coeff = F::one();
|
||||
let mut i = 0;
|
||||
while max_bits != 0 {
|
||||
for i in 0..max_bits {
|
||||
// get bit value
|
||||
let mask = 1 << i as <$gadget as Int>::IntegerType;
|
||||
|
||||
// Allocate the bit_gadget
|
||||
let b = AllocatedBit::alloc(cs.ns(|| format!("result bit_gadget {}", i)), || {
|
||||
result_value.map(|v| (v >> i) & 1 == 1).get()
|
||||
result_value.map(|v| (v & mask) == mask).get()
|
||||
})?;
|
||||
|
||||
// Subtract this bit_gadget from the linear combination to ensure that the sums
|
||||
@ -115,8 +117,6 @@ macro_rules! add_int_impl {
|
||||
|
||||
result_bits.push(b.into());
|
||||
|
||||
max_bits -= 1;
|
||||
i += 1;
|
||||
coeff.double_in_place();
|
||||
}
|
||||
|
||||
|
@ -1,7 +1,24 @@
|
||||
use crate::{binary::RippleCarryAdder, errors::IntegerError, sign_extend::SignExtend, Int, Int16, Int32, Int64, Int8};
|
||||
use crate::{
|
||||
binary::RippleCarryAdder,
|
||||
errors::IntegerError,
|
||||
sign_extend::SignExtend,
|
||||
Int,
|
||||
Int128,
|
||||
Int16,
|
||||
Int32,
|
||||
Int64,
|
||||
Int8,
|
||||
};
|
||||
use snarkos_models::{
|
||||
curves::PrimeField,
|
||||
gadgets::{r1cs::ConstraintSystem, utilities::boolean::Boolean},
|
||||
curves::{FpParameters, PrimeField},
|
||||
gadgets::{
|
||||
r1cs::{Assignment, ConstraintSystem, LinearCombination},
|
||||
utilities::{
|
||||
alloc::AllocGadget,
|
||||
boolean::{AllocatedBit, Boolean},
|
||||
select::CondSelectGadget,
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
/// Multiplication for a signed integer gadget
|
||||
@ -13,32 +30,165 @@ where
|
||||
Self: std::marker::Sized,
|
||||
{
|
||||
#[must_use]
|
||||
fn mul<F: PrimeField, CS: ConstraintSystem<F>>(&self, cs: CS, other: &Self) -> Result<(), IntegerError>;
|
||||
fn mul<F: PrimeField, CS: ConstraintSystem<F>>(&self, cs: CS, other: &Self) -> Result<Self, IntegerError>;
|
||||
}
|
||||
|
||||
macro_rules! mul_int_impl {
|
||||
($($gadget: ident)*) => ($(
|
||||
impl Mul for $gadget {
|
||||
fn mul<F: PrimeField, CS: ConstraintSystem<F>>(&self, cs: CS, other: &Self) -> Result<(), IntegerError> {
|
||||
// let is_constant = Boolean::constant(Self::result_is_constant(&self, &other));
|
||||
// let constant_result = Self::constant(0 as <$gadget as Int>::)
|
||||
//
|
||||
// let double = <$gadget as Int>::SIZE * 2;
|
||||
//
|
||||
// let a = Boolean::sign_extend(&self.bits, double);
|
||||
// let b = Boolean::sign_extend(&other.bits, double);
|
||||
//
|
||||
// let result =
|
||||
//
|
||||
// for bit in b.iter() {
|
||||
//
|
||||
// }
|
||||
fn mul<F: PrimeField, CS: ConstraintSystem<F>>(&self, mut cs: CS, other: &Self) -> Result<Self, IntegerError> {
|
||||
// Conditionally select constant result
|
||||
let is_constant = Boolean::constant(Self::result_is_constant(&self, &other));
|
||||
let allocated_false = Boolean::from(AllocatedBit::alloc(&mut cs.ns(|| "false"), || Ok(false)).unwrap());
|
||||
let false_bit = Boolean::conditionally_select(
|
||||
&mut cs.ns(|| "constant_or_allocated_false"),
|
||||
&is_constant,
|
||||
&Boolean::constant(false),
|
||||
&allocated_false,
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
// Sign extend to double precision
|
||||
let size = <$gadget as Int>::SIZE * 2;
|
||||
|
||||
let a = Boolean::sign_extend(&self.bits, size);
|
||||
let b = Boolean::sign_extend(&other.bits, size);
|
||||
|
||||
let mut bits = vec![false_bit; size];
|
||||
|
||||
// Compute double and add algorithm
|
||||
for (i, b_bit) in b.iter().enumerate() {
|
||||
// double
|
||||
let mut a_shifted = vec![false_bit.clone(); i];
|
||||
a_shifted.append(&mut a.clone());
|
||||
a_shifted.truncate(size);
|
||||
|
||||
// conditionally add
|
||||
let mut to_add = vec![];
|
||||
for (j, a_bit) in a_shifted.iter().enumerate() {
|
||||
let selected_bit = Boolean::conditionally_select(
|
||||
&mut cs.ns(|| format!("select product bit {} {}", i, j)),
|
||||
b_bit,
|
||||
a_bit,
|
||||
&false_bit,
|
||||
)?;
|
||||
|
||||
to_add.push(selected_bit);
|
||||
}
|
||||
|
||||
bits = bits.add_bits(
|
||||
&mut cs.ns(|| format!("add bit {}", i)),
|
||||
&to_add
|
||||
)?;
|
||||
let _carry = bits.pop();
|
||||
}
|
||||
|
||||
// Compute the maximum value of the sum
|
||||
let max_bits = <$gadget as Int>::SIZE;
|
||||
|
||||
// Truncate the bits to the size of the integer
|
||||
bits.truncate(max_bits);
|
||||
|
||||
// Make some arbitrary bounds for ourselves to avoid overflows
|
||||
// in the scalar field
|
||||
assert!(F::Params::MODULUS_BITS >= max_bits as u32);
|
||||
|
||||
// Accumulate the value
|
||||
let result_value = match (self.value, other.value) {
|
||||
(Some(a), Some(b)) => {
|
||||
// check for addition overflow here
|
||||
let val = match a.checked_mul(b) {
|
||||
Some(val) => val,
|
||||
None => return Err(IntegerError::Overflow)
|
||||
};
|
||||
|
||||
Some(val)
|
||||
},
|
||||
_ => {
|
||||
// If any of the operands have unknown value, we won't
|
||||
// know the value of the result
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
// This is a linear combination that we will enforce to be zero
|
||||
let mut lc = LinearCombination::zero();
|
||||
|
||||
let mut all_constants = true;
|
||||
|
||||
|
||||
// Iterate over each bit_gadget of result and add each bit to
|
||||
// the linear combination
|
||||
let mut coeff = F::one();
|
||||
for bit in bits {
|
||||
match bit {
|
||||
Boolean::Is(ref bit) => {
|
||||
all_constants = false;
|
||||
|
||||
// Add the coeff * bit_gadget
|
||||
lc = lc + (coeff, bit.get_variable());
|
||||
}
|
||||
Boolean::Not(ref bit) => {
|
||||
all_constants = false;
|
||||
|
||||
// Add coeff * (1 - bit_gadget) = coeff * ONE - coeff * bit_gadget
|
||||
lc = lc + (coeff, CS::one()) - (coeff, bit.get_variable());
|
||||
}
|
||||
Boolean::Constant(bit) => {
|
||||
if bit {
|
||||
lc = lc + (coeff, CS::one());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
coeff.double_in_place();
|
||||
}
|
||||
|
||||
// The value of the actual result is modulo 2 ^ $size
|
||||
let modular_value = result_value.map(|v| v as <$gadget as Int>::IntegerType);
|
||||
|
||||
if all_constants && modular_value.is_some() {
|
||||
// We can just return a constant, rather than
|
||||
// unpacking the result into allocated bits.
|
||||
|
||||
return Ok(Self::constant(modular_value.unwrap()));
|
||||
}
|
||||
|
||||
// Storage area for the resulting bits
|
||||
let mut result_bits = vec![];
|
||||
|
||||
// Allocate each bit_gadget of the result
|
||||
let mut coeff = F::one();
|
||||
for i in 0..max_bits {
|
||||
// get bit value
|
||||
let mask = 1 << i as <$gadget as Int>::IntegerType;
|
||||
|
||||
// Allocate the bit_gadget
|
||||
let b = AllocatedBit::alloc(cs.ns(|| format!("result bit_gadget {}", i)), || {
|
||||
result_value.map(|v| (v & mask) == mask).get()
|
||||
})?;
|
||||
|
||||
// Subtract this bit_gadget from the linear combination to ensure that the sums
|
||||
// balance out
|
||||
lc = lc - (coeff, b.get_variable());
|
||||
|
||||
result_bits.push(b.into());
|
||||
|
||||
coeff.double_in_place();
|
||||
}
|
||||
|
||||
// Enforce that the linear combination equals zero
|
||||
cs.enforce(|| "modular multiplication", |lc| lc, |lc| lc, |_| lc);
|
||||
|
||||
// Discard carry bits we don't care about
|
||||
result_bits.truncate(<$gadget as Int>::SIZE);
|
||||
|
||||
Ok(Self {
|
||||
bits: result_bits,
|
||||
value: modular_value,
|
||||
})
|
||||
}
|
||||
}
|
||||
)*)
|
||||
}
|
||||
|
||||
// mul_int_impl!(Int8 Int16 Int32 Int64);
|
||||
mul_int_impl!(Int8);
|
||||
mul_int_impl!(Int8 Int16 Int32 Int64 Int128);
|
||||
|
@ -1,6 +1,9 @@
|
||||
use crate::{binary::RippleCarryAdder, errors::IntegerError, signed_integer::*};
|
||||
|
||||
use snarkos_models::{curves::PrimeField, gadgets::r1cs::ConstraintSystem};
|
||||
use snarkos_models::{
|
||||
curves::PrimeField,
|
||||
gadgets::{r1cs::ConstraintSystem, utilities::boolean::Boolean},
|
||||
};
|
||||
|
||||
/// Inverts the given number and adds 1 to the lsb of the result
|
||||
pub trait TwosComplement
|
||||
@ -11,6 +14,22 @@ where
|
||||
fn twos_comp<F: PrimeField, CS: ConstraintSystem<F>>(&self, cs: CS) -> Result<Self, IntegerError>;
|
||||
}
|
||||
|
||||
impl TwosComplement for Vec<Boolean> {
|
||||
fn twos_comp<F: PrimeField, CS: ConstraintSystem<F>>(&self, mut cs: CS) -> Result<Self, IntegerError> {
|
||||
// flip all bits
|
||||
let flipped: Self = self.iter().map(|bit| bit.not()).collect();
|
||||
|
||||
// add one
|
||||
let mut one = vec![Boolean::constant(true)];
|
||||
one.append(&mut vec![Boolean::Constant(false); self.len() - 1]);
|
||||
|
||||
let mut bits = flipped.add_bits(cs.ns(|| format!("add one")), &one)?;
|
||||
let _carry = bits.pop(); // we already accounted for overflow above
|
||||
|
||||
Ok(bits)
|
||||
}
|
||||
}
|
||||
|
||||
macro_rules! twos_comp_int_impl {
|
||||
($($gadget: ident)*) => ($(
|
||||
impl TwosComplement for $gadget {
|
||||
|
@ -30,17 +30,17 @@ macro_rules! int_impl {
|
||||
pub fn constant(value: $type_) -> Self {
|
||||
let mut bits = Vec::with_capacity($size);
|
||||
|
||||
let mut tmp = value;
|
||||
for i in 0..$size {
|
||||
// shift value by i
|
||||
let mask = 1 << i as $type_;
|
||||
let result = value & mask;
|
||||
|
||||
for _ in 0..$size {
|
||||
// If last bit is one, push one.
|
||||
if tmp & 1 == 1 {
|
||||
if result == mask {
|
||||
bits.push(Boolean::constant(true))
|
||||
} else {
|
||||
bits.push(Boolean::constant(false))
|
||||
}
|
||||
|
||||
tmp >>= 1;
|
||||
}
|
||||
|
||||
Self {
|
||||
|
@ -7,3 +7,6 @@ pub use self::select::*;
|
||||
|
||||
pub mod sign_extend;
|
||||
pub use self::sign_extend::*;
|
||||
|
||||
pub mod zero_extend;
|
||||
pub use self::zero_extend::*;
|
||||
|
@ -13,9 +13,10 @@ where
|
||||
impl SignExtend for Boolean {
|
||||
fn sign_extend(bits: &[Boolean], length: usize) -> Vec<Self> {
|
||||
let msb = bits.last().expect("empty bit list");
|
||||
let mut extension = vec![msb.clone(); length];
|
||||
let mut result = Vec::from(bits);
|
||||
let bits_needed = length - bits.len();
|
||||
let mut extension = vec![msb.clone(); bits_needed];
|
||||
|
||||
let mut result = Vec::from(bits);
|
||||
result.append(&mut extension);
|
||||
|
||||
result
|
||||
|
23
gadgets/src/signed_integer/utilities/zero_extend.rs
Normal file
23
gadgets/src/signed_integer/utilities/zero_extend.rs
Normal file
@ -0,0 +1,23 @@
|
||||
use snarkos_models::gadgets::utilities::boolean::Boolean;
|
||||
|
||||
/// Zero extends an array of bits to the desired length.
|
||||
/// Least significant bit first
|
||||
pub trait ZeroExtend
|
||||
where
|
||||
Self: std::marker::Sized,
|
||||
{
|
||||
#[must_use]
|
||||
fn zero_extend(&self, zero: Boolean, length: usize) -> Self;
|
||||
}
|
||||
|
||||
impl ZeroExtend for Vec<Boolean> {
|
||||
fn zero_extend(&self, zero: Boolean, length: usize) -> Self {
|
||||
let bits_needed = length - self.len();
|
||||
let mut extension = vec![zero.clone(); bits_needed];
|
||||
|
||||
let mut result = self.clone();
|
||||
result.append(&mut extension);
|
||||
|
||||
result
|
||||
}
|
||||
}
|
@ -10,33 +10,64 @@ use snarkos_models::{
|
||||
use rand::{Rng, SeedableRng};
|
||||
use rand_xorshift::XorShiftRng;
|
||||
|
||||
fn check_all_constant_bits(mut expected: i8, actual: Int8) {
|
||||
for b in actual.bits.iter() {
|
||||
fn check_all_constant_bits(expected: i8, actual: Int8) {
|
||||
for (i, b) in actual.bits.iter().enumerate() {
|
||||
// shift value by i
|
||||
let mask = 1 << i as i8;
|
||||
let result = expected & mask;
|
||||
|
||||
match b {
|
||||
&Boolean::Is(_) => panic!(),
|
||||
&Boolean::Not(_) => panic!(),
|
||||
&Boolean::Constant(b) => {
|
||||
assert!(b == (expected & 1 == 1));
|
||||
let bit = result == mask;
|
||||
assert_eq!(b, bit);
|
||||
}
|
||||
}
|
||||
|
||||
expected >>= 1;
|
||||
}
|
||||
}
|
||||
|
||||
fn check_all_allocated_bits(mut expected: i8, actual: Int8) {
|
||||
for b in actual.bits.iter() {
|
||||
fn check_all_allocated_bits(expected: i8, actual: Int8) {
|
||||
for (i, b) in actual.bits.iter().enumerate() {
|
||||
// shift value by i
|
||||
let mask = 1 << i as i8;
|
||||
let result = expected & mask;
|
||||
|
||||
match b {
|
||||
&Boolean::Is(ref b) => {
|
||||
assert!(b.get_value().unwrap() == (expected & 1 == 1));
|
||||
let bit = result == mask;
|
||||
assert_eq!(b.get_value().unwrap(), bit);
|
||||
}
|
||||
&Boolean::Not(ref b) => {
|
||||
assert!(!b.get_value().unwrap() == (expected & 1 == 1));
|
||||
let bit = result == mask;
|
||||
assert_eq!(!b.get_value().unwrap(), bit);
|
||||
}
|
||||
&Boolean::Constant(_) => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
expected >>= 1;
|
||||
#[test]
|
||||
fn test_int8_constant_and_alloc() {
|
||||
let mut rng = XorShiftRng::seed_from_u64(1231275789u64);
|
||||
|
||||
for _ in 0..1000 {
|
||||
let mut cs = TestConstraintSystem::<Fr>::new();
|
||||
|
||||
let a: i8 = rng.gen();
|
||||
|
||||
let a_const = Int8::constant(a);
|
||||
|
||||
assert!(a_const.value == Some(a));
|
||||
|
||||
check_all_constant_bits(a, a_const);
|
||||
|
||||
let a_bit = Int8::alloc(cs.ns(|| "a_bit"), || Ok(a)).unwrap();
|
||||
|
||||
assert!(cs.is_satisfied());
|
||||
assert!(a_bit.value == Some(a));
|
||||
|
||||
check_all_allocated_bits(a, a_bit);
|
||||
}
|
||||
}
|
||||
|
||||
@ -178,3 +209,66 @@ fn test_int8_sub() {
|
||||
assert!(!cs.is_satisfied());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_int8_mul_constants() {
|
||||
let mut rng = XorShiftRng::seed_from_u64(1231275789u64);
|
||||
|
||||
for _ in 0..1000 {
|
||||
let mut cs = TestConstraintSystem::<Fr>::new();
|
||||
|
||||
let a: i8 = rng.gen();
|
||||
let b: i8 = rng.gen();
|
||||
|
||||
let expected = match a.checked_mul(b) {
|
||||
Some(valid) => valid,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let a_bit = Int8::constant(a);
|
||||
let b_bit = Int8::constant(b);
|
||||
|
||||
let r = a_bit.mul(cs.ns(|| "multiplication"), &b_bit).unwrap();
|
||||
|
||||
assert!(r.value == Some(expected));
|
||||
|
||||
check_all_constant_bits(expected, r);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_int8_mul() {
|
||||
let mut rng = XorShiftRng::seed_from_u64(1231275789u64);
|
||||
|
||||
for _ in 0..1000 {
|
||||
let mut cs = TestConstraintSystem::<Fr>::new();
|
||||
|
||||
let a: i8 = rng.gen();
|
||||
let b: i8 = rng.gen();
|
||||
|
||||
let expected = match a.checked_mul(b) {
|
||||
Some(valid) => valid,
|
||||
None => continue,
|
||||
};
|
||||
|
||||
let a_bit = Int8::alloc(cs.ns(|| "a_bit"), || Ok(a)).unwrap();
|
||||
let b_bit = Int8::alloc(cs.ns(|| "b_bit"), || Ok(b)).unwrap();
|
||||
|
||||
let r = a_bit.mul(cs.ns(|| "multiplication"), &b_bit).unwrap();
|
||||
|
||||
assert!(cs.is_satisfied());
|
||||
|
||||
assert!(r.value == Some(expected));
|
||||
|
||||
check_all_allocated_bits(expected, r);
|
||||
|
||||
// Flip a bit_gadget and see if the multiplication constraint still works
|
||||
if cs.get("multiplication/result bit_gadget 0/boolean").is_zero() {
|
||||
cs.set("multiplication/result bit_gadget 0/boolean", Fr::one());
|
||||
} else {
|
||||
cs.set("multiplication/result bit_gadget 0/boolean", Fr::zero());
|
||||
}
|
||||
|
||||
assert!(!cs.is_satisfied());
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user