diff --git a/README.md b/README.md
index c23be4a..b3be9e1 100644
--- a/README.md
+++ b/README.md
@@ -60,7 +60,7 @@ bypass the verification of the server's certificate chain and host name by using
)
func main() {
- d := gomail.NewPlainDialer("smtp.example.com", 587, "user", "123456")
+ d := gomail.NewDialer("smtp.example.com", 587, "user", "123456")
d.TLSConfig = &tls.Config{InsecureSkipVerify: true}
// Send emails using d.
diff --git a/auth.go b/auth.go
index 4bcdd06..d28b83a 100644
--- a/auth.go
+++ b/auth.go
@@ -7,51 +7,33 @@ import (
"net/smtp"
)
-// plainAuth is an smtp.Auth that implements the PLAIN authentication mechanism.
-// It fallbacks to the LOGIN mechanism if it is the only mechanism advertised
-// by the server.
-type plainAuth struct {
+// loginAuth is an smtp.Auth that implements the LOGIN authentication mechanism.
+type loginAuth struct {
username string
password string
host string
- login bool
}
-func (a *plainAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
+func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
+ if !server.TLS {
+ advertised := false
+ for _, mechanism := range server.Auth {
+ if mechanism == "LOGIN" {
+ advertised = true
+ break
+ }
+ }
+ if !advertised {
+ return "", nil, errors.New("gomail: unencrypted connection")
+ }
+ }
if server.Name != a.host {
return "", nil, errors.New("gomail: wrong host name")
}
-
- var plain, login bool
- for _, a := range server.Auth {
- switch a {
- case "PLAIN":
- plain = true
- case "LOGIN":
- login = true
- }
- }
-
- if !server.TLS && !plain && !login {
- return "", nil, errors.New("gomail: unencrypted connection")
- }
-
- if !plain && login {
- a.login = true
- return "LOGIN", nil, nil
- }
-
- return "PLAIN", []byte("\x00" + a.username + "\x00" + a.password), nil
+ return "LOGIN", nil, nil
}
-func (a *plainAuth) Next(fromServer []byte, more bool) ([]byte, error) {
- if !a.login {
- if more {
- return nil, errors.New("gomail: unexpected server challenge")
- }
- return nil, nil
- }
-
+func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
if !more {
return nil, nil
}
diff --git a/auth_test.go b/auth_test.go
index 20b4772..428ef34 100644
--- a/auth_test.go
+++ b/auth_test.go
@@ -11,103 +11,51 @@ const (
testHost = "smtp.example.com"
)
-var testAuth = &plainAuth{
- username: testUser,
- password: testPwd,
- host: testHost,
-}
-
-type plainAuthTest struct {
+type authTest struct {
auths []string
challenges []string
tls bool
- wantProto string
wantData []string
wantError bool
}
func TestNoAdvertisement(t *testing.T) {
- testPlainAuth(t, &plainAuthTest{
- auths: []string{},
- challenges: []string{"Username:", "Password:"},
- tls: false,
- wantProto: "PLAIN",
- wantError: true,
+ testLoginAuth(t, &authTest{
+ auths: []string{},
+ tls: false,
+ wantError: true,
})
}
func TestNoAdvertisementTLS(t *testing.T) {
- testPlainAuth(t, &plainAuthTest{
+ testLoginAuth(t, &authTest{
auths: []string{},
challenges: []string{"Username:", "Password:"},
tls: true,
- wantProto: "PLAIN",
- wantData: []string{"\x00" + testUser + "\x00" + testPwd},
- })
-}
-
-func TestPlain(t *testing.T) {
- testPlainAuth(t, &plainAuthTest{
- auths: []string{"PLAIN"},
- challenges: []string{"Username:", "Password:"},
- tls: false,
- wantProto: "PLAIN",
- wantData: []string{"\x00" + testUser + "\x00" + testPwd},
- })
-}
-
-func TestPlainTLS(t *testing.T) {
- testPlainAuth(t, &plainAuthTest{
- auths: []string{"PLAIN"},
- challenges: []string{"Username:", "Password:"},
- tls: true,
- wantProto: "PLAIN",
- wantData: []string{"\x00" + testUser + "\x00" + testPwd},
- })
-}
-
-func TestPlainAndLogin(t *testing.T) {
- testPlainAuth(t, &plainAuthTest{
- auths: []string{"PLAIN", "LOGIN"},
- challenges: []string{"Username:", "Password:"},
- tls: false,
- wantProto: "PLAIN",
- wantData: []string{"\x00" + testUser + "\x00" + testPwd},
- })
-}
-
-func TestPlainAndLoginTLS(t *testing.T) {
- testPlainAuth(t, &plainAuthTest{
- auths: []string{"PLAIN", "LOGIN"},
- challenges: []string{"Username:", "Password:"},
- tls: true,
- wantProto: "PLAIN",
- wantData: []string{"\x00" + testUser + "\x00" + testPwd},
+ wantData: []string{"", testUser, testPwd},
})
}
func TestLogin(t *testing.T) {
- testPlainAuth(t, &plainAuthTest{
- auths: []string{"LOGIN"},
+ testLoginAuth(t, &authTest{
+ auths: []string{"PLAIN", "LOGIN"},
challenges: []string{"Username:", "Password:"},
tls: false,
- wantProto: "LOGIN",
wantData: []string{"", testUser, testPwd},
})
}
func TestLoginTLS(t *testing.T) {
- testPlainAuth(t, &plainAuthTest{
+ testLoginAuth(t, &authTest{
auths: []string{"LOGIN"},
challenges: []string{"Username:", "Password:"},
tls: true,
- wantProto: "LOGIN",
wantData: []string{"", testUser, testPwd},
})
}
-func testPlainAuth(t *testing.T, test *plainAuthTest) {
- auth := &plainAuth{
+func testLoginAuth(t *testing.T, test *authTest) {
+ auth := &loginAuth{
username: testUser,
password: testPwd,
host: testHost,
@@ -119,13 +67,13 @@ func testPlainAuth(t *testing.T, test *plainAuthTest) {
}
proto, toServer, err := auth.Start(server)
if err != nil && !test.wantError {
- t.Fatalf("plainAuth.Start(): %v", err)
+ t.Fatalf("loginAuth.Start(): %v", err)
}
if err != nil && test.wantError {
return
}
- if proto != test.wantProto {
- t.Errorf("invalid protocol, got %q, want %q", proto, test.wantProto)
+ if proto != "LOGIN" {
+ t.Errorf("invalid protocol, got %q, want LOGIN", proto)
}
i := 0
@@ -134,10 +82,6 @@ func testPlainAuth(t *testing.T, test *plainAuthTest) {
t.Errorf("Invalid response, got %q, want %q", got, test.wantData[i])
}
- if proto == "PLAIN" {
- return
- }
-
for _, challenge := range test.challenges {
i++
if i >= len(test.wantData) {
@@ -146,7 +90,7 @@ func testPlainAuth(t *testing.T, test *plainAuthTest) {
toServer, err = auth.Next([]byte(challenge), true)
if err != nil {
- t.Fatalf("plainAuth.Auth(): %v", err)
+ t.Fatalf("loginAuth.Auth(): %v", err)
}
got = string(toServer)
if got != test.wantData[i] {
diff --git a/example_test.go b/example_test.go
index 4fca7f6..c50d0ea 100644
--- a/example_test.go
+++ b/example_test.go
@@ -5,7 +5,6 @@ import (
"html/template"
"io"
"log"
- "net/smtp"
"time"
"gopkg.in/gomail.v2"
@@ -20,7 +19,7 @@ func Example() {
m.SetBody("text/html", "Hello Bob and Cora!")
m.Attach("/home/Alex/lolcat.jpg")
- d := gomail.NewPlainDialer("smtp.example.com", 587, "user", "123456")
+ d := gomail.NewDialer("smtp.example.com", 587, "user", "123456")
// Send the email to Bob, Cora and Dan.
if err := d.DialAndSend(m); err != nil {
@@ -33,7 +32,7 @@ func Example_daemon() {
ch := make(chan *gomail.Message)
go func() {
- d := gomail.NewPlainDialer("smtp.example.com", 587, "user", "123456")
+ d := gomail.NewDialer("smtp.example.com", 587, "user", "123456")
var s gomail.SendCloser
var err error
@@ -80,7 +79,7 @@ func Example_newsletter() {
Address string
}
- d := gomail.NewPlainDialer("smtp.example.com", 587, "user", "123456")
+ d := gomail.NewDialer("smtp.example.com", 587, "user", "123456")
s, err := d.Dial()
if err != nil {
panic(err)
@@ -114,24 +113,6 @@ func Example_noAuth() {
}
}
-// Send an email using the CRAM-MD5 authentication mechanism.
-func Example_cRAMMD5() {
- m := gomail.NewMessage()
- m.SetHeader("From", "from@example.com")
- m.SetHeader("To", "to@example.com")
- m.SetHeader("Subject", "Hello!")
- m.SetBody("text/plain", "Hello!")
-
- d := gomail.Dialer{
- Host: "localhost",
- Port: 587,
- Auth: smtp.CRAMMD5Auth("username", "secret"),
- }
- if err := d.DialAndSend(m); err != nil {
- panic(err)
- }
-}
-
// Send an email using an API or postfix.
func Example_noSMTP() {
m := gomail.NewMessage()
diff --git a/smtp.go b/smtp.go
index 9de4b22..af5d52b 100644
--- a/smtp.go
+++ b/smtp.go
@@ -6,6 +6,7 @@ import (
"io"
"net"
"net/smtp"
+ "strings"
)
// A Dialer is a dialer to an SMTP server.
@@ -14,6 +15,10 @@ type Dialer struct {
Host string
// Port represents the port of the SMTP server.
Port int
+ // Username is the username to use to authenticate to the SMTP server.
+ Username string
+ // Password is the password to use to authenticate to the SMTP server.
+ Password string
// Auth represents the authentication mechanism used to authenticate to the
// SMTP server.
Auth smtp.Auth
@@ -29,98 +34,89 @@ type Dialer struct {
LocalName string
}
-// NewPlainDialer returns a Dialer. The given parameters are used to connect to
-// the SMTP server via a PLAIN authentication mechanism.
-//
-// It fallbacks to the LOGIN mechanism if it is the only mechanism advertised by
-// the server.
-func NewPlainDialer(host string, port int, username, password string) *Dialer {
+// NewDialer returns a new SMTP Dialer. The given parameters are used to connect
+// to the SMTP server.
+func NewDialer(host string, port int, username, password string) *Dialer {
return &Dialer{
- Host: host,
- Port: port,
- Auth: &plainAuth{
- username: username,
- password: password,
- host: host,
- },
- SSL: port == 465,
+ Host: host,
+ Port: port,
+ Username: username,
+ Password: password,
+ SSL: port == 465,
}
}
+// NewPlainDialer returns a new SMTP Dialer. The given parameters are used to
+// connect to the SMTP server.
+//
+// Deprecated: Use NewDialer instead.
+func NewPlainDialer(host string, port int, username, password string) *Dialer {
+ return NewDialer(host, port, username, password)
+}
+
// Dial dials and authenticates to an SMTP server. The returned SendCloser
// should be closed when done using it.
func (d *Dialer) Dial() (SendCloser, error) {
- c, err := d.dial()
+ conn, err := netDial("tcp", addr(d.Host, d.Port))
if err != nil {
return nil, err
}
- if d.Auth != nil {
- if ok, _ := c.Extension("AUTH"); ok {
- if err = c.Auth(d.Auth); err != nil {
+ if d.SSL {
+ conn = tlsClient(conn, d.tlsConfig())
+ }
+
+ c, err := smtpNewClient(conn, d.Host)
+ if err != nil {
+ return nil, err
+ }
+
+ if d.LocalName != "" {
+ if err := c.Hello(d.LocalName); err != nil {
+ return nil, err
+ }
+ }
+
+ if !d.SSL {
+ if ok, _ := c.Extension("STARTTLS"); ok {
+ if err := c.StartTLS(d.tlsConfig()); err != nil {
c.Close()
return nil, err
}
}
}
- return &smtpSender{c}, nil
-}
-
-func (d *Dialer) dial() (smtpClient, error) {
- if d.SSL {
- return d.sslDial()
- }
- return d.starttlsDial()
-}
-
-func (d *Dialer) starttlsDial() (smtpClient, error) {
- c, err := smtpDial(addr(d.Host, d.Port))
- if err != nil {
- return nil, err
- }
-
- if d.LocalName != "" {
- if err := c.Hello(d.LocalName); err != nil {
- return nil, err
+ if d.Auth == nil && d.Username != "" {
+ if ok, auths := c.Extension("AUTH"); ok {
+ if strings.Contains(auths, "CRAM-MD5") {
+ d.Auth = smtp.CRAMMD5Auth(d.Username, d.Password)
+ } else if strings.Contains(auths, "LOGIN") &&
+ !strings.Contains(auths, "PLAIN") {
+ d.Auth = &loginAuth{
+ username: d.Username,
+ password: d.Password,
+ host: d.Host,
+ }
+ } else {
+ d.Auth = smtp.PlainAuth("", d.Username, d.Password, d.Host)
+ }
}
}
- if ok, _ := c.Extension("STARTTLS"); ok {
- if err := c.StartTLS(d.tlsConfig()); err != nil {
+ if d.Auth != nil {
+ if err = c.Auth(d.Auth); err != nil {
c.Close()
return nil, err
}
}
- return c, nil
-}
-
-func (d *Dialer) sslDial() (smtpClient, error) {
- conn, err := tlsDial("tcp", addr(d.Host, d.Port), d.tlsConfig())
- if err != nil {
- return nil, err
- }
-
- c, err := newClient(conn, d.Host)
- if err != nil {
- return nil, err
- }
-
- if d.LocalName != "" {
- if err := c.Hello(d.LocalName); err != nil {
- return nil, err
- }
- }
-
- return c, nil
+ return &smtpSender{c}, nil
}
func (d *Dialer) tlsConfig() *tls.Config {
if d.TLSConfig == nil {
return &tls.Config{ServerName: d.Host}
}
-
return d.TLSConfig
}
@@ -174,11 +170,9 @@ func (c *smtpSender) Close() error {
// Stubbed out for tests.
var (
- smtpDial = func(addr string) (smtpClient, error) {
- return smtp.Dial(addr)
- }
- tlsDial = tls.Dial
- newClient = func(conn net.Conn, host string) (smtpClient, error) {
+ netDial = net.Dial
+ tlsClient = tls.Client
+ smtpNewClient = func(conn net.Conn, host string) (smtpClient, error) {
return smtp.NewClient(conn, host)
}
)
diff --git a/smtp_test.go b/smtp_test.go
index 070e763..300c9b7 100644
--- a/smtp_test.go
+++ b/smtp_test.go
@@ -16,12 +16,14 @@ const (
)
var (
+ testConn = &net.TCPConn{}
testTLSConn = &tls.Conn{}
testConfig = &tls.Config{InsecureSkipVerify: true}
+ testAuth = smtp.PlainAuth("", testUser, testPwd, testHost)
)
func TestDialer(t *testing.T) {
- d := NewPlainDialer(testHost, testPort, "user", "pwd")
+ d := NewDialer(testHost, testPort, "user", "pwd")
testSendMail(t, d, []string{
"Extension STARTTLS",
"StartTLS",
@@ -39,7 +41,7 @@ func TestDialer(t *testing.T) {
}
func TestDialerSSL(t *testing.T) {
- d := NewPlainDialer(testHost, testSSLPort, "user", "pwd")
+ d := NewDialer(testHost, testSSLPort, "user", "pwd")
testSendMail(t, d, []string{
"Extension AUTH",
"Auth",
@@ -55,7 +57,7 @@ func TestDialerSSL(t *testing.T) {
}
func TestDialerConfig(t *testing.T) {
- d := NewPlainDialer(testHost, testPort, "user", "pwd")
+ d := NewDialer(testHost, testPort, "user", "pwd")
d.LocalName = "test"
d.TLSConfig = testConfig
testSendMail(t, d, []string{
@@ -76,7 +78,7 @@ func TestDialerConfig(t *testing.T) {
}
func TestDialerSSLConfig(t *testing.T) {
- d := NewPlainDialer(testHost, testSSLPort, "user", "pwd")
+ d := NewDialer(testHost, testSSLPort, "user", "pwd")
d.LocalName = "test"
d.TLSConfig = testConfig
testSendMail(t, d, []string{
@@ -118,7 +120,6 @@ type mockClient struct {
i int
want []string
addr string
- auth smtp.Auth
config *tls.Config
}
@@ -139,7 +140,9 @@ func (c *mockClient) StartTLS(config *tls.Config) error {
}
func (c *mockClient) Auth(a smtp.Auth) error {
- assertAuth(c.t, a, c.auth)
+ if !reflect.DeepEqual(a, testAuth) {
+ c.t.Errorf("Invalid auth, got %#v, want %#v", a, testAuth)
+ }
c.do("Auth")
return nil
}
@@ -205,28 +208,29 @@ func testSendMail(t *testing.T, d *Dialer, want []string) {
t: t,
want: want,
addr: addr(d.Host, d.Port),
- auth: testAuth,
config: d.TLSConfig,
}
- smtpDial = func(addr string) (smtpClient, error) {
- assertAddr(t, addr, testClient.addr)
- return testClient, nil
- }
-
- tlsDial = func(network, addr string, config *tls.Config) (*tls.Conn, error) {
+ netDial = func(network, address string) (net.Conn, error) {
if network != "tcp" {
t.Errorf("Invalid network, got %q, want tcp", network)
}
- assertAddr(t, addr, testClient.addr)
- assertConfig(t, config, testClient.config)
- return testTLSConn, nil
+ if address != testClient.addr {
+ t.Errorf("Invalid address, got %q, want %q",
+ address, testClient.addr)
+ }
+ return testConn, nil
}
- newClient = func(conn net.Conn, host string) (smtpClient, error) {
- if conn != testTLSConn {
- t.Error("Invalid TLS connection used")
+ tlsClient = func(conn net.Conn, config *tls.Config) *tls.Conn {
+ if conn != testConn {
+ t.Errorf("Invalid conn, got %#v, want %#v", conn, testConn)
}
+ assertConfig(t, config, testClient.config)
+ return testTLSConn
+ }
+
+ smtpNewClient = func(conn net.Conn, host string) (smtpClient, error) {
if host != testHost {
t.Errorf("Invalid host, got %q, want %q", host, testHost)
}
@@ -238,18 +242,6 @@ func testSendMail(t *testing.T, d *Dialer, want []string) {
}
}
-func assertAuth(t *testing.T, got, want smtp.Auth) {
- if !reflect.DeepEqual(got, want) {
- t.Errorf("Invalid auth, got %#v, want %#v", got, want)
- }
-}
-
-func assertAddr(t *testing.T, got, want string) {
- if got != want {
- t.Errorf("Invalid addr, got %q, want %q", got, want)
- }
-}
-
func assertConfig(t *testing.T, got, want *tls.Config) {
if want == nil {
want = &tls.Config{ServerName: testHost}