mirror of
https://github.com/dbcli/pgcli.git
synced 2024-10-04 09:17:08 +03:00
ssh tunnels: allow configuring auto matches (#1302)
This commit is contained in:
parent
ed9d123073
commit
54f0cc9ddd
@ -286,6 +286,7 @@ class PGCli:
|
|||||||
|
|
||||||
self.prompt_app = None
|
self.prompt_app = None
|
||||||
|
|
||||||
|
self.ssh_tunnel_config = c.get("ssh tunnels")
|
||||||
self.ssh_tunnel_url = ssh_tunnel_url
|
self.ssh_tunnel_url = ssh_tunnel_url
|
||||||
self.ssh_tunnel = None
|
self.ssh_tunnel = None
|
||||||
|
|
||||||
@ -599,18 +600,24 @@ class PGCli:
|
|||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if dsn:
|
||||||
|
parsed_dsn = parse_dsn(dsn)
|
||||||
|
if "host" in parsed_dsn:
|
||||||
|
host = parsed_dsn["host"]
|
||||||
|
if "port" in parsed_dsn:
|
||||||
|
port = parsed_dsn["port"]
|
||||||
|
|
||||||
|
if self.ssh_tunnel_config and not self.ssh_tunnel_url:
|
||||||
|
for db_host_regex, tunnel_url in self.ssh_tunnel_config.items():
|
||||||
|
if re.search(db_host_regex, host):
|
||||||
|
self.ssh_tunnel_url = tunnel_url
|
||||||
|
break
|
||||||
|
|
||||||
if self.ssh_tunnel_url:
|
if self.ssh_tunnel_url:
|
||||||
# We add the protocol as urlparse doesn't find it by itself
|
# We add the protocol as urlparse doesn't find it by itself
|
||||||
if "://" not in self.ssh_tunnel_url:
|
if "://" not in self.ssh_tunnel_url:
|
||||||
self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}"
|
self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}"
|
||||||
|
|
||||||
if dsn:
|
|
||||||
parsed_dsn = parse_dsn(dsn)
|
|
||||||
if "host" in parsed_dsn:
|
|
||||||
host = parsed_dsn["host"]
|
|
||||||
if "port" in parsed_dsn:
|
|
||||||
port = parsed_dsn["port"]
|
|
||||||
|
|
||||||
tunnel_info = urlparse(self.ssh_tunnel_url)
|
tunnel_info = urlparse(self.ssh_tunnel_url)
|
||||||
params = {
|
params = {
|
||||||
"local_bind_address": ("127.0.0.1",),
|
"local_bind_address": ("127.0.0.1",),
|
||||||
|
@ -1,6 +1,8 @@
|
|||||||
|
import os
|
||||||
from unittest.mock import patch, MagicMock, ANY
|
from unittest.mock import patch, MagicMock, ANY
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from configobj import ConfigObj
|
||||||
from click.testing import CliRunner
|
from click.testing import CliRunner
|
||||||
from sshtunnel import SSHTunnelForwarder
|
from sshtunnel import SSHTunnelForwarder
|
||||||
|
|
||||||
@ -129,3 +131,58 @@ def test_cli_with_tunnel() -> None:
|
|||||||
mock_pgcli.assert_called_once()
|
mock_pgcli.assert_called_once()
|
||||||
call_args, call_kwargs = mock_pgcli.call_args
|
call_args, call_kwargs = mock_pgcli.call_args
|
||||||
assert call_kwargs["ssh_tunnel_url"] == tunnel_url
|
assert call_kwargs["ssh_tunnel_url"] == tunnel_url
|
||||||
|
|
||||||
|
|
||||||
|
def test_config(
|
||||||
|
tmpdir: os.PathLike, mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
|
||||||
|
) -> None:
|
||||||
|
pgclirc = str(tmpdir.join("rcfile"))
|
||||||
|
|
||||||
|
tunnel_user = "tunnel_user"
|
||||||
|
tunnel_passwd = "tunnel_pass"
|
||||||
|
tunnel_host = "tunnel.host"
|
||||||
|
tunnel_port = 1022
|
||||||
|
tunnel_url = f"{tunnel_user}:{tunnel_passwd}@{tunnel_host}:{tunnel_port}"
|
||||||
|
|
||||||
|
tunnel2_url = "tunnel2.host"
|
||||||
|
|
||||||
|
config = ConfigObj()
|
||||||
|
config.filename = pgclirc
|
||||||
|
config["ssh tunnels"] = {}
|
||||||
|
config["ssh tunnels"][r"\.com$"] = tunnel_url
|
||||||
|
config["ssh tunnels"][r"^hello-"] = tunnel2_url
|
||||||
|
config.write()
|
||||||
|
|
||||||
|
# Unmatched host
|
||||||
|
pgcli = PGCli(pgclirc_file=pgclirc)
|
||||||
|
pgcli.connect(host="unmatched.host")
|
||||||
|
mock_ssh_tunnel_forwarder.assert_not_called()
|
||||||
|
|
||||||
|
# Host matching first tunnel
|
||||||
|
pgcli = PGCli(pgclirc_file=pgclirc)
|
||||||
|
pgcli.connect(host="matched.host.com")
|
||||||
|
mock_ssh_tunnel_forwarder.assert_called_once()
|
||||||
|
call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
|
||||||
|
assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port)
|
||||||
|
assert call_kwargs["ssh_username"] == tunnel_user
|
||||||
|
assert call_kwargs["ssh_password"] == tunnel_passwd
|
||||||
|
mock_ssh_tunnel_forwarder.reset_mock()
|
||||||
|
|
||||||
|
# Host matching second tunnel
|
||||||
|
pgcli = PGCli(pgclirc_file=pgclirc)
|
||||||
|
pgcli.connect(host="hello-i-am-matched")
|
||||||
|
mock_ssh_tunnel_forwarder.assert_called_once()
|
||||||
|
|
||||||
|
call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
|
||||||
|
assert call_kwargs["ssh_address_or_host"] == (tunnel2_url, 22)
|
||||||
|
mock_ssh_tunnel_forwarder.reset_mock()
|
||||||
|
|
||||||
|
# Host matching both tunnels (will use the first one matched)
|
||||||
|
pgcli = PGCli(pgclirc_file=pgclirc)
|
||||||
|
pgcli.connect(host="hello-i-am-matched.com")
|
||||||
|
mock_ssh_tunnel_forwarder.assert_called_once()
|
||||||
|
|
||||||
|
call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args
|
||||||
|
assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port)
|
||||||
|
assert call_kwargs["ssh_username"] == tunnel_user
|
||||||
|
assert call_kwargs["ssh_password"] == tunnel_passwd
|
||||||
|
Loading…
Reference in New Issue
Block a user