Tunnel implementation, allow using ssh on connection screen

This commit is contained in:
Dan Sosedoff 2016-01-14 19:50:01 -06:00
parent fb66acebc3
commit f0f447857f
11 changed files with 229 additions and 93 deletions

View File

@ -7,6 +7,7 @@ import (
"os/signal" "os/signal"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sosedoff/pgweb/pkg/api" "github.com/sosedoff/pgweb/pkg/api"
"github.com/sosedoff/pgweb/pkg/client" "github.com/sosedoff/pgweb/pkg/client"
"github.com/sosedoff/pgweb/pkg/command" "github.com/sosedoff/pgweb/pkg/command"

View File

@ -8,10 +8,12 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sosedoff/pgweb/pkg/bookmarks" "github.com/sosedoff/pgweb/pkg/bookmarks"
"github.com/sosedoff/pgweb/pkg/client" "github.com/sosedoff/pgweb/pkg/client"
"github.com/sosedoff/pgweb/pkg/command" "github.com/sosedoff/pgweb/pkg/command"
"github.com/sosedoff/pgweb/pkg/connection" "github.com/sosedoff/pgweb/pkg/connection"
"github.com/sosedoff/pgweb/pkg/shared"
) )
var ( var (
@ -67,6 +69,7 @@ func GetSessions(c *gin.Context) {
} }
func Connect(c *gin.Context) { func Connect(c *gin.Context) {
var sshInfo *shared.SSHInfo
url := c.Request.FormValue("url") url := c.Request.FormValue("url")
if url == "" { if url == "" {
@ -82,7 +85,11 @@ func Connect(c *gin.Context) {
return return
} }
cl, err := client.NewFromUrl(url) if c.Request.FormValue("ssh") != "" {
sshInfo = parseSshInfo(c)
}
cl, err := client.NewFromUrl(url, sshInfo)
if err != nil { if err != nil {
c.JSON(400, Error{err.Error()}) c.JSON(400, Error{err.Error()})
return return

View File

@ -7,6 +7,8 @@ import (
"strconv" "strconv"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/sosedoff/pgweb/pkg/shared"
) )
var extraMimeTypes = map[string]string{ var extraMimeTypes = map[string]string{
@ -69,6 +71,21 @@ func parseIntFormValue(c *gin.Context, name string, defValue int) (int, error) {
return num, nil return num, nil
} }
func parseSshInfo(c *gin.Context) *shared.SSHInfo {
info := shared.SSHInfo{
Host: c.Request.FormValue("ssh_host"),
Port: c.Request.FormValue("ssh_port"),
User: c.Request.FormValue("ssh_user"),
Password: c.Request.FormValue("ssh_password"),
}
if info.Port == "" {
info.Port = "22"
}
return &info
}
func assetContentType(name string) string { func assetContentType(name string) string {
ext := filepath.Ext(name) ext := filepath.Ext(name)
result := mime.TypeByExtension(ext) result := mime.TypeByExtension(ext)

View File

@ -8,22 +8,19 @@ import (
"github.com/BurntSushi/toml" "github.com/BurntSushi/toml"
"github.com/mitchellh/go-homedir" "github.com/mitchellh/go-homedir"
"github.com/sosedoff/pgweb/pkg/shared"
) )
type Bookmark struct { type Bookmark struct {
Url string `json:"url"` // Postgres connection URL Url string `json:"url"` // Postgres connection URL
Host string `json:"host"` // Server hostname Host string `json:"host"` // Server hostname
Port string `json:"port"` // Server port Port string `json:"port"` // Server port
User string `json:"user"` // Database user User string `json:"user"` // Database user
Password string `json:"password"` // User password Password string `json:"password"` // User password
Database string `json:"database"` // Database name Database string `json:"database"` // Database name
Ssl string `json:"ssl"` // Connection SSL mode Ssl string `json:"ssl"` // Connection SSL mode
Ssh shared.SSHInfo `json:"ssh,omitempty"`
SshHost string `json:"ssh_user"`
SshPort string `json:"ssh_port"`
SshUser string `json:"ssh_user"`
SshPassword string `json:"ssh_password"`
SshKey string `json:"ssh_key"`
} }
func readServerConfig(path string) (Bookmark, error) { func readServerConfig(path string) (Bookmark, error) {

View File

@ -11,6 +11,7 @@ import (
"github.com/sosedoff/pgweb/pkg/command" "github.com/sosedoff/pgweb/pkg/command"
"github.com/sosedoff/pgweb/pkg/connection" "github.com/sosedoff/pgweb/pkg/connection"
"github.com/sosedoff/pgweb/pkg/history" "github.com/sosedoff/pgweb/pkg/history"
"github.com/sosedoff/pgweb/pkg/shared"
"github.com/sosedoff/pgweb/pkg/statements" "github.com/sosedoff/pgweb/pkg/statements"
) )
@ -63,7 +64,32 @@ func New() (*Client, error) {
return &client, nil return &client, nil
} }
func NewFromUrl(url string) (*Client, error) { func NewFromUrl(url string, sshInfo *shared.SSHInfo) (*Client, error) {
var tunnel *Tunnel
if sshInfo != nil {
if command.Opts.Debug {
fmt.Println("Opening SSH tunnel for:", sshInfo)
}
tunnel, err := NewTunnel(sshInfo, url)
if err != nil {
tunnel.Close()
return nil, err
}
err = tunnel.Configure()
if err != nil {
tunnel.Close()
return nil, err
}
go tunnel.Start()
// Override remote postgres port with local proxy port
url = strings.Replace(url, ":5432", fmt.Sprintf(":%v", tunnel.Port), 1)
}
if command.Opts.Debug { if command.Opts.Debug {
fmt.Println("Creating a new client for:", url) fmt.Println("Creating a new client for:", url)
} }
@ -75,6 +101,7 @@ func NewFromUrl(url string) (*Client, error) {
client := Client{ client := Client{
db: db, db: db,
tunnel: tunnel,
ConnectionString: url, ConnectionString: url,
History: history.New(), History: history.New(),
} }
@ -230,9 +257,14 @@ func (client *Client) query(query string, args ...interface{}) (*Result, error)
// Close database connection // Close database connection
func (client *Client) Close() error { func (client *Client) Close() error {
if client.tunnel != nil {
client.tunnel.Close()
}
if client.db != nil { if client.db != nil {
return client.db.Close() return client.db.Close()
} }
return nil return nil
} }

View File

@ -60,7 +60,7 @@ func setup() {
} }
func setupClient() { func setupClient() {
testClient, _ = NewFromUrl("postgres://postgres@localhost/booktown?sslmode=disable") testClient, _ = NewFromUrl("postgres://postgres@localhost/booktown?sslmode=disable", nil)
} }
func teardownClient() { func teardownClient() {
@ -79,7 +79,7 @@ func teardown() {
func test_NewClientFromUrl(t *testing.T) { func test_NewClientFromUrl(t *testing.T) {
url := "postgres://postgres@localhost/booktown?sslmode=disable" url := "postgres://postgres@localhost/booktown?sslmode=disable"
client, err := NewFromUrl(url) client, err := NewFromUrl(url, nil)
if err != nil { if err != nil {
defer client.Close() defer client.Close()
@ -91,7 +91,7 @@ func test_NewClientFromUrl(t *testing.T) {
func test_NewClientFromUrl2(t *testing.T) { func test_NewClientFromUrl2(t *testing.T) {
url := "postgresql://postgres@localhost/booktown?sslmode=disable" url := "postgresql://postgres@localhost/booktown?sslmode=disable"
client, err := NewFromUrl(url) client, err := NewFromUrl(url, nil)
if err != nil { if err != nil {
defer client.Close() defer client.Close()
@ -257,7 +257,7 @@ func test_HistoryError(t *testing.T) {
} }
func test_HistoryUniqueness(t *testing.T) { func test_HistoryUniqueness(t *testing.T) {
client, _ := NewFromUrl("postgres://postgres@localhost/booktown?sslmode=disable") client, _ := NewFromUrl("postgres://postgres@localhost/booktown?sslmode=disable", nil)
client.Query("SELECT * FROM books WHERE id = 1") client.Query("SELECT * FROM books WHERE id = 1")
client.Query("SELECT * FROM books WHERE id = 1") client.Query("SELECT * FROM books WHERE id = 1")

View File

@ -6,12 +6,15 @@ import (
"io/ioutil" "io/ioutil"
"log" "log"
"net" "net"
"net/url"
"os" "os"
"strings"
"sync" "sync"
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
"github.com/sosedoff/pgweb/pkg/connection" "github.com/sosedoff/pgweb/pkg/connection"
"github.com/sosedoff/pgweb/pkg/shared"
) )
const ( const (
@ -22,21 +25,22 @@ const (
type Tunnel struct { type Tunnel struct {
TargetHost string TargetHost string
TargetPort string TargetPort string
Port int
SshHost string SSHInfo *shared.SSHInfo
SshPort string Config *ssh.ClientConfig
SshUser string Client *ssh.Client
SshPassword string Listener *net.TCPListener
SshKey string
Config *ssh.ClientConfig
Client *ssh.Client
} }
func privateKeyPath() string { func privateKeyPath() string {
return os.Getenv("HOME") + "/.ssh/id_rsa" return os.Getenv("HOME") + "/.ssh/id_rsa"
} }
func fileExists(path string) bool {
_, err := os.Stat(path)
return err == nil
}
func parsePrivateKey(keyPath string) (ssh.Signer, error) { func parsePrivateKey(keyPath string) (ssh.Signer, error) {
buff, err := ioutil.ReadFile(keyPath) buff, err := ioutil.ReadFile(keyPath)
if err != nil { if err != nil {
@ -46,10 +50,11 @@ func parsePrivateKey(keyPath string) (ssh.Signer, error) {
return ssh.ParsePrivateKey(buff) return ssh.ParsePrivateKey(buff)
} }
func makeConfig(user, password, keyPath string) (*ssh.ClientConfig, error) { func makeConfig(info *shared.SSHInfo) (*ssh.ClientConfig, error) {
methods := []ssh.AuthMethod{} methods := []ssh.AuthMethod{}
if keyPath != "" { keyPath := privateKeyPath()
if fileExists(keyPath) {
key, err := parsePrivateKey(keyPath) key, err := parsePrivateKey(keyPath)
if err != nil { if err != nil {
return nil, err return nil, err
@ -58,52 +63,19 @@ func makeConfig(user, password, keyPath string) (*ssh.ClientConfig, error) {
methods = append(methods, ssh.PublicKeys(key)) methods = append(methods, ssh.PublicKeys(key))
} }
methods = append(methods, ssh.Password(password)) methods = append(methods, ssh.Password(info.Password))
return &ssh.ClientConfig{User: user, Auth: methods}, nil return &ssh.ClientConfig{User: info.User, Auth: methods}, nil
} }
func (tunnel *Tunnel) sshEndpoint() string { func (tunnel *Tunnel) sshEndpoint() string {
return fmt.Sprintf("%s:%v", tunnel.SshHost, tunnel.SshPort) return fmt.Sprintf("%s:%v", tunnel.SSHInfo.Host, tunnel.SSHInfo.Port)
} }
func (tunnel *Tunnel) targetEndpoint() string { func (tunnel *Tunnel) targetEndpoint() string {
return fmt.Sprintf("%v:%v", tunnel.TargetHost, tunnel.TargetPort) return fmt.Sprintf("%v:%v", tunnel.TargetHost, tunnel.TargetPort)
} }
func (tunnel *Tunnel) Start() error {
config, err := makeConfig(tunnel.SshUser, tunnel.SshPassword, tunnel.SshKey)
if err != nil {
return err
}
client, err := ssh.Dial("tcp", tunnel.sshEndpoint(), config)
if err != nil {
return err
}
defer client.Close()
port, err := connection.AvailablePort(PORT_START, PORT_LIMIT)
if err != nil {
return err
}
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%v", port))
if err != nil {
return err
}
defer listener.Close()
for {
conn, err := listener.Accept()
if err != nil {
return err
}
go tunnel.handleConnection(conn, client)
}
}
func (tunnel *Tunnel) copy(wg *sync.WaitGroup, writer, reader net.Conn) { func (tunnel *Tunnel) copy(wg *sync.WaitGroup, writer, reader net.Conn) {
defer wg.Done() defer wg.Done()
if _, err := io.Copy(writer, reader); err != nil { if _, err := io.Copy(writer, reader); err != nil {
@ -111,8 +83,8 @@ func (tunnel *Tunnel) copy(wg *sync.WaitGroup, writer, reader net.Conn) {
} }
} }
func (tunnel *Tunnel) handleConnection(local net.Conn, sshClient *ssh.Client) { func (tunnel *Tunnel) handleConnection(local net.Conn) {
remote, err := sshClient.Dial("tcp", tunnel.targetEndpoint()) remote, err := tunnel.Client.Dial("tcp", tunnel.targetEndpoint())
if err != nil { if err != nil {
return return
} }
@ -124,4 +96,79 @@ func (tunnel *Tunnel) handleConnection(local net.Conn, sshClient *ssh.Client) {
go tunnel.copy(&wg, remote, local) go tunnel.copy(&wg, remote, local)
wg.Wait() wg.Wait()
local.Close()
}
func (tunnel *Tunnel) Close() {
if tunnel.Client != nil {
tunnel.Client.Close()
}
if tunnel.Listener != nil {
tunnel.Listener.Close()
}
}
func (tunnel *Tunnel) Configure() error {
config, err := makeConfig(tunnel.SSHInfo)
if err != nil {
return err
}
tunnel.Config = config
client, err := ssh.Dial("tcp", tunnel.sshEndpoint(), config)
if err != nil {
return err
}
tunnel.Client = client
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%v", tunnel.Port))
if err != nil {
return err
}
tunnel.Listener = listener.(*net.TCPListener)
return nil
}
func (tunnel *Tunnel) Start() {
for {
conn, err := tunnel.Listener.Accept()
if err != nil {
return
}
go tunnel.handleConnection(conn)
}
tunnel.Close()
}
func NewTunnel(sshInfo *shared.SSHInfo, dbUrl string) (*Tunnel, error) {
uri, err := url.Parse(dbUrl)
if err != nil {
return nil, err
}
listenPort, err := connection.AvailablePort(PORT_START, PORT_LIMIT)
if err != nil {
return nil, err
}
chunks := strings.Split(uri.Host, ":")
host := chunks[0]
port := "5432"
if len(chunks) == 2 {
port = chunks[1]
}
tunnel := &Tunnel{
Port: listenPort,
SSHInfo: sshInfo,
TargetHost: host,
TargetPort: port,
}
return tunnel, nil
} }

File diff suppressed because one or more lines are too long

17
pkg/shared/ssh_info.go Normal file
View File

@ -0,0 +1,17 @@
package shared
import (
"fmt"
)
type SSHInfo struct {
Host string `json:"host,omitempty"`
Port string `json:"port,omitempty"`
User string `json:"user,omitempty"`
Password string `json:"password,omitempty"`
Key string `json:"key,omitempty"`
}
func (info SSHInfo) String() string {
return fmt.Sprintf("%s@%s:%s", info.User, info.Host, info.Port)
}

View File

@ -116,6 +116,7 @@
<div class="btn-group btn-group-sm connection-group-switch"> <div class="btn-group btn-group-sm connection-group-switch">
<button type="button" data="scheme" class="btn btn-default" id="connection_scheme">Scheme</button> <button type="button" data="scheme" class="btn btn-default" id="connection_scheme">Scheme</button>
<button type="button" data="standard" class="btn btn-default active" id="connection_standard">Standard</button> <button type="button" data="standard" class="btn btn-default active" id="connection_standard">Standard</button>
<button type="button" data="ssh" class="btn btn-default" id="connection_ssh">SSH</button>
</div> </div>
</div> </div>
@ -204,14 +205,14 @@
<div class="form-group"> <div class="form-group">
<label class="col-sm-3 control-label">SSH Password</label> <label class="col-sm-3 control-label">SSH Password</label>
<div class="col-sm-9"> <div class="col-sm-9">
<input type="text" id="ssh_password" class="form-control" placeholder="optional" /> <input type="password" id="ssh_password" class="form-control" placeholder="optional" />
</div> </div>
</div> </div>
<div class="form-group"> <div class="form-group">
<label class="col-sm-3 control-label">SSH Port</label> <label class="col-sm-3 control-label">SSH Port</label>
<div class="col-sm-9"> <div class="col-sm-9">
<input type="text" id="pg_host" class="form-control" placeholder="optional" /> <input type="text" id="ssh_port" class="form-control" placeholder="optional" />
</div> </div>
</div> </div>
</div> </div>

View File

@ -648,7 +648,7 @@ function getConnectionString() {
var mode = $(".connection-group-switch button.active").attr("data"); var mode = $(".connection-group-switch button.active").attr("data");
var ssl = $("#connection_ssl").val(); var ssl = $("#connection_ssl").val();
if (mode == "standard") { if (mode == "standard" || mode == "ssh") {
var host = $("#pg_host").val(); var host = $("#pg_host").val();
var port = $("#pg_port").val(); var port = $("#pg_port").val();
var user = $("#pg_user").val(); var user = $("#pg_user").val();
@ -929,22 +929,39 @@ $(document).ready(function() {
$("#pg_password").val(item.password); $("#pg_password").val(item.password);
$("#pg_db").val(item.database); $("#pg_db").val(item.database);
$("#connection_ssl").val(item.ssl); $("#connection_ssl").val(item.ssl);
if (item.ssh) {
$("#ssh_host").val(item.ssh.host);
$("#ssh_port").val(item.ssh.port);
$("#ssh_user").val(item.ssh.user);
$("#ssh_password").val(item.ssh.password);
}
}); });
$("#connection_form").on("submit", function(e) { $("#connection_form").on("submit", function(e) {
e.preventDefault(); e.preventDefault();
var button = $(this).children("button"); var button = $(this).children("button");
var url = getConnectionString(); var params = {
url: getConnectionString()
};
if (url.length == 0) { if (params.url.length == 0) {
return; return;
} }
if ($(".connection-group-switch button.active").attr("data") == "ssh") {
params["ssh"] = 1
params["ssh_host"] = $("#ssh_host").val();
params["ssh_port"] = $("#ssh_port").val();
params["ssh_user"] = $("#ssh_user").val();
params["ssh_password"] = $("#ssh_password").val();
}
$("#connection_error").hide(); $("#connection_error").hide();
button.prop("disabled", true).text("Please wait..."); button.prop("disabled", true).text("Please wait...");
apiCall("post", "/connect", { url: url }, function(resp) { apiCall("post", "/connect", params, function(resp) {
button.prop("disabled", false).text("Connect"); button.prop("disabled", false).text("Connect");
if (resp.error) { if (resp.error) {