pgweb/pkg/client/tunnel.go

205 lines
3.9 KiB
Go
Raw Normal View History

2016-01-13 10:29:14 +03:00
package client
import (
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/url"
2016-01-13 10:29:14 +03:00
"os"
"strings"
2016-01-13 10:29:14 +03:00
"sync"
"time"
2016-01-13 10:29:14 +03:00
"golang.org/x/crypto/ssh"
"github.com/sosedoff/pgweb/pkg/connection"
"github.com/sosedoff/pgweb/pkg/shared"
2016-01-13 10:29:14 +03:00
)
// portStart is the first local port passed to connection.FindAvailablePort
// when NewTunnel picks a port for the tunnel's local listener.
const portStart = 29168

// portLimit is passed to connection.FindAvailablePort together with
// portStart; presumably the number of ports to probe — confirm against
// FindAvailablePort.
const portLimit = 500
2018-12-04 21:42:37 +03:00
// Tunnel represents the connection between local and remote server
2016-01-13 10:29:14 +03:00
type Tunnel struct {
TargetHost string
TargetPort string
Port int
SSHInfo *shared.SSHInfo
Config *ssh.ClientConfig
Client *ssh.Client
Listener *net.TCPListener
2016-01-13 10:29:14 +03:00
}
// privateKeyPath returns the default private key location, $HOME/.ssh/id_rsa.
// HOME is used verbatim, so an unset HOME yields "/.ssh/id_rsa".
func privateKeyPath() string {
	const defaultKeySuffix = "/.ssh/id_rsa"
	home := os.Getenv("HOME")
	return home + defaultKeySuffix
}
// expandKeyPath expands a leading "~" in path to the home directory taken
// from $HOME. Only the forms "~" and "~/..." are expanded. A "~" anywhere
// else in the path, or a "~otheruser/..." form, is left untouched. When HOME
// is unset, path is returned unchanged.
func expandKeyPath(path string) string {
	home := os.Getenv("HOME")
	if home == "" {
		return path
	}
	if path == "~" {
		return home
	}
	// Only a leading "~/" means "home directory"; replacing the first "~"
	// anywhere (as strings.Replace did) corrupts paths like "/keys/my~key".
	if strings.HasPrefix(path, "~/") {
		return home + path[1:]
	}
	return path
}
// fileExists reports whether os.Stat succeeds on path. Any Stat error, not
// just "not found", yields false. Directories are reported as existing.
func fileExists(path string) bool {
	if _, statErr := os.Stat(path); statErr != nil {
		return false
	}
	return true
}
2016-01-13 10:29:14 +03:00
// parsePrivateKey reads the private key file at keyPath and returns an SSH
// signer for it. Read and parse errors are returned unchanged.
// NOTE(review): no passphrase is supplied, so encrypted keys presumably fail
// to parse here — confirm with ssh.ParsePrivateKey docs.
func parsePrivateKey(keyPath string) (ssh.Signer, error) {
	keyBytes, readErr := ioutil.ReadFile(keyPath)
	if readErr != nil {
		return nil, readErr
	}
	signer, parseErr := ssh.ParsePrivateKey(keyBytes)
	return signer, parseErr
}
func makeConfig(info *shared.SSHInfo) (*ssh.ClientConfig, error) {
2016-01-13 10:29:14 +03:00
methods := []ssh.AuthMethod{}
// Try to use user-provided key, fallback to system default key
keyPath := info.Key
if keyPath == "" {
keyPath = privateKeyPath()
} else {
keyPath = expandKeyPath(keyPath)
}
if fileExists(keyPath) {
2016-01-13 10:29:14 +03:00
key, err := parsePrivateKey(keyPath)
if err != nil {
return nil, err
}
methods = append(methods, ssh.PublicKeys(key))
}
methods = append(methods, ssh.Password(info.Password))
2016-01-13 10:29:14 +03:00
cfg := &ssh.ClientConfig{
User: info.User,
Auth: methods,
Timeout: time.Second * 10,
HostKeyCallback: func(hostname string, remote net.Addr, key ssh.PublicKey) error {
return nil
},
}
return cfg, nil
2016-01-13 10:29:14 +03:00
}
func (tunnel *Tunnel) sshEndpoint() string {
return fmt.Sprintf("%s:%v", tunnel.SSHInfo.Host, tunnel.SSHInfo.Port)
2016-01-13 10:29:14 +03:00
}
// targetEndpoint returns the "host:port" address of the database server that
// the SSH server is asked to dial.
//
// net.JoinHostPort brackets IPv6 literals ("[::1]:5432"), which plain Sprintf
// did not. A host that is already bracketed is unwrapped first so it is not
// bracketed twice.
func (tunnel *Tunnel) targetEndpoint() string {
	host := strings.TrimSuffix(strings.TrimPrefix(tunnel.TargetHost, "["), "]")
	return net.JoinHostPort(host, tunnel.TargetPort)
}
// copy streams everything read from reader into writer until io.Copy
// returns, then marks wg done. A non-nil copy error is logged and otherwise
// dropped.
func (tunnel *Tunnel) copy(wg *sync.WaitGroup, writer, reader net.Conn) {
	defer wg.Done()

	_, copyErr := io.Copy(writer, reader)
	if copyErr == nil {
		return
	}
	log.Println("Tunnel copy error:", copyErr)
}
func (tunnel *Tunnel) handleConnection(local net.Conn) {
remote, err := tunnel.Client.Dial("tcp", tunnel.targetEndpoint())
2016-01-13 10:29:14 +03:00
if err != nil {
return
2016-01-13 10:29:14 +03:00
}
wg := sync.WaitGroup{}
wg.Add(2)
go tunnel.copy(&wg, local, remote)
go tunnel.copy(&wg, remote, local)
wg.Wait()
local.Close()
}
2018-12-04 21:42:37 +03:00
// Close closes the tunnel connection: the SSH client first, then the local
// listener. A nil Client or Listener (e.g. before Configure succeeded) is
// skipped. Errors returned by either Close call are ignored.
func (tunnel *Tunnel) Close() {
	if sshClient := tunnel.Client; sshClient != nil {
		sshClient.Close()
	}
	if localListener := tunnel.Listener; localListener != nil {
		localListener.Close()
	}
}
2018-12-04 21:42:37 +03:00
// Configure establishes the tunnel between localhost and remote machine
func (tunnel *Tunnel) Configure() error {
config, err := makeConfig(tunnel.SSHInfo)
2016-01-13 10:29:14 +03:00
if err != nil {
return err
}
tunnel.Config = config
2016-01-13 10:29:14 +03:00
client, err := ssh.Dial("tcp", tunnel.sshEndpoint(), config)
2016-01-13 10:29:14 +03:00
if err != nil {
return err
}
tunnel.Client = client
2016-01-13 10:29:14 +03:00
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%v", tunnel.Port))
2016-01-13 10:29:14 +03:00
if err != nil {
return err
}
tunnel.Listener = listener.(*net.TCPListener)
return nil
}
2016-01-13 10:29:14 +03:00
2018-12-04 21:42:37 +03:00
// Start starts the connection handler loop
func (tunnel *Tunnel) Start() {
2016-01-22 21:12:00 +03:00
defer tunnel.Close()
2016-01-13 10:29:14 +03:00
for {
conn, err := tunnel.Listener.Accept()
2016-01-13 10:29:14 +03:00
if err != nil {
return
2016-01-13 10:29:14 +03:00
}
go tunnel.handleConnection(conn)
2016-01-13 10:29:14 +03:00
}
}
2018-12-04 21:42:37 +03:00
// NewTunnel instantiates a new tunnel struct from given ssh info
func NewTunnel(sshInfo *shared.SSHInfo, dbUrl string) (*Tunnel, error) {
uri, err := url.Parse(dbUrl)
if err != nil {
return nil, err
2016-01-13 10:29:14 +03:00
}
2018-12-04 21:42:37 +03:00
listenPort, err := connection.FindAvailablePort(portStart, portLimit)
2016-01-13 10:29:14 +03:00
if err != nil {
return nil, err
2016-01-13 10:29:14 +03:00
}
chunks := strings.Split(uri.Host, ":")
host := chunks[0]
port := "5432"
2016-01-13 10:29:14 +03:00
if len(chunks) == 2 {
port = chunks[1]
}
2016-01-13 10:29:14 +03:00
tunnel := &Tunnel{
Port: listenPort,
SSHInfo: sshInfo,
TargetHost: host,
TargetPort: port,
}
return tunnel, nil
2016-01-13 10:29:14 +03:00
}