From 062b8e4ef3b6382cdbf13b80585b22328bc705c5 Mon Sep 17 00:00:00 2001 From: alexcesaro Date: Wed, 22 Oct 2014 17:55:36 +0200 Subject: [PATCH] Added support for SMTPS Gomail now automatically uses SMTPS on port 465 --- README.md | 1 + mailer.go | 4 ++-- send.go | 44 +++++++++++++++++++++++++++++++---- send_test.go | 66 ++++++++++++++++++++++++++++++++++++++++++++++------ 4 files changed, 101 insertions(+), 14 deletions(-) diff --git a/README.md b/README.md index f09c5ea..910adc5 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,7 @@ It requires Go 1.2 or newer. * Supports HTML and text templates * Attachments * Embedded images + * SSL/TLS support * Automatic encoding of special characters * Well-documented * High test coverage diff --git a/mailer.go b/mailer.go index 193de54..bcbef01 100644 --- a/mailer.go +++ b/mailer.go @@ -67,7 +67,7 @@ func NewMailer(host string, username string, password string, port int, settings // gomail.NewCustomMailer("host:587", smtp.CRAMMD5Auth("username", "secret")) func NewCustomMailer(addr string, auth smtp.Auth, settings ...MailerSetting) *Mailer { // Error is not handled here to preserve backward compatibility - host, _, _ := net.SplitHostPort(addr) + host, port, _ := net.SplitHostPort(addr) m := &Mailer{ addr: addr, @@ -83,7 +83,7 @@ func NewCustomMailer(addr string, auth smtp.Auth, settings ...MailerSetting) *Ma m.config = &tls.Config{ServerName: host} } if m.send == nil { - m.send = m.getSendMailFunc() + m.send = m.getSendMailFunc(port == "465") } return m diff --git a/send.go b/send.go index 22062f6..77aa6a2 100644 --- a/send.go +++ b/send.go @@ -3,18 +3,22 @@ package gomail import ( "crypto/tls" "io" + "net" "net/smtp" ) -func (m *Mailer) getSendMailFunc() SendMailFunc { +func (m *Mailer) getSendMailFunc(ssl bool) SendMailFunc { return func(addr string, a smtp.Auth, from string, to []string, msg []byte) error { - c, err := initSMTP(addr) + var c smtpClient + var err error + if ssl { + c, err = sslDial(addr, m.host, m.config) + } else { + c, err = starttlsDial(addr, m.config) + } if err != nil { return err } - if ok, _ := c.Extension("STARTTLS"); ok { - return c.StartTLS(m.config) - } defer c.Close() if a != nil { @@ -52,10 +56,40 @@ func (m *Mailer) getSendMailFunc() SendMailFunc { } } +func sslDial(addr, host string, config *tls.Config) (smtpClient, error) { + conn, err := initTLS("tcp", addr, config) + if err != nil { + return nil, err + } + + return newClient(conn, host) +} + +func starttlsDial(addr string, config *tls.Config) (smtpClient, error) { + c, err := initSMTP(addr) + if err != nil { + return c, err + } + + if ok, _ := c.Extension("STARTTLS"); ok { + return c, c.StartTLS(config) + } + + return c, nil +} + var initSMTP = func(addr string) (smtpClient, error) { return smtp.Dial(addr) } +var initTLS = func(network, addr string, config *tls.Config) (*tls.Conn, error) { + return tls.Dial(network, addr, config) +} + +var newClient = func(conn net.Conn, host string) (smtpClient, error) { + return smtp.NewClient(conn, host) +} + type smtpClient interface { Extension(string) (bool, string) StartTLS(*tls.Config) error diff --git a/send_test.go b/send_test.go index f564ec0..23a3db2 100644 --- a/send_test.go +++ b/send_test.go @@ -3,18 +3,21 @@ package gomail import ( "crypto/tls" "io" + "net" "net/smtp" "testing" ) var ( - testAddr = "smtp.example.com:587" - testConfig = &tls.Config{InsecureSkipVerify: true} - testHost = "smtp.example.com" - testAuth = smtp.PlainAuth("", "user", "pwd", "smtp.example.com") - testFrom = "from@example.com" - testTo = []string{"to1@example.com", "to2@example.com"} - testBody = "Test message" + testAddr = "smtp.example.com:587" + testSSLAddr = "smtp.example.com:465" + testTLSConn = &tls.Conn{} + testConfig = &tls.Config{InsecureSkipVerify: true} + testHost = "smtp.example.com" + testAuth = smtp.PlainAuth("", "user", "pwd", "smtp.example.com") + testFrom = "from@example.com" + testTo = []string{"to1@example.com", "to2@example.com"} + testBody = "Test message" ) const wantMsg = "To: to1@example.com, to2@example.com\r\n" + @@ -43,6 +46,21 @@ func TestDefaultSendMail(t *testing.T) { }) } +func TestSSLSendMail(t *testing.T) { + testSendMail(t, testSSLAddr, nil, []string{ + "Extension AUTH", + "Auth", + "Mail " + testFrom, + "Rcpt " + testTo[0], + "Rcpt " + testTo[1], + "Data", + "Write message", + "Close writer", + "Quit", + "Close", + }) +} + func TestTLSConfigSendMail(t *testing.T) { testSendMail(t, testAddr, testConfig, []string{ "Extension STARTTLS", @@ -60,6 +78,21 @@ func TestTLSConfigSendMail(t *testing.T) { }) } +func TestTLSConfigSSLSendMail(t *testing.T) { + testSendMail(t, testSSLAddr, testConfig, []string{ + "Extension AUTH", + "Auth", + "Mail " + testFrom, + "Rcpt " + testTo[0], + "Rcpt " + testTo[1], + "Data", + "Write message", + "Close writer", + "Quit", + "Close", + }) +} + type mockClient struct { t *testing.T i int @@ -152,6 +185,25 @@ func testSendMail(t *testing.T, addr string, config *tls.Config, want []string) return testClient, nil } + initTLS = func(network, addr string, config *tls.Config) (*tls.Conn, error) { + if network != "tcp" { + t.Errorf("Invalid network, got %q, want tcp", network) + } + assertAddr(t, addr, testClient.addr) + assertConfig(t, config, testClient.config) + return testTLSConn, nil + } + + newClient = func(conn net.Conn, host string) (smtpClient, error) { + if conn != testTLSConn { + t.Error("Invalid TLS connection used") + } + if host != testHost { + t.Errorf("Invalid host, got %q, want %q", host, testHost) + } + return testClient, nil + } + msg := NewMessage() msg.SetHeader("From", testFrom) msg.SetHeader("To", testTo...)