From 173ec9a1e54fd3ed4e14c21d17daa6230a7b045a Mon Sep 17 00:00:00 2001 From: Vamshi Surabhi <0x777@users.noreply.github.com> Date: Mon, 12 Aug 2024 14:47:46 -0700 Subject: [PATCH] [PACHA-12] sql: fix failing introspection queries (#958) ### What Introspection queries (on 'hasura' schema) would fail when there is no data in the underlying tables. ### How A more robust 'MemTable' with a comprehensive set of tests is introduced which shouldn't run into these issues. V3_GIT_ORIGIN_REV_ID: e09de03e8d093fb4348514cfed6b6dc1d9b0b0c8 --- v3/Cargo.lock | 1 + v3/crates/sql/Cargo.toml | 3 + v3/crates/sql/src/catalog.rs | 1 + v3/crates/sql/src/catalog/introspection.rs | 427 +++++++----------- .../catalog/introspection/column_metadata.rs | 79 ++++ .../src/catalog/introspection/foreign_keys.rs | 101 +++++ .../catalog/introspection/table_metadata.rs | 68 +++ v3/crates/sql/src/catalog/mem_table.rs | 413 +++++++++++++++++ 8 files changed, 833 insertions(+), 260 deletions(-) create mode 100644 v3/crates/sql/src/catalog/introspection/column_metadata.rs create mode 100644 v3/crates/sql/src/catalog/introspection/foreign_keys.rs create mode 100644 v3/crates/sql/src/catalog/introspection/table_metadata.rs create mode 100644 v3/crates/sql/src/catalog/mem_table.rs diff --git a/v3/Cargo.lock b/v3/Cargo.lock index 016c5ac111c..2023540b981 100644 --- a/v3/Cargo.lock +++ b/v3/Cargo.lock @@ -4411,6 +4411,7 @@ dependencies = [ "serde", "serde_json", "thiserror", + "tokio", "tracing-util", ] diff --git a/v3/crates/sql/Cargo.toml b/v3/crates/sql/Cargo.toml index e22e3bc3008..2de76541a48 100644 --- a/v3/crates/sql/Cargo.toml +++ b/v3/crates/sql/Cargo.toml @@ -22,5 +22,8 @@ serde = { workspace = true, features = ["rc"] } serde_json = { workspace = true } thiserror = { workspace = true } +[dev-dependencies] +tokio = { workspace = true, features = ["macros", "rt-multi-thread"] } + [lints] workspace = true diff --git a/v3/crates/sql/src/catalog.rs b/v3/crates/sql/src/catalog.rs index 1cfb35d1204..419caecdf87 100644 --- a/v3/crates/sql/src/catalog.rs +++ b/v3/crates/sql/src/catalog.rs @@ -15,6 +15,7 @@ mod datafusion { } pub mod introspection; +pub mod mem_table; pub mod model; pub mod subgraph; diff --git a/v3/crates/sql/src/catalog/introspection.rs b/v3/crates/sql/src/catalog/introspection.rs index 81aeef03aa7..bad9f20e8e5 100644 --- a/v3/crates/sql/src/catalog/introspection.rs +++ b/v3/crates/sql/src/catalog/introspection.rs @@ -3,37 +3,33 @@ use std::{any::Any, sync::Arc}; use async_trait::async_trait; +use column_metadata::{ColumnMetadata, ColumnMetadataRow, COLUMN_METADATA}; +use foreign_keys::{InferredForeignKeys, InferredForeignKeysRow, INFERRED_FOREIGN_KEY_CONSTRAINTS}; use indexmap::IndexMap; use metadata_resolve::{self as resolved, ModelRelationshipTarget}; +use table_metadata::{TableMetadata, TableMetadataRow, TABLE_METADATA}; mod datafusion { pub(super) use datafusion::{ - arrow::{ - array::RecordBatch, - datatypes::{DataType, Field, Schema, SchemaRef}, - }, - catalog::schema::SchemaProvider, - common::ScalarValue, - datasource::{TableProvider, TableType}, - error::Result, - execution::context::SessionState, - logical_expr::Expr, - physical_plan::{values::ValuesExec, ExecutionPlan}, + catalog::schema::SchemaProvider, datasource::TableProvider, error::Result, }; } use open_dds::relationships::RelationshipType; use serde::{Deserialize, Serialize}; +use super::mem_table::MemTable; + +mod column_metadata; +mod foreign_keys; +mod table_metadata; + pub const HASURA_METADATA_SCHEMA: &str = "hasura"; -pub const TABLE_METADATA: &str = "table_metadata"; -pub const COLUMN_METADATA: &str = "column_metadata"; -pub const INFERRED_FOREIGN_KEY_CONSTRAINTS: &str = "inferred_foreign_key_constraints"; /// Describes the database schema structure and metadata. #[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] pub(crate) struct Introspection { - pub(crate) table_metadata: TableMetadata, - pub(crate) column_metadata: ColumnMetadata, - pub(crate) inferred_foreign_key_constraints: InferredatafusionoreignKeys, + table_metadata: TableMetadata, + column_metadata: ColumnMetadata, + inferred_foreign_key_constraints: InferredForeignKeys, } impl Introspection { @@ -47,18 +43,18 @@ impl Introspection { let mut foreign_key_constraint_rows = Vec::new(); for (schema_name, schema) in schemas { for (table_name, table) in &schema.tables { - table_metadata_rows.push(TableRow::new( + table_metadata_rows.push(TableMetadataRow::new( schema_name.to_string(), table_name.to_string(), table.description.clone(), )); for (column_name, column_description) in &table.columns { - column_metadata_rows.push(ColumnRow { - schema_name: schema_name.to_string(), - table_name: table_name.clone(), - column_name: column_name.clone(), - description: column_description.clone(), - }); + column_metadata_rows.push(ColumnMetadataRow::new( + schema_name.to_string(), + table_name.clone(), + column_name.clone(), + column_description.clone(), + )); } // TODO: @@ -77,14 +73,14 @@ impl Introspection { ) = &relationship.target { for mapping in mappings { - foreign_key_constraint_rows.push(ForeignKeyRow { - from_schema_name: schema_name.to_string(), - from_table_name: table_name.clone(), - from_column_name: mapping.source_field.field_name.to_string(), - to_schema_name: model_name.subgraph.to_string(), - to_table_name: model_name.name.to_string(), - to_column_name: mapping.target_field.field_name.to_string(), - }); + foreign_key_constraint_rows.push(InferredForeignKeysRow::new( + schema_name.to_string(), + table_name.clone(), + mapping.source_field.field_name.to_string(), + model_name.subgraph.to_string(), + model_name.name.to_string(), + mapping.target_field.field_name.to_string(), + )); } } } @@ -94,197 +90,14 @@ impl Introspection { Introspection { table_metadata: TableMetadata::new(table_metadata_rows), column_metadata: ColumnMetadata::new(column_metadata_rows), - inferred_foreign_key_constraints: InferredatafusionoreignKeys::new( - foreign_key_constraint_rows, - ), + inferred_foreign_key_constraints: InferredForeignKeys::new(foreign_key_constraint_rows), } } } -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] -pub(crate) struct TableMetadata { - schema: datafusion::SchemaRef, - rows: Vec, -} - -impl TableMetadata { - pub(crate) fn new(rows: Vec) -> Self { - let schema_name = datafusion::Field::new("schema_name", datafusion::DataType::Utf8, false); - let table_name = datafusion::Field::new("table_name", datafusion::DataType::Utf8, false); - let description = datafusion::Field::new("description", datafusion::DataType::Utf8, true); - let schema = datafusion::SchemaRef::new(datafusion::Schema::new(vec![ - schema_name, - table_name, - description, - ])); - TableMetadata { schema, rows } - } -} - -impl TableMetadata { - fn to_values_table(&self) -> ValuesTable { - ValuesTable { - schema: self.schema.clone(), - rows: self - .rows - .iter() - .map(|row| { - vec![ - ScalarValue::Utf8(Some(row.schema_name.clone())), - ScalarValue::Utf8(Some(row.table_name.clone())), - ScalarValue::Utf8(row.description.clone()), - ] - }) - .collect(), - } - } -} - -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] -pub(crate) struct TableRow { - schema_name: String, - table_name: String, - description: Option, -} - -impl TableRow { - pub(crate) fn new( - schema_name: String, - table_name: String, - description: Option, - ) -> Self { - Self { - schema_name, - table_name, - description, - } - } -} - -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] -pub(crate) struct ColumnMetadata { - pub(crate) schema: datafusion::SchemaRef, - pub(crate) rows: Vec, -} - -impl ColumnMetadata { - fn new(rows: Vec) -> Self { - let schema_name = datafusion::Field::new("schema_name", datafusion::DataType::Utf8, false); - let table_name = datafusion::Field::new("table_name", datafusion::DataType::Utf8, false); - let column_name = datafusion::Field::new("column_name", datafusion::DataType::Utf8, false); - let description = datafusion::Field::new("description", datafusion::DataType::Utf8, true); - let schema = datafusion::SchemaRef::new(datafusion::Schema::new(vec![ - schema_name, - table_name, - column_name, - description, - ])); - ColumnMetadata { schema, rows } - } - fn to_values_table(&self) -> ValuesTable { - ValuesTable { - schema: self.schema.clone(), - rows: self - .rows - .iter() - .map(|row| { - vec![ - ScalarValue::Utf8(Some(row.schema_name.clone())), - ScalarValue::Utf8(Some(row.table_name.clone())), - ScalarValue::Utf8(Some(row.column_name.clone())), - ScalarValue::Utf8(row.description.clone()), - ] - }) - .collect(), - } - } -} - -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] -pub(crate) struct ColumnRow { - schema_name: String, - table_name: String, - column_name: String, - description: Option, -} - -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] -pub(crate) struct InferredatafusionoreignKeys { - schema: datafusion::SchemaRef, - rows: Vec, -} - -impl InferredatafusionoreignKeys { - fn new(rows: Vec) -> Self { - let from_schema_name = - datafusion::Field::new("from_schema_name", datafusion::DataType::Utf8, false); - let from_table_name = - datafusion::Field::new("from_table_name", datafusion::DataType::Utf8, false); - let from_column_name = - datafusion::Field::new("from_column_name", datafusion::DataType::Utf8, false); - let to_schema_name = - datafusion::Field::new("to_schema_name", datafusion::DataType::Utf8, false); - let to_table_name = - datafusion::Field::new("to_table_name", datafusion::DataType::Utf8, false); - let to_column_name = - datafusion::Field::new("to_column_name", datafusion::DataType::Utf8, false); - let schema = datafusion::SchemaRef::new(datafusion::Schema::new(vec![ - from_schema_name, - from_table_name, - from_column_name, - to_schema_name, - to_table_name, - to_column_name, - ])); - InferredatafusionoreignKeys { schema, rows } - } - fn to_values_table(&self) -> ValuesTable { - ValuesTable { - schema: self.schema.clone(), - rows: self - .rows - .iter() - .map(|row| { - vec![ - ScalarValue::Utf8(Some(row.from_schema_name.clone())), - ScalarValue::Utf8(Some(row.from_table_name.clone())), - ScalarValue::Utf8(Some(row.from_column_name.clone())), - ScalarValue::Utf8(Some(row.to_schema_name.clone())), - ScalarValue::Utf8(Some(row.to_table_name.clone())), - ScalarValue::Utf8(Some(row.to_column_name.clone())), - ] - }) - .collect(), - } - } -} - -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] -enum ScalarValue { - Utf8(Option), -} - -impl ScalarValue { - fn into_datafusion_scalar_value(self) -> datafusion::ScalarValue { - match self { - ScalarValue::Utf8(value) => datafusion::ScalarValue::Utf8(value), - } - } -} - -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] -struct ForeignKeyRow { - from_schema_name: String, - from_table_name: String, - from_column_name: String, - to_schema_name: String, - to_table_name: String, - to_column_name: String, -} - #[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] pub(crate) struct IntrospectionSchemaProvider { - tables: IndexMap>, + tables: IndexMap>, } impl IntrospectionSchemaProvider { @@ -292,17 +105,17 @@ impl IntrospectionSchemaProvider { let tables = [ ( TABLE_METADATA, - introspection.table_metadata.to_values_table(), + introspection.table_metadata.to_table_provider(), ), ( COLUMN_METADATA, - introspection.column_metadata.to_values_table(), + introspection.column_metadata.to_table_provider(), ), ( INFERRED_FOREIGN_KEY_CONSTRAINTS, introspection .inferred_foreign_key_constraints - .to_values_table(), + .to_table_provider(), ), ] .into_iter() @@ -338,51 +151,145 @@ impl datafusion::SchemaProvider for IntrospectionSchemaProvider { } } -// A table with static rows -#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] -struct ValuesTable { - schema: datafusion::SchemaRef, - rows: Vec>, -} +#[cfg(test)] +mod tests { + use ::datafusion::catalog::{CatalogProvider, MemoryCatalogProvider, SchemaProvider}; -#[async_trait] -impl datafusion::TableProvider for ValuesTable { - fn as_any(&self) -> &dyn Any { - self + use super::*; + use ::datafusion::prelude::*; + use std::sync::Arc; + + fn create_test_introspection() -> Introspection { + let table_metadata = TableMetadata::new(vec![ + TableMetadataRow::new( + "public".to_string(), + "users".to_string(), + Some("Users table".to_string()), + ), + TableMetadataRow::new( + "public".to_string(), + "posts".to_string(), + Some("Posts table".to_string()), + ), + ]); + + let column_metadata = ColumnMetadata::new(vec![ + ColumnMetadataRow::new( + "public".to_string(), + "users".to_string(), + "id".to_string(), + Some("User ID".to_string()), + ), + ColumnMetadataRow::new( + "public".to_string(), + "users".to_string(), + "name".to_string(), + Some("User name".to_string()), + ), + ColumnMetadataRow::new( + "public".to_string(), + "posts".to_string(), + "id".to_string(), + Some("Post ID".to_string()), + ), + ColumnMetadataRow::new( + "public".to_string(), + "posts".to_string(), + "user_id".to_string(), + Some("Author's user ID".to_string()), + ), + ]); + + let inferred_foreign_keys = InferredForeignKeys::new(vec![InferredForeignKeysRow::new( + "public".to_string(), + "posts".to_string(), + "user_id".to_string(), + "public".to_string(), + "users".to_string(), + "id".to_string(), + )]); + + Introspection { + table_metadata, + column_metadata, + inferred_foreign_key_constraints: inferred_foreign_keys, + } } - fn schema(&self) -> datafusion::SchemaRef { - self.schema.clone() + #[tokio::test] + async fn test_introspection_schema_provider_table() { + let introspection = create_test_introspection(); + let schema_provider = IntrospectionSchemaProvider::new(&introspection); + + let table_metadata = schema_provider.table(TABLE_METADATA).await.unwrap(); + assert!(table_metadata.is_some()); + + let column_metadata = schema_provider.table(COLUMN_METADATA).await.unwrap(); + assert!(column_metadata.is_some()); + + let foreign_keys = schema_provider + .table(INFERRED_FOREIGN_KEY_CONSTRAINTS) + .await + .unwrap(); + assert!(foreign_keys.is_some()); + + let non_existent_table = schema_provider.table("non_existent").await.unwrap(); + assert!(non_existent_table.is_none()); } - fn table_type(&self) -> datafusion::TableType { - datafusion::TableType::View + // ... (keep the create_test_introspection function and other existing tests) + + fn create_test_context(introspection: &Introspection) -> SessionContext { + let config = SessionConfig::new().with_default_catalog_and_schema("default", "default"); + let schema_provider = Arc::new(IntrospectionSchemaProvider::new(introspection)); + let ctx = SessionContext::new_with_config(config); + let catalog = MemoryCatalogProvider::new(); + catalog + .register_schema(HASURA_METADATA_SCHEMA, schema_provider) + .unwrap(); + ctx.register_catalog("default", Arc::new(catalog)); + ctx } - async fn scan( - &self, - _state: &datafusion::SessionState, - projection: Option<&Vec>, - // filters and limit can be used here to inject some push-down operations if needed - _filters: &[datafusion::Expr], - _limit: Option, - ) -> datafusion::Result> { - let projected_schema = Arc::new(self.schema.project(projection.unwrap_or(&vec![]))?); - let columnar_projection = projection - .unwrap_or(&vec![]) - .iter() - .map(|j| { - self.rows - .iter() - .map(|row| row[*j].clone().into_datafusion_scalar_value()) - }) - .map(datafusion::ScalarValue::iter_to_array) - .collect::>>()?; - Ok(Arc::new(datafusion::ValuesExec::try_new_from_batches( - projected_schema.clone(), - vec![datafusion::RecordBatch::try_new( - projected_schema, - columnar_projection, - )?], - )?)) + + #[tokio::test] + async fn test_query_table_metadata() { + let introspection = create_test_introspection(); + let ctx = create_test_context(&introspection); + + let sql = format!("SELECT * FROM {HASURA_METADATA_SCHEMA}.{TABLE_METADATA}"); + let df = ctx.sql(&sql).await.unwrap(); + let results = df.collect().await.unwrap(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].num_rows(), 2); + } + + #[tokio::test] + async fn test_query_column_metadata() { + let introspection = create_test_introspection(); + let ctx = create_test_context(&introspection); + + let sql = format!( + "SELECT * FROM {HASURA_METADATA_SCHEMA}.{COLUMN_METADATA} WHERE table_name = 'users'", + ); + let df = ctx.sql(&sql).await.unwrap(); + let results = df.collect().await.unwrap(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].num_rows(), 2); + } + + #[tokio::test] + async fn test_query_inferred_foreign_keys() { + let introspection = create_test_introspection(); + let ctx = create_test_context(&introspection); + + let sql = + format!("SELECT * FROM {HASURA_METADATA_SCHEMA}.{INFERRED_FOREIGN_KEY_CONSTRAINTS}",); + let df = ctx.sql(&sql).await.unwrap(); + let results = df.collect().await.unwrap(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].num_rows(), 1); } } diff --git a/v3/crates/sql/src/catalog/introspection/column_metadata.rs b/v3/crates/sql/src/catalog/introspection/column_metadata.rs new file mode 100644 index 00000000000..2d745f3a0b7 --- /dev/null +++ b/v3/crates/sql/src/catalog/introspection/column_metadata.rs @@ -0,0 +1,79 @@ +use serde::{Deserialize, Serialize}; + +use crate::catalog::mem_table::MemTable; + +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] +pub(crate) struct ColumnMetadataRow { + schema_name: String, + table_name: String, + column_name: String, + description: Option, +} + +impl ColumnMetadataRow { + pub(crate) fn new( + schema_name: String, + table_name: String, + column_name: String, + description: Option, + ) -> Self { + Self { + schema_name, + table_name, + column_name, + description, + } + } +} + +pub const COLUMN_METADATA: &str = "column_metadata"; + +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] +pub(super) struct ColumnMetadata { + pub(crate) rows: Vec, +} + +impl ColumnMetadata { + pub(super) fn new(rows: Vec) -> Self { + ColumnMetadata { rows } + } +} + +impl ColumnMetadata { + pub(super) fn to_table_provider(&self) -> MemTable { + MemTable::new_from_iter(vec![ + ( + "schema_name", + false, + self.rows + .iter() + .map(|row| Some(row.schema_name.clone())) + .collect(), + ), + ( + "table_name", + false, + self.rows + .iter() + .map(|row| Some(row.table_name.clone())) + .collect(), + ), + ( + "column_name", + false, + self.rows + .iter() + .map(|row| Some(row.column_name.clone())) + .collect(), + ), + ( + "description", + true, + self.rows + .iter() + .map(|row| row.description.clone()) + .collect(), + ), + ]) + } +} diff --git a/v3/crates/sql/src/catalog/introspection/foreign_keys.rs b/v3/crates/sql/src/catalog/introspection/foreign_keys.rs new file mode 100644 index 00000000000..00308d1d5d2 --- /dev/null +++ b/v3/crates/sql/src/catalog/introspection/foreign_keys.rs @@ -0,0 +1,101 @@ +use serde::{Deserialize, Serialize}; + +use crate::catalog::mem_table::MemTable; + +pub(super) const INFERRED_FOREIGN_KEY_CONSTRAINTS: &str = "inferred_foreign_key_constraints"; + +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] +pub(super) struct InferredForeignKeysRow { + from_schema_name: String, + from_table_name: String, + from_column_name: String, + to_schema_name: String, + to_table_name: String, + to_column_name: String, +} + +impl InferredForeignKeysRow { + pub(super) fn new( + from_schema_name: String, + from_table_name: String, + from_column_name: String, + to_schema_name: String, + to_table_name: String, + to_column_name: String, + ) -> Self { + Self { + from_schema_name, + from_table_name, + from_column_name, + to_schema_name, + to_table_name, + to_column_name, + } + } +} + +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] +pub(crate) struct InferredForeignKeys { + rows: Vec, +} + +impl InferredForeignKeys { + pub(super) fn new(rows: Vec) -> Self { + InferredForeignKeys { rows } + } +} + +impl InferredForeignKeys { + pub(super) fn to_table_provider(&self) -> MemTable { + MemTable::new_from_iter(vec![ + ( + "from_schema_name", + false, + self.rows + .iter() + .map(|row| Some(row.from_schema_name.clone())) + .collect(), + ), + ( + "from_table_name", + false, + self.rows + .iter() + .map(|row| Some(row.from_table_name.clone())) + .collect(), + ), + ( + "from_column_name", + false, + self.rows + .iter() + .map(|row| Some(row.from_column_name.clone())) + .collect(), + ), + ( + "to_schema_name", + false, + self.rows + .iter() + .map(|row| Some(row.to_schema_name.clone())) + .collect(), + ), + ( + "to_table_name", + false, + self.rows + .iter() + .map(|row| Some(row.to_table_name.clone())) + .collect(), + ), + ( + "to_column_name", + false, + self.rows + .iter() + .map(|row| Some(row.to_column_name.clone())) + .collect(), + ), + ]) + } +} diff --git a/v3/crates/sql/src/catalog/introspection/table_metadata.rs b/v3/crates/sql/src/catalog/introspection/table_metadata.rs new file mode 100644 index 00000000000..31738dede32 --- /dev/null +++ b/v3/crates/sql/src/catalog/introspection/table_metadata.rs @@ -0,0 +1,68 @@ +use serde::{Deserialize, Serialize}; + +use crate::catalog::mem_table::MemTable; + +pub const TABLE_METADATA: &str = "table_metadata"; + +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] +pub(crate) struct TableMetadata { + rows: Vec, +} + +impl TableMetadata { + pub(crate) fn new(rows: Vec) -> Self { + TableMetadata { rows } + } +} + +impl TableMetadata { + pub(crate) fn to_table_provider(&self) -> MemTable { + MemTable::new_from_iter(vec![ + ( + "schema_name", + false, + self.rows + .iter() + .map(|row| Some(row.schema_name.clone())) + .collect(), + ), + ( + "table_name", + false, + self.rows + .iter() + .map(|row| Some(row.table_name.clone())) + .collect(), + ), + ( + "description", + true, + self.rows + .iter() + .map(|row| row.description.clone()) + .collect(), + ), + ]) + } +} + +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] +pub(crate) struct TableMetadataRow { + schema_name: String, + table_name: String, + description: Option, +} + +impl TableMetadataRow { + pub(crate) fn new( + schema_name: String, + table_name: String, + description: Option, + ) -> Self { + Self { + schema_name, + table_name, + description, + } + } +} diff --git a/v3/crates/sql/src/catalog/mem_table.rs b/v3/crates/sql/src/catalog/mem_table.rs new file mode 100644 index 00000000000..eadd93131df --- /dev/null +++ b/v3/crates/sql/src/catalog/mem_table.rs @@ -0,0 +1,413 @@ +//! In-memory table implementation for DataFusion +//! +//! This module provides a serializable and deserializable in-memory table implementation. +//! Datafusion's built-in MemTable doesn't support the serializability property +use async_trait::async_trait; +use std::{any::Any, sync::Arc}; +mod datafusion { + pub(super) use datafusion::{ + arrow::{ + array::{ArrayRef, BooleanArray, Int64Array, RecordBatch, StringArray}, + datatypes::{DataType, Field, SchemaBuilder, SchemaRef}, + }, + datasource::{TableProvider, TableType}, + error::Result, + execution::context::SessionState, + logical_expr::Expr, + physical_plan::{values::ValuesExec, ExecutionPlan}, + }; +} +use serde::{Deserialize, Serialize}; + +/// Represents the data for a single column in a table. +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] +pub(crate) enum ColumnData { + Utf8 { data: Vec> }, + Int64 { data: Vec> }, + Bool { data: Vec> }, +} + +impl<'a> FromIterator> for ColumnData { + fn from_iter>>(iter: I) -> Self { + ColumnData::Utf8 { + data: iter + .into_iter() + .map(|s| s.map(std::borrow::ToOwned::to_owned)) + .collect(), + } + } +} + +impl FromIterator> for ColumnData { + fn from_iter>>(iter: I) -> Self { + ColumnData::Utf8 { + data: iter.into_iter().collect(), + } + } +} + +impl FromIterator> for ColumnData { + fn from_iter>>(iter: I) -> Self { + ColumnData::Int64 { + data: iter.into_iter().collect(), + } + } +} + +impl FromIterator> for ColumnData { + fn from_iter>>(iter: I) -> Self { + ColumnData::Bool { + data: iter.into_iter().collect(), + } + } +} + +impl ColumnData { + fn data_type(&self) -> datafusion::DataType { + match self { + ColumnData::Utf8 { .. } => datafusion::DataType::Utf8, + ColumnData::Int64 { .. } => datafusion::DataType::Int64, + ColumnData::Bool { .. } => datafusion::DataType::Boolean, + } + } + fn into_array_ref(self) -> datafusion::ArrayRef { + match self { + ColumnData::Utf8 { data } => Arc::new(datafusion::StringArray::from(data)), + ColumnData::Int64 { data } => Arc::new(datafusion::Int64Array::from(data)), + ColumnData::Bool { data } => Arc::new(datafusion::BooleanArray::from(data)), + } + } +} + +/// A table with a fixed set rows, stored in memory. Like datafusion's MemTable but serializable +/// and deserializable +#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)] +pub(super) struct MemTable { + schema: datafusion::SchemaRef, + columns: Vec, +} + +impl MemTable { + /// Creates a new MemTable from an iterator of (field_name, nullable, ColumnData) tuples. + /// + /// # Arguments + /// + /// * `value` - An iterator yielding tuples of (field_name, nullable, ColumnData) + /// + /// # Returns + /// + /// A new MemTable instance + /// + /// Based on try_from_iter_with_nullable on RecordBatch + pub(crate) fn new_from_iter(value: I) -> Self + where + I: IntoIterator, + F: AsRef, + { + let iter = value.into_iter(); + let capacity = iter.size_hint().0; + let mut schema = datafusion::SchemaBuilder::with_capacity(capacity); + let mut columns = Vec::with_capacity(capacity); + + for (field_name, nullable, array) in iter { + let field_name = field_name.as_ref(); + schema.push(datafusion::Field::new( + field_name, + array.data_type().clone(), + nullable, + )); + columns.push(array); + } + + let schema = Arc::new(schema.finish()); + MemTable { schema, columns } + } +} + +#[async_trait] +impl datafusion::TableProvider for MemTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> datafusion::SchemaRef { + self.schema.clone() + } + + fn table_type(&self) -> datafusion::TableType { + datafusion::TableType::View + } + async fn scan( + &self, + _state: &datafusion::SessionState, + projection: Option<&Vec>, + // filters and limit can be used here to inject some push-down operations if needed + _filters: &[datafusion::Expr], + _limit: Option, + ) -> datafusion::Result> { + let projected_schema = Arc::new(self.schema.project(projection.unwrap_or(&vec![]))?); + let columnar_projection = projection + .unwrap_or(&vec![]) + .iter() + .map(|j| self.columns[*j].clone().into_array_ref()) + .collect::>(); + Ok(Arc::new(datafusion::ValuesExec::try_new_from_batches( + projected_schema.clone(), + vec![datafusion::RecordBatch::try_new( + projected_schema, + columnar_projection, + )?], + )?)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use ::datafusion::arrow::array::{Array, BooleanArray, Int64Array, StringArray}; + use ::datafusion::prelude::*; + + #[test] + fn test_column_data_utf8() { + let utf8_data = vec![Some("hello".to_string()), None, Some("world".to_string())]; + let column = ColumnData::from_iter(utf8_data); + let array_ref = column.into_array_ref(); + let string_array = array_ref.as_any().downcast_ref::().unwrap(); + assert_eq!(string_array.len(), 3); + assert_eq!(string_array.value(0), "hello"); + assert!(string_array.is_null(1)); + assert_eq!(string_array.value(2), "world"); + } + + #[test] + fn test_column_data_int64() { + let int64_data = vec![Some(1), None, Some(42)]; + let column = ColumnData::from_iter(int64_data); + let array_ref = column.into_array_ref(); + let int64_array = array_ref.as_any().downcast_ref::().unwrap(); + assert_eq!(int64_array.len(), 3); + assert_eq!(int64_array.value(0), 1); + assert!(int64_array.is_null(1)); + assert_eq!(int64_array.value(2), 42); + } + + #[test] + fn test_column_data_bool() { + let bool_data = vec![Some(true), None, Some(false)]; + let column = ColumnData::from_iter(bool_data); + let array_ref = column.into_array_ref(); + let bool_array = array_ref.as_any().downcast_ref::().unwrap(); + assert_eq!(bool_array.len(), 3); + assert!(bool_array.value(0)); + assert!(bool_array.is_null(1)); + assert!(!bool_array.value(2)); + } + + #[test] + fn test_column_data_empty() { + let empty_data: Vec> = vec![]; + let column = ColumnData::from_iter(empty_data); + let array_ref = column.into_array_ref(); + let string_array = array_ref.as_any().downcast_ref::().unwrap(); + assert_eq!(string_array.len(), 0); + } + + #[test] + fn test_column_data_all_null() { + let all_null_data: Vec> = vec![None, None, None]; + let column = ColumnData::from_iter(all_null_data); + let array_ref = column.into_array_ref(); + let string_array = array_ref.as_any().downcast_ref::().unwrap(); + assert_eq!(string_array.len(), 3); + assert!(string_array.is_null(0)); + assert!(string_array.is_null(1)); + assert!(string_array.is_null(2)); + } + + #[tokio::test] + async fn test_mem_table_provider_sql_full_scan() -> datafusion::Result<()> { + let ctx = create_test_context()?; + + let df = ctx.sql("SELECT * FROM test_table").await?; + let results = df.collect().await?; + + assert_eq!(results.len(), 1); // One batch + let batch = &results[0]; + assert_eq!(batch.num_columns(), 3); + assert_eq!(batch.num_rows(), 3); + + // Verify column data + let col1 = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col1.value(0), "a"); + assert!(col1.is_null(1)); + assert_eq!(col1.value(2), "c"); + + let col2 = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col2.value(0), 1); + assert_eq!(col2.value(1), 2); + assert_eq!(col2.value(2), 3); + + let col3 = batch + .column(2) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(col3.value(0)); + assert!(!col3.value(1)); + assert!(col3.is_null(2)); + + Ok(()) + } + + #[tokio::test] + async fn test_mem_table_provider_sql_projection() -> datafusion::Result<()> { + let ctx = create_test_context()?; + + let df = ctx.sql("SELECT col1, col3 FROM test_table").await?; + let results = df.collect().await?; + + assert_eq!(results.len(), 1); + let batch = &results[0]; + assert_eq!(batch.num_columns(), 2); + assert_eq!(batch.num_rows(), 3); + + // Verify projected columns + let col1 = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col1.value(0), "a"); + assert!(col1.is_null(1)); + assert_eq!(col1.value(2), "c"); + + let col3 = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(col3.value(0)); + assert!(!col3.value(1)); + assert!(col3.is_null(2)); + + Ok(()) + } + + #[tokio::test] + async fn test_mem_table_provider_sql_filter() -> datafusion::Result<()> { + let ctx = create_test_context()?; + + let df = ctx + .sql("SELECT col1, col2 FROM test_table WHERE col2 > 1") + .await?; + let results = df.collect().await?; + + assert_eq!(results.len(), 1); + let batch = &results[0]; + assert_eq!(batch.num_columns(), 2); + assert_eq!(batch.num_rows(), 2); + + let col1 = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert!(col1.is_null(0)); + assert_eq!(col1.value(1), "c"); + + let col2 = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(col2.value(0), 2); + assert_eq!(col2.value(1), 3); + + Ok(()) + } + + #[tokio::test] + async fn test_mem_table_provider_sql_aggregation() -> datafusion::Result<()> { + let ctx = create_test_context()?; + + let df = ctx + .sql("SELECT COUNT(*), SUM(col2) FROM test_table") + .await?; + let results = df.collect().await?; + + assert_eq!(results.len(), 1); + let batch = &results[0]; + assert_eq!(batch.num_columns(), 2); + assert_eq!(batch.num_rows(), 1); + + let count = batch + .column(0) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(count.value(0), 3); + + let sum = batch + .column(1) + .as_any() + .downcast_ref::() + .unwrap(); + assert_eq!(sum.value(0), 6); + + Ok(()) + } + + #[tokio::test] + async fn test_mem_table_provider_sql_empty() -> datafusion::Result<()> { + let ctx = SessionContext::new(); + let empty_table = MemTable::new_from_iter(vec![( + "empty_col".to_string(), + true, + ColumnData::from_iter(Vec::>::new()), + )]); + + ctx.register_table("empty_table", Arc::new(empty_table))?; + + let df = ctx.sql("SELECT * FROM empty_table").await?; + let results = df.collect().await?; + + assert_eq!(results.len(), 1); + let batch = &results[0]; + assert_eq!(batch.num_columns(), 1); + assert_eq!(batch.num_rows(), 0); + + Ok(()) + } + + fn create_test_context() -> datafusion::Result { + let ctx = SessionContext::new(); + let mem_table = MemTable::new_from_iter(vec![ + ( + "col1".to_string(), + true, + ColumnData::from_iter(vec![Some("a"), None, Some("c")]), + ), + ( + "col2".to_string(), + false, + ColumnData::from_iter(vec![Some(1), Some(2), Some(3)]), + ), + ( + "col3".to_string(), + true, + ColumnData::from_iter(vec![Some(true), Some(false), None]), + ), + ]); + + ctx.register_table("test_table", Arc::new(mem_table))?; + Ok(ctx) + } +}