[PACHA-12] sql: fix failing introspection queries (#958)

### What

Introspection queries (on 'hasura' schema) would fail when there is no
data in the underlying tables.

### How

A more robust 'MemTable' implementation, backed by a comprehensive set of
tests, is introduced; it handles empty underlying tables correctly and so
avoids these failures.

V3_GIT_ORIGIN_REV_ID: e09de03e8d093fb4348514cfed6b6dc1d9b0b0c8
This commit is contained in:
Vamshi Surabhi 2024-08-12 14:47:46 -07:00 committed by hasura-bot
parent 6f9e92c160
commit 173ec9a1e5
8 changed files with 833 additions and 260 deletions

1
v3/Cargo.lock generated
View File

@ -4411,6 +4411,7 @@ dependencies = [
"serde",
"serde_json",
"thiserror",
"tokio",
"tracing-util",
]

View File

@ -22,5 +22,8 @@ serde = { workspace = true, features = ["rc"] }
serde_json = { workspace = true }
thiserror = { workspace = true }
[dev-dependencies]
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
[lints]
workspace = true

View File

@ -15,6 +15,7 @@ mod datafusion {
}
pub mod introspection;
pub mod mem_table;
pub mod model;
pub mod subgraph;

View File

@ -3,37 +3,33 @@
use std::{any::Any, sync::Arc};
use async_trait::async_trait;
use column_metadata::{ColumnMetadata, ColumnMetadataRow, COLUMN_METADATA};
use foreign_keys::{InferredForeignKeys, InferredForeignKeysRow, INFERRED_FOREIGN_KEY_CONSTRAINTS};
use indexmap::IndexMap;
use metadata_resolve::{self as resolved, ModelRelationshipTarget};
use table_metadata::{TableMetadata, TableMetadataRow, TABLE_METADATA};
mod datafusion {
pub(super) use datafusion::{
arrow::{
array::RecordBatch,
datatypes::{DataType, Field, Schema, SchemaRef},
},
catalog::schema::SchemaProvider,
common::ScalarValue,
datasource::{TableProvider, TableType},
error::Result,
execution::context::SessionState,
logical_expr::Expr,
physical_plan::{values::ValuesExec, ExecutionPlan},
catalog::schema::SchemaProvider, datasource::TableProvider, error::Result,
};
}
use open_dds::relationships::RelationshipType;
use serde::{Deserialize, Serialize};
use super::mem_table::MemTable;
mod column_metadata;
mod foreign_keys;
mod table_metadata;
pub const HASURA_METADATA_SCHEMA: &str = "hasura";
pub const TABLE_METADATA: &str = "table_metadata";
pub const COLUMN_METADATA: &str = "column_metadata";
pub const INFERRED_FOREIGN_KEY_CONSTRAINTS: &str = "inferred_foreign_key_constraints";
/// Describes the database schema structure and metadata.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub(crate) struct Introspection {
pub(crate) table_metadata: TableMetadata,
pub(crate) column_metadata: ColumnMetadata,
pub(crate) inferred_foreign_key_constraints: InferredatafusionoreignKeys,
table_metadata: TableMetadata,
column_metadata: ColumnMetadata,
inferred_foreign_key_constraints: InferredForeignKeys,
}
impl Introspection {
@ -47,18 +43,18 @@ impl Introspection {
let mut foreign_key_constraint_rows = Vec::new();
for (schema_name, schema) in schemas {
for (table_name, table) in &schema.tables {
table_metadata_rows.push(TableRow::new(
table_metadata_rows.push(TableMetadataRow::new(
schema_name.to_string(),
table_name.to_string(),
table.description.clone(),
));
for (column_name, column_description) in &table.columns {
column_metadata_rows.push(ColumnRow {
schema_name: schema_name.to_string(),
table_name: table_name.clone(),
column_name: column_name.clone(),
description: column_description.clone(),
});
column_metadata_rows.push(ColumnMetadataRow::new(
schema_name.to_string(),
table_name.clone(),
column_name.clone(),
column_description.clone(),
));
}
// TODO:
@ -77,14 +73,14 @@ impl Introspection {
) = &relationship.target
{
for mapping in mappings {
foreign_key_constraint_rows.push(ForeignKeyRow {
from_schema_name: schema_name.to_string(),
from_table_name: table_name.clone(),
from_column_name: mapping.source_field.field_name.to_string(),
to_schema_name: model_name.subgraph.to_string(),
to_table_name: model_name.name.to_string(),
to_column_name: mapping.target_field.field_name.to_string(),
});
foreign_key_constraint_rows.push(InferredForeignKeysRow::new(
schema_name.to_string(),
table_name.clone(),
mapping.source_field.field_name.to_string(),
model_name.subgraph.to_string(),
model_name.name.to_string(),
mapping.target_field.field_name.to_string(),
));
}
}
}
@ -94,197 +90,14 @@ impl Introspection {
Introspection {
table_metadata: TableMetadata::new(table_metadata_rows),
column_metadata: ColumnMetadata::new(column_metadata_rows),
inferred_foreign_key_constraints: InferredatafusionoreignKeys::new(
foreign_key_constraint_rows,
),
inferred_foreign_key_constraints: InferredForeignKeys::new(foreign_key_constraint_rows),
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub(crate) struct TableMetadata {
schema: datafusion::SchemaRef,
rows: Vec<TableRow>,
}
impl TableMetadata {
pub(crate) fn new(rows: Vec<TableRow>) -> Self {
let schema_name = datafusion::Field::new("schema_name", datafusion::DataType::Utf8, false);
let table_name = datafusion::Field::new("table_name", datafusion::DataType::Utf8, false);
let description = datafusion::Field::new("description", datafusion::DataType::Utf8, true);
let schema = datafusion::SchemaRef::new(datafusion::Schema::new(vec![
schema_name,
table_name,
description,
]));
TableMetadata { schema, rows }
}
}
impl TableMetadata {
fn to_values_table(&self) -> ValuesTable {
ValuesTable {
schema: self.schema.clone(),
rows: self
.rows
.iter()
.map(|row| {
vec![
ScalarValue::Utf8(Some(row.schema_name.clone())),
ScalarValue::Utf8(Some(row.table_name.clone())),
ScalarValue::Utf8(row.description.clone()),
]
})
.collect(),
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub(crate) struct TableRow {
schema_name: String,
table_name: String,
description: Option<String>,
}
impl TableRow {
pub(crate) fn new(
schema_name: String,
table_name: String,
description: Option<String>,
) -> Self {
Self {
schema_name,
table_name,
description,
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub(crate) struct ColumnMetadata {
pub(crate) schema: datafusion::SchemaRef,
pub(crate) rows: Vec<ColumnRow>,
}
impl ColumnMetadata {
fn new(rows: Vec<ColumnRow>) -> Self {
let schema_name = datafusion::Field::new("schema_name", datafusion::DataType::Utf8, false);
let table_name = datafusion::Field::new("table_name", datafusion::DataType::Utf8, false);
let column_name = datafusion::Field::new("column_name", datafusion::DataType::Utf8, false);
let description = datafusion::Field::new("description", datafusion::DataType::Utf8, true);
let schema = datafusion::SchemaRef::new(datafusion::Schema::new(vec![
schema_name,
table_name,
column_name,
description,
]));
ColumnMetadata { schema, rows }
}
fn to_values_table(&self) -> ValuesTable {
ValuesTable {
schema: self.schema.clone(),
rows: self
.rows
.iter()
.map(|row| {
vec![
ScalarValue::Utf8(Some(row.schema_name.clone())),
ScalarValue::Utf8(Some(row.table_name.clone())),
ScalarValue::Utf8(Some(row.column_name.clone())),
ScalarValue::Utf8(row.description.clone()),
]
})
.collect(),
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub(crate) struct ColumnRow {
schema_name: String,
table_name: String,
column_name: String,
description: Option<String>,
}
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub(crate) struct InferredatafusionoreignKeys {
schema: datafusion::SchemaRef,
rows: Vec<ForeignKeyRow>,
}
impl InferredatafusionoreignKeys {
fn new(rows: Vec<ForeignKeyRow>) -> Self {
let from_schema_name =
datafusion::Field::new("from_schema_name", datafusion::DataType::Utf8, false);
let from_table_name =
datafusion::Field::new("from_table_name", datafusion::DataType::Utf8, false);
let from_column_name =
datafusion::Field::new("from_column_name", datafusion::DataType::Utf8, false);
let to_schema_name =
datafusion::Field::new("to_schema_name", datafusion::DataType::Utf8, false);
let to_table_name =
datafusion::Field::new("to_table_name", datafusion::DataType::Utf8, false);
let to_column_name =
datafusion::Field::new("to_column_name", datafusion::DataType::Utf8, false);
let schema = datafusion::SchemaRef::new(datafusion::Schema::new(vec![
from_schema_name,
from_table_name,
from_column_name,
to_schema_name,
to_table_name,
to_column_name,
]));
InferredatafusionoreignKeys { schema, rows }
}
fn to_values_table(&self) -> ValuesTable {
ValuesTable {
schema: self.schema.clone(),
rows: self
.rows
.iter()
.map(|row| {
vec![
ScalarValue::Utf8(Some(row.from_schema_name.clone())),
ScalarValue::Utf8(Some(row.from_table_name.clone())),
ScalarValue::Utf8(Some(row.from_column_name.clone())),
ScalarValue::Utf8(Some(row.to_schema_name.clone())),
ScalarValue::Utf8(Some(row.to_table_name.clone())),
ScalarValue::Utf8(Some(row.to_column_name.clone())),
]
})
.collect(),
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
enum ScalarValue {
Utf8(Option<String>),
}
impl ScalarValue {
fn into_datafusion_scalar_value(self) -> datafusion::ScalarValue {
match self {
ScalarValue::Utf8(value) => datafusion::ScalarValue::Utf8(value),
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
struct ForeignKeyRow {
from_schema_name: String,
from_table_name: String,
from_column_name: String,
to_schema_name: String,
to_table_name: String,
to_column_name: String,
}
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub(crate) struct IntrospectionSchemaProvider {
tables: IndexMap<String, Arc<ValuesTable>>,
tables: IndexMap<String, Arc<MemTable>>,
}
impl IntrospectionSchemaProvider {
@ -292,17 +105,17 @@ impl IntrospectionSchemaProvider {
let tables = [
(
TABLE_METADATA,
introspection.table_metadata.to_values_table(),
introspection.table_metadata.to_table_provider(),
),
(
COLUMN_METADATA,
introspection.column_metadata.to_values_table(),
introspection.column_metadata.to_table_provider(),
),
(
INFERRED_FOREIGN_KEY_CONSTRAINTS,
introspection
.inferred_foreign_key_constraints
.to_values_table(),
.to_table_provider(),
),
]
.into_iter()
@ -338,51 +151,145 @@ impl datafusion::SchemaProvider for IntrospectionSchemaProvider {
}
}
// A table with static rows
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
struct ValuesTable {
schema: datafusion::SchemaRef,
rows: Vec<Vec<ScalarValue>>,
}
#[cfg(test)]
mod tests {
use ::datafusion::catalog::{CatalogProvider, MemoryCatalogProvider, SchemaProvider};
#[async_trait]
impl datafusion::TableProvider for ValuesTable {
fn as_any(&self) -> &dyn Any {
self
use super::*;
use ::datafusion::prelude::*;
use std::sync::Arc;
fn create_test_introspection() -> Introspection {
let table_metadata = TableMetadata::new(vec![
TableMetadataRow::new(
"public".to_string(),
"users".to_string(),
Some("Users table".to_string()),
),
TableMetadataRow::new(
"public".to_string(),
"posts".to_string(),
Some("Posts table".to_string()),
),
]);
let column_metadata = ColumnMetadata::new(vec![
ColumnMetadataRow::new(
"public".to_string(),
"users".to_string(),
"id".to_string(),
Some("User ID".to_string()),
),
ColumnMetadataRow::new(
"public".to_string(),
"users".to_string(),
"name".to_string(),
Some("User name".to_string()),
),
ColumnMetadataRow::new(
"public".to_string(),
"posts".to_string(),
"id".to_string(),
Some("Post ID".to_string()),
),
ColumnMetadataRow::new(
"public".to_string(),
"posts".to_string(),
"user_id".to_string(),
Some("Author's user ID".to_string()),
),
]);
let inferred_foreign_keys = InferredForeignKeys::new(vec![InferredForeignKeysRow::new(
"public".to_string(),
"posts".to_string(),
"user_id".to_string(),
"public".to_string(),
"users".to_string(),
"id".to_string(),
)]);
Introspection {
table_metadata,
column_metadata,
inferred_foreign_key_constraints: inferred_foreign_keys,
}
}
fn schema(&self) -> datafusion::SchemaRef {
self.schema.clone()
#[tokio::test]
async fn test_introspection_schema_provider_table() {
let introspection = create_test_introspection();
let schema_provider = IntrospectionSchemaProvider::new(&introspection);
let table_metadata = schema_provider.table(TABLE_METADATA).await.unwrap();
assert!(table_metadata.is_some());
let column_metadata = schema_provider.table(COLUMN_METADATA).await.unwrap();
assert!(column_metadata.is_some());
let foreign_keys = schema_provider
.table(INFERRED_FOREIGN_KEY_CONSTRAINTS)
.await
.unwrap();
assert!(foreign_keys.is_some());
let non_existent_table = schema_provider.table("non_existent").await.unwrap();
assert!(non_existent_table.is_none());
}
fn table_type(&self) -> datafusion::TableType {
datafusion::TableType::View
// ... (keep the create_test_introspection function and other existing tests)
fn create_test_context(introspection: &Introspection) -> SessionContext {
let config = SessionConfig::new().with_default_catalog_and_schema("default", "default");
let schema_provider = Arc::new(IntrospectionSchemaProvider::new(introspection));
let ctx = SessionContext::new_with_config(config);
let catalog = MemoryCatalogProvider::new();
catalog
.register_schema(HASURA_METADATA_SCHEMA, schema_provider)
.unwrap();
ctx.register_catalog("default", Arc::new(catalog));
ctx
}
async fn scan(
&self,
_state: &datafusion::SessionState,
projection: Option<&Vec<usize>>,
// filters and limit can be used here to inject some push-down operations if needed
_filters: &[datafusion::Expr],
_limit: Option<usize>,
) -> datafusion::Result<Arc<dyn datafusion::ExecutionPlan>> {
let projected_schema = Arc::new(self.schema.project(projection.unwrap_or(&vec![]))?);
let columnar_projection = projection
.unwrap_or(&vec![])
.iter()
.map(|j| {
self.rows
.iter()
.map(|row| row[*j].clone().into_datafusion_scalar_value())
})
.map(datafusion::ScalarValue::iter_to_array)
.collect::<datafusion::Result<Vec<_>>>()?;
Ok(Arc::new(datafusion::ValuesExec::try_new_from_batches(
projected_schema.clone(),
vec![datafusion::RecordBatch::try_new(
projected_schema,
columnar_projection,
)?],
)?))
#[tokio::test]
async fn test_query_table_metadata() {
let introspection = create_test_introspection();
let ctx = create_test_context(&introspection);
let sql = format!("SELECT * FROM {HASURA_METADATA_SCHEMA}.{TABLE_METADATA}");
let df = ctx.sql(&sql).await.unwrap();
let results = df.collect().await.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].num_rows(), 2);
}
#[tokio::test]
async fn test_query_column_metadata() {
let introspection = create_test_introspection();
let ctx = create_test_context(&introspection);
let sql = format!(
"SELECT * FROM {HASURA_METADATA_SCHEMA}.{COLUMN_METADATA} WHERE table_name = 'users'",
);
let df = ctx.sql(&sql).await.unwrap();
let results = df.collect().await.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].num_rows(), 2);
}
#[tokio::test]
async fn test_query_inferred_foreign_keys() {
let introspection = create_test_introspection();
let ctx = create_test_context(&introspection);
let sql =
format!("SELECT * FROM {HASURA_METADATA_SCHEMA}.{INFERRED_FOREIGN_KEY_CONSTRAINTS}",);
let df = ctx.sql(&sql).await.unwrap();
let results = df.collect().await.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].num_rows(), 1);
}
}

View File

@ -0,0 +1,79 @@
use serde::{Deserialize, Serialize};
use crate::catalog::mem_table::MemTable;
/// A single row of the `column_metadata` introspection table, describing one
/// column of a user-facing table.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub(crate) struct ColumnMetadataRow {
    // Schema the table lives in.
    schema_name: String,
    // Table this column belongs to.
    table_name: String,
    column_name: String,
    // Optional human-readable description; surfaced as a nullable column.
    description: Option<String>,
}

impl ColumnMetadataRow {
    /// Constructs a row; a `None` description becomes a NULL cell.
    pub(crate) fn new(
        schema_name: String,
        table_name: String,
        column_name: String,
        description: Option<String>,
    ) -> Self {
        Self {
            schema_name,
            table_name,
            column_name,
            description,
        }
    }
}
/// Name of the column-metadata introspection table.
pub const COLUMN_METADATA: &str = "column_metadata";

/// The full `column_metadata` table: one `ColumnMetadataRow` per column.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub(super) struct ColumnMetadata {
    pub(crate) rows: Vec<ColumnMetadataRow>,
}

impl ColumnMetadata {
    pub(super) fn new(rows: Vec<ColumnMetadataRow>) -> Self {
        ColumnMetadata { rows }
    }
}
impl ColumnMetadata {
    /// Materializes these rows as an in-memory `MemTable` with columns
    /// `schema_name`, `table_name`, `column_name` (non-nullable) and
    /// `description` (nullable).
    pub(super) fn to_table_provider(&self) -> MemTable {
        // Build all four columns in a single pass over the rows.
        let mut schema_names = Vec::with_capacity(self.rows.len());
        let mut table_names = Vec::with_capacity(self.rows.len());
        let mut column_names = Vec::with_capacity(self.rows.len());
        let mut descriptions = Vec::with_capacity(self.rows.len());
        for row in &self.rows {
            schema_names.push(Some(row.schema_name.clone()));
            table_names.push(Some(row.table_name.clone()));
            column_names.push(Some(row.column_name.clone()));
            descriptions.push(row.description.clone());
        }
        MemTable::new_from_iter(vec![
            ("schema_name", false, schema_names.into_iter().collect()),
            ("table_name", false, table_names.into_iter().collect()),
            ("column_name", false, column_names.into_iter().collect()),
            ("description", true, descriptions.into_iter().collect()),
        ])
    }
}

View File

@ -0,0 +1,101 @@
use serde::{Deserialize, Serialize};
use crate::catalog::mem_table::MemTable;
/// Name of the inferred-foreign-keys introspection table.
pub(super) const INFERRED_FOREIGN_KEY_CONSTRAINTS: &str = "inferred_foreign_key_constraints";

/// A single inferred foreign-key constraint, linking a column of one table
/// (the `from_*` coordinates) to a column of another (the `to_*` coordinates).
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub(super) struct InferredForeignKeysRow {
    from_schema_name: String,
    from_table_name: String,
    from_column_name: String,
    to_schema_name: String,
    to_table_name: String,
    to_column_name: String,
}

impl InferredForeignKeysRow {
    /// Constructs a row from the fully-qualified "from" and "to" column coordinates.
    pub(super) fn new(
        from_schema_name: String,
        from_table_name: String,
        from_column_name: String,
        to_schema_name: String,
        to_table_name: String,
        to_column_name: String,
    ) -> Self {
        Self {
            from_schema_name,
            from_table_name,
            from_column_name,
            to_schema_name,
            to_table_name,
            to_column_name,
        }
    }
}
/// The full `inferred_foreign_key_constraints` table: one row per inferred constraint.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub(crate) struct InferredForeignKeys {
    rows: Vec<InferredForeignKeysRow>,
}

impl InferredForeignKeys {
    pub(super) fn new(rows: Vec<InferredForeignKeysRow>) -> Self {
        InferredForeignKeys { rows }
    }
}
impl InferredForeignKeys {
    /// Materializes these rows as an in-memory `MemTable` with six
    /// non-nullable columns: `from_schema_name`, `from_table_name`,
    /// `from_column_name`, `to_schema_name`, `to_table_name`, `to_column_name`.
    pub(super) fn to_table_provider(&self) -> MemTable {
        // Build all six columns in a single pass over the rows.
        let n = self.rows.len();
        let mut from_schema_names = Vec::with_capacity(n);
        let mut from_table_names = Vec::with_capacity(n);
        let mut from_column_names = Vec::with_capacity(n);
        let mut to_schema_names = Vec::with_capacity(n);
        let mut to_table_names = Vec::with_capacity(n);
        let mut to_column_names = Vec::with_capacity(n);
        for row in &self.rows {
            from_schema_names.push(Some(row.from_schema_name.clone()));
            from_table_names.push(Some(row.from_table_name.clone()));
            from_column_names.push(Some(row.from_column_name.clone()));
            to_schema_names.push(Some(row.to_schema_name.clone()));
            to_table_names.push(Some(row.to_table_name.clone()));
            to_column_names.push(Some(row.to_column_name.clone()));
        }
        MemTable::new_from_iter(vec![
            ("from_schema_name", false, from_schema_names.into_iter().collect()),
            ("from_table_name", false, from_table_names.into_iter().collect()),
            ("from_column_name", false, from_column_names.into_iter().collect()),
            ("to_schema_name", false, to_schema_names.into_iter().collect()),
            ("to_table_name", false, to_table_names.into_iter().collect()),
            ("to_column_name", false, to_column_names.into_iter().collect()),
        ])
    }
}

View File

@ -0,0 +1,68 @@
use serde::{Deserialize, Serialize};
use crate::catalog::mem_table::MemTable;
/// Name of the table-metadata introspection table.
pub const TABLE_METADATA: &str = "table_metadata";

/// The full `table_metadata` table: one `TableMetadataRow` per user-facing table.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub(crate) struct TableMetadata {
    rows: Vec<TableMetadataRow>,
}

impl TableMetadata {
    pub(crate) fn new(rows: Vec<TableMetadataRow>) -> Self {
        TableMetadata { rows }
    }
}
impl TableMetadata {
    /// Materializes these rows as an in-memory `MemTable` with columns
    /// `schema_name`, `table_name` (non-nullable) and `description` (nullable).
    pub(crate) fn to_table_provider(&self) -> MemTable {
        // Build all three columns in a single pass over the rows.
        let mut schema_names = Vec::with_capacity(self.rows.len());
        let mut table_names = Vec::with_capacity(self.rows.len());
        let mut descriptions = Vec::with_capacity(self.rows.len());
        for row in &self.rows {
            schema_names.push(Some(row.schema_name.clone()));
            table_names.push(Some(row.table_name.clone()));
            descriptions.push(row.description.clone());
        }
        MemTable::new_from_iter(vec![
            ("schema_name", false, schema_names.into_iter().collect()),
            ("table_name", false, table_names.into_iter().collect()),
            ("description", true, descriptions.into_iter().collect()),
        ])
    }
}
/// A single row of the `table_metadata` introspection table.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub(crate) struct TableMetadataRow {
    schema_name: String,
    table_name: String,
    // Optional human-readable description; surfaced as a nullable column.
    description: Option<String>,
}

impl TableMetadataRow {
    /// Constructs a row; a `None` description becomes a NULL cell.
    pub(crate) fn new(
        schema_name: String,
        table_name: String,
        description: Option<String>,
    ) -> Self {
        Self {
            schema_name,
            table_name,
            description,
        }
    }
}

View File

@ -0,0 +1,413 @@
//! In-memory table implementation for DataFusion
//!
//! This module provides a serializable and deserializable in-memory table implementation.
//! Datafusion's built-in MemTable doesn't support the serializability property
use async_trait::async_trait;
use std::{any::Any, sync::Arc};
mod datafusion {
pub(super) use datafusion::{
arrow::{
array::{ArrayRef, BooleanArray, Int64Array, RecordBatch, StringArray},
datatypes::{DataType, Field, SchemaBuilder, SchemaRef},
},
datasource::{TableProvider, TableType},
error::Result,
execution::context::SessionState,
logical_expr::Expr,
physical_plan::{values::ValuesExec, ExecutionPlan},
};
}
use serde::{Deserialize, Serialize};
/// Represents the data for a single column in a table.
///
/// Each variant stores one column of optional values; `None` entries become
/// NULLs in the Arrow array produced by `into_array_ref`.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub(crate) enum ColumnData {
    Utf8 { data: Vec<Option<String>> },
    Int64 { data: Vec<Option<i64>> },
    Bool { data: Vec<Option<bool>> },
}
// Convenience constructors: collect iterators of optional values directly
// into a typed `ColumnData` column.

/// Builds a `Utf8` column from borrowed strings, cloning each into an owned `String`.
impl<'a> FromIterator<Option<&'a str>> for ColumnData {
    fn from_iter<I: IntoIterator<Item = Option<&'a str>>>(iter: I) -> Self {
        ColumnData::Utf8 {
            data: iter
                .into_iter()
                .map(|s| s.map(std::borrow::ToOwned::to_owned))
                .collect(),
        }
    }
}

/// Builds a `Utf8` column from owned strings.
impl FromIterator<Option<String>> for ColumnData {
    fn from_iter<I: IntoIterator<Item = Option<String>>>(iter: I) -> Self {
        ColumnData::Utf8 {
            data: iter.into_iter().collect(),
        }
    }
}

/// Builds an `Int64` column.
impl FromIterator<Option<i64>> for ColumnData {
    fn from_iter<I: IntoIterator<Item = Option<i64>>>(iter: I) -> Self {
        ColumnData::Int64 {
            data: iter.into_iter().collect(),
        }
    }
}

/// Builds a `Bool` column.
impl FromIterator<Option<bool>> for ColumnData {
    fn from_iter<I: IntoIterator<Item = Option<bool>>>(iter: I) -> Self {
        ColumnData::Bool {
            data: iter.into_iter().collect(),
        }
    }
}
impl ColumnData {
    /// The Arrow `DataType` corresponding to this column's variant.
    fn data_type(&self) -> datafusion::DataType {
        match self {
            ColumnData::Utf8 { .. } => datafusion::DataType::Utf8,
            ColumnData::Int64 { .. } => datafusion::DataType::Int64,
            ColumnData::Bool { .. } => datafusion::DataType::Boolean,
        }
    }

    /// Converts this column into an Arrow array, consuming it; `None` entries
    /// become NULLs in the resulting array.
    fn into_array_ref(self) -> datafusion::ArrayRef {
        match self {
            ColumnData::Utf8 { data } => Arc::new(datafusion::StringArray::from(data)),
            ColumnData::Int64 { data } => Arc::new(datafusion::Int64Array::from(data)),
            ColumnData::Bool { data } => Arc::new(datafusion::BooleanArray::from(data)),
        }
    }
}
/// A table with a fixed set of rows, stored in memory. Like datafusion's MemTable but
/// serializable and deserializable.
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq)]
pub(super) struct MemTable {
    // One field per entry in `columns`, in the same order.
    schema: datafusion::SchemaRef,
    // Column-major data; each `ColumnData` holds one column's values.
    columns: Vec<ColumnData>,
}
impl MemTable {
    /// Creates a new MemTable from an iterator of (field_name, nullable, ColumnData) tuples.
    ///
    /// # Arguments
    ///
    /// * `value` - An iterator yielding tuples of (field_name, nullable, ColumnData)
    ///
    /// # Returns
    ///
    /// A new MemTable instance
    ///
    /// Based on try_from_iter_with_nullable on RecordBatch
    pub(crate) fn new_from_iter<I, F>(value: I) -> Self
    where
        I: IntoIterator<Item = (F, bool, ColumnData)>,
        F: AsRef<str>,
    {
        let iter = value.into_iter();
        // size_hint's lower bound lets us preallocate for the common Vec case.
        let capacity = iter.size_hint().0;
        let mut schema = datafusion::SchemaBuilder::with_capacity(capacity);
        let mut columns = Vec::with_capacity(capacity);
        for (field_name, nullable, array) in iter {
            let field_name = field_name.as_ref();
            // `data_type()` already returns an owned `DataType`; the previous
            // `.clone()` on it was redundant.
            schema.push(datafusion::Field::new(
                field_name,
                array.data_type(),
                nullable,
            ));
            columns.push(array);
        }
        let schema = Arc::new(schema.finish());
        MemTable { schema, columns }
    }
}
#[async_trait]
impl datafusion::TableProvider for MemTable {
    fn as_any(&self) -> &dyn Any {
        self
    }

    fn schema(&self) -> datafusion::SchemaRef {
        self.schema.clone()
    }

    fn table_type(&self) -> datafusion::TableType {
        datafusion::TableType::View
    }

    /// Produces a single-batch execution plan over the in-memory columns.
    async fn scan(
        &self,
        _state: &datafusion::SessionState,
        projection: Option<&Vec<usize>>,
        // filters and limit can be used here to inject some push-down operations if needed
        _filters: &[datafusion::Expr],
        _limit: Option<usize>,
    ) -> datafusion::Result<Arc<dyn datafusion::ExecutionPlan>> {
        // Per the `TableProvider::scan` contract, `projection == None` means
        // "all columns"; the previous `unwrap_or(&vec![])` wrongly treated it
        // as "no columns".
        let all_columns: Vec<usize> = (0..self.schema.fields().len()).collect();
        let projection = projection.unwrap_or(&all_columns);
        let projected_schema = Arc::new(self.schema.project(projection)?);
        let columnar_projection = projection
            .iter()
            .map(|j| self.columns[*j].clone().into_array_ref())
            .collect::<Vec<_>>();
        Ok(Arc::new(datafusion::ValuesExec::try_new_from_batches(
            projected_schema.clone(),
            vec![datafusion::RecordBatch::try_new(
                projected_schema,
                columnar_projection,
            )?],
        )?))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ::datafusion::arrow::array::{Array, BooleanArray, Int64Array, StringArray};
    use ::datafusion::prelude::*;

    // -- ColumnData -> Arrow array conversions -------------------------------

    #[test]
    fn test_column_data_utf8() {
        let utf8_data = vec![Some("hello".to_string()), None, Some("world".to_string())];
        let column = ColumnData::from_iter(utf8_data);
        let array_ref = column.into_array_ref();
        let string_array = array_ref.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(string_array.len(), 3);
        assert_eq!(string_array.value(0), "hello");
        assert!(string_array.is_null(1));
        assert_eq!(string_array.value(2), "world");
    }

    #[test]
    fn test_column_data_int64() {
        let int64_data = vec![Some(1), None, Some(42)];
        let column = ColumnData::from_iter(int64_data);
        let array_ref = column.into_array_ref();
        let int64_array = array_ref.as_any().downcast_ref::<Int64Array>().unwrap();
        assert_eq!(int64_array.len(), 3);
        assert_eq!(int64_array.value(0), 1);
        assert!(int64_array.is_null(1));
        assert_eq!(int64_array.value(2), 42);
    }

    #[test]
    fn test_column_data_bool() {
        let bool_data = vec![Some(true), None, Some(false)];
        let column = ColumnData::from_iter(bool_data);
        let array_ref = column.into_array_ref();
        let bool_array = array_ref.as_any().downcast_ref::<BooleanArray>().unwrap();
        assert_eq!(bool_array.len(), 3);
        assert!(bool_array.value(0));
        assert!(bool_array.is_null(1));
        assert!(!bool_array.value(2));
    }

    // Edge cases: zero rows and all-NULL columns must still convert cleanly.

    #[test]
    fn test_column_data_empty() {
        let empty_data: Vec<Option<String>> = vec![];
        let column = ColumnData::from_iter(empty_data);
        let array_ref = column.into_array_ref();
        let string_array = array_ref.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(string_array.len(), 0);
    }

    #[test]
    fn test_column_data_all_null() {
        let all_null_data: Vec<Option<String>> = vec![None, None, None];
        let column = ColumnData::from_iter(all_null_data);
        let array_ref = column.into_array_ref();
        let string_array = array_ref.as_any().downcast_ref::<StringArray>().unwrap();
        assert_eq!(string_array.len(), 3);
        assert!(string_array.is_null(0));
        assert!(string_array.is_null(1));
        assert!(string_array.is_null(2));
    }

    // -- End-to-end SQL over a registered MemTable ---------------------------

    #[tokio::test]
    async fn test_mem_table_provider_sql_full_scan() -> datafusion::Result<()> {
        let ctx = create_test_context()?;
        let df = ctx.sql("SELECT * FROM test_table").await?;
        let results = df.collect().await?;
        assert_eq!(results.len(), 1); // One batch
        let batch = &results[0];
        assert_eq!(batch.num_columns(), 3);
        assert_eq!(batch.num_rows(), 3);
        // Verify column data
        let col1 = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(col1.value(0), "a");
        assert!(col1.is_null(1));
        assert_eq!(col1.value(2), "c");
        let col2 = batch
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(col2.value(0), 1);
        assert_eq!(col2.value(1), 2);
        assert_eq!(col2.value(2), 3);
        let col3 = batch
            .column(2)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert!(col3.value(0));
        assert!(!col3.value(1));
        assert!(col3.is_null(2));
        Ok(())
    }

    #[tokio::test]
    async fn test_mem_table_provider_sql_projection() -> datafusion::Result<()> {
        let ctx = create_test_context()?;
        let df = ctx.sql("SELECT col1, col3 FROM test_table").await?;
        let results = df.collect().await?;
        assert_eq!(results.len(), 1);
        let batch = &results[0];
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 3);
        // Verify projected columns
        let col1 = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert_eq!(col1.value(0), "a");
        assert!(col1.is_null(1));
        assert_eq!(col1.value(2), "c");
        let col3 = batch
            .column(1)
            .as_any()
            .downcast_ref::<BooleanArray>()
            .unwrap();
        assert!(col3.value(0));
        assert!(!col3.value(1));
        assert!(col3.is_null(2));
        Ok(())
    }

    #[tokio::test]
    async fn test_mem_table_provider_sql_filter() -> datafusion::Result<()> {
        let ctx = create_test_context()?;
        let df = ctx
            .sql("SELECT col1, col2 FROM test_table WHERE col2 > 1")
            .await?;
        let results = df.collect().await?;
        assert_eq!(results.len(), 1);
        let batch = &results[0];
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 2);
        let col1 = batch
            .column(0)
            .as_any()
            .downcast_ref::<StringArray>()
            .unwrap();
        assert!(col1.is_null(0));
        assert_eq!(col1.value(1), "c");
        let col2 = batch
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(col2.value(0), 2);
        assert_eq!(col2.value(1), 3);
        Ok(())
    }

    #[tokio::test]
    async fn test_mem_table_provider_sql_aggregation() -> datafusion::Result<()> {
        let ctx = create_test_context()?;
        let df = ctx
            .sql("SELECT COUNT(*), SUM(col2) FROM test_table")
            .await?;
        let results = df.collect().await?;
        assert_eq!(results.len(), 1);
        let batch = &results[0];
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 1);
        let count = batch
            .column(0)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(count.value(0), 3);
        let sum = batch
            .column(1)
            .as_any()
            .downcast_ref::<Int64Array>()
            .unwrap();
        assert_eq!(sum.value(0), 6);
        Ok(())
    }

    // Regression test for the original failure mode: querying a table with
    // zero rows must succeed.
    #[tokio::test]
    async fn test_mem_table_provider_sql_empty() -> datafusion::Result<()> {
        let ctx = SessionContext::new();
        let empty_table = MemTable::new_from_iter(vec![(
            "empty_col".to_string(),
            true,
            ColumnData::from_iter(Vec::<Option<String>>::new()),
        )]);
        ctx.register_table("empty_table", Arc::new(empty_table))?;
        let df = ctx.sql("SELECT * FROM empty_table").await?;
        let results = df.collect().await?;
        assert_eq!(results.len(), 1);
        let batch = &results[0];
        assert_eq!(batch.num_columns(), 1);
        assert_eq!(batch.num_rows(), 0);
        Ok(())
    }

    // Shared fixture: a three-column, three-row MemTable registered as
    // `test_table` on a fresh SessionContext.
    fn create_test_context() -> datafusion::Result<SessionContext> {
        let ctx = SessionContext::new();
        let mem_table = MemTable::new_from_iter(vec![
            (
                "col1".to_string(),
                true,
                ColumnData::from_iter(vec![Some("a"), None, Some("c")]),
            ),
            (
                "col2".to_string(),
                false,
                ColumnData::from_iter(vec![Some(1), Some(2), Some(3)]),
            ),
            (
                "col3".to_string(),
                true,
                ColumnData::from_iter(vec![Some(true), Some(false), None]),
            ),
        ]);
        ctx.register_table("test_table", Arc::new(mem_table))?;
        Ok(ctx)
    }
}