mirror of
https://github.com/AdguardTeam/AdGuardHome.git
synced 2024-12-15 11:22:49 +03:00
106 lines
2.6 KiB
Go
106 lines
2.6 KiB
Go
package upstream
|
|
|
|
import (
|
|
"crypto/tls"
|
|
"time"
|
|
|
|
"github.com/miekg/dns"
|
|
"golang.org/x/net/context"
|
|
)
|
|
|
|
// DnsUpstream is a very simple upstream implementation for plain DNS
|
|
type DnsUpstream struct {
|
|
endpoint string // IP:port
|
|
timeout time.Duration // Max read and write timeout
|
|
proto string // Protocol (tcp, tcp-tls, or udp)
|
|
transport *Transport // Persistent connections cache
|
|
}
|
|
|
|
// NewDnsUpstream creates a new DNS upstream
|
|
func NewDnsUpstream(endpoint string, proto string, tlsServerName string) (Upstream, error) {
|
|
u := &DnsUpstream{
|
|
endpoint: endpoint,
|
|
timeout: defaultTimeout,
|
|
proto: proto,
|
|
}
|
|
|
|
var tlsConfig *tls.Config
|
|
|
|
if proto == "tcp-tls" {
|
|
tlsConfig = new(tls.Config)
|
|
tlsConfig.ServerName = tlsServerName
|
|
}
|
|
|
|
// Initialize the connections cache
|
|
u.transport = NewTransport(endpoint)
|
|
u.transport.tlsConfig = tlsConfig
|
|
u.transport.Start()
|
|
|
|
return u, nil
|
|
}
|
|
|
|
// Exchange provides an implementation for the Upstream interface
|
|
func (u *DnsUpstream) Exchange(ctx context.Context, query *dns.Msg) (*dns.Msg, error) {
|
|
resp, err := u.exchange(u.proto, query)
|
|
|
|
// Retry over TCP if response is truncated
|
|
if err == dns.ErrTruncated && u.proto == "udp" {
|
|
resp, err = u.exchange("tcp", query)
|
|
} else if err == dns.ErrTruncated && resp != nil {
|
|
// Reassemble something to be sent to client
|
|
m := new(dns.Msg)
|
|
m.SetReply(query)
|
|
m.Truncated = true
|
|
m.Authoritative = true
|
|
m.Rcode = dns.RcodeSuccess
|
|
return m, nil
|
|
}
|
|
|
|
if err != nil {
|
|
resp = &dns.Msg{}
|
|
resp.SetRcode(resp, dns.RcodeServerFailure)
|
|
}
|
|
|
|
return resp, err
|
|
}
|
|
|
|
// Clear resources
|
|
func (u *DnsUpstream) Close() error {
|
|
// Close active connections
|
|
u.transport.Stop()
|
|
return nil
|
|
}
|
|
|
|
// Performs a synchronous query. It sends the message m via the conn
|
|
// c and waits for a reply. The conn c is not closed.
|
|
func (u *DnsUpstream) exchange(proto string, query *dns.Msg) (r *dns.Msg, err error) {
|
|
// Establish a connection if needed (or reuse cached)
|
|
conn, err := u.transport.Dial(proto)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Write the request with a timeout
|
|
conn.SetWriteDeadline(time.Now().Add(u.timeout))
|
|
if err = conn.WriteMsg(query); err != nil {
|
|
conn.Close() // Not giving it back
|
|
return nil, err
|
|
}
|
|
|
|
// Write response with a timeout
|
|
conn.SetReadDeadline(time.Now().Add(u.timeout))
|
|
r, err = conn.ReadMsg()
|
|
if err != nil {
|
|
conn.Close() // Not giving it back
|
|
} else if err == nil && r.Id != query.Id {
|
|
err = dns.ErrId
|
|
conn.Close() // Not giving it back
|
|
}
|
|
|
|
if err == nil {
|
|
// Return it back to the connections cache if there were no errors
|
|
u.transport.Yield(conn)
|
|
}
|
|
return r, err
|
|
}
|