mirror of
https://github.com/sosedoff/pgweb.git
synced 2024-12-15 03:36:33 +03:00
Tunnel implementation, allow using ssh on connection screen
This commit is contained in:
parent
fb66acebc3
commit
f0f447857f
1
main.go
1
main.go
@ -7,6 +7,7 @@ import (
|
||||
"os/signal"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/sosedoff/pgweb/pkg/api"
|
||||
"github.com/sosedoff/pgweb/pkg/client"
|
||||
"github.com/sosedoff/pgweb/pkg/command"
|
||||
|
@ -8,10 +8,12 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/sosedoff/pgweb/pkg/bookmarks"
|
||||
"github.com/sosedoff/pgweb/pkg/client"
|
||||
"github.com/sosedoff/pgweb/pkg/command"
|
||||
"github.com/sosedoff/pgweb/pkg/connection"
|
||||
"github.com/sosedoff/pgweb/pkg/shared"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -67,6 +69,7 @@ func GetSessions(c *gin.Context) {
|
||||
}
|
||||
|
||||
func Connect(c *gin.Context) {
|
||||
var sshInfo *shared.SSHInfo
|
||||
url := c.Request.FormValue("url")
|
||||
|
||||
if url == "" {
|
||||
@ -82,7 +85,11 @@ func Connect(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
cl, err := client.NewFromUrl(url)
|
||||
if c.Request.FormValue("ssh") != "" {
|
||||
sshInfo = parseSshInfo(c)
|
||||
}
|
||||
|
||||
cl, err := client.NewFromUrl(url, sshInfo)
|
||||
if err != nil {
|
||||
c.JSON(400, Error{err.Error()})
|
||||
return
|
||||
|
@ -7,6 +7,8 @@ import (
|
||||
"strconv"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
||||
"github.com/sosedoff/pgweb/pkg/shared"
|
||||
)
|
||||
|
||||
var extraMimeTypes = map[string]string{
|
||||
@ -69,6 +71,21 @@ func parseIntFormValue(c *gin.Context, name string, defValue int) (int, error) {
|
||||
return num, nil
|
||||
}
|
||||
|
||||
func parseSshInfo(c *gin.Context) *shared.SSHInfo {
|
||||
info := shared.SSHInfo{
|
||||
Host: c.Request.FormValue("ssh_host"),
|
||||
Port: c.Request.FormValue("ssh_port"),
|
||||
User: c.Request.FormValue("ssh_user"),
|
||||
Password: c.Request.FormValue("ssh_password"),
|
||||
}
|
||||
|
||||
if info.Port == "" {
|
||||
info.Port = "22"
|
||||
}
|
||||
|
||||
return &info
|
||||
}
|
||||
|
||||
func assetContentType(name string) string {
|
||||
ext := filepath.Ext(name)
|
||||
result := mime.TypeByExtension(ext)
|
||||
|
@ -8,22 +8,19 @@ import (
|
||||
|
||||
"github.com/BurntSushi/toml"
|
||||
"github.com/mitchellh/go-homedir"
|
||||
|
||||
"github.com/sosedoff/pgweb/pkg/shared"
|
||||
)
|
||||
|
||||
type Bookmark struct {
|
||||
Url string `json:"url"` // Postgres connection URL
|
||||
Host string `json:"host"` // Server hostname
|
||||
Port string `json:"port"` // Server port
|
||||
User string `json:"user"` // Database user
|
||||
Password string `json:"password"` // User password
|
||||
Database string `json:"database"` // Database name
|
||||
Ssl string `json:"ssl"` // Connection SSL mode
|
||||
|
||||
SshHost string `json:"ssh_user"`
|
||||
SshPort string `json:"ssh_port"`
|
||||
SshUser string `json:"ssh_user"`
|
||||
SshPassword string `json:"ssh_password"`
|
||||
SshKey string `json:"ssh_key"`
|
||||
Url string `json:"url"` // Postgres connection URL
|
||||
Host string `json:"host"` // Server hostname
|
||||
Port string `json:"port"` // Server port
|
||||
User string `json:"user"` // Database user
|
||||
Password string `json:"password"` // User password
|
||||
Database string `json:"database"` // Database name
|
||||
Ssl string `json:"ssl"` // Connection SSL mode
|
||||
Ssh shared.SSHInfo `json:"ssh,omitempty"`
|
||||
}
|
||||
|
||||
func readServerConfig(path string) (Bookmark, error) {
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"github.com/sosedoff/pgweb/pkg/command"
|
||||
"github.com/sosedoff/pgweb/pkg/connection"
|
||||
"github.com/sosedoff/pgweb/pkg/history"
|
||||
"github.com/sosedoff/pgweb/pkg/shared"
|
||||
"github.com/sosedoff/pgweb/pkg/statements"
|
||||
)
|
||||
|
||||
@ -63,7 +64,32 @@ func New() (*Client, error) {
|
||||
return &client, nil
|
||||
}
|
||||
|
||||
func NewFromUrl(url string) (*Client, error) {
|
||||
func NewFromUrl(url string, sshInfo *shared.SSHInfo) (*Client, error) {
|
||||
var tunnel *Tunnel
|
||||
|
||||
if sshInfo != nil {
|
||||
if command.Opts.Debug {
|
||||
fmt.Println("Opening SSH tunnel for:", sshInfo)
|
||||
}
|
||||
|
||||
tunnel, err := NewTunnel(sshInfo, url)
|
||||
if err != nil {
|
||||
tunnel.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = tunnel.Configure()
|
||||
if err != nil {
|
||||
tunnel.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
go tunnel.Start()
|
||||
|
||||
// Override remote postgres port with local proxy port
|
||||
url = strings.Replace(url, ":5432", fmt.Sprintf(":%v", tunnel.Port), 1)
|
||||
}
|
||||
|
||||
if command.Opts.Debug {
|
||||
fmt.Println("Creating a new client for:", url)
|
||||
}
|
||||
@ -75,6 +101,7 @@ func NewFromUrl(url string) (*Client, error) {
|
||||
|
||||
client := Client{
|
||||
db: db,
|
||||
tunnel: tunnel,
|
||||
ConnectionString: url,
|
||||
History: history.New(),
|
||||
}
|
||||
@ -230,9 +257,14 @@ func (client *Client) query(query string, args ...interface{}) (*Result, error)
|
||||
|
||||
// Close database connection
|
||||
func (client *Client) Close() error {
|
||||
if client.tunnel != nil {
|
||||
client.tunnel.Close()
|
||||
}
|
||||
|
||||
if client.db != nil {
|
||||
return client.db.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -60,7 +60,7 @@ func setup() {
|
||||
}
|
||||
|
||||
func setupClient() {
|
||||
testClient, _ = NewFromUrl("postgres://postgres@localhost/booktown?sslmode=disable")
|
||||
testClient, _ = NewFromUrl("postgres://postgres@localhost/booktown?sslmode=disable", nil)
|
||||
}
|
||||
|
||||
func teardownClient() {
|
||||
@ -79,7 +79,7 @@ func teardown() {
|
||||
|
||||
func test_NewClientFromUrl(t *testing.T) {
|
||||
url := "postgres://postgres@localhost/booktown?sslmode=disable"
|
||||
client, err := NewFromUrl(url)
|
||||
client, err := NewFromUrl(url, nil)
|
||||
|
||||
if err != nil {
|
||||
defer client.Close()
|
||||
@ -91,7 +91,7 @@ func test_NewClientFromUrl(t *testing.T) {
|
||||
|
||||
func test_NewClientFromUrl2(t *testing.T) {
|
||||
url := "postgresql://postgres@localhost/booktown?sslmode=disable"
|
||||
client, err := NewFromUrl(url)
|
||||
client, err := NewFromUrl(url, nil)
|
||||
|
||||
if err != nil {
|
||||
defer client.Close()
|
||||
@ -257,7 +257,7 @@ func test_HistoryError(t *testing.T) {
|
||||
}
|
||||
|
||||
func test_HistoryUniqueness(t *testing.T) {
|
||||
client, _ := NewFromUrl("postgres://postgres@localhost/booktown?sslmode=disable")
|
||||
client, _ := NewFromUrl("postgres://postgres@localhost/booktown?sslmode=disable", nil)
|
||||
|
||||
client.Query("SELECT * FROM books WHERE id = 1")
|
||||
client.Query("SELECT * FROM books WHERE id = 1")
|
||||
|
@ -6,12 +6,15 @@ import (
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/sosedoff/pgweb/pkg/connection"
|
||||
"github.com/sosedoff/pgweb/pkg/shared"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -22,21 +25,22 @@ const (
|
||||
type Tunnel struct {
|
||||
TargetHost string
|
||||
TargetPort string
|
||||
|
||||
SshHost string
|
||||
SshPort string
|
||||
SshUser string
|
||||
SshPassword string
|
||||
SshKey string
|
||||
|
||||
Config *ssh.ClientConfig
|
||||
Client *ssh.Client
|
||||
Port int
|
||||
SSHInfo *shared.SSHInfo
|
||||
Config *ssh.ClientConfig
|
||||
Client *ssh.Client
|
||||
Listener *net.TCPListener
|
||||
}
|
||||
|
||||
func privateKeyPath() string {
|
||||
return os.Getenv("HOME") + "/.ssh/id_rsa"
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func parsePrivateKey(keyPath string) (ssh.Signer, error) {
|
||||
buff, err := ioutil.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
@ -46,10 +50,11 @@ func parsePrivateKey(keyPath string) (ssh.Signer, error) {
|
||||
return ssh.ParsePrivateKey(buff)
|
||||
}
|
||||
|
||||
func makeConfig(user, password, keyPath string) (*ssh.ClientConfig, error) {
|
||||
func makeConfig(info *shared.SSHInfo) (*ssh.ClientConfig, error) {
|
||||
methods := []ssh.AuthMethod{}
|
||||
|
||||
if keyPath != "" {
|
||||
keyPath := privateKeyPath()
|
||||
if fileExists(keyPath) {
|
||||
key, err := parsePrivateKey(keyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -58,52 +63,19 @@ func makeConfig(user, password, keyPath string) (*ssh.ClientConfig, error) {
|
||||
methods = append(methods, ssh.PublicKeys(key))
|
||||
}
|
||||
|
||||
methods = append(methods, ssh.Password(password))
|
||||
methods = append(methods, ssh.Password(info.Password))
|
||||
|
||||
return &ssh.ClientConfig{User: user, Auth: methods}, nil
|
||||
return &ssh.ClientConfig{User: info.User, Auth: methods}, nil
|
||||
}
|
||||
|
||||
func (tunnel *Tunnel) sshEndpoint() string {
|
||||
return fmt.Sprintf("%s:%v", tunnel.SshHost, tunnel.SshPort)
|
||||
return fmt.Sprintf("%s:%v", tunnel.SSHInfo.Host, tunnel.SSHInfo.Port)
|
||||
}
|
||||
|
||||
func (tunnel *Tunnel) targetEndpoint() string {
|
||||
return fmt.Sprintf("%v:%v", tunnel.TargetHost, tunnel.TargetPort)
|
||||
}
|
||||
|
||||
func (tunnel *Tunnel) Start() error {
|
||||
config, err := makeConfig(tunnel.SshUser, tunnel.SshPassword, tunnel.SshKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
client, err := ssh.Dial("tcp", tunnel.sshEndpoint(), config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer client.Close()
|
||||
|
||||
port, err := connection.AvailablePort(PORT_START, PORT_LIMIT)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%v", port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer listener.Close()
|
||||
|
||||
for {
|
||||
conn, err := listener.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go tunnel.handleConnection(conn, client)
|
||||
}
|
||||
}
|
||||
|
||||
func (tunnel *Tunnel) copy(wg *sync.WaitGroup, writer, reader net.Conn) {
|
||||
defer wg.Done()
|
||||
if _, err := io.Copy(writer, reader); err != nil {
|
||||
@ -111,8 +83,8 @@ func (tunnel *Tunnel) copy(wg *sync.WaitGroup, writer, reader net.Conn) {
|
||||
}
|
||||
}
|
||||
|
||||
func (tunnel *Tunnel) handleConnection(local net.Conn, sshClient *ssh.Client) {
|
||||
remote, err := sshClient.Dial("tcp", tunnel.targetEndpoint())
|
||||
func (tunnel *Tunnel) handleConnection(local net.Conn) {
|
||||
remote, err := tunnel.Client.Dial("tcp", tunnel.targetEndpoint())
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@ -124,4 +96,79 @@ func (tunnel *Tunnel) handleConnection(local net.Conn, sshClient *ssh.Client) {
|
||||
go tunnel.copy(&wg, remote, local)
|
||||
|
||||
wg.Wait()
|
||||
local.Close()
|
||||
}
|
||||
|
||||
func (tunnel *Tunnel) Close() {
|
||||
if tunnel.Client != nil {
|
||||
tunnel.Client.Close()
|
||||
}
|
||||
|
||||
if tunnel.Listener != nil {
|
||||
tunnel.Listener.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func (tunnel *Tunnel) Configure() error {
|
||||
config, err := makeConfig(tunnel.SSHInfo)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tunnel.Config = config
|
||||
|
||||
client, err := ssh.Dial("tcp", tunnel.sshEndpoint(), config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tunnel.Client = client
|
||||
|
||||
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%v", tunnel.Port))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
tunnel.Listener = listener.(*net.TCPListener)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (tunnel *Tunnel) Start() {
|
||||
for {
|
||||
conn, err := tunnel.Listener.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
go tunnel.handleConnection(conn)
|
||||
}
|
||||
|
||||
tunnel.Close()
|
||||
}
|
||||
|
||||
func NewTunnel(sshInfo *shared.SSHInfo, dbUrl string) (*Tunnel, error) {
|
||||
uri, err := url.Parse(dbUrl)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
listenPort, err := connection.AvailablePort(PORT_START, PORT_LIMIT)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
chunks := strings.Split(uri.Host, ":")
|
||||
host := chunks[0]
|
||||
port := "5432"
|
||||
|
||||
if len(chunks) == 2 {
|
||||
port = chunks[1]
|
||||
}
|
||||
|
||||
tunnel := &Tunnel{
|
||||
Port: listenPort,
|
||||
SSHInfo: sshInfo,
|
||||
TargetHost: host,
|
||||
TargetPort: port,
|
||||
}
|
||||
|
||||
return tunnel, nil
|
||||
}
|
||||
|
File diff suppressed because one or more lines are too long
17
pkg/shared/ssh_info.go
Normal file
17
pkg/shared/ssh_info.go
Normal file
@ -0,0 +1,17 @@
|
||||
package shared
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type SSHInfo struct {
|
||||
Host string `json:"host,omitempty"`
|
||||
Port string `json:"port,omitempty"`
|
||||
User string `json:"user,omitempty"`
|
||||
Password string `json:"password,omitempty"`
|
||||
Key string `json:"key,omitempty"`
|
||||
}
|
||||
|
||||
func (info SSHInfo) String() string {
|
||||
return fmt.Sprintf("%s@%s:%s", info.User, info.Host, info.Port)
|
||||
}
|
@ -116,6 +116,7 @@
|
||||
<div class="btn-group btn-group-sm connection-group-switch">
|
||||
<button type="button" data="scheme" class="btn btn-default" id="connection_scheme">Scheme</button>
|
||||
<button type="button" data="standard" class="btn btn-default active" id="connection_standard">Standard</button>
|
||||
<button type="button" data="ssh" class="btn btn-default" id="connection_ssh">SSH</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@ -204,14 +205,14 @@
|
||||
<div class="form-group">
|
||||
<label class="col-sm-3 control-label">SSH Password</label>
|
||||
<div class="col-sm-9">
|
||||
<input type="text" id="ssh_password" class="form-control" placeholder="optional" />
|
||||
<input type="password" id="ssh_password" class="form-control" placeholder="optional" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label class="col-sm-3 control-label">SSH Port</label>
|
||||
<div class="col-sm-9">
|
||||
<input type="text" id="pg_host" class="form-control" placeholder="optional" />
|
||||
<input type="text" id="ssh_port" class="form-control" placeholder="optional" />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
@ -648,7 +648,7 @@ function getConnectionString() {
|
||||
var mode = $(".connection-group-switch button.active").attr("data");
|
||||
var ssl = $("#connection_ssl").val();
|
||||
|
||||
if (mode == "standard") {
|
||||
if (mode == "standard" || mode == "ssh") {
|
||||
var host = $("#pg_host").val();
|
||||
var port = $("#pg_port").val();
|
||||
var user = $("#pg_user").val();
|
||||
@ -929,22 +929,39 @@ $(document).ready(function() {
|
||||
$("#pg_password").val(item.password);
|
||||
$("#pg_db").val(item.database);
|
||||
$("#connection_ssl").val(item.ssl);
|
||||
|
||||
if (item.ssh) {
|
||||
$("#ssh_host").val(item.ssh.host);
|
||||
$("#ssh_port").val(item.ssh.port);
|
||||
$("#ssh_user").val(item.ssh.user);
|
||||
$("#ssh_password").val(item.ssh.password);
|
||||
}
|
||||
});
|
||||
|
||||
$("#connection_form").on("submit", function(e) {
|
||||
e.preventDefault();
|
||||
|
||||
var button = $(this).children("button");
|
||||
var url = getConnectionString();
|
||||
var params = {
|
||||
url: getConnectionString()
|
||||
};
|
||||
|
||||
if (url.length == 0) {
|
||||
if (params.url.length == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
if ($(".connection-group-switch button.active").attr("data") == "ssh") {
|
||||
params["ssh"] = 1
|
||||
params["ssh_host"] = $("#ssh_host").val();
|
||||
params["ssh_port"] = $("#ssh_port").val();
|
||||
params["ssh_user"] = $("#ssh_user").val();
|
||||
params["ssh_password"] = $("#ssh_password").val();
|
||||
}
|
||||
|
||||
$("#connection_error").hide();
|
||||
button.prop("disabled", true).text("Please wait...");
|
||||
|
||||
apiCall("post", "/connect", { url: url }, function(resp) {
|
||||
apiCall("post", "/connect", params, function(resp) {
|
||||
button.prop("disabled", false).text("Connect");
|
||||
|
||||
if (resp.error) {
|
||||
|
Loading…
Reference in New Issue
Block a user