diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fe504aec9c..03ac454bba2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,7 @@ ### Bug fixes and improvements +- server: fix subscriptions with session argument in user-defined function (fix #6657) - server: MSSQL: Support ORDER BY for text/ntext types. - server: MSSQL: Support _lt, _eq, etc. for text/ntext types. - server: MSSQL: Fix offset when there's no order by. diff --git a/server/src-lib/Hasura/Backends/Postgres/Execute/LiveQuery.hs b/server/src-lib/Hasura/Backends/Postgres/Execute/LiveQuery.hs index da4125038ee..cb350a8cfbb 100644 --- a/server/src-lib/Hasura/Backends/Postgres/Execute/LiveQuery.hs +++ b/server/src-lib/Hasura/Backends/Postgres/Execute/LiveQuery.hs @@ -151,8 +151,8 @@ mkMultiplexedQuery rootFields = MultiplexedQuery . Q.fromBuilder . toSQL $ S.mkS -- about various parameters of the query along the way. resolveMultiplexedValue :: (MonadState (QueryParametersInfo ('Postgres pgKind)) m) - => UnpreparedValue ('Postgres pgKind) -> m S.SQLExp -resolveMultiplexedValue = \case + => SessionVariables -> UnpreparedValue ('Postgres pgKind) -> m S.SQLExp +resolveMultiplexedValue allSessionVars = \case UVParameter varM colVal -> do varJsonPath <- case fmap PS.getName varM of Just varName -> do @@ -167,7 +167,10 @@ resolveMultiplexedValue = \case modifying qpiReferencedSessionVariables (Set.insert sessVar) pure $ fromResVars ty ["session", sessionVariableToText sessVar] UVLiteral sqlExp -> pure sqlExp - UVSession -> pure $ fromResVars (CollectableTypeScalar PGJSON) ["session"] + UVSession -> do + -- if the entire session is referenced, then add all session vars in referenced vars + modifying qpiReferencedSessionVariables (const $ getSessionVariablesSet allSessionVars) + pure $ fromResVars (CollectableTypeScalar PGJSON) ["session"] where fromResVars pgType jPath = addTypeAnnotation pgType $ S.SEOpApp (S.SQLOp "#>>") [ S.SEQIdentifier $ S.QIdentifier (S.QualifiedIdentifier (Identifier "_subs") Nothing) (Identifier "result_vars") diff --git a/server/src-lib/Hasura/Backends/Postgres/Instances/Execute.hs b/server/src-lib/Hasura/Backends/Postgres/Instances/Execute.hs index 9699042c348..ab832a2e00b 100644 --- a/server/src-lib/Hasura/Backends/Postgres/Instances/Execute.hs +++ b/server/src-lib/Hasura/Backends/Postgres/Instances/Execute.hs @@ -272,7 +272,7 @@ pgDBSubscriptionPlan -> m (LiveQueryPlan ('Postgres pgKind) (MultiplexedQuery ('Postgres pgKind))) pgDBSubscriptionPlan userInfo _sourceName sourceConfig unpreparedAST = do (preparedAST, PGL.QueryParametersInfo{..}) <- flip runStateT mempty $ - for unpreparedAST $ traverseQueryDB PGL.resolveMultiplexedValue + for unpreparedAST $ traverseQueryDB (PGL.resolveMultiplexedValue $ _uiSession userInfo) let multiplexedQuery = PGL.mkMultiplexedQuery preparedAST roleName = _uiRole userInfo parameterizedPlan = ParameterizedLiveQueryPlan roleName multiplexedQuery diff --git a/server/tests-py/queries/subscriptions/udf_session_args/setup.yaml b/server/tests-py/queries/subscriptions/udf_session_args/setup.yaml new file mode 100644 index 00000000000..07983181fd0 --- /dev/null +++ b/server/tests-py/queries/subscriptions/udf_session_args/setup.yaml @@ -0,0 +1,40 @@ +type: bulk +args: + +- type: run_sql + args: + sql: | + CREATE TABLE profile ( + id TEXT, + name TEXT + ); + + INSERT INTO profile (id, name) VALUES ('10', 'Miles'), ('42', 'Charlie'); + + CREATE FUNCTION me(hasura_session json) + RETURNS SETOF profile AS $$ + SELECT * FROM profile + WHERE id = hasura_session ->> 'x-hasura-user-id' + $$ LANGUAGE sql STABLE; + +- type: track_table + args: + schema: public + name: profile + +- type: create_select_permission + args: + table: profile + role: user + permission: + columns: [id, name] + filter: {} + +- type: track_function + version: 2 + args: + function: + name: me + schema: public + configuration: + session_argument: hasura_session diff --git a/server/tests-py/queries/subscriptions/udf_session_args/teardown.yaml b/server/tests-py/queries/subscriptions/udf_session_args/teardown.yaml new file mode 100644 index 00000000000..73516c6ef7b --- /dev/null +++ b/server/tests-py/queries/subscriptions/udf_session_args/teardown.yaml @@ -0,0 +1,8 @@ +type: bulk +args: +- type: run_sql + args: + cascade: true + sql: | + drop function me(json) cascade; + drop table profile cascade; diff --git a/server/tests-py/test_subscriptions.py b/server/tests-py/test_subscriptions.py index f6ce301245c..e4f78d30353 100644 --- a/server/tests-py/test_subscriptions.py +++ b/server/tests-py/test_subscriptions.py @@ -314,3 +314,35 @@ class TestSubscriptionMultiplexing: sql = response['sql'] assert isinstance(sql, str), response return sql + + +@pytest.mark.parametrize("backend", ['mssql', 'postgres']) +@usefixtures('per_class_tests_db_state', 'per_backend_tests', 'ws_conn_init') +class TestSubscriptionUDFWithSessionArg: + """ + Test a user-defined function which uses the entire session variables as argument + """ + + query = """ + subscription { + me { + id + name + } + } + """ + + @classmethod + def dir(cls): + return 'queries/subscriptions/udf_session_args' + + def test_user_defined_function_with_session_argument(self, hge_ctx, ws_client): + ws_client.init_as_admin() + headers = {'x-hasura-role': 'user', 'x-hasura-user-id': '42'} + if hge_ctx.hge_key is not None: + headers['X-Hasura-Admin-Secret'] = hge_ctx.hge_key + payload = {'query': self.query} + resp = ws_client.send_query(payload, headers=headers, timeout=15) + ev = next(resp) + assert ev['type'] == 'data', ev + assert ev['payload']['data'] == {'me': [{'id': '42', 'name': 'Charlie'}]}, ev['payload']['data']