Merge pull request #843 from binwiederhier/acl-underscores

Fix ACL issues with underscores
This commit is contained in:
Philipp C. Heckel 2023-08-18 22:52:01 +02:00 committed by GitHub
commit 63629efae7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 255 additions and 18 deletions

View File

@ -2,7 +2,7 @@
Binaries for all releases can be found on the GitHub releases pages for the [ntfy server](https://github.com/binwiederhier/ntfy/releases) Binaries for all releases can be found on the GitHub releases pages for the [ntfy server](https://github.com/binwiederhier/ntfy/releases)
and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/releases). and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/releases).
### ntfy server v2.7.0 ## ntfy server v2.7.0
Released August 17, 2023 Released August 17, 2023
This release ships Markdown support for the web app (not in the Android app yet), and adds support for This release ships Markdown support for the web app (not in the Android app yet), and adds support for
@ -1283,6 +1283,12 @@ and the [ntfy Android app](https://github.com/binwiederhier/ntfy-android/release
## Not released yet ## Not released yet
### ntfy server v2.8.0 (UNRELEASED)
**Bug fixes + maintenance:**
* Fix ACL issue with topic patterns containing underscores ([#840](https://github.com/binwiederhier/ntfy/issues/840), thanks to [@Joe-0237](https://github.com/Joe-0237) for reporting)
### ntfy Android app v1.16.1 (UNRELEASED) ### ntfy Android app v1.16.1 (UNRELEASED)
**Features:** **Features:**

View File

@ -160,7 +160,7 @@ const (
SELECT read, write SELECT read, write
FROM user_access a FROM user_access a
JOIN user u ON u.id = a.user_id JOIN user u ON u.id = a.user_id
WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic WHERE (u.user = ? OR u.user = ?) AND ? LIKE a.topic ESCAPE '\'
ORDER BY u.user DESC ORDER BY u.user DESC
` `
@ -235,7 +235,7 @@ const (
selectOtherAccessCountQuery = ` selectOtherAccessCountQuery = `
SELECT COUNT(*) SELECT COUNT(*)
FROM user_access FROM user_access
WHERE (topic = ? OR ? LIKE topic) WHERE (topic = ? OR ? LIKE topic ESCAPE '\')
AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?)) AND (owner_user_id IS NULL OR owner_user_id != (SELECT id FROM user WHERE user = ?))
` `
deleteAllAccessQuery = `DELETE FROM user_access` deleteAllAccessQuery = `DELETE FROM user_access`
@ -312,7 +312,7 @@ const (
// Schema management queries // Schema management queries
const ( const (
currentSchemaVersion = 4 currentSchemaVersion = 5
insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)` insertSchemaVersion = `INSERT INTO schemaVersion VALUES (1, ?)`
updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1` updateSchemaVersion = `UPDATE schemaVersion SET version = ? WHERE id = 1`
selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1` selectSchemaVersionQuery = `SELECT version FROM schemaVersion WHERE id = 1`
@ -422,6 +422,11 @@ const (
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
); );
` `
// 4 -> 5
migrate4To5UpdateQueries = `
UPDATE user_access SET topic = REPLACE(topic, '_', '\_');
`
) )
var ( var (
@ -429,6 +434,7 @@ var (
1: migrateFrom1, 1: migrateFrom1,
2: migrateFrom2, 2: migrateFrom2,
3: migrateFrom3, 3: migrateFrom3,
4: migrateFrom4,
} }
) )
@ -1123,7 +1129,7 @@ func (a *Manager) Reservations(username string) ([]Reservation, error) {
return nil, err return nil, err
} }
reservations = append(reservations, Reservation{ reservations = append(reservations, Reservation{
Topic: topic, Topic: unescapeUnderscore(topic),
Owner: NewPermission(ownerRead, ownerWrite), Owner: NewPermission(ownerRead, ownerWrite),
Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null Everyone: NewPermission(everyoneRead.Bool, everyoneWrite.Bool), // false if null
}) })
@ -1133,7 +1139,7 @@ func (a *Manager) Reservations(username string) ([]Reservation, error) {
// HasReservation returns true if the given topic access is owned by the user // HasReservation returns true if the given topic access is owned by the user
func (a *Manager) HasReservation(username, topic string) (bool, error) { func (a *Manager) HasReservation(username, topic string) (bool, error) {
rows, err := a.db.Query(selectUserHasReservationQuery, username, topic) rows, err := a.db.Query(selectUserHasReservationQuery, username, escapeUnderscore(topic))
if err != nil { if err != nil {
return false, err return false, err
} }
@ -1168,7 +1174,7 @@ func (a *Manager) ReservationsCount(username string) (int64, error) {
// ReservationOwner returns user ID of the user that owns this topic, or an // ReservationOwner returns user ID of the user that owns this topic, or an
// empty string if it's not owned by anyone // empty string if it's not owned by anyone
func (a *Manager) ReservationOwner(topic string) (string, error) { func (a *Manager) ReservationOwner(topic string) (string, error) {
rows, err := a.db.Query(selectUserReservationsOwnerQuery, topic) rows, err := a.db.Query(selectUserReservationsOwnerQuery, escapeUnderscore(topic))
if err != nil { if err != nil {
return "", err return "", err
} }
@ -1263,7 +1269,7 @@ func (a *Manager) AllowReservation(username string, topic string) error {
if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) { if (!AllowedUsername(username) && username != Everyone) || !AllowedTopic(topic) {
return ErrInvalidArgument return ErrInvalidArgument
} }
rows, err := a.db.Query(selectOtherAccessCountQuery, topic, topic, username) rows, err := a.db.Query(selectOtherAccessCountQuery, escapeUnderscore(topic), escapeUnderscore(topic), username)
if err != nil { if err != nil {
return err return err
} }
@ -1328,10 +1334,10 @@ func (a *Manager) AddReservation(username string, topic string, everyone Permiss
return err return err
} }
defer tx.Rollback() defer tx.Rollback()
if _, err := tx.Exec(upsertUserAccessQuery, username, topic, true, true, username, username); err != nil { if _, err := tx.Exec(upsertUserAccessQuery, username, escapeUnderscore(topic), true, true, username, username); err != nil {
return err return err
} }
if _, err := tx.Exec(upsertUserAccessQuery, Everyone, topic, everyone.IsRead(), everyone.IsWrite(), username, username); err != nil { if _, err := tx.Exec(upsertUserAccessQuery, Everyone, escapeUnderscore(topic), everyone.IsRead(), everyone.IsWrite(), username, username); err != nil {
return err return err
} }
return tx.Commit() return tx.Commit()
@ -1354,10 +1360,10 @@ func (a *Manager) RemoveReservations(username string, topics ...string) error {
} }
defer tx.Rollback() defer tx.Rollback()
for _, topic := range topics { for _, topic := range topics {
if _, err := tx.Exec(deleteTopicAccessQuery, username, username, topic); err != nil { if _, err := tx.Exec(deleteTopicAccessQuery, username, username, escapeUnderscore(topic)); err != nil {
return err return err
} }
if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, topic); err != nil { if _, err := tx.Exec(deleteTopicAccessQuery, Everyone, Everyone, escapeUnderscore(topic)); err != nil {
return err return err
} }
} }
@ -1484,12 +1490,24 @@ func (a *Manager) Close() error {
return a.db.Close() return a.db.Close()
} }
// toSQLWildcard converts a wildcard string to a SQL wildcard string. It only allows '*' as wildcards,
// and escapes '_', assuming '\' as escape character.
func toSQLWildcard(s string) string { func toSQLWildcard(s string) string {
return strings.ReplaceAll(s, "*", "%") return escapeUnderscore(strings.ReplaceAll(s, "*", "%"))
} }
// fromSQLWildcard converts a SQL wildcard string to a wildcard string. It converts '%' to '*',
// and removes the '\_' escape character.
func fromSQLWildcard(s string) string { func fromSQLWildcard(s string) string {
return strings.ReplaceAll(s, "%", "*") return strings.ReplaceAll(unescapeUnderscore(s), "%", "*")
}
func escapeUnderscore(s string) string {
return strings.ReplaceAll(s, "_", "\\_")
}
func unescapeUnderscore(s string) string {
return strings.ReplaceAll(s, "\\_", "_")
} }
func runStartupQueries(db *sql.DB, startupQueries string) error { func runStartupQueries(db *sql.DB, startupQueries string) error {
@ -1627,6 +1645,22 @@ func migrateFrom3(db *sql.DB) error {
return tx.Commit() return tx.Commit()
} }
func migrateFrom4(db *sql.DB) error {
log.Tag(tag).Info("Migrating user database schema: from 4 to 5")
tx, err := db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.Exec(migrate4To5UpdateQueries); err != nil {
return err
}
if _, err := tx.Exec(updateSchemaVersion, 5); err != nil {
return err
}
return tx.Commit()
}
func nullString(s string) sql.NullString { func nullString(s string) sql.NullString {
if s == "" { if s == "" {
return sql.NullString{} return sql.NullString{}

View File

@ -330,7 +330,7 @@ func TestManager_Reservations(t *testing.T) {
a := newTestManager(t, PermissionDenyAll) a := newTestManager(t, PermissionDenyAll)
require.Nil(t, a.AddUser("phil", "phil", RoleUser)) require.Nil(t, a.AddUser("phil", "phil", RoleUser))
require.Nil(t, a.AddUser("ben", "ben", RoleUser)) require.Nil(t, a.AddUser("ben", "ben", RoleUser))
require.Nil(t, a.AddReservation("ben", "ztopic", PermissionDenyAll)) require.Nil(t, a.AddReservation("ben", "ztopic_", PermissionDenyAll))
require.Nil(t, a.AddReservation("ben", "readme", PermissionRead)) require.Nil(t, a.AddReservation("ben", "readme", PermissionRead))
require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead)) require.Nil(t, a.AllowAccess("ben", "something-else", PermissionRead))
@ -343,7 +343,7 @@ func TestManager_Reservations(t *testing.T) {
Everyone: PermissionRead, Everyone: PermissionRead,
}, reservations[0]) }, reservations[0])
require.Equal(t, Reservation{ require.Equal(t, Reservation{
Topic: "ztopic", Topic: "ztopic_",
Owner: PermissionReadWrite, Owner: PermissionReadWrite,
Everyone: PermissionDenyAll, Everyone: PermissionDenyAll,
}, reservations[1]) }, reservations[1])
@ -352,6 +352,14 @@ func TestManager_Reservations(t *testing.T) {
require.Nil(t, err) require.Nil(t, err)
require.True(t, b) require.True(t, b)
b, err = a.HasReservation("ben", "ztopic_")
require.Nil(t, err)
require.True(t, b)
b, err = a.HasReservation("ben", "ztopicX") // _ != X (used to be a SQL wildcard issue)
require.Nil(t, err)
require.False(t, b)
b, err = a.HasReservation("notben", "readme") b, err = a.HasReservation("notben", "readme")
require.Nil(t, err) require.Nil(t, err)
require.False(t, b) require.False(t, b)
@ -371,11 +379,17 @@ func TestManager_Reservations(t *testing.T) {
err = a.AllowReservation("phil", "readme") err = a.AllowReservation("phil", "readme")
require.Equal(t, errTopicOwnedByOthers, err) require.Equal(t, errTopicOwnedByOthers, err)
err = a.AllowReservation("phil", "ztopic_")
require.Equal(t, errTopicOwnedByOthers, err)
err = a.AllowReservation("phil", "ztopicX")
require.Nil(t, err)
err = a.AllowReservation("phil", "not-reserved") err = a.AllowReservation("phil", "not-reserved")
require.Nil(t, err) require.Nil(t, err)
// Now remove them again // Now remove them again
require.Nil(t, a.RemoveReservations("ben", "ztopic", "readme")) require.Nil(t, a.RemoveReservations("ben", "ztopic_", "readme"))
count, err = a.ReservationsCount("ben") count, err = a.ReservationsCount("ben")
require.Nil(t, err) require.Nil(t, err)
@ -978,7 +992,44 @@ func TestUser_PhoneNumberAdd_Multiple_Users_Same_Number(t *testing.T) {
require.Nil(t, a.AddPhoneNumber(ben.ID, "+1234567890")) require.Nil(t, a.AddPhoneNumber(ben.ID, "+1234567890"))
} }
func TestSqliteCache_Migration_From1(t *testing.T) { func TestManager_Topic_Wildcard_With_Asterisk_Underscore(t *testing.T) {
f := filepath.Join(t.TempDir(), "user.db")
a := newTestManagerFromFile(t, f, "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval)
require.Nil(t, a.AllowAccess(Everyone, "*_", PermissionRead))
require.Nil(t, a.AllowAccess(Everyone, "__*_", PermissionRead))
require.Nil(t, a.Authorize(nil, "allowed_", PermissionRead))
require.Nil(t, a.Authorize(nil, "__allowed_", PermissionRead))
require.Nil(t, a.Authorize(nil, "_allowed_", PermissionRead)) // The "%" in "%\_" matches the first "_"
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "notallowed", PermissionRead))
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "_notallowed", PermissionRead))
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "__notallowed", PermissionRead))
}
func TestManager_Topic_Wildcard_With_Underscore(t *testing.T) {
f := filepath.Join(t.TempDir(), "user.db")
a := newTestManagerFromFile(t, f, "", PermissionDenyAll, DefaultUserPasswordBcryptCost, DefaultUserStatsQueueWriterInterval)
require.Nil(t, a.AllowAccess(Everyone, "mytopic_", PermissionReadWrite))
require.Nil(t, a.Authorize(nil, "mytopic_", PermissionRead))
require.Nil(t, a.Authorize(nil, "mytopic_", PermissionWrite))
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "mytopicX", PermissionRead))
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "mytopicX", PermissionWrite))
}
func TestToFromSQLWildcard(t *testing.T) {
require.Equal(t, "up%", toSQLWildcard("up*"))
require.Equal(t, "up\\_%", toSQLWildcard("up_*"))
require.Equal(t, "foo", toSQLWildcard("foo"))
require.Equal(t, "up*", fromSQLWildcard("up%"))
require.Equal(t, "up_*", fromSQLWildcard("up\\_%"))
require.Equal(t, "foo", fromSQLWildcard("foo"))
require.Equal(t, "up*", fromSQLWildcard(toSQLWildcard("up*")))
require.Equal(t, "up_*", fromSQLWildcard(toSQLWildcard("up_*")))
require.Equal(t, "foo", fromSQLWildcard(toSQLWildcard("foo")))
}
func TestMigrationFrom1(t *testing.T) {
filename := filepath.Join(t.TempDir(), "user.db") filename := filepath.Join(t.TempDir(), "user.db")
db, err := sql.Open("sqlite3", filename) db, err := sql.Open("sqlite3", filename)
require.Nil(t, err) require.Nil(t, err)
@ -1063,6 +1114,152 @@ func TestSqliteCache_Migration_From1(t *testing.T) {
require.Equal(t, PermissionRead, everyoneGrants[0].Allow) require.Equal(t, PermissionRead, everyoneGrants[0].Allow)
} }
func TestMigrationFrom4(t *testing.T) {
filename := filepath.Join(t.TempDir(), "user.db")
db, err := sql.Open("sqlite3", filename)
require.Nil(t, err)
// Create "version 4" schema
_, err = db.Exec(`
BEGIN;
CREATE TABLE IF NOT EXISTS tier (
id TEXT PRIMARY KEY,
code TEXT NOT NULL,
name TEXT NOT NULL,
messages_limit INT NOT NULL,
messages_expiry_duration INT NOT NULL,
emails_limit INT NOT NULL,
calls_limit INT NOT NULL,
reservations_limit INT NOT NULL,
attachment_file_size_limit INT NOT NULL,
attachment_total_size_limit INT NOT NULL,
attachment_expiry_duration INT NOT NULL,
attachment_bandwidth_limit INT NOT NULL,
stripe_monthly_price_id TEXT,
stripe_yearly_price_id TEXT
);
CREATE UNIQUE INDEX idx_tier_code ON tier (code);
CREATE UNIQUE INDEX idx_tier_stripe_monthly_price_id ON tier (stripe_monthly_price_id);
CREATE UNIQUE INDEX idx_tier_stripe_yearly_price_id ON tier (stripe_yearly_price_id);
CREATE TABLE IF NOT EXISTS user (
id TEXT PRIMARY KEY,
tier_id TEXT,
user TEXT NOT NULL,
pass TEXT NOT NULL,
role TEXT CHECK (role IN ('anonymous', 'admin', 'user')) NOT NULL,
prefs JSON NOT NULL DEFAULT '{}',
sync_topic TEXT NOT NULL,
stats_messages INT NOT NULL DEFAULT (0),
stats_emails INT NOT NULL DEFAULT (0),
stats_calls INT NOT NULL DEFAULT (0),
stripe_customer_id TEXT,
stripe_subscription_id TEXT,
stripe_subscription_status TEXT,
stripe_subscription_interval TEXT,
stripe_subscription_paid_until INT,
stripe_subscription_cancel_at INT,
created INT NOT NULL,
deleted INT,
FOREIGN KEY (tier_id) REFERENCES tier (id)
);
CREATE UNIQUE INDEX idx_user ON user (user);
CREATE UNIQUE INDEX idx_user_stripe_customer_id ON user (stripe_customer_id);
CREATE UNIQUE INDEX idx_user_stripe_subscription_id ON user (stripe_subscription_id);
CREATE TABLE IF NOT EXISTS user_access (
user_id TEXT NOT NULL,
topic TEXT NOT NULL,
read INT NOT NULL,
write INT NOT NULL,
owner_user_id INT,
PRIMARY KEY (user_id, topic),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE,
FOREIGN KEY (owner_user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_token (
user_id TEXT NOT NULL,
token TEXT NOT NULL,
label TEXT NOT NULL,
last_access INT NOT NULL,
last_origin TEXT NOT NULL,
expires INT NOT NULL,
PRIMARY KEY (user_id, token),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS user_phone (
user_id TEXT NOT NULL,
phone_number TEXT NOT NULL,
PRIMARY KEY (user_id, phone_number),
FOREIGN KEY (user_id) REFERENCES user (id) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS schemaVersion (
id INT PRIMARY KEY,
version INT NOT NULL
);
INSERT INTO user (id, user, pass, role, sync_topic, created)
VALUES ('u_everyone', '*', '', 'anonymous', '', UNIXEPOCH())
ON CONFLICT (id) DO NOTHING;
INSERT INTO schemaVersion (id, version) VALUES (1, 4);
COMMIT;
`)
require.Nil(t, err)
// Insert a few ACL entries
_, err = db.Exec(`
BEGIN;
INSERT INTO user_access (user_id, topic, read, write) values ('u_everyone', 'mytopic_', 1, 1);
INSERT INTO user_access (user_id, topic, read, write) values ('u_everyone', 'up%', 1, 1);
INSERT INTO user_access (user_id, topic, read, write) values ('u_everyone', 'down_%', 1, 1);
COMMIT;
`)
require.Nil(t, err)
// Create manager to trigger migration
a := newTestManagerFromFile(t, filename, "", PermissionDenyAll, bcrypt.MinCost, DefaultUserStatsQueueWriterInterval)
checkSchemaVersion(t, a.db)
// Add another
require.Nil(t, a.AllowAccess(Everyone, "left_*", PermissionReadWrite))
// Check "external view" of grants
everyoneGrants, err := a.Grants(Everyone)
require.Nil(t, err)
require.Equal(t, 4, len(everyoneGrants))
require.Equal(t, "down_*", everyoneGrants[0].TopicPattern)
require.Equal(t, "left_*", everyoneGrants[1].TopicPattern)
require.Equal(t, "mytopic_", everyoneGrants[2].TopicPattern)
require.Equal(t, "up*", everyoneGrants[3].TopicPattern)
// Check they are stored correctly in the database
rows, err := db.Query(`SELECT topic FROM user_access WHERE user_id = 'u_everyone' ORDER BY topic`)
require.Nil(t, err)
topicPatterns := make([]string, 0)
for rows.Next() {
var topicPattern string
require.Nil(t, rows.Scan(&topicPattern))
topicPatterns = append(topicPatterns, topicPattern)
}
require.Nil(t, rows.Close())
require.Equal(t, 4, len(topicPatterns))
require.Equal(t, "down\\_%", topicPatterns[0])
require.Equal(t, "left\\_%", topicPatterns[1])
require.Equal(t, "mytopic\\_", topicPatterns[2])
require.Equal(t, "up%", topicPatterns[3])
// Check that ACL works as excepted
require.Nil(t, a.Authorize(nil, "down_123", PermissionRead))
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "downX123", PermissionRead))
require.Nil(t, a.Authorize(nil, "left_abc", PermissionRead))
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "leftX123", PermissionRead))
require.Nil(t, a.Authorize(nil, "mytopic_", PermissionRead))
require.Equal(t, ErrUnauthorized, a.Authorize(nil, "mytopicX", PermissionRead))
require.Nil(t, a.Authorize(nil, "up123", PermissionRead))
require.Nil(t, a.Authorize(nil, "up", PermissionRead)) // % matches 0 or more characters
}
func checkSchemaVersion(t *testing.T, db *sql.DB) { func checkSchemaVersion(t *testing.T, db *sql.DB) {
rows, err := db.Query(`SELECT version FROM schemaVersion`) rows, err := db.Query(`SELECT version FROM schemaVersion`)
require.Nil(t, err) require.Nil(t, err)