diff --git a/auth.go b/auth.go new file mode 100644 index 0000000..c7b63d5 --- /dev/null +++ b/auth.go @@ -0,0 +1,67 @@ +package gomail + +import ( + "bytes" + "errors" + "fmt" + "net/smtp" +) + +// plainAuth is an smtp.Auth that implements the PLAIN authentication mechanism. +// It fallbacks to the LOGIN mechanism if it is the only mechanism advertised +// by the server. +type plainAuth struct { + username string + password string + host string + login bool +} + +func (a *plainAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { + if server.Name != a.host { + return "", nil, errors.New("gomail: wrong host name") + } + + var plain, login bool + for _, a := range server.Auth { + switch a { + case "PLAIN": + plain = true + case "LOGIN": + login = true + } + } + + if !server.TLS && !plain && !login { + return "", nil, errors.New("gomail: unencrypted connection") + } + + if !plain && login { + a.login = true + return "LOGIN", nil, nil + } + + return "PLAIN", []byte(a.username + "\x00" + a.password), nil +} + +func (a *plainAuth) Next(fromServer []byte, more bool) ([]byte, error) { + if !a.login { + if more { + return nil, errors.New("gomail: unexpected server challenge") + } + return nil, nil + } + + if !more { + return nil, nil + } + + switch { + case bytes.Equal(fromServer, []byte("Username:")): + return []byte(a.username), nil + case bytes.Equal(fromServer, []byte("Password:")): + return []byte(a.password), nil + default: + return nil, fmt.Errorf("gomail: unexpected server challenge: %s", fromServer) + } +} diff --git a/auth_test.go b/auth_test.go new file mode 100644 index 0000000..092a85f --- /dev/null +++ b/auth_test.go @@ -0,0 +1,156 @@ +package gomail + +import ( + "net/smtp" + "testing" +) + +const ( + testUser = "user" + testPwd = "pwd" + testHost = "smtp.example.com" +) + +var testAuth = &plainAuth{ + username: testUser, + password: testPwd, + host: testHost, +} + +type plainAuthTest struct { + auths []string + challenges []string + tls bool + wantProto string + wantData []string + wantError bool +} + +func TestNoAdvertisement(t *testing.T) { + testPlainAuth(t, &plainAuthTest{ + auths: []string{}, + challenges: []string{"Username:", "Password:"}, + tls: false, + wantProto: "PLAIN", + wantError: true, + }) +} + +func TestNoAdvertisementTLS(t *testing.T) { + testPlainAuth(t, &plainAuthTest{ + auths: []string{}, + challenges: []string{"Username:", "Password:"}, + tls: true, + wantProto: "PLAIN", + wantData: []string{testUser + "\x00" + testPwd}, + }) +} + +func TestPlain(t *testing.T) { + testPlainAuth(t, &plainAuthTest{ + auths: []string{"PLAIN"}, + challenges: []string{"Username:", "Password:"}, + tls: false, + wantProto: "PLAIN", + wantData: []string{testUser + "\x00" + testPwd}, + }) +} + +func TestPlainTLS(t *testing.T) { + testPlainAuth(t, &plainAuthTest{ + auths: []string{"PLAIN"}, + challenges: []string{"Username:", "Password:"}, + tls: true, + wantProto: "PLAIN", + wantData: []string{testUser + "\x00" + testPwd}, + }) +} + +func TestPlainAndLogin(t *testing.T) { + testPlainAuth(t, &plainAuthTest{ + auths: []string{"PLAIN", "LOGIN"}, + challenges: []string{"Username:", "Password:"}, + tls: false, + wantProto: "PLAIN", + wantData: []string{testUser + "\x00" + testPwd}, + }) +} + +func TestPlainAndLoginTLS(t *testing.T) { + testPlainAuth(t, &plainAuthTest{ + auths: []string{"PLAIN", "LOGIN"}, + challenges: []string{"Username:", "Password:"}, + tls: true, + wantProto: "PLAIN", + wantData: []string{testUser + "\x00" + testPwd}, + }) +} + +func TestLogin(t *testing.T) { + testPlainAuth(t, &plainAuthTest{ + auths: []string{"LOGIN"}, + challenges: []string{"Username:", "Password:"}, + tls: false, + wantProto: "LOGIN", + wantData: []string{"", testUser, testPwd}, + }) +} + +func TestLoginTLS(t *testing.T) { + testPlainAuth(t, &plainAuthTest{ + auths: []string{"LOGIN"}, + challenges: []string{"Username:", "Password:"}, + tls: true, + wantProto: "LOGIN", + wantData: []string{"", testUser, testPwd}, + }) +} + +func testPlainAuth(t *testing.T, test *plainAuthTest) { + auth := &plainAuth{ + username: testUser, + password: testPwd, + host: testHost, + } + server := &smtp.ServerInfo{ + Name: testHost, + TLS: test.tls, + Auth: test.auths, + } + proto, toServer, err := auth.Start(server) + if err != nil && !test.wantError { + t.Fatalf("plainAuth.Start(): %v", err) + } + if err != nil && test.wantError { + return + } + if proto != test.wantProto { + t.Errorf("invalid protocol, got %q, want %q", proto, test.wantProto) + } + + i := 0 + got := string(toServer) + if got != test.wantData[i] { + t.Errorf("Invalid response, got %q, want %q", got, test.wantData[i]) + } + + if proto == "PLAIN" { + return + } + + for _, challenge := range test.challenges { + i++ + if i >= len(test.wantData) { + t.Fatalf("unexpected challenge: %q", challenge) + } + + toServer, err = auth.Next([]byte(challenge), true) + if err != nil { + t.Fatalf("plainAuth.Auth(): %v", err) + } + got = string(toServer) + if got != test.wantData[i] { + t.Errorf("Invalid response, got %q, want %q", got, test.wantData[i]) + } + } +} diff --git a/login.go b/login.go deleted file mode 100644 index ee4b3b4..0000000 --- a/login.go +++ /dev/null @@ -1,54 +0,0 @@ -package gomail - -import ( - "errors" - "fmt" - "net/smtp" - "strings" -) - -type loginAuth struct { - username string - password string - host string -} - -// LoginAuth returns an Auth that implements the LOGIN authentication mechanism. -func LoginAuth(username, password, host string) smtp.Auth { - return &loginAuth{username, password, host} -} - -func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) { - if !server.TLS { - advertised := false - for _, mechanism := range server.Auth { - if mechanism == "LOGIN" { - advertised = true - break - } - } - if !advertised { - return "", nil, errors.New("gomail: unencrypted connection") - } - } - if server.Name != a.host { - return "", nil, errors.New("gomail: wrong host name") - } - return "LOGIN", nil, nil -} - -func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) { - if !more { - return nil, nil - } - - command := strings.ToLower(strings.TrimSuffix(string(fromServer), ":")) - switch command { - case "username": - return []byte(fmt.Sprintf("%s", a.username)), nil - case "password": - return []byte(fmt.Sprintf("%s", a.password)), nil - default: - return nil, fmt.Errorf("gomail: unexpected server challenge: %s", command) - } -} diff --git a/login_test.go b/login_test.go deleted file mode 100644 index 64e1762..0000000 --- a/login_test.go +++ /dev/null @@ -1,66 +0,0 @@ -package gomail - -import ( - "net/smtp" - "testing" -) - -type output struct { - proto string - data []string - err error -} - -const ( - testUser = "user" - testPwd = "pwd" -) - -func TestPlainAuth(t *testing.T) { - tests := []struct { - serverProtos []string - serverChallenges []string - proto string - data []string - }{ - { - serverProtos: []string{"LOGIN"}, - serverChallenges: []string{"Username:", "Password:"}, - proto: "LOGIN", - data: []string{"", testUser, testPwd}, - }, - } - - for _, test := range tests { - auth := LoginAuth(testUser, testPwd, testHost) - server := &smtp.ServerInfo{ - Name: testHost, - TLS: true, - Auth: test.serverProtos, - } - proto, toServer, err := auth.Start(server) - if err != nil { - t.Fatalf("Start error: %v", err) - } - if proto != test.proto { - t.Errorf("Invalid protocol, got %q, want %q", proto, test.proto) - } - - i := 0 - got := string(toServer) - if got != test.data[i] { - t.Errorf("Invalid response, got %q, want %q", got, test.data[i]) - } - for _, challenge := range test.serverChallenges { - toServer, err = auth.Next([]byte(challenge), true) - if err != nil { - t.Fatalf("Auth error: %v", err) - } - i++ - got = string(toServer) - if got != test.data[i] { - t.Errorf("Invalid response, got %q, want %q", got, test.data[i]) - } - } - } -} diff --git a/smtp.go b/smtp.go index 4222be4..d80888b 100644 --- a/smtp.go +++ b/smtp.go @@ -26,14 +26,19 @@ type SMTPDialer struct { TLSConfig *tls.Config } -// NewPlainDialer returns an SMTPDialer. The given parameters are used to -// connect to the SMTP server via a PLAIN authentication mechanism. +// NewPlainDialer returns a Dialer. The given parameters are used to connect to +// the SMTP server via a PLAIN authentication mechanism. It fallbacks to the +// LOGIN mechanism if it is the only mechanism advertised by the server. func NewPlainDialer(host, username, password string, port int) *SMTPDialer { return &SMTPDialer{ Host: host, Port: port, - Auth: smtp.PlainAuth("", username, password, host), - SSL: port == 465, + Auth: &plainAuth{ + username: username, + password: password, + host: host, + }, + SSL: port == 465, } } diff --git a/smtp_test.go b/smtp_test.go index 1d8cdc4..3554c64 100644 --- a/smtp_test.go +++ b/smtp_test.go @@ -9,13 +9,14 @@ import ( "testing" ) -var ( - testHost = "smtp.example.com" +const ( testPort = 587 testSSLPort = 465 +) + +var ( testTLSConn = &tls.Conn{} testConfig = &tls.Config{InsecureSkipVerify: true} - testAuth = smtp.PlainAuth("", "user", "pwd", testHost) ) func TestSMTPDialer(t *testing.T) {