Replaced Mailer type with the Sender interface and SMTPDialer type

Fixes #10
Fixes #17
Fixes #32
This commit is contained in:
Alexandre Cesaro 2015-06-30 11:55:03 +02:00
parent a7fe250544
commit 01674ee5b6
6 changed files with 619 additions and 516 deletions

View File

@ -2,7 +2,8 @@ package gomail
import ( import (
"encoding/base64" "encoding/base64"
"net/smtp" "io"
"io/ioutil"
"path/filepath" "path/filepath"
"regexp" "regexp"
"strconv" "strconv"
@ -11,6 +12,12 @@ import (
"time" "time"
) )
func init() {
now = func() time.Time {
return time.Date(2014, 06, 25, 17, 46, 0, 0, time.UTC)
}
}
type message struct { type message struct {
from string from string
to []string to []string
@ -23,8 +30,8 @@ func TestMessage(t *testing.T) {
msg.SetHeader("To", msg.FormatAddress("to@example.com", "Señor To"), "tobis@example.com") msg.SetHeader("To", msg.FormatAddress("to@example.com", "Señor To"), "tobis@example.com")
msg.SetAddressHeader("Cc", "cc@example.com", "A, B") msg.SetAddressHeader("Cc", "cc@example.com", "A, B")
msg.SetAddressHeader("X-To", "ccbis@example.com", "à, b") msg.SetAddressHeader("X-To", "ccbis@example.com", "à, b")
msg.SetDateHeader("X-Date", stubNow()) msg.SetDateHeader("X-Date", now())
msg.SetHeader("X-Date-2", msg.FormatDate(stubNow())) msg.SetHeader("X-Date-2", msg.FormatDate(now()))
msg.SetHeader("Subject", "¡Hola, señor!") msg.SetHeader("Subject", "¡Hola, señor!")
msg.SetHeaders(map[string][]string{ msg.SetHeaders(map[string][]string{
"X-Headers": {"Test", "Café"}, "X-Headers": {"Test", "Café"},
@ -488,31 +495,20 @@ func TestBase64LineLength(t *testing.T) {
} }
func testMessage(t *testing.T, msg *Message, bCount int, emails ...message) { func testMessage(t *testing.T, msg *Message, bCount int, emails ...message) {
now = stubNow err := Send(stubSendMail(t, bCount, emails...), msg)
mailer := NewMailer("host", "username", "password", 587, SetSendMail(stubSendMail(t, bCount, emails...)))
err := mailer.Send(msg)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
} }
func stubNow() time.Time { func stubSendMail(t *testing.T, bCount int, emails ...message) SendFunc {
return time.Date(2014, 06, 25, 17, 46, 0, 0, time.UTC)
}
func stubSendMail(t *testing.T, bCount int, emails ...message) SendMailFunc {
i := 0 i := 0
return func(addr string, a smtp.Auth, from string, to []string, msg []byte) error { return func(from string, to []string, msg io.Reader) error {
if i > len(emails) { if i > len(emails) {
t.Fatalf("Only %d mails should be sent", len(emails)) t.Fatalf("Only %d mails should be sent", len(emails))
} }
want := emails[i] want := emails[i]
if addr != "host:587" {
t.Fatalf("Invalid address, got %q, want host:587", addr)
}
if from != want.from { if from != want.from {
t.Fatalf("Invalid from, got %q, want %q", from, want.from) t.Fatalf("Invalid from, got %q, want %q", from, want.from)
} }
@ -531,7 +527,11 @@ func stubSendMail(t *testing.T, bCount int, emails ...message) SendMailFunc {
} }
} }
got := string(msg) content, err := ioutil.ReadAll(msg)
if err != nil {
t.Error(err)
}
got := string(content)
wantMsg := string("Mime-Version: 1.0\r\n" + wantMsg := string("Mime-Version: 1.0\r\n" +
"Date: Wed, 25 Jun 2014 17:46:00 +0000\r\n" + "Date: Wed, 25 Jun 2014 17:46:00 +0000\r\n" +
want.content) want.content)
@ -613,7 +613,7 @@ func getBoundaries(t *testing.T, count int, msg string) []string {
var boundaryRegExp = regexp.MustCompile("boundary=(\\w+)") var boundaryRegExp = regexp.MustCompile("boundary=(\\w+)")
func BenchmarkFull(b *testing.B) { func BenchmarkFull(b *testing.B) {
emptyFunc := func(addr string, a smtp.Auth, from string, to []string, msg []byte) error { emptyFunc := func(from string, to []string, msg io.Reader) error {
return nil return nil
} }
@ -631,8 +631,7 @@ func BenchmarkFull(b *testing.B) {
msg.Attach(CreateFile("benchmark.txt", []byte("Benchmark"))) msg.Attach(CreateFile("benchmark.txt", []byte("Benchmark")))
msg.Embed(CreateFile("benchmark.jpg", []byte("Benchmark"))) msg.Embed(CreateFile("benchmark.jpg", []byte("Benchmark")))
mailer := NewMailer("host", "username", "password", 587, SetSendMail(emptyFunc)) if err := Send(SendFunc(emptyFunc), msg); err != nil {
if err := mailer.Send(msg); err != nil {
panic(err) panic(err)
} }
} }

205
mailer.go
View File

@ -1,205 +0,0 @@
package gomail
import (
"crypto/tls"
"errors"
"fmt"
"io/ioutil"
"net"
"net/mail"
"net/smtp"
"strings"
)
// A Mailer represents an SMTP server.
type Mailer struct {
addr string
host string
config *tls.Config
auth smtp.Auth
send SendMailFunc
}
// A MailerSetting can be used in a mailer constructor to configure it.
type MailerSetting func(m *Mailer)
// SetSendMail allows to set the email-sending function of a mailer.
//
// Example:
//
// myFunc := func(addr string, a smtp.Auth, from string, to []string, msg []byte) error {
// // Implement your email-sending function similar to smtp.SendMail
// }
// mailer := gomail.NewMailer("host", "user", "pwd", 465, SetSendMail(myFunc))
func SetSendMail(s SendMailFunc) MailerSetting {
return func(m *Mailer) {
m.send = s
}
}
// SetTLSConfig allows to set the TLS configuration used to connect the SMTP
// server.
func SetTLSConfig(c *tls.Config) MailerSetting {
return func(m *Mailer) {
m.config = c
}
}
// A SendMailFunc is a function to send emails with the same signature than
// smtp.SendMail.
type SendMailFunc func(addr string, a smtp.Auth, from string, to []string, msg []byte) error
// NewMailer returns a mailer. The given parameters are used to connect to the
// SMTP server via a PLAIN authentication mechanism.
func NewMailer(host string, username string, password string, port int, settings ...MailerSetting) *Mailer {
return NewCustomMailer(
fmt.Sprintf("%s:%d", host, port),
smtp.PlainAuth("", username, password, host),
settings...,
)
}
// NewCustomMailer creates a mailer with the given authentication mechanism.
//
// Example:
//
// gomail.NewCustomMailer("host:587", smtp.CRAMMD5Auth("username", "secret"))
func NewCustomMailer(addr string, auth smtp.Auth, settings ...MailerSetting) *Mailer {
// Error is not handled here to preserve backward compatibility
host, port, _ := net.SplitHostPort(addr)
m := &Mailer{
addr: addr,
host: host,
auth: auth,
}
for _, s := range settings {
s(m)
}
if m.config == nil {
m.config = &tls.Config{ServerName: host}
}
if m.send == nil {
m.send = m.getSendMailFunc(port == "465")
}
return m
}
// Send sends the emails to all the recipients of the message.
func (m *Mailer) Send(msg *Message) error {
message := msg.Export()
from, err := getFrom(message)
if err != nil {
return err
}
recipients, bcc, err := getRecipients(message)
if err != nil {
return err
}
h := flattenHeader(message, "")
body, err := ioutil.ReadAll(message.Body)
if err != nil {
return err
}
mail := append(h, body...)
if err := m.send(m.addr, m.auth, from, recipients, mail); err != nil {
return err
}
for _, to := range bcc {
h = flattenHeader(message, to)
mail = append(h, body...)
if err := m.send(m.addr, m.auth, from, []string{to}, mail); err != nil {
return err
}
}
return nil
}
func flattenHeader(msg *mail.Message, bcc string) []byte {
buf := getBuffer()
defer putBuffer(buf)
for field, value := range msg.Header {
if field != "Bcc" {
buf.WriteString(field)
buf.WriteString(": ")
buf.WriteString(strings.Join(value, ", "))
buf.WriteString("\r\n")
} else if bcc != "" {
for _, to := range value {
if strings.Contains(to, bcc) {
buf.WriteString(field)
buf.WriteString(": ")
buf.WriteString(to)
buf.WriteString("\r\n")
}
}
}
}
buf.WriteString("\r\n")
return buf.Bytes()
}
func getFrom(msg *mail.Message) (string, error) {
from := msg.Header.Get("Sender")
if from == "" {
from = msg.Header.Get("From")
if from == "" {
return "", errors.New("mailer: invalid message, \"From\" field is absent")
}
}
return parseAddress(from)
}
func getRecipients(msg *mail.Message) (recipients, bcc []string, err error) {
for _, field := range []string{"Bcc", "To", "Cc"} {
if addresses, ok := msg.Header[field]; ok {
for _, addr := range addresses {
switch field {
case "Bcc":
bcc, err = addAdress(bcc, addr)
default:
recipients, err = addAdress(recipients, addr)
}
if err != nil {
return recipients, bcc, err
}
}
}
}
return recipients, bcc, nil
}
func addAdress(list []string, addr string) ([]string, error) {
addr, err := parseAddress(addr)
if err != nil {
return list, err
}
for _, a := range list {
if addr == a {
return list, nil
}
}
return append(list, addr), nil
}
func parseAddress(field string) (string, error) {
a, err := mail.ParseAddress(field)
if a == nil {
return "", err
}
return a.Address, err
}

243
send.go
View File

@ -1,102 +1,161 @@
package gomail package gomail
import ( import (
"crypto/tls" "bytes"
"errors"
"fmt"
"io" "io"
"net" "io/ioutil"
"net/smtp" "net/mail"
"strings"
) )
func (m *Mailer) getSendMailFunc(ssl bool) SendMailFunc { // Sender is the interface that wraps the Send method.
return func(addr string, a smtp.Auth, from string, to []string, msg []byte) error { //
var c smtpClient // Send sends an email to the given addresses.
var err error type Sender interface {
if ssl { Send(from string, to []string, msg io.Reader) error
c, err = sslDial(addr, m.host, m.config)
} else {
c, err = starttlsDial(addr, m.config)
}
if err != nil {
return err
}
defer c.Close()
if a != nil {
if ok, _ := c.Extension("AUTH"); ok {
if err = c.Auth(a); err != nil {
return err
}
}
} }
if err = c.Mail(from); err != nil { // SendCloser is the interface that groups the Send and Close methods.
return err type SendCloser interface {
} Sender
for _, addr := range to {
if err = c.Rcpt(addr); err != nil {
return err
}
}
w, err := c.Data()
if err != nil {
return err
}
_, err = w.Write(msg)
if err != nil {
return err
}
err = w.Close()
if err != nil {
return err
}
return c.Quit()
}
}
func sslDial(addr, host string, config *tls.Config) (smtpClient, error) {
conn, err := initTLS("tcp", addr, config)
if err != nil {
return nil, err
}
return newClient(conn, host)
}
func starttlsDial(addr string, config *tls.Config) (smtpClient, error) {
c, err := initSMTP(addr)
if err != nil {
return c, err
}
if ok, _ := c.Extension("STARTTLS"); ok {
return c, c.StartTLS(config)
}
return c, nil
}
var initSMTP = func(addr string) (smtpClient, error) {
return smtp.Dial(addr)
}
var initTLS = func(network, addr string, config *tls.Config) (*tls.Conn, error) {
return tls.Dial(network, addr, config)
}
var newClient = func(conn net.Conn, host string) (smtpClient, error) {
return smtp.NewClient(conn, host)
}
type smtpClient interface {
Extension(string) (bool, string)
StartTLS(*tls.Config) error
Auth(smtp.Auth) error
Mail(string) error
Rcpt(string) error
Data() (io.WriteCloser, error)
Quit() error
Close() error Close() error
} }
// A SendFunc is a function that sends emails to the given adresses.
// The SendFunc type is an adapter to allow the use of ordinary functions as
// email senders. If f is a function with the appropriate signature, SendFunc(f)
// is a Sender object that calls f.
type SendFunc func(from string, to []string, msg io.Reader) error
// Send calls f(from, to, msg).
func (f SendFunc) Send(from string, to []string, msg io.Reader) error {
return f(from, to, msg)
}
// Send sends emails using the given Sender.
func Send(s Sender, msg ...*Message) error {
for i, m := range msg {
if err := send(s, m); err != nil {
return fmt.Errorf("gomail: could not send email %d: %v", i+1, err)
}
}
return nil
}
func send(s Sender, msg *Message) error {
message := msg.Export()
from, err := getFrom(message)
if err != nil {
return err
}
recipients, bcc, err := getRecipients(message)
if err != nil {
return err
}
h := flattenHeader(message, "")
body, err := ioutil.ReadAll(message.Body)
if err != nil {
return err
}
mail := bytes.NewReader(append(h, body...))
if err := s.Send(from, recipients, mail); err != nil {
return err
}
for _, to := range bcc {
h = flattenHeader(message, to)
mail = bytes.NewReader(append(h, body...))
if err := s.Send(from, []string{to}, mail); err != nil {
return err
}
}
return nil
}
func flattenHeader(msg *mail.Message, bcc string) []byte {
buf := getBuffer()
defer putBuffer(buf)
for field, value := range msg.Header {
if field != "Bcc" {
buf.WriteString(field)
buf.WriteString(": ")
buf.WriteString(strings.Join(value, ", "))
buf.WriteString("\r\n")
} else if bcc != "" {
for _, to := range value {
if strings.Contains(to, bcc) {
buf.WriteString(field)
buf.WriteString(": ")
buf.WriteString(to)
buf.WriteString("\r\n")
}
}
}
}
buf.WriteString("\r\n")
return buf.Bytes()
}
func getFrom(msg *mail.Message) (string, error) {
from := msg.Header.Get("Sender")
if from == "" {
from = msg.Header.Get("From")
if from == "" {
return "", errors.New("mailer: invalid message, \"From\" field is absent")
}
}
return parseAddress(from)
}
func getRecipients(msg *mail.Message) (recipients, bcc []string, err error) {
for _, field := range []string{"Bcc", "To", "Cc"} {
if addresses, ok := msg.Header[field]; ok {
for _, addr := range addresses {
switch field {
case "Bcc":
bcc, err = addAdress(bcc, addr)
default:
recipients, err = addAdress(recipients, addr)
}
if err != nil {
return recipients, bcc, err
}
}
}
}
return recipients, bcc, nil
}
func addAdress(list []string, addr string) ([]string, error) {
addr, err := parseAddress(addr)
if err != nil {
return list, err
}
for _, a := range list {
if addr == a {
return list, nil
}
}
return append(list, addr), nil
}
func parseAddress(field string) (string, error) {
a, err := mail.ParseAddress(field)
if a == nil {
return "", err
}
return a.Address, err
}

View File

@ -1,245 +1,79 @@
package gomail package gomail
import ( import (
"crypto/tls"
"io" "io"
"net" "io/ioutil"
"net/smtp" "reflect"
"testing" "testing"
) )
var ( const (
testAddr = "smtp.example.com:587" testTo1 = "to1@example.com"
testSSLAddr = "smtp.example.com:465" testTo2 = "to2@example.com"
testTLSConn = &tls.Conn{}
testConfig = &tls.Config{InsecureSkipVerify: true}
testHost = "smtp.example.com"
testAuth = smtp.PlainAuth("", "user", "pwd", "smtp.example.com")
testFrom = "from@example.com" testFrom = "from@example.com"
testTo = []string{"to1@example.com", "to2@example.com"}
testBody = "Test message" testBody = "Test message"
) testMsg = "To: " + testTo1 + ", " + testTo2 + "\r\n" +
"From: " + testFrom + "\r\n" +
const wantMsg = "To: to1@example.com, to2@example.com\r\n" +
"From: from@example.com\r\n" +
"Mime-Version: 1.0\r\n" + "Mime-Version: 1.0\r\n" +
"Date: Wed, 25 Jun 2014 17:46:00 +0000\r\n" + "Date: Wed, 25 Jun 2014 17:46:00 +0000\r\n" +
"Content-Type: text/plain; charset=UTF-8\r\n" + "Content-Type: text/plain; charset=UTF-8\r\n" +
"Content-Transfer-Encoding: quoted-printable\r\n" + "Content-Transfer-Encoding: quoted-printable\r\n" +
"\r\n" + "\r\n" +
"Test message" testBody
)
func TestDefaultSendMail(t *testing.T) { type mockSender SendFunc
testSendMail(t, testAddr, nil, []string{
"Extension STARTTLS", func (s mockSender) Send(from string, to []string, msg io.Reader) error {
"StartTLS", return s(from, to, msg)
"Extension AUTH",
"Auth",
"Mail " + testFrom,
"Rcpt " + testTo[0],
"Rcpt " + testTo[1],
"Data",
"Write message",
"Close writer",
"Quit",
"Close",
})
} }
func TestSSLSendMail(t *testing.T) { type mockSendCloser struct {
testSendMail(t, testSSLAddr, nil, []string{ mockSender
"Extension AUTH", close func() error
"Auth",
"Mail " + testFrom,
"Rcpt " + testTo[0],
"Rcpt " + testTo[1],
"Data",
"Write message",
"Close writer",
"Quit",
"Close",
})
} }
func TestTLSConfigSendMail(t *testing.T) { func (s *mockSendCloser) Close() error {
testSendMail(t, testAddr, testConfig, []string{ return s.close()
"Extension STARTTLS",
"StartTLS",
"Extension AUTH",
"Auth",
"Mail " + testFrom,
"Rcpt " + testTo[0],
"Rcpt " + testTo[1],
"Data",
"Write message",
"Close writer",
"Quit",
"Close",
})
} }
func TestTLSConfigSSLSendMail(t *testing.T) { func TestSend(t *testing.T) {
testSendMail(t, testSSLAddr, testConfig, []string{ s := &mockSendCloser{
"Extension AUTH", mockSender: stubSend(t, testFrom, []string{testTo1, testTo2}, testMsg),
"Auth", close: func() error {
"Mail " + testFrom, t.Error("Close() should not be called in Send()")
"Rcpt " + testTo[0], return nil
"Rcpt " + testTo[1], },
"Data", }
"Write message", if err := Send(s, getTestMessage()); err != nil {
"Close writer", t.Errorf("Send(): %v", err)
"Quit", }
"Close",
})
} }
type mockClient struct { func getTestMessage() *Message {
t *testing.T m := NewMessage()
i int m.SetHeader("From", testFrom)
want []string m.SetHeader("To", testTo1, testTo2)
addr string m.SetBody("text/plain", testBody)
auth smtp.Auth
config *tls.Config return m
} }
func (c *mockClient) Extension(ext string) (bool, string) { func stubSend(t *testing.T, wantFrom string, wantTo []string, wantBody string) mockSender {
c.do("Extension " + ext) return func(from string, to []string, msg io.Reader) error {
return true, "" if from != wantFrom {
t.Errorf("invalid from, got %q, want %q", from, wantFrom)
}
if !reflect.DeepEqual(to, wantTo) {
t.Errorf("invalid to, got %v, want %v", to, wantTo)
} }
func (c *mockClient) StartTLS(config *tls.Config) error { content, err := ioutil.ReadAll(msg)
assertConfig(c.t, config, c.config) if err != nil {
c.do("StartTLS") t.Fatal(err)
}
compareBodies(t, string(content), wantBody)
return nil return nil
} }
func (c *mockClient) Auth(a smtp.Auth) error {
assertAuth(c.t, a, c.auth)
c.do("Auth")
return nil
}
func (c *mockClient) Mail(from string) error {
c.do("Mail " + from)
return nil
}
func (c *mockClient) Rcpt(to string) error {
c.do("Rcpt " + to)
return nil
}
func (c *mockClient) Data() (io.WriteCloser, error) {
c.do("Data")
return &mockWriter{c: c, want: wantMsg}, nil
}
func (c *mockClient) Quit() error {
c.do("Quit")
return nil
}
func (c *mockClient) Close() error {
c.do("Close")
return nil
}
func (c *mockClient) do(cmd string) {
if c.i >= len(c.want) {
c.t.Fatalf("Invalid command %q", cmd)
}
if cmd != c.want[c.i] {
c.t.Fatalf("Invalid command, got %q, want %q", cmd, c.want[c.i])
}
c.i++
}
type mockWriter struct {
want string
c *mockClient
}
func (w *mockWriter) Write(p []byte) (int, error) {
w.c.do("Write message")
compareBodies(w.c.t, string(p), w.want)
return len(p), nil
}
func (w *mockWriter) Close() error {
w.c.do("Close writer")
return nil
}
func testSendMail(t *testing.T, addr string, config *tls.Config, want []string) {
testClient := &mockClient{
t: t,
want: want,
addr: addr,
auth: testAuth,
config: config,
}
initSMTP = func(addr string) (smtpClient, error) {
assertAddr(t, addr, testClient.addr)
return testClient, nil
}
initTLS = func(network, addr string, config *tls.Config) (*tls.Conn, error) {
if network != "tcp" {
t.Errorf("Invalid network, got %q, want tcp", network)
}
assertAddr(t, addr, testClient.addr)
assertConfig(t, config, testClient.config)
return testTLSConn, nil
}
newClient = func(conn net.Conn, host string) (smtpClient, error) {
if conn != testTLSConn {
t.Error("Invalid TLS connection used")
}
if host != testHost {
t.Errorf("Invalid host, got %q, want %q", host, testHost)
}
return testClient, nil
}
msg := NewMessage()
msg.SetHeader("From", testFrom)
msg.SetHeader("To", testTo...)
msg.SetBody("text/plain", testBody)
var settings []MailerSetting
if config != nil {
settings = []MailerSetting{SetTLSConfig(config)}
}
mailer := NewCustomMailer(addr, testAuth, settings...)
if err := mailer.Send(msg); err != nil {
t.Error(err)
}
}
func assertAuth(t *testing.T, got, want smtp.Auth) {
if got != want {
t.Errorf("Invalid auth, got %#v, want %#v", got, want)
}
}
func assertAddr(t *testing.T, got, want string) {
if got != want {
t.Errorf("Invalid addr, got %q, want %q", got, want)
}
}
func assertConfig(t *testing.T, got, want *tls.Config) {
if want == nil {
want = &tls.Config{ServerName: testHost}
}
if got.ServerName != want.ServerName {
t.Errorf("Invalid field ServerName in config, got %q, want %q", got.ServerName, want.ServerName)
}
if got.InsecureSkipVerify != want.InsecureSkipVerify {
t.Errorf("Invalid field InsecureSkipVerify in config, got %v, want %v", got.InsecureSkipVerify, want.InsecureSkipVerify)
}
} }

168
smtp.go Normal file
View File

@ -0,0 +1,168 @@
package gomail
import (
"crypto/tls"
"fmt"
"io"
"net"
"net/smtp"
)
// An SMTPDialer is a dialer to an SMTP server.
type SMTPDialer struct {
// Host represents the host of the SMTP server.
Host string
// Port represents the port of the SMTP server.
Port int
// Auth represents the authentication mechanism used to authenticate to the
// SMTP server.
Auth smtp.Auth
// SSL defines whether an SSL connection is used. It should be false in
// most cases since the authentication mechanism should use the STARTTLS
// extension instead.
SSL bool
// TSLConfig represents the TLS configuration used for the TLS (when the
// STARTTLS extension is used) or SSL connection.
TLSConfig *tls.Config
}
// NewPlainDialer returns an SMTPDialer. The given parameters are used to
// connect to the SMTP server via a PLAIN authentication mechanism.
func NewPlainDialer(host, username, password string, port int) *SMTPDialer {
return &SMTPDialer{
Host: host,
Port: port,
Auth: smtp.PlainAuth("", username, password, host),
SSL: port == 465,
}
}
// Dial dials and authenticates to an SMTP server. The returned SendCloser
// should be closed when done using it.
func (d *SMTPDialer) Dial() (SendCloser, error) {
c, err := d.dial()
if err != nil {
return nil, err
}
if d.Auth != nil {
if ok, _ := c.Extension("AUTH"); ok {
if err = c.Auth(d.Auth); err != nil {
c.Close()
return nil, err
}
}
}
return &smtpSender{c}, nil
}
func (d *SMTPDialer) dial() (smtpClient, error) {
if d.SSL {
return d.sslDial()
}
return d.starttlsDial()
}
func (d *SMTPDialer) starttlsDial() (smtpClient, error) {
c, err := smtpDial(addr(d.Host, d.Port))
if err != nil {
return nil, err
}
if ok, _ := c.Extension("STARTTLS"); ok {
if err := c.StartTLS(d.tlsConfig()); err != nil {
c.Close()
return nil, err
}
}
return c, nil
}
func (d *SMTPDialer) sslDial() (smtpClient, error) {
conn, err := tlsDial("tcp", addr(d.Host, d.Port), d.tlsConfig())
if err != nil {
return nil, err
}
return newClient(conn, d.Host)
}
func (d *SMTPDialer) tlsConfig() *tls.Config {
if d.TLSConfig == nil {
return &tls.Config{ServerName: d.Host}
}
return d.TLSConfig
}
func addr(host string, port int) string {
return fmt.Sprintf("%s:%d", host, port)
}
// DialAndSend opens a connection to an SMTP server, sends the given emails and
// closes the connection.
func (d *SMTPDialer) DialAndSend(msg ...*Message) error {
s, err := d.Dial()
if err != nil {
return err
}
defer s.Close()
return Send(s, msg...)
}
type smtpSender struct {
smtpClient
}
func (c *smtpSender) Send(from string, to []string, msg io.Reader) error {
if err := c.Mail(from); err != nil {
return err
}
for _, addr := range to {
if err := c.Rcpt(addr); err != nil {
return err
}
}
w, err := c.Data()
if err != nil {
return err
}
if _, err = io.Copy(w, msg); err != nil {
w.Close()
return err
}
return w.Close()
}
func (c *smtpSender) Close() error {
return c.Quit()
}
// Stubbed out for tests.
var (
smtpDial = func(addr string) (smtpClient, error) {
return smtp.Dial(addr)
}
tlsDial = tls.Dial
newClient = func(conn net.Conn, host string) (smtpClient, error) {
return smtp.NewClient(conn, host)
}
)
type smtpClient interface {
Extension(string) (bool, string)
StartTLS(*tls.Config) error
Auth(smtp.Auth) error
Mail(string) error
Rcpt(string) error
Data() (io.WriteCloser, error)
Quit() error
Close() error
}

248
smtp_test.go Normal file
View File

@ -0,0 +1,248 @@
package gomail
import (
"crypto/tls"
"io"
"net"
"net/smtp"
"reflect"
"testing"
)
var (
testHost = "smtp.example.com"
testPort = 587
testSSLPort = 465
testTLSConn = &tls.Conn{}
testConfig = &tls.Config{InsecureSkipVerify: true}
testAuth = smtp.PlainAuth("", "user", "pwd", testHost)
)
func TestSMTPDialer(t *testing.T) {
d := NewPlainDialer(testHost, "user", "pwd", testPort)
testSendMail(t, d, []string{
"Extension STARTTLS",
"StartTLS",
"Extension AUTH",
"Auth",
"Mail " + testFrom,
"Rcpt " + testTo1,
"Rcpt " + testTo2,
"Data",
"Write message",
"Close writer",
"Quit",
"Close",
})
}
func TestSMTPDialerSSL(t *testing.T) {
d := NewPlainDialer(testHost, "user", "pwd", testSSLPort)
testSendMail(t, d, []string{
"Extension AUTH",
"Auth",
"Mail " + testFrom,
"Rcpt " + testTo1,
"Rcpt " + testTo2,
"Data",
"Write message",
"Close writer",
"Quit",
"Close",
})
}
func TestSMTPDialerConfig(t *testing.T) {
d := NewPlainDialer(testHost, "user", "pwd", testPort)
d.TLSConfig = testConfig
testSendMail(t, d, []string{
"Extension STARTTLS",
"StartTLS",
"Extension AUTH",
"Auth",
"Mail " + testFrom,
"Rcpt " + testTo1,
"Rcpt " + testTo2,
"Data",
"Write message",
"Close writer",
"Quit",
"Close",
})
}
func TestSMTPDialerSSLConfig(t *testing.T) {
d := NewPlainDialer(testHost, "user", "pwd", testSSLPort)
d.TLSConfig = testConfig
testSendMail(t, d, []string{
"Extension AUTH",
"Auth",
"Mail " + testFrom,
"Rcpt " + testTo1,
"Rcpt " + testTo2,
"Data",
"Write message",
"Close writer",
"Quit",
"Close",
})
}
func TestSMTPDialerNoAuth(t *testing.T) {
d := &SMTPDialer{
Host: testHost,
Port: testPort,
}
testSendMail(t, d, []string{
"Extension STARTTLS",
"StartTLS",
"Mail " + testFrom,
"Rcpt " + testTo1,
"Rcpt " + testTo2,
"Data",
"Write message",
"Close writer",
"Quit",
"Close",
})
}
type mockClient struct {
t *testing.T
i int
want []string
addr string
auth smtp.Auth
config *tls.Config
}
func (c *mockClient) Extension(ext string) (bool, string) {
c.do("Extension " + ext)
return true, ""
}
func (c *mockClient) StartTLS(config *tls.Config) error {
assertConfig(c.t, config, c.config)
c.do("StartTLS")
return nil
}
func (c *mockClient) Auth(a smtp.Auth) error {
assertAuth(c.t, a, c.auth)
c.do("Auth")
return nil
}
func (c *mockClient) Mail(from string) error {
c.do("Mail " + from)
return nil
}
func (c *mockClient) Rcpt(to string) error {
c.do("Rcpt " + to)
return nil
}
func (c *mockClient) Data() (io.WriteCloser, error) {
c.do("Data")
return &mockWriter{c: c, want: testMsg}, nil
}
func (c *mockClient) Quit() error {
c.do("Quit")
return nil
}
func (c *mockClient) Close() error {
c.do("Close")
return nil
}
func (c *mockClient) do(cmd string) {
if c.i >= len(c.want) {
c.t.Fatalf("Invalid command %q", cmd)
}
if cmd != c.want[c.i] {
c.t.Fatalf("Invalid command, got %q, want %q", cmd, c.want[c.i])
}
c.i++
}
type mockWriter struct {
want string
c *mockClient
}
func (w *mockWriter) Write(p []byte) (int, error) {
w.c.do("Write message")
compareBodies(w.c.t, string(p), w.want)
return len(p), nil
}
func (w *mockWriter) Close() error {
w.c.do("Close writer")
return nil
}
func testSendMail(t *testing.T, d *SMTPDialer, want []string) {
testClient := &mockClient{
t: t,
want: want,
addr: addr(d.Host, d.Port),
auth: testAuth,
config: d.TLSConfig,
}
smtpDial = func(addr string) (smtpClient, error) {
assertAddr(t, addr, testClient.addr)
return testClient, nil
}
tlsDial = func(network, addr string, config *tls.Config) (*tls.Conn, error) {
if network != "tcp" {
t.Errorf("Invalid network, got %q, want tcp", network)
}
assertAddr(t, addr, testClient.addr)
assertConfig(t, config, testClient.config)
return testTLSConn, nil
}
newClient = func(conn net.Conn, host string) (smtpClient, error) {
if conn != testTLSConn {
t.Error("Invalid TLS connection used")
}
if host != testHost {
t.Errorf("Invalid host, got %q, want %q", host, testHost)
}
return testClient, nil
}
if err := d.DialAndSend(getTestMessage()); err != nil {
t.Error(err)
}
}
func assertAuth(t *testing.T, got, want smtp.Auth) {
if !reflect.DeepEqual(got, want) {
t.Errorf("Invalid auth, got %#v, want %#v", got, want)
}
}
func assertAddr(t *testing.T, got, want string) {
if got != want {
t.Errorf("Invalid addr, got %q, want %q", got, want)
}
}
func assertConfig(t *testing.T, got, want *tls.Config) {
if want == nil {
want = &tls.Config{ServerName: testHost}
}
if got.ServerName != want.ServerName {
t.Errorf("Invalid field ServerName in config, got %q, want %q", got.ServerName, want.ServerName)
}
if got.InsecureSkipVerify != want.InsecureSkipVerify {
t.Errorf("Invalid field InsecureSkipVerify in config, got %v, want %v", got.InsecureSkipVerify, want.InsecureSkipVerify)
}
}