From 38051b9465cf87d957857c0cf20b24b20823b960 Mon Sep 17 00:00:00 2001 From: Dan Sosedoff Date: Wed, 7 Dec 2022 11:58:07 -0600 Subject: [PATCH] Add support for user functions (#608) * Add initial support for functions * Show functions definitions * Fix client tests * Fix schema objects search * Perform partial matching for functions * Add function test * Make sure to close client connections so that database could be dropped in tests * Fix lint * Allow to copy the view/functions definitions * Nits --- Makefile | 2 +- pkg/api/api.go | 26 ++++-- pkg/api/routes.go | 1 + pkg/client/client.go | 4 + pkg/client/client_test.go | 141 ++++++++++++++++++++++++-------- pkg/client/result.go | 43 ++++++---- pkg/statements/sql.go | 3 + pkg/statements/sql/function.sql | 7 ++ pkg/statements/sql/objects.sql | 71 ++++++++++------ static/css/app.css | 24 ++++++ static/js/app.js | 62 +++++++++++--- 11 files changed, 292 insertions(+), 92 deletions(-) create mode 100644 pkg/statements/sql/function.sql diff --git a/Makefile b/Makefile index eb65d4c..2fb38ab 100644 --- a/Makefile +++ b/Makefile @@ -28,7 +28,7 @@ usage: @echo "" test: - go test -race -cover ./pkg/... + go test -v -race -cover ./pkg/... test-all: @./script/test_all.sh diff --git a/pkg/api/api.go b/pkg/api/api.go index ae76567..17a89a4 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -342,13 +342,21 @@ func GetSchemas(c *gin.Context) { // GetTable renders table information func GetTable(c *gin.Context) { - var res *client.Result - var err error + var ( + res *client.Result + err error + ) - if c.Request.FormValue("type") == client.ObjTypeMaterializedView { - res, err = DB(c).MaterializedView(c.Params.ByName("table")) - } else { - res, err = DB(c).Table(c.Params.ByName("table")) + db := DB(c) + tableName := c.Params.ByName("table") + + switch c.Request.FormValue("type") { + case client.ObjTypeMaterializedView: + res, err = db.MaterializedView(tableName) + case client.ObjTypeFunction: + res, err = db.Function(tableName) + default: + res, err = db.Table(tableName) } serveResult(c, res, err) @@ -541,3 +549,9 @@ func DataExport(c *gin.Context) { badRequest(c, err) } } + +// GetFunction renders function information +func GetFunction(c *gin.Context) { + res, err := DB(c).Function(c.Param("id")) + serveResult(c, res, err) +} diff --git a/pkg/api/routes.go b/pkg/api/routes.go index 9a83314..35048b1 100644 --- a/pkg/api/routes.go +++ b/pkg/api/routes.go @@ -42,6 +42,7 @@ func SetupRoutes(router *gin.Engine) { api.GET("/tables/:table/info", GetTableInfo) api.GET("/tables/:table/indexes", GetTableIndexes) api.GET("/tables/:table/constraints", GetTableConstraints) + api.GET("/functions/:id", GetFunction) api.GET("/query", RunQuery) api.POST("/query", RunQuery) api.GET("/explain", ExplainQuery) diff --git a/pkg/client/client.go b/pkg/client/client.go index 7497791..1085c2e 100644 --- a/pkg/client/client.go +++ b/pkg/client/client.go @@ -197,6 +197,10 @@ func (client *Client) MaterializedView(name string) (*Result, error) { return client.query(statements.MaterializedView, name) } +func (client *Client) Function(id string) (*Result, error) { + return client.query(statements.Function, id) +} + func (client *Client) TableRows(table string, opts RowsOptions) (*Result, error) { schema, table := getSchemaAndTable(table) sql := fmt.Sprintf(`SELECT * FROM "%s"."%s"`, schema, table) diff --git a/pkg/client/client_test.go b/pkg/client/client_test.go index f9e3909..6989db9 100644 --- a/pkg/client/client_test.go +++ b/pkg/client/client_test.go @@ -6,6 +6,7 @@ import ( "os" "os/exec" "runtime" + "sort" "testing" "time" @@ -32,6 +33,26 @@ func mapKeys(data map[string]*Objects) []string { return result } +func objectNames(data []Object) []string { + names := make([]string, len(data)) + for i, obj := range data { + names[i] = obj.Name + } + + sort.Strings(names) + return names +} + +// assertMatches is a helper method to check if src slice contains any elements of expected slice +func assertMatches(t *testing.T, expected, src []string) { + assert.NotEqual(t, 0, len(expected)) + assert.NotEqual(t, 0, len(src)) + + for _, val := range expected { + assert.Contains(t, src, val) + } +} + func pgVersion() (int, int) { var major, minor int fmt.Sscanf(os.Getenv("PGVERSION"), "%d.%d", &major, &minor) @@ -118,12 +139,12 @@ func setupClient() { func teardownClient() { if testClient != nil { - testClient.db.Close() + testClient.Close() } } func teardown() { - _, err := exec.Command( + output, err := exec.Command( testCommands["dropdb"], "-U", serverUser, "-h", serverHost, @@ -133,31 +154,28 @@ func teardown() { if err != nil { fmt.Println("Teardown error:", err) + fmt.Printf("%s\n", output) } } -func testNewClientFromUrl(t *testing.T) { - url := fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase) - client, err := NewFromUrl(url, nil) +func testNewClientFromURL(t *testing.T) { + t.Run("postgres prefix", func(t *testing.T) { + url := fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase) + client, err := NewFromUrl(url, nil) - if err != nil { - defer client.Close() - } + assert.Equal(t, nil, err) + assert.Equal(t, url, client.ConnectionString) + assert.NoError(t, client.Close()) + }) - assert.Equal(t, nil, err) - assert.Equal(t, url, client.ConnectionString) -} + t.Run("postgresql prefix", func(t *testing.T) { + url := fmt.Sprintf("postgresql://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase) + client, err := NewFromUrl(url, nil) -func testNewClientFromUrl2(t *testing.T) { - url := fmt.Sprintf("postgresql://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase) - client, err := NewFromUrl(url, nil) - - if err != nil { - defer client.Close() - } - - assert.Equal(t, nil, err) - assert.Equal(t, url, client.ConnectionString) + assert.Equal(t, nil, err) + assert.Equal(t, url, client.ConnectionString) + assert.NoError(t, client.Close()) + }) } func testClientIdleTime(t *testing.T) { @@ -202,16 +220,13 @@ func testActivity(t *testing.T) { res, err := testClient.Activity() assert.NoError(t, err) - for _, val := range expected { - assert.Contains(t, res.Columns, val) - } + assertMatches(t, expected, res.Columns) } func testDatabases(t *testing.T) { res, err := testClient.Databases() assert.NoError(t, err) - assert.Contains(t, res, "booktown") - assert.Contains(t, res, "postgres") + assertMatches(t, []string{"booktown", "postgres"}, res) } func testObjects(t *testing.T) { @@ -245,16 +260,44 @@ func testObjects(t *testing.T) { "text_sorting", } + functions := []string{ + "add_shipment", + "add_two_loop", + "books_by_subject", + "compound_word", + "count_by_two", + "double_price", + "extract_all_titles", + "extract_all_titles2", + "extract_title", + "first", + "get_author", + "get_author", + "get_customer_id", + "get_customer_name", + "html_linebreaks", + "in_stock", + "isbn_to_title", + "mixed", + "raise_test", + "ship_item", + "stock_amount", + "test", + "title", + "triple_price", + } + assert.NoError(t, err) - assert.Equal(t, []string{"schema", "name", "type", "owner", "comment"}, res.Columns) + assert.Equal(t, []string{"oid", "schema", "name", "type", "owner", "comment"}, res.Columns) assert.Equal(t, []string{"public"}, mapKeys(objects)) - assert.Equal(t, tables, objects["public"].Tables) - assert.Equal(t, []string{"recent_shipments", "stock_view"}, objects["public"].Views) - assert.Equal(t, []string{"author_ids", "book_ids", "shipments_ship_id_seq", "subject_ids"}, objects["public"].Sequences) + assert.Equal(t, tables, objectNames(objects["public"].Tables)) + assertMatches(t, functions, objectNames(objects["public"].Functions)) + assert.Equal(t, []string{"recent_shipments", "stock_view"}, objectNames(objects["public"].Views)) + assert.Equal(t, []string{"author_ids", "book_ids", "shipments_ship_id_seq", "subject_ids"}, objectNames(objects["public"].Sequences)) major, minor := pgVersion() if minor == 0 || minor >= 3 { - assert.Equal(t, []string{"m_stock_view"}, objects["public"].MaterializedViews) + assert.Equal(t, []string{"m_stock_view"}, objectNames(objects["public"].MaterializedViews)) } else { t.Logf("Skipping materialized view on %d.%d\n", major, minor) } @@ -428,6 +471,33 @@ func testTableRowsOrderEscape(t *testing.T) { assert.Nil(t, rows) } +func testFunctions(t *testing.T) { + funcName := "get_customer_name" + funcID := "" + + res, err := testClient.Objects() + assert.NoError(t, err) + + for _, row := range res.Rows { + if row[2] == funcName { + funcID = row[0].(string) + break + } + } + + res, err = testClient.Function("12345") + assert.NoError(t, err) + assertMatches(t, []string{"oid", "proname", "functiondef"}, res.Columns) + assert.Equal(t, 0, len(res.Rows)) + + res, err = testClient.Function(funcID) + assert.NoError(t, err) + assertMatches(t, []string{"oid", "proname", "functiondef"}, res.Columns) + assert.Equal(t, 1, len(res.Rows)) + assert.Equal(t, funcName, res.Rows[0][1]) + assert.Contains(t, res.Rows[0][len(res.Columns)-1], "SELECT INTO customer_fname, customer_lname") +} + func testResult(t *testing.T) { t.Run("json", func(t *testing.T) { result, err := testClient.Query("SELECT * FROM books LIMIT 1") @@ -466,8 +536,8 @@ func testHistory(t *testing.T) { t.Run("unique queries", func(t *testing.T) { url := fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase) - client, err := NewFromUrl(url, nil) - assert.NoError(t, err) + client, _ := NewFromUrl(url, nil) + defer client.Close() for i := 0; i < 3; i++ { _, err := client.Query("SELECT * FROM books WHERE id = 1") @@ -487,6 +557,7 @@ func testReadOnlyMode(t *testing.T) { url := fmt.Sprintf("postgres://%s@%s:%s/%s?sslmode=disable", serverUser, serverHost, serverPort, serverDatabase) client, _ := NewFromUrl(url, nil) + defer client.Close() err := client.SetReadOnlyMode() assert.NoError(t, err) @@ -522,8 +593,7 @@ func TestAll(t *testing.T) { setup() setupClient() - testNewClientFromUrl(t) - testNewClientFromUrl2(t) + testNewClientFromURL(t) testClientIdleTime(t) testTest(t) testInfo(t) @@ -544,6 +614,7 @@ func TestAll(t *testing.T) { testQueryError(t) testQueryInvalidTable(t) testTableRowsOrderEscape(t) + testFunctions(t) testResult(t) testHistory(t) testReadOnlyMode(t) diff --git a/pkg/client/result.go b/pkg/client/result.go index ee374a1..f446a02 100644 --- a/pkg/client/result.go +++ b/pkg/client/result.go @@ -18,6 +18,7 @@ const ( ObjTypeView = "view" ObjTypeMaterializedView = "materialized_view" ObjTypeSequence = "sequence" + ObjTypeFunction = "function" ) type ( @@ -36,11 +37,17 @@ type ( Rows []Row `json:"rows"` } + Object struct { + OID string `json:"oid"` + Name string `json:"name"` + } + Objects struct { - Tables []string `json:"table"` - Views []string `json:"view"` - MaterializedViews []string `json:"materialized_view"` - Sequences []string `json:"sequence"` + Tables []Object `json:"table"` + Views []Object `json:"view"` + MaterializedViews []Object `json:"materialized_view"` + Functions []Object `json:"function"` + Sequences []Object `json:"sequence"` } ) @@ -154,28 +161,34 @@ func ObjectsFromResult(res *Result) map[string]*Objects { objects := map[string]*Objects{} for _, row := range res.Rows { - schema := row[0].(string) - name := row[1].(string) - objectType := row[2].(string) + oid := row[0].(string) + schema := row[1].(string) + name := row[2].(string) + objectType := row[3].(string) if objects[schema] == nil { objects[schema] = &Objects{ - Tables: []string{}, - Views: []string{}, - MaterializedViews: []string{}, - Sequences: []string{}, + Tables: []Object{}, + Views: []Object{}, + MaterializedViews: []Object{}, + Functions: []Object{}, + Sequences: []Object{}, } } + obj := Object{OID: oid, Name: name} + switch objectType { case ObjTypeTable: - objects[schema].Tables = append(objects[schema].Tables, name) + objects[schema].Tables = append(objects[schema].Tables, obj) case ObjTypeView: - objects[schema].Views = append(objects[schema].Views, name) + objects[schema].Views = append(objects[schema].Views, obj) case ObjTypeMaterializedView: - objects[schema].MaterializedViews = append(objects[schema].MaterializedViews, name) + objects[schema].MaterializedViews = append(objects[schema].MaterializedViews, obj) + case ObjTypeFunction: + objects[schema].Functions = append(objects[schema].Functions, obj) case ObjTypeSequence: - objects[schema].Sequences = append(objects[schema].Sequences, name) + objects[schema].Sequences = append(objects[schema].Sequences, obj) } } diff --git a/pkg/statements/sql.go b/pkg/statements/sql.go index 1dd71ba..d6b2f3f 100644 --- a/pkg/statements/sql.go +++ b/pkg/statements/sql.go @@ -38,6 +38,9 @@ var ( //go:embed sql/objects.sql Objects string + //go:embed sql/function.sql + Function string + // Activity queries for specific PG versions Activity = map[string]string{ "default": "SELECT * FROM pg_stat_activity WHERE datname = current_database()", diff --git a/pkg/statements/sql/function.sql b/pkg/statements/sql/function.sql new file mode 100644 index 0000000..4eede75 --- /dev/null +++ b/pkg/statements/sql/function.sql @@ -0,0 +1,7 @@ +SELECT + p.*, + pg_get_functiondef(oid) AS functiondef +FROM + pg_catalog.pg_proc p +WHERE + oid = $1::oid diff --git a/pkg/statements/sql/objects.sql b/pkg/statements/sql/objects.sql index de1ba0f..43e73d6 100644 --- a/pkg/statements/sql/objects.sql +++ b/pkg/statements/sql/objects.sql @@ -1,25 +1,46 @@ -SELECT - n.nspname AS schema, - c.relname AS name, - CASE c.relkind - WHEN 'r' THEN 'table' - WHEN 'v' THEN 'view' - WHEN 'm' THEN 'materialized_view' - WHEN 'i' THEN 'index' - WHEN 'S' THEN 'sequence' - WHEN 's' THEN 'special' - WHEN 'f' THEN 'foreign_table' - END AS type, - pg_catalog.pg_get_userbyid(c.relowner) AS owner, - pg_catalog.obj_description(c.oid) AS comment -FROM - pg_catalog.pg_class c -LEFT JOIN - pg_catalog.pg_namespace n ON n.oid = c.relnamespace -WHERE - c.relkind IN ('r','v','m','S','s','') - AND n.nspname !~ '^pg_toast' - AND n.nspname NOT IN ('information_schema', 'pg_catalog') - AND has_schema_privilege(n.nspname, 'USAGE') -ORDER BY - 1, 2 +WITH all_objects AS ( + SELECT + c.oid, + n.nspname AS schema, + c.relname AS name, + CASE c.relkind + WHEN 'r' THEN 'table' + WHEN 'v' THEN 'view' + WHEN 'm' THEN 'materialized_view' + WHEN 'i' THEN 'index' + WHEN 'S' THEN 'sequence' + WHEN 's' THEN 'special' + WHEN 'f' THEN 'foreign_table' + END AS type, + pg_catalog.pg_get_userbyid(c.relowner) AS owner, + pg_catalog.obj_description(c.oid) AS comment + FROM + pg_catalog.pg_class c + LEFT JOIN + pg_catalog.pg_namespace n ON n.oid = c.relnamespace + WHERE + c.relkind IN ('r','v','m','S','s','') + AND n.nspname !~ '^pg_toast' + AND n.nspname NOT IN ('information_schema', 'pg_catalog') + AND has_schema_privilege(n.nspname, 'USAGE') + + UNION + + SELECT + p.oid, + n.nspname AS schema, + p.proname AS name, + 'function' AS function, + pg_catalog.pg_get_userbyid(p.proowner) AS owner, + NULL AS comment + FROM + pg_catalog.pg_namespace n + JOIN + pg_catalog.pg_proc p ON p.pronamespace = n.oid + WHERE + n.nspname !~ '^pg_toast' + AND n.nspname NOT IN ('information_schema', 'pg_catalog') + AND p.prokind = 'f' +) +SELECT * FROM all_objects +ORDER BY 1, 2 diff --git a/static/css/app.css b/static/css/app.css index 59a1077..72a57e2 100644 --- a/static/css/app.css +++ b/static/css/app.css @@ -583,6 +583,30 @@ #results_view pre { border: 0px none; + position: relative; +} + +#results_view .copy { + position: absolute; + display: none; + text-align: center; + line-height: 30px; + right: 4px; + top: 4px; + width: 30px; + height: 30px; + background: #fff; + border: 1px solid #ddd; + border-radius: 3px; + cursor: pointer; +} + +#results_view .copy:hover { + border-color: #999; +} + +#results_view pre:hover .copy { + display: block; } .full #output { diff --git a/static/js/app.js b/static/js/app.js index af2e7cf..cf74617 100644 --- a/static/js/app.js +++ b/static/js/app.js @@ -88,6 +88,7 @@ function getTableRows(table, opts, cb) { apiCall("get", "/tables/" + table function getTableStructure(table, opts, cb) { apiCall("get", "/tables/" + table, opts, cb); } function getTableIndexes(table, cb) { apiCall("get", "/tables/" + table + "/indexes", {}, cb); } function getTableConstraints(table, cb) { apiCall("get", "/tables/" + table + "/constraints", {}, cb); } +function getFunction(id, cb) { apiCall("get", "/functions/" + id, {}, cb); } function getHistory(cb) { apiCall("get", "/history", {}, cb); } function getBookmarks(cb) { apiCall("get", "/bookmarks", {}, cb); } function executeQuery(query, cb) { apiCall("post", "/query", { query: query }, cb); } @@ -106,6 +107,7 @@ function buildSchemaSection(name, objects) { "table": "Tables", "view": "Views", "materialized_view": "Materialized Views", + "function": "Functions", "sequence": "Sequences" }; @@ -113,6 +115,7 @@ function buildSchemaSection(name, objects) { "table": '', "view": '', "materialized_view": '', + "function": '', "sequence": '' }; @@ -123,7 +126,7 @@ function buildSchemaSection(name, objects) { section += "
" + name + "
"; section += "
"; - ["table", "view", "materialized_view", "sequence"].forEach(function(group) { + ["table", "view", "materialized_view", "function", "sequence"].forEach(function(group) { group_klass = ""; if (name == "public" && group == "table") group_klass = "expanded"; @@ -133,8 +136,14 @@ function buildSchemaSection(name, objects) { if (objects[group]) { objects[group].forEach(function(item) { - var id = name + "." + item; - section += "
  • " + icons[group] + " " + item + "
  • "; + var id = name + "." + item.name; + + // Use function OID since multiple functions with the same name might exist + if (group == "function") { + id = item.oid; + } + + section += "
  • " + icons[group] + " " + item.name + "
  • "; }); section += "
    "; } @@ -154,6 +163,7 @@ function loadSchemas() { table: [], view: [], materialized_view: [], + function: [], sequence: [] }; } @@ -170,13 +180,14 @@ function loadSchemas() { autocompleteObjects = []; for (schema in data) { for (kind in data[schema]) { - if (!(kind == "table" || kind == "view" || kind == "materialized_view")) { + if (!(kind == "table" || kind == "view" || kind == "materialized_view" || kind == "function")) { continue } + for (item in data[schema][kind]) { autocompleteObjects.push({ - caption: data[schema][kind][item], - value: data[schema][kind][item], + caption: data[schema][kind][item].name, + value: data[schema][kind][item].name, meta: kind }); } @@ -507,6 +518,11 @@ function showTableContent(sortColumn, sortOrder) { return; } + if (getCurrentObject().type == "function") { + alert("Cant view rows for a function"); + return; + } + var opts = { limit: getRowsLimit(), offset: getPaginationOffset(), @@ -569,6 +585,13 @@ function showTableStructure() { $("#body").prop("class", "full"); getTableStructure(name, { type: getCurrentObject().type }, function(data) { + if (getCurrentObject().type == "function") { + var name = data.rows[0][data.columns.indexOf("proname")]; + var definition = data.rows[0][data.columns.indexOf("functiondef")]; + showFunctionDefinition(name, definition); + return + } + buildTable(data); $("#results").addClass("no-crop"); }); @@ -576,14 +599,25 @@ function showTableStructure() { function showViewDefinition(viewName, viewDefintion) { setCurrentTab("table_structure"); + renderResultsView("View definition for: " + viewName + "", viewDefintion); +} +function showFunctionDefinition(functionName, definition) { + setCurrentTab("table_structure"); + renderResultsView("Function definition for: " + functionName + "", definition) +} + +function renderResultsView(title, content) { $("#results").addClass("no-crop"); $("#input").hide(); $("#body").prop("class", "full"); $("#results").hide(); - var title = $("
    ").prop("class", "title").html("View definition for: " + viewName + ""); - var content = $("
    ").text(viewDefintion);
    +  var title = $("
    ").prop("class", "title").html(title); + var content = $("
    ").text(content);
    +  console.log(content);
    +
    +  $("
    ").html("").addClass("copy").appendTo(content); $("#results_view").html(""); title.appendTo("#results_view"); @@ -1138,7 +1172,7 @@ function bindDatabaseObjectsFilter() { $(".clear-objects-filter").show(); $(".schema-group").addClass("expanded"); - filterTimeout = setTimeout(function () { + filterTimeout = setTimeout(function() { filterObjectsByName(val) }, 200); }); @@ -1354,6 +1388,10 @@ $(document).ready(function() { exportTo("xml"); }); + $("#results_view").on("click", ".copy", function() { + copyToClipboard($(this).parent().text()); + }); + $("#results").on("click", "tr", function(e) { $("#results tr.selected").removeClass(); $(this).addClass("selected"); @@ -1378,7 +1416,11 @@ $(document).ready(function() { $(".current-page").data("page", 1); $(".filters select, .filters input").val(""); - showTableInfo(); + if (currentObject.type == "function") { + sessionStorage.setItem("tab", "table_structure"); + } else { + showTableInfo(); + } switch(sessionStorage.getItem("tab")) { case "table_content":