Replaced Mailer type with the Sender interface and SMTPDialer type
Fixes #10 Fixes #17 Fixes #32
This commit is contained in:
parent
a7fe250544
commit
01674ee5b6
|
@ -2,7 +2,8 @@ package gomail
|
|||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"net/smtp"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"strconv"
|
||||
|
@ -11,6 +12,12 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
func init() {
|
||||
now = func() time.Time {
|
||||
return time.Date(2014, 06, 25, 17, 46, 0, 0, time.UTC)
|
||||
}
|
||||
}
|
||||
|
||||
type message struct {
|
||||
from string
|
||||
to []string
|
||||
|
@ -23,8 +30,8 @@ func TestMessage(t *testing.T) {
|
|||
msg.SetHeader("To", msg.FormatAddress("to@example.com", "Señor To"), "tobis@example.com")
|
||||
msg.SetAddressHeader("Cc", "cc@example.com", "A, B")
|
||||
msg.SetAddressHeader("X-To", "ccbis@example.com", "à, b")
|
||||
msg.SetDateHeader("X-Date", stubNow())
|
||||
msg.SetHeader("X-Date-2", msg.FormatDate(stubNow()))
|
||||
msg.SetDateHeader("X-Date", now())
|
||||
msg.SetHeader("X-Date-2", msg.FormatDate(now()))
|
||||
msg.SetHeader("Subject", "¡Hola, señor!")
|
||||
msg.SetHeaders(map[string][]string{
|
||||
"X-Headers": {"Test", "Café"},
|
||||
|
@ -488,31 +495,20 @@ func TestBase64LineLength(t *testing.T) {
|
|||
}
|
||||
|
||||
func testMessage(t *testing.T, msg *Message, bCount int, emails ...message) {
|
||||
now = stubNow
|
||||
mailer := NewMailer("host", "username", "password", 587, SetSendMail(stubSendMail(t, bCount, emails...)))
|
||||
|
||||
err := mailer.Send(msg)
|
||||
err := Send(stubSendMail(t, bCount, emails...), msg)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func stubNow() time.Time {
|
||||
return time.Date(2014, 06, 25, 17, 46, 0, 0, time.UTC)
|
||||
}
|
||||
|
||||
func stubSendMail(t *testing.T, bCount int, emails ...message) SendMailFunc {
|
||||
func stubSendMail(t *testing.T, bCount int, emails ...message) SendFunc {
|
||||
i := 0
|
||||
return func(addr string, a smtp.Auth, from string, to []string, msg []byte) error {
|
||||
return func(from string, to []string, msg io.Reader) error {
|
||||
if i > len(emails) {
|
||||
t.Fatalf("Only %d mails should be sent", len(emails))
|
||||
}
|
||||
want := emails[i]
|
||||
|
||||
if addr != "host:587" {
|
||||
t.Fatalf("Invalid address, got %q, want host:587", addr)
|
||||
}
|
||||
|
||||
if from != want.from {
|
||||
t.Fatalf("Invalid from, got %q, want %q", from, want.from)
|
||||
}
|
||||
|
@ -531,7 +527,11 @@ func stubSendMail(t *testing.T, bCount int, emails ...message) SendMailFunc {
|
|||
}
|
||||
}
|
||||
|
||||
got := string(msg)
|
||||
content, err := ioutil.ReadAll(msg)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
got := string(content)
|
||||
wantMsg := string("Mime-Version: 1.0\r\n" +
|
||||
"Date: Wed, 25 Jun 2014 17:46:00 +0000\r\n" +
|
||||
want.content)
|
||||
|
@ -613,7 +613,7 @@ func getBoundaries(t *testing.T, count int, msg string) []string {
|
|||
var boundaryRegExp = regexp.MustCompile("boundary=(\\w+)")
|
||||
|
||||
func BenchmarkFull(b *testing.B) {
|
||||
emptyFunc := func(addr string, a smtp.Auth, from string, to []string, msg []byte) error {
|
||||
emptyFunc := func(from string, to []string, msg io.Reader) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -631,8 +631,7 @@ func BenchmarkFull(b *testing.B) {
|
|||
msg.Attach(CreateFile("benchmark.txt", []byte("Benchmark")))
|
||||
msg.Embed(CreateFile("benchmark.jpg", []byte("Benchmark")))
|
||||
|
||||
mailer := NewMailer("host", "username", "password", 587, SetSendMail(emptyFunc))
|
||||
if err := mailer.Send(msg); err != nil {
|
||||
if err := Send(SendFunc(emptyFunc), msg); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
|
205
mailer.go
205
mailer.go
|
@ -1,205 +0,0 @@
|
|||
package gomail
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net"
|
||||
"net/mail"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A Mailer represents an SMTP server.
|
||||
type Mailer struct {
|
||||
addr string
|
||||
host string
|
||||
config *tls.Config
|
||||
auth smtp.Auth
|
||||
send SendMailFunc
|
||||
}
|
||||
|
||||
// A MailerSetting can be used in a mailer constructor to configure it.
|
||||
type MailerSetting func(m *Mailer)
|
||||
|
||||
// SetSendMail allows to set the email-sending function of a mailer.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// myFunc := func(addr string, a smtp.Auth, from string, to []string, msg []byte) error {
|
||||
// // Implement your email-sending function similar to smtp.SendMail
|
||||
// }
|
||||
// mailer := gomail.NewMailer("host", "user", "pwd", 465, SetSendMail(myFunc))
|
||||
func SetSendMail(s SendMailFunc) MailerSetting {
|
||||
return func(m *Mailer) {
|
||||
m.send = s
|
||||
}
|
||||
}
|
||||
|
||||
// SetTLSConfig allows to set the TLS configuration used to connect the SMTP
|
||||
// server.
|
||||
func SetTLSConfig(c *tls.Config) MailerSetting {
|
||||
return func(m *Mailer) {
|
||||
m.config = c
|
||||
}
|
||||
}
|
||||
|
||||
// A SendMailFunc is a function to send emails with the same signature than
|
||||
// smtp.SendMail.
|
||||
type SendMailFunc func(addr string, a smtp.Auth, from string, to []string, msg []byte) error
|
||||
|
||||
// NewMailer returns a mailer. The given parameters are used to connect to the
|
||||
// SMTP server via a PLAIN authentication mechanism.
|
||||
func NewMailer(host string, username string, password string, port int, settings ...MailerSetting) *Mailer {
|
||||
return NewCustomMailer(
|
||||
fmt.Sprintf("%s:%d", host, port),
|
||||
smtp.PlainAuth("", username, password, host),
|
||||
settings...,
|
||||
)
|
||||
}
|
||||
|
||||
// NewCustomMailer creates a mailer with the given authentication mechanism.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// gomail.NewCustomMailer("host:587", smtp.CRAMMD5Auth("username", "secret"))
|
||||
func NewCustomMailer(addr string, auth smtp.Auth, settings ...MailerSetting) *Mailer {
|
||||
// Error is not handled here to preserve backward compatibility
|
||||
host, port, _ := net.SplitHostPort(addr)
|
||||
|
||||
m := &Mailer{
|
||||
addr: addr,
|
||||
host: host,
|
||||
auth: auth,
|
||||
}
|
||||
|
||||
for _, s := range settings {
|
||||
s(m)
|
||||
}
|
||||
|
||||
if m.config == nil {
|
||||
m.config = &tls.Config{ServerName: host}
|
||||
}
|
||||
if m.send == nil {
|
||||
m.send = m.getSendMailFunc(port == "465")
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
// Send sends the emails to all the recipients of the message.
|
||||
func (m *Mailer) Send(msg *Message) error {
|
||||
message := msg.Export()
|
||||
|
||||
from, err := getFrom(message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
recipients, bcc, err := getRecipients(message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h := flattenHeader(message, "")
|
||||
body, err := ioutil.ReadAll(message.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mail := append(h, body...)
|
||||
if err := m.send(m.addr, m.auth, from, recipients, mail); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, to := range bcc {
|
||||
h = flattenHeader(message, to)
|
||||
mail = append(h, body...)
|
||||
if err := m.send(m.addr, m.auth, from, []string{to}, mail); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func flattenHeader(msg *mail.Message, bcc string) []byte {
|
||||
buf := getBuffer()
|
||||
defer putBuffer(buf)
|
||||
|
||||
for field, value := range msg.Header {
|
||||
if field != "Bcc" {
|
||||
buf.WriteString(field)
|
||||
buf.WriteString(": ")
|
||||
buf.WriteString(strings.Join(value, ", "))
|
||||
buf.WriteString("\r\n")
|
||||
} else if bcc != "" {
|
||||
for _, to := range value {
|
||||
if strings.Contains(to, bcc) {
|
||||
buf.WriteString(field)
|
||||
buf.WriteString(": ")
|
||||
buf.WriteString(to)
|
||||
buf.WriteString("\r\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
buf.WriteString("\r\n")
|
||||
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func getFrom(msg *mail.Message) (string, error) {
|
||||
from := msg.Header.Get("Sender")
|
||||
if from == "" {
|
||||
from = msg.Header.Get("From")
|
||||
if from == "" {
|
||||
return "", errors.New("mailer: invalid message, \"From\" field is absent")
|
||||
}
|
||||
}
|
||||
|
||||
return parseAddress(from)
|
||||
}
|
||||
|
||||
func getRecipients(msg *mail.Message) (recipients, bcc []string, err error) {
|
||||
for _, field := range []string{"Bcc", "To", "Cc"} {
|
||||
if addresses, ok := msg.Header[field]; ok {
|
||||
for _, addr := range addresses {
|
||||
switch field {
|
||||
case "Bcc":
|
||||
bcc, err = addAdress(bcc, addr)
|
||||
default:
|
||||
recipients, err = addAdress(recipients, addr)
|
||||
}
|
||||
if err != nil {
|
||||
return recipients, bcc, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return recipients, bcc, nil
|
||||
}
|
||||
|
||||
func addAdress(list []string, addr string) ([]string, error) {
|
||||
addr, err := parseAddress(addr)
|
||||
if err != nil {
|
||||
return list, err
|
||||
}
|
||||
for _, a := range list {
|
||||
if addr == a {
|
||||
return list, nil
|
||||
}
|
||||
}
|
||||
|
||||
return append(list, addr), nil
|
||||
}
|
||||
|
||||
func parseAddress(field string) (string, error) {
|
||||
a, err := mail.ParseAddress(field)
|
||||
if a == nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return a.Address, err
|
||||
}
|
205
send.go
205
send.go
|
@ -1,102 +1,161 @@
|
|||
package gomail
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"io/ioutil"
|
||||
"net/mail"
|
||||
"strings"
|
||||
)
|
||||
|
||||
func (m *Mailer) getSendMailFunc(ssl bool) SendMailFunc {
|
||||
return func(addr string, a smtp.Auth, from string, to []string, msg []byte) error {
|
||||
var c smtpClient
|
||||
var err error
|
||||
if ssl {
|
||||
c, err = sslDial(addr, m.host, m.config)
|
||||
} else {
|
||||
c, err = starttlsDial(addr, m.config)
|
||||
// Sender is the interface that wraps the Send method.
|
||||
//
|
||||
// Send sends an email to the given addresses.
|
||||
type Sender interface {
|
||||
Send(from string, to []string, msg io.Reader) error
|
||||
}
|
||||
|
||||
// SendCloser is the interface that groups the Send and Close methods.
|
||||
type SendCloser interface {
|
||||
Sender
|
||||
Close() error
|
||||
}
|
||||
|
||||
// A SendFunc is a function that sends emails to the given adresses.
|
||||
// The SendFunc type is an adapter to allow the use of ordinary functions as
|
||||
// email senders. If f is a function with the appropriate signature, SendFunc(f)
|
||||
// is a Sender object that calls f.
|
||||
type SendFunc func(from string, to []string, msg io.Reader) error
|
||||
|
||||
// Send calls f(from, to, msg).
|
||||
func (f SendFunc) Send(from string, to []string, msg io.Reader) error {
|
||||
return f(from, to, msg)
|
||||
}
|
||||
|
||||
// Send sends emails using the given Sender.
|
||||
func Send(s Sender, msg ...*Message) error {
|
||||
for i, m := range msg {
|
||||
if err := send(s, m); err != nil {
|
||||
return fmt.Errorf("gomail: could not send email %d: %v", i+1, err)
|
||||
}
|
||||
if err != nil {
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func send(s Sender, msg *Message) error {
|
||||
message := msg.Export()
|
||||
|
||||
from, err := getFrom(message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
recipients, bcc, err := getRecipients(message)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
h := flattenHeader(message, "")
|
||||
body, err := ioutil.ReadAll(message.Body)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mail := bytes.NewReader(append(h, body...))
|
||||
if err := s.Send(from, recipients, mail); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, to := range bcc {
|
||||
h = flattenHeader(message, to)
|
||||
mail = bytes.NewReader(append(h, body...))
|
||||
if err := s.Send(from, []string{to}, mail); err != nil {
|
||||
return err
|
||||
}
|
||||
defer c.Close()
|
||||
}
|
||||
|
||||
if a != nil {
|
||||
if ok, _ := c.Extension("AUTH"); ok {
|
||||
if err = c.Auth(a); err != nil {
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
func flattenHeader(msg *mail.Message, bcc string) []byte {
|
||||
buf := getBuffer()
|
||||
defer putBuffer(buf)
|
||||
|
||||
for field, value := range msg.Header {
|
||||
if field != "Bcc" {
|
||||
buf.WriteString(field)
|
||||
buf.WriteString(": ")
|
||||
buf.WriteString(strings.Join(value, ", "))
|
||||
buf.WriteString("\r\n")
|
||||
} else if bcc != "" {
|
||||
for _, to := range value {
|
||||
if strings.Contains(to, bcc) {
|
||||
buf.WriteString(field)
|
||||
buf.WriteString(": ")
|
||||
buf.WriteString(to)
|
||||
buf.WriteString("\r\n")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
buf.WriteString("\r\n")
|
||||
|
||||
if err = c.Mail(from); err != nil {
|
||||
return err
|
||||
return buf.Bytes()
|
||||
}
|
||||
|
||||
func getFrom(msg *mail.Message) (string, error) {
|
||||
from := msg.Header.Get("Sender")
|
||||
if from == "" {
|
||||
from = msg.Header.Get("From")
|
||||
if from == "" {
|
||||
return "", errors.New("mailer: invalid message, \"From\" field is absent")
|
||||
}
|
||||
}
|
||||
|
||||
for _, addr := range to {
|
||||
if err = c.Rcpt(addr); err != nil {
|
||||
return err
|
||||
return parseAddress(from)
|
||||
}
|
||||
|
||||
func getRecipients(msg *mail.Message) (recipients, bcc []string, err error) {
|
||||
for _, field := range []string{"Bcc", "To", "Cc"} {
|
||||
if addresses, ok := msg.Header[field]; ok {
|
||||
for _, addr := range addresses {
|
||||
switch field {
|
||||
case "Bcc":
|
||||
bcc, err = addAdress(bcc, addr)
|
||||
default:
|
||||
recipients, err = addAdress(recipients, addr)
|
||||
}
|
||||
if err != nil {
|
||||
return recipients, bcc, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
w, err := c.Data()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write(msg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = w.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return c.Quit()
|
||||
}
|
||||
|
||||
return recipients, bcc, nil
|
||||
}
|
||||
|
||||
func sslDial(addr, host string, config *tls.Config) (smtpClient, error) {
|
||||
conn, err := initTLS("tcp", addr, config)
|
||||
func addAdress(list []string, addr string) ([]string, error) {
|
||||
addr, err := parseAddress(addr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return list, err
|
||||
}
|
||||
for _, a := range list {
|
||||
if addr == a {
|
||||
return list, nil
|
||||
}
|
||||
}
|
||||
|
||||
return newClient(conn, host)
|
||||
return append(list, addr), nil
|
||||
}
|
||||
|
||||
func starttlsDial(addr string, config *tls.Config) (smtpClient, error) {
|
||||
c, err := initSMTP(addr)
|
||||
if err != nil {
|
||||
return c, err
|
||||
func parseAddress(field string) (string, error) {
|
||||
a, err := mail.ParseAddress(field)
|
||||
if a == nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if ok, _ := c.Extension("STARTTLS"); ok {
|
||||
return c, c.StartTLS(config)
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
var initSMTP = func(addr string) (smtpClient, error) {
|
||||
return smtp.Dial(addr)
|
||||
}
|
||||
|
||||
var initTLS = func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
return tls.Dial(network, addr, config)
|
||||
}
|
||||
|
||||
var newClient = func(conn net.Conn, host string) (smtpClient, error) {
|
||||
return smtp.NewClient(conn, host)
|
||||
}
|
||||
|
||||
type smtpClient interface {
|
||||
Extension(string) (bool, string)
|
||||
StartTLS(*tls.Config) error
|
||||
Auth(smtp.Auth) error
|
||||
Mail(string) error
|
||||
Rcpt(string) error
|
||||
Data() (io.WriteCloser, error)
|
||||
Quit() error
|
||||
Close() error
|
||||
return a.Address, err
|
||||
}
|
||||
|
|
268
send_test.go
268
send_test.go
|
@ -1,245 +1,79 @@
|
|||
package gomail
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"io/ioutil"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
testAddr = "smtp.example.com:587"
|
||||
testSSLAddr = "smtp.example.com:465"
|
||||
testTLSConn = &tls.Conn{}
|
||||
testConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
testHost = "smtp.example.com"
|
||||
testAuth = smtp.PlainAuth("", "user", "pwd", "smtp.example.com")
|
||||
testFrom = "from@example.com"
|
||||
testTo = []string{"to1@example.com", "to2@example.com"}
|
||||
testBody = "Test message"
|
||||
const (
|
||||
testTo1 = "to1@example.com"
|
||||
testTo2 = "to2@example.com"
|
||||
testFrom = "from@example.com"
|
||||
testBody = "Test message"
|
||||
testMsg = "To: " + testTo1 + ", " + testTo2 + "\r\n" +
|
||||
"From: " + testFrom + "\r\n" +
|
||||
"Mime-Version: 1.0\r\n" +
|
||||
"Date: Wed, 25 Jun 2014 17:46:00 +0000\r\n" +
|
||||
"Content-Type: text/plain; charset=UTF-8\r\n" +
|
||||
"Content-Transfer-Encoding: quoted-printable\r\n" +
|
||||
"\r\n" +
|
||||
testBody
|
||||
)
|
||||
|
||||
const wantMsg = "To: to1@example.com, to2@example.com\r\n" +
|
||||
"From: from@example.com\r\n" +
|
||||
"Mime-Version: 1.0\r\n" +
|
||||
"Date: Wed, 25 Jun 2014 17:46:00 +0000\r\n" +
|
||||
"Content-Type: text/plain; charset=UTF-8\r\n" +
|
||||
"Content-Transfer-Encoding: quoted-printable\r\n" +
|
||||
"\r\n" +
|
||||
"Test message"
|
||||
type mockSender SendFunc
|
||||
|
||||
func TestDefaultSendMail(t *testing.T) {
|
||||
testSendMail(t, testAddr, nil, []string{
|
||||
"Extension STARTTLS",
|
||||
"StartTLS",
|
||||
"Extension AUTH",
|
||||
"Auth",
|
||||
"Mail " + testFrom,
|
||||
"Rcpt " + testTo[0],
|
||||
"Rcpt " + testTo[1],
|
||||
"Data",
|
||||
"Write message",
|
||||
"Close writer",
|
||||
"Quit",
|
||||
"Close",
|
||||
})
|
||||
func (s mockSender) Send(from string, to []string, msg io.Reader) error {
|
||||
return s(from, to, msg)
|
||||
}
|
||||
|
||||
func TestSSLSendMail(t *testing.T) {
|
||||
testSendMail(t, testSSLAddr, nil, []string{
|
||||
"Extension AUTH",
|
||||
"Auth",
|
||||
"Mail " + testFrom,
|
||||
"Rcpt " + testTo[0],
|
||||
"Rcpt " + testTo[1],
|
||||
"Data",
|
||||
"Write message",
|
||||
"Close writer",
|
||||
"Quit",
|
||||
"Close",
|
||||
})
|
||||
type mockSendCloser struct {
|
||||
mockSender
|
||||
close func() error
|
||||
}
|
||||
|
||||
func TestTLSConfigSendMail(t *testing.T) {
|
||||
testSendMail(t, testAddr, testConfig, []string{
|
||||
"Extension STARTTLS",
|
||||
"StartTLS",
|
||||
"Extension AUTH",
|
||||
"Auth",
|
||||
"Mail " + testFrom,
|
||||
"Rcpt " + testTo[0],
|
||||
"Rcpt " + testTo[1],
|
||||
"Data",
|
||||
"Write message",
|
||||
"Close writer",
|
||||
"Quit",
|
||||
"Close",
|
||||
})
|
||||
func (s *mockSendCloser) Close() error {
|
||||
return s.close()
|
||||
}
|
||||
|
||||
func TestTLSConfigSSLSendMail(t *testing.T) {
|
||||
testSendMail(t, testSSLAddr, testConfig, []string{
|
||||
"Extension AUTH",
|
||||
"Auth",
|
||||
"Mail " + testFrom,
|
||||
"Rcpt " + testTo[0],
|
||||
"Rcpt " + testTo[1],
|
||||
"Data",
|
||||
"Write message",
|
||||
"Close writer",
|
||||
"Quit",
|
||||
"Close",
|
||||
})
|
||||
}
|
||||
|
||||
type mockClient struct {
|
||||
t *testing.T
|
||||
i int
|
||||
want []string
|
||||
addr string
|
||||
auth smtp.Auth
|
||||
config *tls.Config
|
||||
}
|
||||
|
||||
func (c *mockClient) Extension(ext string) (bool, string) {
|
||||
c.do("Extension " + ext)
|
||||
return true, ""
|
||||
}
|
||||
|
||||
func (c *mockClient) StartTLS(config *tls.Config) error {
|
||||
assertConfig(c.t, config, c.config)
|
||||
c.do("StartTLS")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mockClient) Auth(a smtp.Auth) error {
|
||||
assertAuth(c.t, a, c.auth)
|
||||
c.do("Auth")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mockClient) Mail(from string) error {
|
||||
c.do("Mail " + from)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mockClient) Rcpt(to string) error {
|
||||
c.do("Rcpt " + to)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mockClient) Data() (io.WriteCloser, error) {
|
||||
c.do("Data")
|
||||
return &mockWriter{c: c, want: wantMsg}, nil
|
||||
}
|
||||
|
||||
func (c *mockClient) Quit() error {
|
||||
c.do("Quit")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mockClient) Close() error {
|
||||
c.do("Close")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mockClient) do(cmd string) {
|
||||
if c.i >= len(c.want) {
|
||||
c.t.Fatalf("Invalid command %q", cmd)
|
||||
func TestSend(t *testing.T) {
|
||||
s := &mockSendCloser{
|
||||
mockSender: stubSend(t, testFrom, []string{testTo1, testTo2}, testMsg),
|
||||
close: func() error {
|
||||
t.Error("Close() should not be called in Send()")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
if cmd != c.want[c.i] {
|
||||
c.t.Fatalf("Invalid command, got %q, want %q", cmd, c.want[c.i])
|
||||
if err := Send(s, getTestMessage()); err != nil {
|
||||
t.Errorf("Send(): %v", err)
|
||||
}
|
||||
c.i++
|
||||
}
|
||||
|
||||
type mockWriter struct {
|
||||
want string
|
||||
c *mockClient
|
||||
func getTestMessage() *Message {
|
||||
m := NewMessage()
|
||||
m.SetHeader("From", testFrom)
|
||||
m.SetHeader("To", testTo1, testTo2)
|
||||
m.SetBody("text/plain", testBody)
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (w *mockWriter) Write(p []byte) (int, error) {
|
||||
w.c.do("Write message")
|
||||
compareBodies(w.c.t, string(p), w.want)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *mockWriter) Close() error {
|
||||
w.c.do("Close writer")
|
||||
return nil
|
||||
}
|
||||
|
||||
func testSendMail(t *testing.T, addr string, config *tls.Config, want []string) {
|
||||
testClient := &mockClient{
|
||||
t: t,
|
||||
want: want,
|
||||
addr: addr,
|
||||
auth: testAuth,
|
||||
config: config,
|
||||
}
|
||||
|
||||
initSMTP = func(addr string) (smtpClient, error) {
|
||||
assertAddr(t, addr, testClient.addr)
|
||||
return testClient, nil
|
||||
}
|
||||
|
||||
initTLS = func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
if network != "tcp" {
|
||||
t.Errorf("Invalid network, got %q, want tcp", network)
|
||||
func stubSend(t *testing.T, wantFrom string, wantTo []string, wantBody string) mockSender {
|
||||
return func(from string, to []string, msg io.Reader) error {
|
||||
if from != wantFrom {
|
||||
t.Errorf("invalid from, got %q, want %q", from, wantFrom)
|
||||
}
|
||||
assertAddr(t, addr, testClient.addr)
|
||||
assertConfig(t, config, testClient.config)
|
||||
return testTLSConn, nil
|
||||
}
|
||||
|
||||
newClient = func(conn net.Conn, host string) (smtpClient, error) {
|
||||
if conn != testTLSConn {
|
||||
t.Error("Invalid TLS connection used")
|
||||
if !reflect.DeepEqual(to, wantTo) {
|
||||
t.Errorf("invalid to, got %v, want %v", to, wantTo)
|
||||
}
|
||||
if host != testHost {
|
||||
t.Errorf("Invalid host, got %q, want %q", host, testHost)
|
||||
|
||||
content, err := ioutil.ReadAll(msg)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return testClient, nil
|
||||
}
|
||||
compareBodies(t, string(content), wantBody)
|
||||
|
||||
msg := NewMessage()
|
||||
msg.SetHeader("From", testFrom)
|
||||
msg.SetHeader("To", testTo...)
|
||||
msg.SetBody("text/plain", testBody)
|
||||
|
||||
var settings []MailerSetting
|
||||
if config != nil {
|
||||
settings = []MailerSetting{SetTLSConfig(config)}
|
||||
}
|
||||
|
||||
mailer := NewCustomMailer(addr, testAuth, settings...)
|
||||
if err := mailer.Send(msg); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func assertAuth(t *testing.T, got, want smtp.Auth) {
|
||||
if got != want {
|
||||
t.Errorf("Invalid auth, got %#v, want %#v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func assertAddr(t *testing.T, got, want string) {
|
||||
if got != want {
|
||||
t.Errorf("Invalid addr, got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func assertConfig(t *testing.T, got, want *tls.Config) {
|
||||
if want == nil {
|
||||
want = &tls.Config{ServerName: testHost}
|
||||
}
|
||||
if got.ServerName != want.ServerName {
|
||||
t.Errorf("Invalid field ServerName in config, got %q, want %q", got.ServerName, want.ServerName)
|
||||
}
|
||||
if got.InsecureSkipVerify != want.InsecureSkipVerify {
|
||||
t.Errorf("Invalid field InsecureSkipVerify in config, got %v, want %v", got.InsecureSkipVerify, want.InsecureSkipVerify)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
|
|
@ -0,0 +1,168 @@
|
|||
package gomail
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/smtp"
|
||||
)
|
||||
|
||||
// An SMTPDialer is a dialer to an SMTP server.
|
||||
type SMTPDialer struct {
|
||||
// Host represents the host of the SMTP server.
|
||||
Host string
|
||||
// Port represents the port of the SMTP server.
|
||||
Port int
|
||||
// Auth represents the authentication mechanism used to authenticate to the
|
||||
// SMTP server.
|
||||
Auth smtp.Auth
|
||||
// SSL defines whether an SSL connection is used. It should be false in
|
||||
// most cases since the authentication mechanism should use the STARTTLS
|
||||
// extension instead.
|
||||
SSL bool
|
||||
// TSLConfig represents the TLS configuration used for the TLS (when the
|
||||
// STARTTLS extension is used) or SSL connection.
|
||||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
// NewPlainDialer returns an SMTPDialer. The given parameters are used to
|
||||
// connect to the SMTP server via a PLAIN authentication mechanism.
|
||||
func NewPlainDialer(host, username, password string, port int) *SMTPDialer {
|
||||
return &SMTPDialer{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Auth: smtp.PlainAuth("", username, password, host),
|
||||
SSL: port == 465,
|
||||
}
|
||||
}
|
||||
|
||||
// Dial dials and authenticates to an SMTP server. The returned SendCloser
|
||||
// should be closed when done using it.
|
||||
func (d *SMTPDialer) Dial() (SendCloser, error) {
|
||||
c, err := d.dial()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if d.Auth != nil {
|
||||
if ok, _ := c.Extension("AUTH"); ok {
|
||||
if err = c.Auth(d.Auth); err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &smtpSender{c}, nil
|
||||
}
|
||||
|
||||
func (d *SMTPDialer) dial() (smtpClient, error) {
|
||||
if d.SSL {
|
||||
return d.sslDial()
|
||||
}
|
||||
return d.starttlsDial()
|
||||
}
|
||||
|
||||
func (d *SMTPDialer) starttlsDial() (smtpClient, error) {
|
||||
c, err := smtpDial(addr(d.Host, d.Port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if ok, _ := c.Extension("STARTTLS"); ok {
|
||||
if err := c.StartTLS(d.tlsConfig()); err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (d *SMTPDialer) sslDial() (smtpClient, error) {
|
||||
conn, err := tlsDial("tcp", addr(d.Host, d.Port), d.tlsConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return newClient(conn, d.Host)
|
||||
}
|
||||
|
||||
func (d *SMTPDialer) tlsConfig() *tls.Config {
|
||||
if d.TLSConfig == nil {
|
||||
return &tls.Config{ServerName: d.Host}
|
||||
}
|
||||
|
||||
return d.TLSConfig
|
||||
}
|
||||
|
||||
func addr(host string, port int) string {
|
||||
return fmt.Sprintf("%s:%d", host, port)
|
||||
}
|
||||
|
||||
// DialAndSend opens a connection to an SMTP server, sends the given emails and
|
||||
// closes the connection.
|
||||
func (d *SMTPDialer) DialAndSend(msg ...*Message) error {
|
||||
s, err := d.Dial()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer s.Close()
|
||||
|
||||
return Send(s, msg...)
|
||||
}
|
||||
|
||||
type smtpSender struct {
|
||||
smtpClient
|
||||
}
|
||||
|
||||
func (c *smtpSender) Send(from string, to []string, msg io.Reader) error {
|
||||
if err := c.Mail(from); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, addr := range to {
|
||||
if err := c.Rcpt(addr); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
w, err := c.Data()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err = io.Copy(w, msg); err != nil {
|
||||
w.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return w.Close()
|
||||
}
|
||||
|
||||
func (c *smtpSender) Close() error {
|
||||
return c.Quit()
|
||||
}
|
||||
|
||||
// Stubbed out for tests.
|
||||
var (
|
||||
smtpDial = func(addr string) (smtpClient, error) {
|
||||
return smtp.Dial(addr)
|
||||
}
|
||||
tlsDial = tls.Dial
|
||||
newClient = func(conn net.Conn, host string) (smtpClient, error) {
|
||||
return smtp.NewClient(conn, host)
|
||||
}
|
||||
)
|
||||
|
||||
type smtpClient interface {
|
||||
Extension(string) (bool, string)
|
||||
StartTLS(*tls.Config) error
|
||||
Auth(smtp.Auth) error
|
||||
Mail(string) error
|
||||
Rcpt(string) error
|
||||
Data() (io.WriteCloser, error)
|
||||
Quit() error
|
||||
Close() error
|
||||
}
|
|
@ -0,0 +1,248 @@
|
|||
package gomail
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
testHost = "smtp.example.com"
|
||||
testPort = 587
|
||||
testSSLPort = 465
|
||||
testTLSConn = &tls.Conn{}
|
||||
testConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
testAuth = smtp.PlainAuth("", "user", "pwd", testHost)
|
||||
)
|
||||
|
||||
func TestSMTPDialer(t *testing.T) {
|
||||
d := NewPlainDialer(testHost, "user", "pwd", testPort)
|
||||
testSendMail(t, d, []string{
|
||||
"Extension STARTTLS",
|
||||
"StartTLS",
|
||||
"Extension AUTH",
|
||||
"Auth",
|
||||
"Mail " + testFrom,
|
||||
"Rcpt " + testTo1,
|
||||
"Rcpt " + testTo2,
|
||||
"Data",
|
||||
"Write message",
|
||||
"Close writer",
|
||||
"Quit",
|
||||
"Close",
|
||||
})
|
||||
}
|
||||
|
||||
func TestSMTPDialerSSL(t *testing.T) {
|
||||
d := NewPlainDialer(testHost, "user", "pwd", testSSLPort)
|
||||
testSendMail(t, d, []string{
|
||||
"Extension AUTH",
|
||||
"Auth",
|
||||
"Mail " + testFrom,
|
||||
"Rcpt " + testTo1,
|
||||
"Rcpt " + testTo2,
|
||||
"Data",
|
||||
"Write message",
|
||||
"Close writer",
|
||||
"Quit",
|
||||
"Close",
|
||||
})
|
||||
}
|
||||
|
||||
func TestSMTPDialerConfig(t *testing.T) {
|
||||
d := NewPlainDialer(testHost, "user", "pwd", testPort)
|
||||
d.TLSConfig = testConfig
|
||||
testSendMail(t, d, []string{
|
||||
"Extension STARTTLS",
|
||||
"StartTLS",
|
||||
"Extension AUTH",
|
||||
"Auth",
|
||||
"Mail " + testFrom,
|
||||
"Rcpt " + testTo1,
|
||||
"Rcpt " + testTo2,
|
||||
"Data",
|
||||
"Write message",
|
||||
"Close writer",
|
||||
"Quit",
|
||||
"Close",
|
||||
})
|
||||
}
|
||||
|
||||
func TestSMTPDialerSSLConfig(t *testing.T) {
|
||||
d := NewPlainDialer(testHost, "user", "pwd", testSSLPort)
|
||||
d.TLSConfig = testConfig
|
||||
testSendMail(t, d, []string{
|
||||
"Extension AUTH",
|
||||
"Auth",
|
||||
"Mail " + testFrom,
|
||||
"Rcpt " + testTo1,
|
||||
"Rcpt " + testTo2,
|
||||
"Data",
|
||||
"Write message",
|
||||
"Close writer",
|
||||
"Quit",
|
||||
"Close",
|
||||
})
|
||||
}
|
||||
|
||||
func TestSMTPDialerNoAuth(t *testing.T) {
|
||||
d := &SMTPDialer{
|
||||
Host: testHost,
|
||||
Port: testPort,
|
||||
}
|
||||
testSendMail(t, d, []string{
|
||||
"Extension STARTTLS",
|
||||
"StartTLS",
|
||||
"Mail " + testFrom,
|
||||
"Rcpt " + testTo1,
|
||||
"Rcpt " + testTo2,
|
||||
"Data",
|
||||
"Write message",
|
||||
"Close writer",
|
||||
"Quit",
|
||||
"Close",
|
||||
})
|
||||
}
|
||||
|
||||
type mockClient struct {
|
||||
t *testing.T
|
||||
i int
|
||||
want []string
|
||||
addr string
|
||||
auth smtp.Auth
|
||||
config *tls.Config
|
||||
}
|
||||
|
||||
func (c *mockClient) Extension(ext string) (bool, string) {
|
||||
c.do("Extension " + ext)
|
||||
return true, ""
|
||||
}
|
||||
|
||||
func (c *mockClient) StartTLS(config *tls.Config) error {
|
||||
assertConfig(c.t, config, c.config)
|
||||
c.do("StartTLS")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mockClient) Auth(a smtp.Auth) error {
|
||||
assertAuth(c.t, a, c.auth)
|
||||
c.do("Auth")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mockClient) Mail(from string) error {
|
||||
c.do("Mail " + from)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mockClient) Rcpt(to string) error {
|
||||
c.do("Rcpt " + to)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mockClient) Data() (io.WriteCloser, error) {
|
||||
c.do("Data")
|
||||
return &mockWriter{c: c, want: testMsg}, nil
|
||||
}
|
||||
|
||||
func (c *mockClient) Quit() error {
|
||||
c.do("Quit")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mockClient) Close() error {
|
||||
c.do("Close")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mockClient) do(cmd string) {
|
||||
if c.i >= len(c.want) {
|
||||
c.t.Fatalf("Invalid command %q", cmd)
|
||||
}
|
||||
|
||||
if cmd != c.want[c.i] {
|
||||
c.t.Fatalf("Invalid command, got %q, want %q", cmd, c.want[c.i])
|
||||
}
|
||||
c.i++
|
||||
}
|
||||
|
||||
type mockWriter struct {
|
||||
want string
|
||||
c *mockClient
|
||||
}
|
||||
|
||||
func (w *mockWriter) Write(p []byte) (int, error) {
|
||||
w.c.do("Write message")
|
||||
compareBodies(w.c.t, string(p), w.want)
|
||||
return len(p), nil
|
||||
}
|
||||
|
||||
func (w *mockWriter) Close() error {
|
||||
w.c.do("Close writer")
|
||||
return nil
|
||||
}
|
||||
|
||||
func testSendMail(t *testing.T, d *SMTPDialer, want []string) {
|
||||
testClient := &mockClient{
|
||||
t: t,
|
||||
want: want,
|
||||
addr: addr(d.Host, d.Port),
|
||||
auth: testAuth,
|
||||
config: d.TLSConfig,
|
||||
}
|
||||
|
||||
smtpDial = func(addr string) (smtpClient, error) {
|
||||
assertAddr(t, addr, testClient.addr)
|
||||
return testClient, nil
|
||||
}
|
||||
|
||||
tlsDial = func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
if network != "tcp" {
|
||||
t.Errorf("Invalid network, got %q, want tcp", network)
|
||||
}
|
||||
assertAddr(t, addr, testClient.addr)
|
||||
assertConfig(t, config, testClient.config)
|
||||
return testTLSConn, nil
|
||||
}
|
||||
|
||||
newClient = func(conn net.Conn, host string) (smtpClient, error) {
|
||||
if conn != testTLSConn {
|
||||
t.Error("Invalid TLS connection used")
|
||||
}
|
||||
if host != testHost {
|
||||
t.Errorf("Invalid host, got %q, want %q", host, testHost)
|
||||
}
|
||||
return testClient, nil
|
||||
}
|
||||
|
||||
if err := d.DialAndSend(getTestMessage()); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func assertAuth(t *testing.T, got, want smtp.Auth) {
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("Invalid auth, got %#v, want %#v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func assertAddr(t *testing.T, got, want string) {
|
||||
if got != want {
|
||||
t.Errorf("Invalid addr, got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func assertConfig(t *testing.T, got, want *tls.Config) {
|
||||
if want == nil {
|
||||
want = &tls.Config{ServerName: testHost}
|
||||
}
|
||||
if got.ServerName != want.ServerName {
|
||||
t.Errorf("Invalid field ServerName in config, got %q, want %q", got.ServerName, want.ServerName)
|
||||
}
|
||||
if got.InsecureSkipVerify != want.InsecureSkipVerify {
|
||||
t.Errorf("Invalid field InsecureSkipVerify in config, got %v, want %v", got.InsecureSkipVerify, want.InsecureSkipVerify)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue