diff --git a/CHANGELOG.md b/CHANGELOG.md index 5f6451155c..ce76f1414d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,10 @@ for Rust libraries in [RFC #1105](https://github.com/rust-lang/rfcs/blob/master/ * Add `min` function that mirrors SQLs MIN. +* Added support for the `Numeric` data type. Since there is no Big Decimal type + in the standard library, a dumb struct has been provided which mirrors what + Postgres provides, which can be converted into whatever crate you are using. + ## [0.2.0] - 2015-11-30 ### Added diff --git a/diesel/src/types/impls/floats.rs b/diesel/src/types/impls/floats.rs deleted file mode 100644 index 8ccfd8877c..0000000000 --- a/diesel/src/types/impls/floats.rs +++ /dev/null @@ -1,38 +0,0 @@ -extern crate byteorder; - -use self::byteorder::{ReadBytesExt, WriteBytesExt, BigEndian}; -use super::option::UnexpectedNullError; -use types::{FromSql, ToSql, IsNull}; -use types; -use std::error::Error; -use std::io::Write; - -impl FromSql for f32 { - fn from_sql(bytes: Option<&[u8]>) -> Result> { - let mut bytes = not_none!(bytes); - bytes.read_f32::().map_err(|e| Box::new(e) as Box) - } -} - -impl ToSql for f32 { - fn to_sql(&self, out: &mut W) -> Result> { - out.write_f32::(*self) - .map(|_| IsNull::No) - .map_err(|e| Box::new(e) as Box) - } -} - -impl FromSql for f64 { - fn from_sql(bytes: Option<&[u8]>) -> Result> { - let mut bytes = not_none!(bytes); - bytes.read_f64::().map_err(|e| Box::new(e) as Box) - } -} - -impl ToSql for f64 { - fn to_sql(&self, out: &mut W) -> Result> { - out.write_f64::(*self) - .map(|_| IsNull::No) - .map_err(|e| Box::new(e) as Box) - } -} diff --git a/diesel/src/types/impls/floats/mod.rs b/diesel/src/types/impls/floats/mod.rs new file mode 100644 index 0000000000..0d32f7e98f --- /dev/null +++ b/diesel/src/types/impls/floats/mod.rs @@ -0,0 +1,135 @@ +extern crate byteorder; + +use self::byteorder::{ReadBytesExt, WriteBytesExt, BigEndian}; +use super::option::UnexpectedNullError; +use types::{FromSql, ToSql, IsNull}; +use types; +use std::error::Error; +use std::io::Write; + +#[cfg(feature = "quickcheck")] +mod quickcheck_impls; + +impl FromSql for f32 { + fn from_sql(bytes: Option<&[u8]>) -> Result> { + let mut bytes = not_none!(bytes); + bytes.read_f32::().map_err(|e| Box::new(e) as Box) + } +} + +impl ToSql for f32 { + fn to_sql(&self, out: &mut W) -> Result> { + out.write_f32::(*self) + .map(|_| IsNull::No) + .map_err(|e| Box::new(e) as Box) + } +} + +impl FromSql for f64 { + fn from_sql(bytes: Option<&[u8]>) -> Result> { + let mut bytes = not_none!(bytes); + bytes.read_f64::().map_err(|e| Box::new(e) as Box) + } +} + +impl ToSql for f64 { + fn to_sql(&self, out: &mut W) -> Result> { + out.write_f64::(*self) + .map(|_| IsNull::No) + .map_err(|e| Box::new(e) as Box) + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PgNumeric { + Positive { + weight: i16, + scale: u16, + digits: Vec, + }, + Negative { + weight: i16, + scale: u16, + digits: Vec, + }, + NaN, +} + +#[derive(Debug, Clone, Copy)] +struct InvalidNumericSign(u16); + +impl ::std::fmt::Display for InvalidNumericSign { + fn fmt(&self, f: &mut ::std::fmt::Formatter) -> ::std::fmt::Result { + write!(f, "InvalidNumericSign({0:x})", self.0) + } +} + +impl Error for InvalidNumericSign { + fn description(&self) -> &str { + "sign for numeric field was not one of 0, 0x4000, 0xC000" + } +} + +impl FromSql for PgNumeric { + fn from_sql(bytes: Option<&[u8]>) -> Result> { + let mut bytes = not_none!(bytes); + let ndigits = try!(bytes.read_u16::()); + let mut digits = Vec::with_capacity(ndigits as usize); + let weight = try!(bytes.read_i16::()); + let sign = try!(bytes.read_u16::()); + let scale = try!(bytes.read_u16::()); + for _ in 0..ndigits { + digits.push(try!(bytes.read_i16::())); + } + + match sign { + 0 => Ok(PgNumeric::Positive { + weight: weight, + scale: scale, + digits: digits, + }), + 0x4000 => Ok(PgNumeric::Negative { + weight: weight, + scale: scale, + digits: digits, + }), + 0xC000 => Ok(PgNumeric::NaN), + invalid => Err(Box::new(InvalidNumericSign(invalid))), + } + } +} + +impl ToSql for PgNumeric { + fn to_sql(&self, out: &mut W) -> Result> { + let sign = match self { + &PgNumeric::Positive { .. } => 0, + &PgNumeric::Negative { .. } => 0x4000, + &PgNumeric::NaN => 0xC000, + }; + let empty_vec = Vec::new(); + let digits = match self { + &PgNumeric::Positive { ref digits, .. } => digits, + &PgNumeric::Negative { ref digits, .. } => digits, + &PgNumeric::NaN => &empty_vec, + }; + let weight = match self { + &PgNumeric::Positive { weight, .. } => weight, + &PgNumeric::Negative { weight, .. } => weight, + &PgNumeric::NaN => 0, + }; + let scale = match self { + &PgNumeric::Positive { scale, .. } => scale, + &PgNumeric::Negative { scale, .. } => scale, + &PgNumeric::NaN => 0, + }; + try!(out.write_u16::(digits.len() as u16)); + try!(out.write_i16::(weight)); + try!(out.write_u16::(sign)); + try!(out.write_u16::(scale)); + for digit in digits.iter() { + try!(out.write_i16::(*digit)); + } + + Ok(IsNull::No) + } +} diff --git a/diesel/src/types/impls/floats/quickcheck_impls.rs b/diesel/src/types/impls/floats/quickcheck_impls.rs new file mode 100644 index 0000000000..3534c0f2a9 --- /dev/null +++ b/diesel/src/types/impls/floats/quickcheck_impls.rs @@ -0,0 +1,65 @@ +extern crate quickcheck; + +use self::quickcheck::{Arbitrary, Gen}; + +use super::PgNumeric; + +const SCALE_MASK: u16 = 0x3FFF; + +impl Arbitrary for PgNumeric { + fn arbitrary(g: &mut G) -> Self { + let mut variant = Option::::arbitrary(g); + let mut weight = -1; + while weight < 0 { + // Oh postgres... Don't ever change. http://bit.ly/lol-code-comments + weight = i16::arbitrary(g); + } + let scale = u16::arbitrary(g) & SCALE_MASK; + let digits = gen_vec_of_appropriate_length_valid_digits(g, weight as u16, scale); + if digits.len() == 0 { + weight = 0; + variant = Some(true); + } + + match variant { + Some(true) => PgNumeric::Positive { + digits: digits, + weight: weight, + scale: scale, + }, + Some(false) => PgNumeric::Negative { + digits: digits, + weight: weight, + scale: scale, + }, + None => PgNumeric::NaN, + } + } +} + +fn gen_vec_of_appropriate_length_valid_digits +(g: &mut G, weight: u16, scale: u16) -> Vec { + let max_digits = ::std::cmp::min(weight, scale); + let mut digits = Vec::::arbitrary(g).into_iter() + .map(|d| d.0) + .skip_while(|d| d == &0) // drop leading zeros + .take(max_digits as usize) + .collect::>(); + while digits.last() == Some(&0) { + digits.pop(); // drop trailing zeros + } + digits +} + +#[derive(Debug, Clone, Copy)] +struct Digit(i16); + +impl Arbitrary for Digit { + fn arbitrary(g: &mut G) -> Self { + let mut n = -1; + while n < 0 || n >= 10000 { + n = i16::arbitrary(g); + } + Digit(n) + } +} diff --git a/diesel/src/types/impls/mod.rs b/diesel/src/types/impls/mod.rs index 48f952b48c..da8415d5b1 100644 --- a/diesel/src/types/impls/mod.rs +++ b/diesel/src/types/impls/mod.rs @@ -80,7 +80,7 @@ macro_rules! primitive_impls { mod array; pub mod date_and_time; -mod floats; +pub mod floats; mod integers; mod option; mod primitives; diff --git a/diesel/src/types/impls/primitives.rs b/diesel/src/types/impls/primitives.rs index aa8419b9a4..0a8cca73c5 100644 --- a/diesel/src/types/impls/primitives.rs +++ b/diesel/src/types/impls/primitives.rs @@ -1,7 +1,9 @@ -use expression::{Expression, AsExpression}; -use expression::bound::Bound; use std::error::Error; use std::io::Write; + +use data_types::PgNumeric; +use expression::bound::Bound; +use expression::{Expression, AsExpression}; use super::option::UnexpectedNullError; use types::{NativeSqlType, FromSql, ToSql, IsNull}; use {Queriable, types}; @@ -15,6 +17,7 @@ primitive_impls! { Float -> (f32, 700), Double -> (f64, 701), + Numeric -> (PgNumeric, 1700), VarChar -> (String, 1043), Text -> (String, 25), diff --git a/diesel/src/types/mod.rs b/diesel/src/types/mod.rs index cd9d8ac1dc..865da6e7c1 100644 --- a/diesel/src/types/mod.rs +++ b/diesel/src/types/mod.rs @@ -11,6 +11,7 @@ pub mod structs { //! there is no existing Rust primitive, or where using it would be //! confusing (such as date and time types) pub use super::super::impls::date_and_time::{PgTimestamp, PgDate, PgTime, PgInterval}; + pub use super::super::impls::floats::PgNumeric; } } @@ -33,6 +34,7 @@ pub type BigSerial = BigInt; #[derive(Clone, Copy, Default)] pub struct Float; #[derive(Clone, Copy, Default)] pub struct Double; +#[derive(Clone, Copy, Default)] pub struct Numeric; #[derive(Clone, Copy, Default)] pub struct VarChar; #[derive(Clone, Copy, Default)] pub struct Text; diff --git a/diesel/src/types/ops.rs b/diesel/src/types/ops.rs index 36c69a4d73..d6fce313af 100644 --- a/diesel/src/types/ops.rs +++ b/diesel/src/types/ops.rs @@ -50,7 +50,7 @@ macro_rules! numeric_type { } } -numeric_type!(SmallInt, Integer, BigInt, Float, Double); +numeric_type!(SmallInt, Integer, BigInt, Float, Double, Numeric); impl Add for super::Timestamp { type Rhs = super::Interval; diff --git a/diesel_tests/tests/types.rs b/diesel_tests/tests/types.rs index d1ef06f597..e152f55092 100644 --- a/diesel_tests/tests/types.rs +++ b/diesel_tests/tests/types.rs @@ -263,6 +263,29 @@ fn pg_timestamp_to_sql_timestamp() { assert!(!query_to_sql_equality::(expected_non_equal_value, value)); } +#[test] +fn pg_numeric_from_sql() { + use diesel::data_types::PgNumeric; + + let query = "SELECT 1.0::numeric"; + let expected_value = PgNumeric::Positive { + digits: vec![1], + weight: 0, + scale: 1, + }; + assert_eq!(expected_value, query_single_value::(query)); + let query = "SELECT -31.0::numeric"; + let expected_value = PgNumeric::Negative { + digits: vec![31], + weight: 0, + scale: 1, + }; + assert_eq!(expected_value, query_single_value::(query)); + let query = "SELECT 'NaN'::numeric"; + let expected_value = PgNumeric::NaN; + assert_eq!(expected_value, query_single_value::(query)); +} + fn query_single_value>(sql: &str) -> U { let connection = connection(); let mut cursor = connection.query_sql::(sql) diff --git a/diesel_tests/tests/types_roundtrip.rs b/diesel_tests/tests/types_roundtrip.rs index e4ee5cbdbd..0da85219ad 100644 --- a/diesel_tests/tests/types_roundtrip.rs +++ b/diesel_tests/tests/types_roundtrip.rs @@ -57,3 +57,4 @@ test_round_trip!(date_roundtrips, Date, PgDate, "date"); test_round_trip!(time_roundtrips, Time, PgTime, "time"); test_round_trip!(timestamp_roundtrips, Timestamp, PgTimestamp, "timestamp"); test_round_trip!(interval_roundtrips, Interval, PgInterval, "interval"); +test_round_trip!(numeric_roundtrips, Numeric, PgNumeric, "numeric");