diff --git a/smtp.go b/smtp.go index 8f19b0c..2aa49c8 100644 --- a/smtp.go +++ b/smtp.go @@ -7,6 +7,7 @@ import ( "net" "net/smtp" "strings" + "time" ) // A Dialer is a dialer to an SMTP server. @@ -57,7 +58,7 @@ func NewPlainDialer(host string, port int, username, password string) *Dialer { // Dial dials and authenticates to an SMTP server. The returned SendCloser // should be closed when done using it. func (d *Dialer) Dial() (SendCloser, error) { - conn, err := netDial("tcp", addr(d.Host, d.Port)) + conn, err := netDialTimeout("tcp", addr(d.Host, d.Port), 10*time.Second) if err != nil { return nil, err } @@ -181,9 +182,9 @@ func (c *smtpSender) Close() error { // Stubbed out for tests. var ( - netDial = net.Dial - tlsClient = tls.Client - smtpNewClient = func(conn net.Conn, host string) (smtpClient, error) { + netDialTimeout = net.DialTimeout + tlsClient = tls.Client + smtpNewClient = func(conn net.Conn, host string) (smtpClient, error) { return smtp.NewClient(conn, host) } ) diff --git a/smtp_test.go b/smtp_test.go index ac1e5ef..b6f9155 100644 --- a/smtp_test.go +++ b/smtp_test.go @@ -8,6 +8,7 @@ import ( "net/smtp" "reflect" "testing" + "time" ) const ( @@ -247,7 +248,7 @@ func doTestSendMail(t *testing.T, d *Dialer, want []string, timeout bool) { timeout: timeout, } - netDial = func(network, address string) (net.Conn, error) { + netDialTimeout = func(network, address string, d time.Duration) (net.Conn, error) { if network != "tcp" { t.Errorf("Invalid network, got %q, want tcp", network) }