From a0730d2c47921dcf2ede8854f918481ca7e54a86 Mon Sep 17 00:00:00 2001 From: alexcesaro Date: Wed, 22 Oct 2014 17:47:24 +0200 Subject: [PATCH] Added the SetTLSConfig mailer setting --- README.md | 18 ++++- gomail_test.go | 8 +- mailer.go | 41 ++++++++--- send.go | 68 +++++++++++++++++ send_test.go | 193 +++++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 311 insertions(+), 17 deletions(-) create mode 100644 send.go create mode 100644 send_test.go diff --git a/README.md b/README.md index fb3e56c..f09c5ea 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ ## Introduction -Package gomail provides a simple interface to send emails. +Gomail is a very simple and powerful package to send emails. It requires Go 1.2 or newer. @@ -54,7 +54,7 @@ func main() { msg.Attach(f) // Send the email to Bob, Cora and Dan - mailer := gomail.NewMailer("smtp.example.com", "user", "123456", 25) + mailer := gomail.NewMailer("smtp.example.com", "user", "123456", 587) if err := mailer.Send(msg); err != nil { panic(err) } @@ -62,6 +62,20 @@ func main() { ``` +## FAQ + +### x509: certificate signed by unknown authority + +If you get this error it means the certificate used by the SMTP server is not +considered valid by the client running Gomail. As a quick workaround you can +bypass the verification of the server's certificate chain and host name by using +`SetTLSConfig`: + + mailer := gomail.NewMailer("smtp.example.com", "user", "123456", 587, gomail.SetTLSConfig(&tls.Config{InsecureSkipVerify: true})) + +Note, however, that this is insecure and should not be used in production. + + ## Contact You are more than welcome to open issues and send pull requests if you find a diff --git a/gomail_test.go b/gomail_test.go index 694374b..77e3e43 100644 --- a/gomail_test.go +++ b/gomail_test.go @@ -450,7 +450,7 @@ func TestBase64LineLength(t *testing.T) { func testMessage(t *testing.T, msg *Message, bCount int, emails ...message) { now = stubNow - mailer := NewMailer("host", "username", "password", 25, SetSendMail(stubSendMail(t, bCount, emails...))) + mailer := NewMailer("host", "username", "password", 587, SetSendMail(stubSendMail(t, bCount, emails...))) err := mailer.Send(msg) if err != nil { @@ -470,8 +470,8 @@ func stubSendMail(t *testing.T, bCount int, emails ...message) SendMailFunc { } want := emails[i] - if addr != "host:25" { - t.Fatalf("Invalid address, got %q, want host:25", addr) + if addr != "host:587" { + t.Fatalf("Invalid address, got %q, want host:587", addr) } if from != want.from { @@ -592,7 +592,7 @@ func BenchmarkFull(b *testing.B) { msg.Attach(CreateFile("benchmark.txt", []byte("Benchmark"))) msg.Embed(CreateFile("benchmark.jpg", []byte("Benchmark"))) - mailer := NewMailer("host", "username", "password", 25, SetSendMail(emptyFunc)) + mailer := NewMailer("host", "username", "password", 587, SetSendMail(emptyFunc)) if err := mailer.Send(msg); err != nil { panic(err) } diff --git a/mailer.go b/mailer.go index 1c5ec1f..193de54 100644 --- a/mailer.go +++ b/mailer.go @@ -2,9 +2,11 @@ package gomail import ( "bytes" + "crypto/tls" "errors" "fmt" "io/ioutil" + "net" "net/mail" "net/smtp" "strings" @@ -12,15 +14,17 @@ import ( // A Mailer represents an SMTP server. type Mailer struct { - addr string - auth smtp.Auth - send SendMailFunc + addr string + host string + config *tls.Config + auth smtp.Auth + send SendMailFunc } // A MailerSetting can be used in a mailer constructor to configure it. type MailerSetting func(m *Mailer) -// SetSendMail is an option to set the email-sending function of a mailer. +// SetSendMail allows to set the email-sending function of a mailer. // // Example: // @@ -34,6 +38,14 @@ func SetSendMail(s SendMailFunc) MailerSetting { } } +// SetTLSConfig allows to set the TLS configuration used to connect the SMTP +// server. +func SetTLSConfig(c *tls.Config) MailerSetting { + return func(m *Mailer) { + m.config = c + } +} + // A SendMailFunc is a function to send emails with the same signature than // smtp.SendMail. type SendMailFunc func(addr string, a smtp.Auth, from string, to []string, msg []byte) error @@ -52,22 +64,29 @@ func NewMailer(host string, username string, password string, port int, settings // // Example: // -// gomail.NewCustomMailer("host:25", smtp.CRAMMD5Auth("username", "secret")) +// gomail.NewCustomMailer("host:587", smtp.CRAMMD5Auth("username", "secret")) func NewCustomMailer(addr string, auth smtp.Auth, settings ...MailerSetting) *Mailer { + // Error is not handled here to preserve backward compatibility + host, _, _ := net.SplitHostPort(addr) + m := &Mailer{ addr: addr, + host: host, auth: auth, - send: smtp.SendMail, } - m.applySettings(settings) - return m -} - -func (m *Mailer) applySettings(settings []MailerSetting) { for _, s := range settings { s(m) } + + if m.config == nil { + m.config = &tls.Config{ServerName: host} + } + if m.send == nil { + m.send = m.getSendMailFunc() + } + + return m } // Send sends the emails to all the recipients of the message. diff --git a/send.go b/send.go new file mode 100644 index 0000000..22062f6 --- /dev/null +++ b/send.go @@ -0,0 +1,68 @@ +package gomail + +import ( + "crypto/tls" + "io" + "net/smtp" +) + +func (m *Mailer) getSendMailFunc() SendMailFunc { + return func(addr string, a smtp.Auth, from string, to []string, msg []byte) error { + c, err := initSMTP(addr) + if err != nil { + return err + } + if ok, _ := c.Extension("STARTTLS"); ok { + return c.StartTLS(m.config) + } + defer c.Close() + + if a != nil { + if ok, _ := c.Extension("AUTH"); ok { + if err = c.Auth(a); err != nil { + return err + } + } + } + + if err = c.Mail(from); err != nil { + return err + } + + for _, addr := range to { + if err = c.Rcpt(addr); err != nil { + return err + } + } + + w, err := c.Data() + if err != nil { + return err + } + _, err = w.Write(msg) + if err != nil { + return err + } + err = w.Close() + if err != nil { + return err + } + + return c.Quit() + } +} + +var initSMTP = func(addr string) (smtpClient, error) { + return smtp.Dial(addr) +} + +type smtpClient interface { + Extension(string) (bool, string) + StartTLS(*tls.Config) error + Auth(smtp.Auth) error + Mail(string) error + Rcpt(string) error + Data() (io.WriteCloser, error) + Quit() error + Close() error +} diff --git a/send_test.go b/send_test.go new file mode 100644 index 0000000..f564ec0 --- /dev/null +++ b/send_test.go @@ -0,0 +1,193 @@ +package gomail + +import ( + "crypto/tls" + "io" + "net/smtp" + "testing" +) + +var ( + testAddr = "smtp.example.com:587" + testConfig = &tls.Config{InsecureSkipVerify: true} + testHost = "smtp.example.com" + testAuth = smtp.PlainAuth("", "user", "pwd", "smtp.example.com") + testFrom = "from@example.com" + testTo = []string{"to1@example.com", "to2@example.com"} + testBody = "Test message" +) + +const wantMsg = "To: to1@example.com, to2@example.com\r\n" + + "From: from@example.com\r\n" + + "Mime-Version: 1.0\r\n" + + "Date: 25 Jun 14 17:46 +0000\r\n" + + "Content-Type: text/plain; charset=UTF-8\r\n" + + "Content-Transfer-Encoding: quoted-printable\r\n" + + "\r\n" + + "Test message" + +func TestDefaultSendMail(t *testing.T) { + testSendMail(t, testAddr, nil, []string{ + "Extension STARTTLS", + "StartTLS", + "Extension AUTH", + "Auth", + "Mail " + testFrom, + "Rcpt " + testTo[0], + "Rcpt " + testTo[1], + "Data", + "Write message", + "Close writer", + "Quit", + "Close", + }) +} + +func TestTLSConfigSendMail(t *testing.T) { + testSendMail(t, testAddr, testConfig, []string{ + "Extension STARTTLS", + "StartTLS", + "Extension AUTH", + "Auth", + "Mail " + testFrom, + "Rcpt " + testTo[0], + "Rcpt " + testTo[1], + "Data", + "Write message", + "Close writer", + "Quit", + "Close", + }) +} + +type mockClient struct { + t *testing.T + i int + want []string + addr string + auth smtp.Auth + config *tls.Config +} + +func (c *mockClient) Extension(ext string) (bool, string) { + c.do("Extension " + ext) + return true, "" +} + +func (c *mockClient) StartTLS(config *tls.Config) error { + assertConfig(c.t, config, c.config) + c.do("StartTLS") + return nil +} + +func (c *mockClient) Auth(a smtp.Auth) error { + assertAuth(c.t, a, c.auth) + c.do("Auth") + return nil +} + +func (c *mockClient) Mail(from string) error { + c.do("Mail " + from) + return nil +} + +func (c *mockClient) Rcpt(to string) error { + c.do("Rcpt " + to) + return nil +} + +func (c *mockClient) Data() (io.WriteCloser, error) { + c.do("Data") + return &mockWriter{c: c, want: wantMsg}, nil +} + +func (c *mockClient) Quit() error { + c.do("Quit") + return nil +} + +func (c *mockClient) Close() error { + c.do("Close") + return nil +} + +func (c *mockClient) do(cmd string) { + if c.i >= len(c.want) { + c.t.Fatalf("Invalid command %q", cmd) + } + + if cmd != c.want[c.i] { + c.t.Fatalf("Invalid command, got %q, want %q", cmd, c.want[c.i]) + } + c.i++ +} + +type mockWriter struct { + want string + c *mockClient +} + +func (w *mockWriter) Write(p []byte) (int, error) { + w.c.do("Write message") + compareBodies(w.c.t, string(p), w.want) + return len(p), nil +} + +func (w *mockWriter) Close() error { + w.c.do("Close writer") + return nil +} + +func testSendMail(t *testing.T, addr string, config *tls.Config, want []string) { + testClient := &mockClient{ + t: t, + want: want, + addr: addr, + auth: testAuth, + config: config, + } + + initSMTP = func(addr string) (smtpClient, error) { + assertAddr(t, addr, testClient.addr) + return testClient, nil + } + + msg := NewMessage() + msg.SetHeader("From", testFrom) + msg.SetHeader("To", testTo...) + msg.SetBody("text/plain", testBody) + + var settings []MailerSetting + if config != nil { + settings = []MailerSetting{SetTLSConfig(config)} + } + + mailer := NewCustomMailer(addr, testAuth, settings...) + if err := mailer.Send(msg); err != nil { + t.Error(err) + } +} + +func assertAuth(t *testing.T, got, want smtp.Auth) { + if got != want { + t.Errorf("Invalid auth, got %#v, want %#v", got, want) + } +} + +func assertAddr(t *testing.T, got, want string) { + if got != want { + t.Errorf("Invalid addr, got %q, want %q", got, want) + } +} + +func assertConfig(t *testing.T, got, want *tls.Config) { + if want == nil { + want = &tls.Config{ServerName: testHost} + } + if got.ServerName != want.ServerName { + t.Errorf("Invalid field ServerName in config, got %q, want %q", got.ServerName, want.ServerName) + } + if got.InsecureSkipVerify != want.InsecureSkipVerify { + t.Errorf("Invalid field InsecureSkipVerify in config, got %v, want %v", got.InsecureSkipVerify, want.InsecureSkipVerify) + } +}