diff --git a/README.md b/README.md index c23be4a..b3be9e1 100644 --- a/README.md +++ b/README.md @@ -60,7 +60,7 @@ bypass the verification of the server's certificate chain and host name by using ) func main() { - d := gomail.NewPlainDialer("smtp.example.com", 587, "user", "123456") + d := gomail.NewDialer("smtp.example.com", 587, "user", "123456") d.TLSConfig = &tls.Config{InsecureSkipVerify: true} // Send emails using d. diff --git a/auth.go b/auth.go index 4bcdd06..d28b83a 100644 --- a/auth.go +++ b/auth.go @@ -7,51 +7,33 @@ import ( "net/smtp" ) -// plainAuth is an smtp.Auth that implements the PLAIN authentication mechanism. -// It fallbacks to the LOGIN mechanism if it is the only mechanism advertised -// by the server. -type plainAuth struct { +// loginAuth is an smtp.Auth that implements the LOGIN authentication mechanism. +type loginAuth struct { username string password string host string - login bool } -func (a *plainAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { +func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { + if !server.TLS { + advertised := false + for _, mechanism := range server.Auth { + if mechanism == "LOGIN" { + advertised = true + break + } + } + if !advertised { + return "", nil, errors.New("gomail: unencrypted connection") + } + } if server.Name != a.host { return "", nil, errors.New("gomail: wrong host name") } - - var plain, login bool - for _, a := range server.Auth { - switch a { - case "PLAIN": - plain = true - case "LOGIN": - login = true - } - } - - if !server.TLS && !plain && !login { - return "", nil, errors.New("gomail: unencrypted connection") - } - - if !plain && login { - a.login = true - return "LOGIN", nil, nil - } - - return "PLAIN", []byte("\x00" + a.username + "\x00" + a.password), nil + return "LOGIN", nil, nil } -func (a *plainAuth) Next(fromServer []byte, more bool) ([]byte, error) { - if !a.login { - if more { - return nil, errors.New("gomail: unexpected server challenge") - } - return nil, nil - } - +func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) { if !more { return nil, nil } diff --git a/auth_test.go b/auth_test.go index 20b4772..428ef34 100644 --- a/auth_test.go +++ b/auth_test.go @@ -11,103 +11,51 @@ const ( testHost = "smtp.example.com" ) -var testAuth = &plainAuth{ - username: testUser, - password: testPwd, - host: testHost, -} - -type plainAuthTest struct { +type authTest struct { auths []string challenges []string tls bool - wantProto string wantData []string wantError bool } func TestNoAdvertisement(t *testing.T) { - testPlainAuth(t, &plainAuthTest{ - auths: []string{}, - challenges: []string{"Username:", "Password:"}, - tls: false, - wantProto: "PLAIN", - wantError: true, + testLoginAuth(t, &authTest{ + auths: []string{}, + tls: false, + wantError: true, }) } func TestNoAdvertisementTLS(t *testing.T) { - testPlainAuth(t, &plainAuthTest{ + testLoginAuth(t, &authTest{ auths: []string{}, challenges: []string{"Username:", "Password:"}, tls: true, - wantProto: "PLAIN", - wantData: []string{"\x00" + testUser + "\x00" + testPwd}, - }) -} - -func TestPlain(t *testing.T) { - testPlainAuth(t, &plainAuthTest{ - auths: []string{"PLAIN"}, - challenges: []string{"Username:", "Password:"}, - tls: false, - wantProto: "PLAIN", - wantData: []string{"\x00" + testUser + "\x00" + testPwd}, - }) -} - -func TestPlainTLS(t *testing.T) { - testPlainAuth(t, &plainAuthTest{ - auths: []string{"PLAIN"}, - challenges: []string{"Username:", "Password:"}, - tls: true, - wantProto: "PLAIN", - wantData: []string{"\x00" + testUser + "\x00" + testPwd}, - }) -} - -func TestPlainAndLogin(t *testing.T) { - testPlainAuth(t, &plainAuthTest{ - auths: []string{"PLAIN", "LOGIN"}, - challenges: []string{"Username:", "Password:"}, - tls: false, - wantProto: "PLAIN", - wantData: []string{"\x00" + testUser + "\x00" + testPwd}, - }) -} - -func TestPlainAndLoginTLS(t *testing.T) { - testPlainAuth(t, &plainAuthTest{ - auths: []string{"PLAIN", "LOGIN"}, - challenges: []string{"Username:", "Password:"}, - tls: true, - wantProto: "PLAIN", - wantData: []string{"\x00" + testUser + "\x00" + testPwd}, + wantData: []string{"", testUser, testPwd}, }) } func TestLogin(t *testing.T) { - testPlainAuth(t, &plainAuthTest{ - auths: []string{"LOGIN"}, + testLoginAuth(t, &authTest{ + auths: []string{"PLAIN", "LOGIN"}, challenges: []string{"Username:", "Password:"}, tls: false, - wantProto: "LOGIN", wantData: []string{"", testUser, testPwd}, }) } func TestLoginTLS(t *testing.T) { - testPlainAuth(t, &plainAuthTest{ + testLoginAuth(t, &authTest{ auths: []string{"LOGIN"}, challenges: []string{"Username:", "Password:"}, tls: true, - wantProto: "LOGIN", wantData: []string{"", testUser, testPwd}, }) } -func testPlainAuth(t *testing.T, test *plainAuthTest) { - auth := &plainAuth{ +func testLoginAuth(t *testing.T, test *authTest) { + auth := &loginAuth{ username: testUser, password: testPwd, host: testHost, @@ -119,13 +67,13 @@ func testPlainAuth(t *testing.T, test *plainAuthTest) { } proto, toServer, err := auth.Start(server) if err != nil && !test.wantError { - t.Fatalf("plainAuth.Start(): %v", err) + t.Fatalf("loginAuth.Start(): %v", err) } if err != nil && test.wantError { return } - if proto != test.wantProto { - t.Errorf("invalid protocol, got %q, want %q", proto, test.wantProto) + if proto != "LOGIN" { + t.Errorf("invalid protocol, got %q, want LOGIN", proto) } i := 0 @@ -134,10 +82,6 @@ func testPlainAuth(t *testing.T, test *plainAuthTest) { t.Errorf("Invalid response, got %q, want %q", got, test.wantData[i]) } - if proto == "PLAIN" { - return - } - for _, challenge := range test.challenges { i++ if i >= len(test.wantData) { @@ -146,7 +90,7 @@ func testPlainAuth(t *testing.T, test *plainAuthTest) { toServer, err = auth.Next([]byte(challenge), true) if err != nil { - t.Fatalf("plainAuth.Auth(): %v", err) + t.Fatalf("loginAuth.Auth(): %v", err) } got = string(toServer) if got != test.wantData[i] { diff --git a/example_test.go b/example_test.go index 4fca7f6..c50d0ea 100644 --- a/example_test.go +++ b/example_test.go @@ -5,7 +5,6 @@ import ( "html/template" "io" "log" - "net/smtp" "time" "gopkg.in/gomail.v2" @@ -20,7 +19,7 @@ func Example() { m.SetBody("text/html", "Hello Bob and Cora!") m.Attach("/home/Alex/lolcat.jpg") - d := gomail.NewPlainDialer("smtp.example.com", 587, "user", "123456") + d := gomail.NewDialer("smtp.example.com", 587, "user", "123456") // Send the email to Bob, Cora and Dan. if err := d.DialAndSend(m); err != nil { @@ -33,7 +32,7 @@ func Example_daemon() { ch := make(chan *gomail.Message) go func() { - d := gomail.NewPlainDialer("smtp.example.com", 587, "user", "123456") + d := gomail.NewDialer("smtp.example.com", 587, "user", "123456") var s gomail.SendCloser var err error @@ -80,7 +79,7 @@ func Example_newsletter() { Address string } - d := gomail.NewPlainDialer("smtp.example.com", 587, "user", "123456") + d := gomail.NewDialer("smtp.example.com", 587, "user", "123456") s, err := d.Dial() if err != nil { panic(err) @@ -114,24 +113,6 @@ func Example_noAuth() { } } -// Send an email using the CRAM-MD5 authentication mechanism. -func Example_cRAMMD5() { - m := gomail.NewMessage() - m.SetHeader("From", "from@example.com") - m.SetHeader("To", "to@example.com") - m.SetHeader("Subject", "Hello!") - m.SetBody("text/plain", "Hello!") - - d := gomail.Dialer{ - Host: "localhost", - Port: 587, - Auth: smtp.CRAMMD5Auth("username", "secret"), - } - if err := d.DialAndSend(m); err != nil { - panic(err) - } -} - // Send an email using an API or postfix. func Example_noSMTP() { m := gomail.NewMessage() diff --git a/smtp.go b/smtp.go index 9de4b22..af5d52b 100644 --- a/smtp.go +++ b/smtp.go @@ -6,6 +6,7 @@ import ( "io" "net" "net/smtp" + "strings" ) // A Dialer is a dialer to an SMTP server. @@ -14,6 +15,10 @@ type Dialer struct { Host string // Port represents the port of the SMTP server. Port int + // Username is the username to use to authenticate to the SMTP server. + Username string + // Password is the password to use to authenticate to the SMTP server. + Password string // Auth represents the authentication mechanism used to authenticate to the // SMTP server. Auth smtp.Auth @@ -29,98 +34,89 @@ type Dialer struct { LocalName string } -// NewPlainDialer returns a Dialer. The given parameters are used to connect to -// the SMTP server via a PLAIN authentication mechanism. -// -// It fallbacks to the LOGIN mechanism if it is the only mechanism advertised by -// the server. -func NewPlainDialer(host string, port int, username, password string) *Dialer { +// NewDialer returns a new SMTP Dialer. The given parameters are used to connect +// to the SMTP server. +func NewDialer(host string, port int, username, password string) *Dialer { return &Dialer{ - Host: host, - Port: port, - Auth: &plainAuth{ - username: username, - password: password, - host: host, - }, - SSL: port == 465, + Host: host, + Port: port, + Username: username, + Password: password, + SSL: port == 465, } } +// NewPlainDialer returns a new SMTP Dialer. The given parameters are used to +// connect to the SMTP server. +// +// Deprecated: Use NewDialer instead. +func NewPlainDialer(host string, port int, username, password string) *Dialer { + return NewDialer(host, port, username, password) +} + // Dial dials and authenticates to an SMTP server. The returned SendCloser // should be closed when done using it. func (d *Dialer) Dial() (SendCloser, error) { - c, err := d.dial() + conn, err := netDial("tcp", addr(d.Host, d.Port)) if err != nil { return nil, err } - if d.Auth != nil { - if ok, _ := c.Extension("AUTH"); ok { - if err = c.Auth(d.Auth); err != nil { + if d.SSL { + conn = tlsClient(conn, d.tlsConfig()) + } + + c, err := smtpNewClient(conn, d.Host) + if err != nil { + return nil, err + } + + if d.LocalName != "" { + if err := c.Hello(d.LocalName); err != nil { + return nil, err + } + } + + if !d.SSL { + if ok, _ := c.Extension("STARTTLS"); ok { + if err := c.StartTLS(d.tlsConfig()); err != nil { c.Close() return nil, err } } } - return &smtpSender{c}, nil -} - -func (d *Dialer) dial() (smtpClient, error) { - if d.SSL { - return d.sslDial() - } - return d.starttlsDial() -} - -func (d *Dialer) starttlsDial() (smtpClient, error) { - c, err := smtpDial(addr(d.Host, d.Port)) - if err != nil { - return nil, err - } - - if d.LocalName != "" { - if err := c.Hello(d.LocalName); err != nil { - return nil, err + if d.Auth == nil && d.Username != "" { + if ok, auths := c.Extension("AUTH"); ok { + if strings.Contains(auths, "CRAM-MD5") { + d.Auth = smtp.CRAMMD5Auth(d.Username, d.Password) + } else if strings.Contains(auths, "LOGIN") && + !strings.Contains(auths, "PLAIN") { + d.Auth = &loginAuth{ + username: d.Username, + password: d.Password, + host: d.Host, + } + } else { + d.Auth = smtp.PlainAuth("", d.Username, d.Password, d.Host) + } } } - if ok, _ := c.Extension("STARTTLS"); ok { - if err := c.StartTLS(d.tlsConfig()); err != nil { + if d.Auth != nil { + if err = c.Auth(d.Auth); err != nil { c.Close() return nil, err } } - return c, nil -} - -func (d *Dialer) sslDial() (smtpClient, error) { - conn, err := tlsDial("tcp", addr(d.Host, d.Port), d.tlsConfig()) - if err != nil { - return nil, err - } - - c, err := newClient(conn, d.Host) - if err != nil { - return nil, err - } - - if d.LocalName != "" { - if err := c.Hello(d.LocalName); err != nil { - return nil, err - } - } - - return c, nil + return &smtpSender{c}, nil } func (d *Dialer) tlsConfig() *tls.Config { if d.TLSConfig == nil { return &tls.Config{ServerName: d.Host} } - return d.TLSConfig } @@ -174,11 +170,9 @@ func (c *smtpSender) Close() error { // Stubbed out for tests. var ( - smtpDial = func(addr string) (smtpClient, error) { - return smtp.Dial(addr) - } - tlsDial = tls.Dial - newClient = func(conn net.Conn, host string) (smtpClient, error) { + netDial = net.Dial + tlsClient = tls.Client + smtpNewClient = func(conn net.Conn, host string) (smtpClient, error) { return smtp.NewClient(conn, host) } ) diff --git a/smtp_test.go b/smtp_test.go index 070e763..300c9b7 100644 --- a/smtp_test.go +++ b/smtp_test.go @@ -16,12 +16,14 @@ const ( ) var ( + testConn = &net.TCPConn{} testTLSConn = &tls.Conn{} testConfig = &tls.Config{InsecureSkipVerify: true} + testAuth = smtp.PlainAuth("", testUser, testPwd, testHost) ) func TestDialer(t *testing.T) { - d := NewPlainDialer(testHost, testPort, "user", "pwd") + d := NewDialer(testHost, testPort, "user", "pwd") testSendMail(t, d, []string{ "Extension STARTTLS", "StartTLS", @@ -39,7 +41,7 @@ func TestDialer(t *testing.T) { } func TestDialerSSL(t *testing.T) { - d := NewPlainDialer(testHost, testSSLPort, "user", "pwd") + d := NewDialer(testHost, testSSLPort, "user", "pwd") testSendMail(t, d, []string{ "Extension AUTH", "Auth", @@ -55,7 +57,7 @@ func TestDialerSSL(t *testing.T) { } func TestDialerConfig(t *testing.T) { - d := NewPlainDialer(testHost, testPort, "user", "pwd") + d := NewDialer(testHost, testPort, "user", "pwd") d.LocalName = "test" d.TLSConfig = testConfig testSendMail(t, d, []string{ @@ -76,7 +78,7 @@ func TestDialerConfig(t *testing.T) { } func TestDialerSSLConfig(t *testing.T) { - d := NewPlainDialer(testHost, testSSLPort, "user", "pwd") + d := NewDialer(testHost, testSSLPort, "user", "pwd") d.LocalName = "test" d.TLSConfig = testConfig testSendMail(t, d, []string{ @@ -118,7 +120,6 @@ type mockClient struct { i int want []string addr string - auth smtp.Auth config *tls.Config } @@ -139,7 +140,9 @@ func (c *mockClient) StartTLS(config *tls.Config) error { } func (c *mockClient) Auth(a smtp.Auth) error { - assertAuth(c.t, a, c.auth) + if !reflect.DeepEqual(a, testAuth) { + c.t.Errorf("Invalid auth, got %#v, want %#v", a, testAuth) + } c.do("Auth") return nil } @@ -205,28 +208,29 @@ func testSendMail(t *testing.T, d *Dialer, want []string) { t: t, want: want, addr: addr(d.Host, d.Port), - auth: testAuth, config: d.TLSConfig, } - smtpDial = func(addr string) (smtpClient, error) { - assertAddr(t, addr, testClient.addr) - return testClient, nil - } - - tlsDial = func(network, addr string, config *tls.Config) (*tls.Conn, error) { + netDial = func(network, address string) (net.Conn, error) { if network != "tcp" { t.Errorf("Invalid network, got %q, want tcp", network) } - assertAddr(t, addr, testClient.addr) - assertConfig(t, config, testClient.config) - return testTLSConn, nil + if address != testClient.addr { + t.Errorf("Invalid address, got %q, want %q", + address, testClient.addr) + } + return testConn, nil } - newClient = func(conn net.Conn, host string) (smtpClient, error) { - if conn != testTLSConn { - t.Error("Invalid TLS connection used") + tlsClient = func(conn net.Conn, config *tls.Config) *tls.Conn { + if conn != testConn { + t.Errorf("Invalid conn, got %#v, want %#v", conn, testConn) } + assertConfig(t, config, testClient.config) + return testTLSConn + } + + smtpNewClient = func(conn net.Conn, host string) (smtpClient, error) { if host != testHost { t.Errorf("Invalid host, got %q, want %q", host, testHost) } @@ -238,18 +242,6 @@ func testSendMail(t *testing.T, d *Dialer, want []string) { } } -func assertAuth(t *testing.T, got, want smtp.Auth) { - if !reflect.DeepEqual(got, want) { - t.Errorf("Invalid auth, got %#v, want %#v", got, want) - } -} - -func assertAddr(t *testing.T, got, want string) { - if got != want { - t.Errorf("Invalid addr, got %q, want %q", got, want) - } -} - func assertConfig(t *testing.T, got, want *tls.Config) { if want == nil { want = &tls.Config{ServerName: testHost}