diff --git a/gomail_test.go b/gomail_test.go index 30d0eeb..3cdd9a9 100644 --- a/gomail_test.go +++ b/gomail_test.go @@ -2,7 +2,8 @@ package gomail import ( "encoding/base64" - "net/smtp" + "io" + "io/ioutil" "path/filepath" "regexp" "strconv" @@ -11,6 +12,12 @@ import ( "time" ) +func init() { + now = func() time.Time { + return time.Date(2014, 06, 25, 17, 46, 0, 0, time.UTC) + } +} + type message struct { from string to []string @@ -23,8 +30,8 @@ func TestMessage(t *testing.T) { msg.SetHeader("To", msg.FormatAddress("to@example.com", "Señor To"), "tobis@example.com") msg.SetAddressHeader("Cc", "cc@example.com", "A, B") msg.SetAddressHeader("X-To", "ccbis@example.com", "à, b") - msg.SetDateHeader("X-Date", stubNow()) - msg.SetHeader("X-Date-2", msg.FormatDate(stubNow())) + msg.SetDateHeader("X-Date", now()) + msg.SetHeader("X-Date-2", msg.FormatDate(now())) msg.SetHeader("Subject", "¡Hola, señor!") msg.SetHeaders(map[string][]string{ "X-Headers": {"Test", "Café"}, @@ -488,31 +495,20 @@ func TestBase64LineLength(t *testing.T) { } func testMessage(t *testing.T, msg *Message, bCount int, emails ...message) { - now = stubNow - mailer := NewMailer("host", "username", "password", 587, SetSendMail(stubSendMail(t, bCount, emails...))) - - err := mailer.Send(msg) + err := Send(stubSendMail(t, bCount, emails...), msg) if err != nil { t.Error(err) } } -func stubNow() time.Time { - return time.Date(2014, 06, 25, 17, 46, 0, 0, time.UTC) -} - -func stubSendMail(t *testing.T, bCount int, emails ...message) SendMailFunc { +func stubSendMail(t *testing.T, bCount int, emails ...message) SendFunc { i := 0 - return func(addr string, a smtp.Auth, from string, to []string, msg []byte) error { + return func(from string, to []string, msg io.Reader) error { if i > len(emails) { t.Fatalf("Only %d mails should be sent", len(emails)) } want := emails[i] - if addr != "host:587" { - t.Fatalf("Invalid address, got %q, want host:587", addr) - } - if from != want.from { t.Fatalf("Invalid from, got %q, want %q", from, want.from) } @@ -531,7 +527,11 @@ func stubSendMail(t *testing.T, bCount int, emails ...message) SendMailFunc { } } - got := string(msg) + content, err := ioutil.ReadAll(msg) + if err != nil { + t.Error(err) + } + got := string(content) wantMsg := string("Mime-Version: 1.0\r\n" + "Date: Wed, 25 Jun 2014 17:46:00 +0000\r\n" + want.content) @@ -613,7 +613,7 @@ func getBoundaries(t *testing.T, count int, msg string) []string { var boundaryRegExp = regexp.MustCompile("boundary=(\\w+)") func BenchmarkFull(b *testing.B) { - emptyFunc := func(addr string, a smtp.Auth, from string, to []string, msg []byte) error { + emptyFunc := func(from string, to []string, msg io.Reader) error { return nil } @@ -631,8 +631,7 @@ func BenchmarkFull(b *testing.B) { msg.Attach(CreateFile("benchmark.txt", []byte("Benchmark"))) msg.Embed(CreateFile("benchmark.jpg", []byte("Benchmark"))) - mailer := NewMailer("host", "username", "password", 587, SetSendMail(emptyFunc)) - if err := mailer.Send(msg); err != nil { + if err := Send(SendFunc(emptyFunc), msg); err != nil { panic(err) } } diff --git a/mailer.go b/mailer.go deleted file mode 100644 index 99ab619..0000000 --- a/mailer.go +++ /dev/null @@ -1,205 +0,0 @@ -package gomail - -import ( - "crypto/tls" - "errors" - "fmt" - "io/ioutil" - "net" - "net/mail" - "net/smtp" - "strings" -) - -// A Mailer represents an SMTP server. -type Mailer struct { - addr string - host string - config *tls.Config - auth smtp.Auth - send SendMailFunc -} - -// A MailerSetting can be used in a mailer constructor to configure it. -type MailerSetting func(m *Mailer) - -// SetSendMail allows to set the email-sending function of a mailer. -// -// Example: -// -// myFunc := func(addr string, a smtp.Auth, from string, to []string, msg []byte) error { -// // Implement your email-sending function similar to smtp.SendMail -// } -// mailer := gomail.NewMailer("host", "user", "pwd", 465, SetSendMail(myFunc)) -func SetSendMail(s SendMailFunc) MailerSetting { - return func(m *Mailer) { - m.send = s - } -} - -// SetTLSConfig allows to set the TLS configuration used to connect the SMTP -// server. -func SetTLSConfig(c *tls.Config) MailerSetting { - return func(m *Mailer) { - m.config = c - } -} - -// A SendMailFunc is a function to send emails with the same signature than -// smtp.SendMail. -type SendMailFunc func(addr string, a smtp.Auth, from string, to []string, msg []byte) error - -// NewMailer returns a mailer. The given parameters are used to connect to the -// SMTP server via a PLAIN authentication mechanism. -func NewMailer(host string, username string, password string, port int, settings ...MailerSetting) *Mailer { - return NewCustomMailer( - fmt.Sprintf("%s:%d", host, port), - smtp.PlainAuth("", username, password, host), - settings..., - ) -} - -// NewCustomMailer creates a mailer with the given authentication mechanism. -// -// Example: -// -// gomail.NewCustomMailer("host:587", smtp.CRAMMD5Auth("username", "secret")) -func NewCustomMailer(addr string, auth smtp.Auth, settings ...MailerSetting) *Mailer { - // Error is not handled here to preserve backward compatibility - host, port, _ := net.SplitHostPort(addr) - - m := &Mailer{ - addr: addr, - host: host, - auth: auth, - } - - for _, s := range settings { - s(m) - } - - if m.config == nil { - m.config = &tls.Config{ServerName: host} - } - if m.send == nil { - m.send = m.getSendMailFunc(port == "465") - } - - return m -} - -// Send sends the emails to all the recipients of the message. -func (m *Mailer) Send(msg *Message) error { - message := msg.Export() - - from, err := getFrom(message) - if err != nil { - return err - } - recipients, bcc, err := getRecipients(message) - if err != nil { - return err - } - - h := flattenHeader(message, "") - body, err := ioutil.ReadAll(message.Body) - if err != nil { - return err - } - - mail := append(h, body...) - if err := m.send(m.addr, m.auth, from, recipients, mail); err != nil { - return err - } - - for _, to := range bcc { - h = flattenHeader(message, to) - mail = append(h, body...) - if err := m.send(m.addr, m.auth, from, []string{to}, mail); err != nil { - return err - } - } - - return nil -} - -func flattenHeader(msg *mail.Message, bcc string) []byte { - buf := getBuffer() - defer putBuffer(buf) - - for field, value := range msg.Header { - if field != "Bcc" { - buf.WriteString(field) - buf.WriteString(": ") - buf.WriteString(strings.Join(value, ", ")) - buf.WriteString("\r\n") - } else if bcc != "" { - for _, to := range value { - if strings.Contains(to, bcc) { - buf.WriteString(field) - buf.WriteString(": ") - buf.WriteString(to) - buf.WriteString("\r\n") - } - } - } - } - buf.WriteString("\r\n") - - return buf.Bytes() -} - -func getFrom(msg *mail.Message) (string, error) { - from := msg.Header.Get("Sender") - if from == "" { - from = msg.Header.Get("From") - if from == "" { - return "", errors.New("mailer: invalid message, \"From\" field is absent") - } - } - - return parseAddress(from) -} - -func getRecipients(msg *mail.Message) (recipients, bcc []string, err error) { - for _, field := range []string{"Bcc", "To", "Cc"} { - if addresses, ok := msg.Header[field]; ok { - for _, addr := range addresses { - switch field { - case "Bcc": - bcc, err = addAdress(bcc, addr) - default: - recipients, err = addAdress(recipients, addr) - } - if err != nil { - return recipients, bcc, err - } - } - } - } - - return recipients, bcc, nil -} - -func addAdress(list []string, addr string) ([]string, error) { - addr, err := parseAddress(addr) - if err != nil { - return list, err - } - for _, a := range list { - if addr == a { - return list, nil - } - } - - return append(list, addr), nil -} - -func parseAddress(field string) (string, error) { - a, err := mail.ParseAddress(field) - if a == nil { - return "", err - } - - return a.Address, err -} diff --git a/send.go b/send.go index 77aa6a2..05a6c68 100644 --- a/send.go +++ b/send.go @@ -1,102 +1,161 @@ package gomail import ( - "crypto/tls" + "bytes" + "errors" + "fmt" "io" - "net" - "net/smtp" + "io/ioutil" + "net/mail" + "strings" ) -func (m *Mailer) getSendMailFunc(ssl bool) SendMailFunc { - return func(addr string, a smtp.Auth, from string, to []string, msg []byte) error { - var c smtpClient - var err error - if ssl { - c, err = sslDial(addr, m.host, m.config) - } else { - c, err = starttlsDial(addr, m.config) +// Sender is the interface that wraps the Send method. +// +// Send sends an email to the given addresses. +type Sender interface { + Send(from string, to []string, msg io.Reader) error +} + +// SendCloser is the interface that groups the Send and Close methods. +type SendCloser interface { + Sender + Close() error +} + +// A SendFunc is a function that sends emails to the given adresses. +// The SendFunc type is an adapter to allow the use of ordinary functions as +// email senders. If f is a function with the appropriate signature, SendFunc(f) +// is a Sender object that calls f. +type SendFunc func(from string, to []string, msg io.Reader) error + +// Send calls f(from, to, msg). +func (f SendFunc) Send(from string, to []string, msg io.Reader) error { + return f(from, to, msg) +} + +// Send sends emails using the given Sender. +func Send(s Sender, msg ...*Message) error { + for i, m := range msg { + if err := send(s, m); err != nil { + return fmt.Errorf("gomail: could not send email %d: %v", i+1, err) } - if err != nil { + } + + return nil +} + +func send(s Sender, msg *Message) error { + message := msg.Export() + + from, err := getFrom(message) + if err != nil { + return err + } + recipients, bcc, err := getRecipients(message) + if err != nil { + return err + } + + h := flattenHeader(message, "") + body, err := ioutil.ReadAll(message.Body) + if err != nil { + return err + } + + mail := bytes.NewReader(append(h, body...)) + if err := s.Send(from, recipients, mail); err != nil { + return err + } + + for _, to := range bcc { + h = flattenHeader(message, to) + mail = bytes.NewReader(append(h, body...)) + if err := s.Send(from, []string{to}, mail); err != nil { return err } - defer c.Close() + } - if a != nil { - if ok, _ := c.Extension("AUTH"); ok { - if err = c.Auth(a); err != nil { - return err + return nil +} + +func flattenHeader(msg *mail.Message, bcc string) []byte { + buf := getBuffer() + defer putBuffer(buf) + + for field, value := range msg.Header { + if field != "Bcc" { + buf.WriteString(field) + buf.WriteString(": ") + buf.WriteString(strings.Join(value, ", ")) + buf.WriteString("\r\n") + } else if bcc != "" { + for _, to := range value { + if strings.Contains(to, bcc) { + buf.WriteString(field) + buf.WriteString(": ") + buf.WriteString(to) + buf.WriteString("\r\n") } } } + } + buf.WriteString("\r\n") - if err = c.Mail(from); err != nil { - return err + return buf.Bytes() +} + +func getFrom(msg *mail.Message) (string, error) { + from := msg.Header.Get("Sender") + if from == "" { + from = msg.Header.Get("From") + if from == "" { + return "", errors.New("mailer: invalid message, \"From\" field is absent") } + } - for _, addr := range to { - if err = c.Rcpt(addr); err != nil { - return err + return parseAddress(from) +} + +func getRecipients(msg *mail.Message) (recipients, bcc []string, err error) { + for _, field := range []string{"Bcc", "To", "Cc"} { + if addresses, ok := msg.Header[field]; ok { + for _, addr := range addresses { + switch field { + case "Bcc": + bcc, err = addAdress(bcc, addr) + default: + recipients, err = addAdress(recipients, addr) + } + if err != nil { + return recipients, bcc, err + } } } - - w, err := c.Data() - if err != nil { - return err - } - _, err = w.Write(msg) - if err != nil { - return err - } - err = w.Close() - if err != nil { - return err - } - - return c.Quit() } + + return recipients, bcc, nil } -func sslDial(addr, host string, config *tls.Config) (smtpClient, error) { - conn, err := initTLS("tcp", addr, config) +func addAdress(list []string, addr string) ([]string, error) { + addr, err := parseAddress(addr) if err != nil { - return nil, err + return list, err + } + for _, a := range list { + if addr == a { + return list, nil + } } - return newClient(conn, host) + return append(list, addr), nil } -func starttlsDial(addr string, config *tls.Config) (smtpClient, error) { - c, err := initSMTP(addr) - if err != nil { - return c, err +func parseAddress(field string) (string, error) { + a, err := mail.ParseAddress(field) + if a == nil { + return "", err } - if ok, _ := c.Extension("STARTTLS"); ok { - return c, c.StartTLS(config) - } - - return c, nil -} - -var initSMTP = func(addr string) (smtpClient, error) { - return smtp.Dial(addr) -} - -var initTLS = func(network, addr string, config *tls.Config) (*tls.Conn, error) { - return tls.Dial(network, addr, config) -} - -var newClient = func(conn net.Conn, host string) (smtpClient, error) { - return smtp.NewClient(conn, host) -} - -type smtpClient interface { - Extension(string) (bool, string) - StartTLS(*tls.Config) error - Auth(smtp.Auth) error - Mail(string) error - Rcpt(string) error - Data() (io.WriteCloser, error) - Quit() error - Close() error + return a.Address, err } diff --git a/send_test.go b/send_test.go index 7f074bc..f252ac5 100644 --- a/send_test.go +++ b/send_test.go @@ -1,245 +1,79 @@ package gomail import ( - "crypto/tls" "io" - "net" - "net/smtp" + "io/ioutil" + "reflect" "testing" ) -var ( - testAddr = "smtp.example.com:587" - testSSLAddr = "smtp.example.com:465" - testTLSConn = &tls.Conn{} - testConfig = &tls.Config{InsecureSkipVerify: true} - testHost = "smtp.example.com" - testAuth = smtp.PlainAuth("", "user", "pwd", "smtp.example.com") - testFrom = "from@example.com" - testTo = []string{"to1@example.com", "to2@example.com"} - testBody = "Test message" +const ( + testTo1 = "to1@example.com" + testTo2 = "to2@example.com" + testFrom = "from@example.com" + testBody = "Test message" + testMsg = "To: " + testTo1 + ", " + testTo2 + "\r\n" + + "From: " + testFrom + "\r\n" + + "Mime-Version: 1.0\r\n" + + "Date: Wed, 25 Jun 2014 17:46:00 +0000\r\n" + + "Content-Type: text/plain; charset=UTF-8\r\n" + + "Content-Transfer-Encoding: quoted-printable\r\n" + + "\r\n" + + testBody ) -const wantMsg = "To: to1@example.com, to2@example.com\r\n" + - "From: from@example.com\r\n" + - "Mime-Version: 1.0\r\n" + - "Date: Wed, 25 Jun 2014 17:46:00 +0000\r\n" + - "Content-Type: text/plain; charset=UTF-8\r\n" + - "Content-Transfer-Encoding: quoted-printable\r\n" + - "\r\n" + - "Test message" +type mockSender SendFunc -func TestDefaultSendMail(t *testing.T) { - testSendMail(t, testAddr, nil, []string{ - "Extension STARTTLS", - "StartTLS", - "Extension AUTH", - "Auth", - "Mail " + testFrom, - "Rcpt " + testTo[0], - "Rcpt " + testTo[1], - "Data", - "Write message", - "Close writer", - "Quit", - "Close", - }) +func (s mockSender) Send(from string, to []string, msg io.Reader) error { + return s(from, to, msg) } -func TestSSLSendMail(t *testing.T) { - testSendMail(t, testSSLAddr, nil, []string{ - "Extension AUTH", - "Auth", - "Mail " + testFrom, - "Rcpt " + testTo[0], - "Rcpt " + testTo[1], - "Data", - "Write message", - "Close writer", - "Quit", - "Close", - }) +type mockSendCloser struct { + mockSender + close func() error } -func TestTLSConfigSendMail(t *testing.T) { - testSendMail(t, testAddr, testConfig, []string{ - "Extension STARTTLS", - "StartTLS", - "Extension AUTH", - "Auth", - "Mail " + testFrom, - "Rcpt " + testTo[0], - "Rcpt " + testTo[1], - "Data", - "Write message", - "Close writer", - "Quit", - "Close", - }) +func (s *mockSendCloser) Close() error { + return s.close() } -func TestTLSConfigSSLSendMail(t *testing.T) { - testSendMail(t, testSSLAddr, testConfig, []string{ - "Extension AUTH", - "Auth", - "Mail " + testFrom, - "Rcpt " + testTo[0], - "Rcpt " + testTo[1], - "Data", - "Write message", - "Close writer", - "Quit", - "Close", - }) -} - -type mockClient struct { - t *testing.T - i int - want []string - addr string - auth smtp.Auth - config *tls.Config -} - -func (c *mockClient) Extension(ext string) (bool, string) { - c.do("Extension " + ext) - return true, "" -} - -func (c *mockClient) StartTLS(config *tls.Config) error { - assertConfig(c.t, config, c.config) - c.do("StartTLS") - return nil -} - -func (c *mockClient) Auth(a smtp.Auth) error { - assertAuth(c.t, a, c.auth) - c.do("Auth") - return nil -} - -func (c *mockClient) Mail(from string) error { - c.do("Mail " + from) - return nil -} - -func (c *mockClient) Rcpt(to string) error { - c.do("Rcpt " + to) - return nil -} - -func (c *mockClient) Data() (io.WriteCloser, error) { - c.do("Data") - return &mockWriter{c: c, want: wantMsg}, nil -} - -func (c *mockClient) Quit() error { - c.do("Quit") - return nil -} - -func (c *mockClient) Close() error { - c.do("Close") - return nil -} - -func (c *mockClient) do(cmd string) { - if c.i >= len(c.want) { - c.t.Fatalf("Invalid command %q", cmd) +func TestSend(t *testing.T) { + s := &mockSendCloser{ + mockSender: stubSend(t, testFrom, []string{testTo1, testTo2}, testMsg), + close: func() error { + t.Error("Close() should not be called in Send()") + return nil + }, } - - if cmd != c.want[c.i] { - c.t.Fatalf("Invalid command, got %q, want %q", cmd, c.want[c.i]) + if err := Send(s, getTestMessage()); err != nil { + t.Errorf("Send(): %v", err) } - c.i++ } -type mockWriter struct { - want string - c *mockClient +func getTestMessage() *Message { + m := NewMessage() + m.SetHeader("From", testFrom) + m.SetHeader("To", testTo1, testTo2) + m.SetBody("text/plain", testBody) + + return m } -func (w *mockWriter) Write(p []byte) (int, error) { - w.c.do("Write message") - compareBodies(w.c.t, string(p), w.want) - return len(p), nil -} - -func (w *mockWriter) Close() error { - w.c.do("Close writer") - return nil -} - -func testSendMail(t *testing.T, addr string, config *tls.Config, want []string) { - testClient := &mockClient{ - t: t, - want: want, - addr: addr, - auth: testAuth, - config: config, - } - - initSMTP = func(addr string) (smtpClient, error) { - assertAddr(t, addr, testClient.addr) - return testClient, nil - } - - initTLS = func(network, addr string, config *tls.Config) (*tls.Conn, error) { - if network != "tcp" { - t.Errorf("Invalid network, got %q, want tcp", network) +func stubSend(t *testing.T, wantFrom string, wantTo []string, wantBody string) mockSender { + return func(from string, to []string, msg io.Reader) error { + if from != wantFrom { + t.Errorf("invalid from, got %q, want %q", from, wantFrom) } - assertAddr(t, addr, testClient.addr) - assertConfig(t, config, testClient.config) - return testTLSConn, nil - } - - newClient = func(conn net.Conn, host string) (smtpClient, error) { - if conn != testTLSConn { - t.Error("Invalid TLS connection used") + if !reflect.DeepEqual(to, wantTo) { + t.Errorf("invalid to, got %v, want %v", to, wantTo) } - if host != testHost { - t.Errorf("Invalid host, got %q, want %q", host, testHost) + + content, err := ioutil.ReadAll(msg) + if err != nil { + t.Fatal(err) } - return testClient, nil - } + compareBodies(t, string(content), wantBody) - msg := NewMessage() - msg.SetHeader("From", testFrom) - msg.SetHeader("To", testTo...) - msg.SetBody("text/plain", testBody) - - var settings []MailerSetting - if config != nil { - settings = []MailerSetting{SetTLSConfig(config)} - } - - mailer := NewCustomMailer(addr, testAuth, settings...) - if err := mailer.Send(msg); err != nil { - t.Error(err) - } -} - -func assertAuth(t *testing.T, got, want smtp.Auth) { - if got != want { - t.Errorf("Invalid auth, got %#v, want %#v", got, want) - } -} - -func assertAddr(t *testing.T, got, want string) { - if got != want { - t.Errorf("Invalid addr, got %q, want %q", got, want) - } -} - -func assertConfig(t *testing.T, got, want *tls.Config) { - if want == nil { - want = &tls.Config{ServerName: testHost} - } - if got.ServerName != want.ServerName { - t.Errorf("Invalid field ServerName in config, got %q, want %q", got.ServerName, want.ServerName) - } - if got.InsecureSkipVerify != want.InsecureSkipVerify { - t.Errorf("Invalid field InsecureSkipVerify in config, got %v, want %v", got.InsecureSkipVerify, want.InsecureSkipVerify) + return nil } } diff --git a/smtp.go b/smtp.go new file mode 100644 index 0000000..4222be4 --- /dev/null +++ b/smtp.go @@ -0,0 +1,168 @@ +package gomail + +import ( + "crypto/tls" + "fmt" + "io" + "net" + "net/smtp" +) + +// An SMTPDialer is a dialer to an SMTP server. +type SMTPDialer struct { + // Host represents the host of the SMTP server. + Host string + // Port represents the port of the SMTP server. + Port int + // Auth represents the authentication mechanism used to authenticate to the + // SMTP server. + Auth smtp.Auth + // SSL defines whether an SSL connection is used. It should be false in + // most cases since the authentication mechanism should use the STARTTLS + // extension instead. + SSL bool + // TSLConfig represents the TLS configuration used for the TLS (when the + // STARTTLS extension is used) or SSL connection. + TLSConfig *tls.Config +} + +// NewPlainDialer returns an SMTPDialer. The given parameters are used to +// connect to the SMTP server via a PLAIN authentication mechanism. +func NewPlainDialer(host, username, password string, port int) *SMTPDialer { + return &SMTPDialer{ + Host: host, + Port: port, + Auth: smtp.PlainAuth("", username, password, host), + SSL: port == 465, + } +} + +// Dial dials and authenticates to an SMTP server. The returned SendCloser +// should be closed when done using it. +func (d *SMTPDialer) Dial() (SendCloser, error) { + c, err := d.dial() + if err != nil { + return nil, err + } + + if d.Auth != nil { + if ok, _ := c.Extension("AUTH"); ok { + if err = c.Auth(d.Auth); err != nil { + c.Close() + return nil, err + } + } + } + + return &smtpSender{c}, nil +} + +func (d *SMTPDialer) dial() (smtpClient, error) { + if d.SSL { + return d.sslDial() + } + return d.starttlsDial() +} + +func (d *SMTPDialer) starttlsDial() (smtpClient, error) { + c, err := smtpDial(addr(d.Host, d.Port)) + if err != nil { + return nil, err + } + + if ok, _ := c.Extension("STARTTLS"); ok { + if err := c.StartTLS(d.tlsConfig()); err != nil { + c.Close() + return nil, err + } + } + + return c, nil +} + +func (d *SMTPDialer) sslDial() (smtpClient, error) { + conn, err := tlsDial("tcp", addr(d.Host, d.Port), d.tlsConfig()) + if err != nil { + return nil, err + } + + return newClient(conn, d.Host) +} + +func (d *SMTPDialer) tlsConfig() *tls.Config { + if d.TLSConfig == nil { + return &tls.Config{ServerName: d.Host} + } + + return d.TLSConfig +} + +func addr(host string, port int) string { + return fmt.Sprintf("%s:%d", host, port) +} + +// DialAndSend opens a connection to an SMTP server, sends the given emails and +// closes the connection. +func (d *SMTPDialer) DialAndSend(msg ...*Message) error { + s, err := d.Dial() + if err != nil { + return err + } + defer s.Close() + + return Send(s, msg...) +} + +type smtpSender struct { + smtpClient +} + +func (c *smtpSender) Send(from string, to []string, msg io.Reader) error { + if err := c.Mail(from); err != nil { + return err + } + + for _, addr := range to { + if err := c.Rcpt(addr); err != nil { + return err + } + } + + w, err := c.Data() + if err != nil { + return err + } + + if _, err = io.Copy(w, msg); err != nil { + w.Close() + return err + } + + return w.Close() +} + +func (c *smtpSender) Close() error { + return c.Quit() +} + +// Stubbed out for tests. +var ( + smtpDial = func(addr string) (smtpClient, error) { + return smtp.Dial(addr) + } + tlsDial = tls.Dial + newClient = func(conn net.Conn, host string) (smtpClient, error) { + return smtp.NewClient(conn, host) + } +) + +type smtpClient interface { + Extension(string) (bool, string) + StartTLS(*tls.Config) error + Auth(smtp.Auth) error + Mail(string) error + Rcpt(string) error + Data() (io.WriteCloser, error) + Quit() error + Close() error +} diff --git a/smtp_test.go b/smtp_test.go new file mode 100644 index 0000000..1d8cdc4 --- /dev/null +++ b/smtp_test.go @@ -0,0 +1,248 @@ +package gomail + +import ( + "crypto/tls" + "io" + "net" + "net/smtp" + "reflect" + "testing" +) + +var ( + testHost = "smtp.example.com" + testPort = 587 + testSSLPort = 465 + testTLSConn = &tls.Conn{} + testConfig = &tls.Config{InsecureSkipVerify: true} + testAuth = smtp.PlainAuth("", "user", "pwd", testHost) +) + +func TestSMTPDialer(t *testing.T) { + d := NewPlainDialer(testHost, "user", "pwd", testPort) + testSendMail(t, d, []string{ + "Extension STARTTLS", + "StartTLS", + "Extension AUTH", + "Auth", + "Mail " + testFrom, + "Rcpt " + testTo1, + "Rcpt " + testTo2, + "Data", + "Write message", + "Close writer", + "Quit", + "Close", + }) +} + +func TestSMTPDialerSSL(t *testing.T) { + d := NewPlainDialer(testHost, "user", "pwd", testSSLPort) + testSendMail(t, d, []string{ + "Extension AUTH", + "Auth", + "Mail " + testFrom, + "Rcpt " + testTo1, + "Rcpt " + testTo2, + "Data", + "Write message", + "Close writer", + "Quit", + "Close", + }) +} + +func TestSMTPDialerConfig(t *testing.T) { + d := NewPlainDialer(testHost, "user", "pwd", testPort) + d.TLSConfig = testConfig + testSendMail(t, d, []string{ + "Extension STARTTLS", + "StartTLS", + "Extension AUTH", + "Auth", + "Mail " + testFrom, + "Rcpt " + testTo1, + "Rcpt " + testTo2, + "Data", + "Write message", + "Close writer", + "Quit", + "Close", + }) +} + +func TestSMTPDialerSSLConfig(t *testing.T) { + d := NewPlainDialer(testHost, "user", "pwd", testSSLPort) + d.TLSConfig = testConfig + testSendMail(t, d, []string{ + "Extension AUTH", + "Auth", + "Mail " + testFrom, + "Rcpt " + testTo1, + "Rcpt " + testTo2, + "Data", + "Write message", + "Close writer", + "Quit", + "Close", + }) +} + +func TestSMTPDialerNoAuth(t *testing.T) { + d := &SMTPDialer{ + Host: testHost, + Port: testPort, + } + testSendMail(t, d, []string{ + "Extension STARTTLS", + "StartTLS", + "Mail " + testFrom, + "Rcpt " + testTo1, + "Rcpt " + testTo2, + "Data", + "Write message", + "Close writer", + "Quit", + "Close", + }) +} + +type mockClient struct { + t *testing.T + i int + want []string + addr string + auth smtp.Auth + config *tls.Config +} + +func (c *mockClient) Extension(ext string) (bool, string) { + c.do("Extension " + ext) + return true, "" +} + +func (c *mockClient) StartTLS(config *tls.Config) error { + assertConfig(c.t, config, c.config) + c.do("StartTLS") + return nil +} + +func (c *mockClient) Auth(a smtp.Auth) error { + assertAuth(c.t, a, c.auth) + c.do("Auth") + return nil +} + +func (c *mockClient) Mail(from string) error { + c.do("Mail " + from) + return nil +} + +func (c *mockClient) Rcpt(to string) error { + c.do("Rcpt " + to) + return nil +} + +func (c *mockClient) Data() (io.WriteCloser, error) { + c.do("Data") + return &mockWriter{c: c, want: testMsg}, nil +} + +func (c *mockClient) Quit() error { + c.do("Quit") + return nil +} + +func (c *mockClient) Close() error { + c.do("Close") + return nil +} + +func (c *mockClient) do(cmd string) { + if c.i >= len(c.want) { + c.t.Fatalf("Invalid command %q", cmd) + } + + if cmd != c.want[c.i] { + c.t.Fatalf("Invalid command, got %q, want %q", cmd, c.want[c.i]) + } + c.i++ +} + +type mockWriter struct { + want string + c *mockClient +} + +func (w *mockWriter) Write(p []byte) (int, error) { + w.c.do("Write message") + compareBodies(w.c.t, string(p), w.want) + return len(p), nil +} + +func (w *mockWriter) Close() error { + w.c.do("Close writer") + return nil +} + +func testSendMail(t *testing.T, d *SMTPDialer, want []string) { + testClient := &mockClient{ + t: t, + want: want, + addr: addr(d.Host, d.Port), + auth: testAuth, + config: d.TLSConfig, + } + + smtpDial = func(addr string) (smtpClient, error) { + assertAddr(t, addr, testClient.addr) + return testClient, nil + } + + tlsDial = func(network, addr string, config *tls.Config) (*tls.Conn, error) { + if network != "tcp" { + t.Errorf("Invalid network, got %q, want tcp", network) + } + assertAddr(t, addr, testClient.addr) + assertConfig(t, config, testClient.config) + return testTLSConn, nil + } + + newClient = func(conn net.Conn, host string) (smtpClient, error) { + if conn != testTLSConn { + t.Error("Invalid TLS connection used") + } + if host != testHost { + t.Errorf("Invalid host, got %q, want %q", host, testHost) + } + return testClient, nil + } + + if err := d.DialAndSend(getTestMessage()); err != nil { + t.Error(err) + } +} + +func assertAuth(t *testing.T, got, want smtp.Auth) { + if !reflect.DeepEqual(got, want) { + t.Errorf("Invalid auth, got %#v, want %#v", got, want) + } +} + +func assertAddr(t *testing.T, got, want string) { + if got != want { + t.Errorf("Invalid addr, got %q, want %q", got, want) + } +} + +func assertConfig(t *testing.T, got, want *tls.Config) { + if want == nil { + want = &tls.Config{ServerName: testHost} + } + if got.ServerName != want.ServerName { + t.Errorf("Invalid field ServerName in config, got %q, want %q", got.ServerName, want.ServerName) + } + if got.InsecureSkipVerify != want.InsecureSkipVerify { + t.Errorf("Invalid field InsecureSkipVerify in config, got %v, want %v", got.InsecureSkipVerify, want.InsecureSkipVerify) + } +}