Removed LoginAuth and updated NewPlainDialer to handle LOGIN auth
This commit is contained in:
parent
31a7bd9a49
commit
f01c0a3645
|
@ -0,0 +1,67 @@
|
|||
package gomail
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/smtp"
|
||||
)
|
||||
|
||||
// plainAuth is an smtp.Auth that implements the PLAIN authentication mechanism.
|
||||
// It fallbacks to the LOGIN mechanism if it is the only mechanism advertised
|
||||
// by the server.
|
||||
type plainAuth struct {
|
||||
username string
|
||||
password string
|
||||
host string
|
||||
login bool
|
||||
}
|
||||
|
||||
func (a *plainAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
|
||||
if server.Name != a.host {
|
||||
return "", nil, errors.New("gomail: wrong host name")
|
||||
}
|
||||
|
||||
var plain, login bool
|
||||
for _, a := range server.Auth {
|
||||
switch a {
|
||||
case "PLAIN":
|
||||
plain = true
|
||||
case "LOGIN":
|
||||
login = true
|
||||
}
|
||||
}
|
||||
|
||||
if !server.TLS && !plain && !login {
|
||||
return "", nil, errors.New("gomail: unencrypted connection")
|
||||
}
|
||||
|
||||
if !plain && login {
|
||||
a.login = true
|
||||
return "LOGIN", nil, nil
|
||||
}
|
||||
|
||||
return "PLAIN", []byte(a.username + "\x00" + a.password), nil
|
||||
}
|
||||
|
||||
func (a *plainAuth) Next(fromServer []byte, more bool) ([]byte, error) {
|
||||
if !a.login {
|
||||
if more {
|
||||
return nil, errors.New("gomail: unexpected server challenge")
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
if !more {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case bytes.Equal(fromServer, []byte("Username:")):
|
||||
return []byte(a.username), nil
|
||||
case bytes.Equal(fromServer, []byte("Password:")):
|
||||
return []byte(a.password), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("gomail: unexpected server challenge: %s", fromServer)
|
||||
}
|
||||
}
|
|
@ -0,0 +1,156 @@
|
|||
package gomail
|
||||
|
||||
import (
|
||||
"net/smtp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const (
|
||||
testUser = "user"
|
||||
testPwd = "pwd"
|
||||
testHost = "smtp.example.com"
|
||||
)
|
||||
|
||||
var testAuth = &plainAuth{
|
||||
username: testUser,
|
||||
password: testPwd,
|
||||
host: testHost,
|
||||
}
|
||||
|
||||
type plainAuthTest struct {
|
||||
auths []string
|
||||
challenges []string
|
||||
tls bool
|
||||
wantProto string
|
||||
wantData []string
|
||||
wantError bool
|
||||
}
|
||||
|
||||
func TestNoAdvertisement(t *testing.T) {
|
||||
testPlainAuth(t, &plainAuthTest{
|
||||
auths: []string{},
|
||||
challenges: []string{"Username:", "Password:"},
|
||||
tls: false,
|
||||
wantProto: "PLAIN",
|
||||
wantError: true,
|
||||
})
|
||||
}
|
||||
|
||||
func TestNoAdvertisementTLS(t *testing.T) {
|
||||
testPlainAuth(t, &plainAuthTest{
|
||||
auths: []string{},
|
||||
challenges: []string{"Username:", "Password:"},
|
||||
tls: true,
|
||||
wantProto: "PLAIN",
|
||||
wantData: []string{testUser + "\x00" + testPwd},
|
||||
})
|
||||
}
|
||||
|
||||
func TestPlain(t *testing.T) {
|
||||
testPlainAuth(t, &plainAuthTest{
|
||||
auths: []string{"PLAIN"},
|
||||
challenges: []string{"Username:", "Password:"},
|
||||
tls: false,
|
||||
wantProto: "PLAIN",
|
||||
wantData: []string{testUser + "\x00" + testPwd},
|
||||
})
|
||||
}
|
||||
|
||||
func TestPlainTLS(t *testing.T) {
|
||||
testPlainAuth(t, &plainAuthTest{
|
||||
auths: []string{"PLAIN"},
|
||||
challenges: []string{"Username:", "Password:"},
|
||||
tls: true,
|
||||
wantProto: "PLAIN",
|
||||
wantData: []string{testUser + "\x00" + testPwd},
|
||||
})
|
||||
}
|
||||
|
||||
func TestPlainAndLogin(t *testing.T) {
|
||||
testPlainAuth(t, &plainAuthTest{
|
||||
auths: []string{"PLAIN", "LOGIN"},
|
||||
challenges: []string{"Username:", "Password:"},
|
||||
tls: false,
|
||||
wantProto: "PLAIN",
|
||||
wantData: []string{testUser + "\x00" + testPwd},
|
||||
})
|
||||
}
|
||||
|
||||
func TestPlainAndLoginTLS(t *testing.T) {
|
||||
testPlainAuth(t, &plainAuthTest{
|
||||
auths: []string{"PLAIN", "LOGIN"},
|
||||
challenges: []string{"Username:", "Password:"},
|
||||
tls: true,
|
||||
wantProto: "PLAIN",
|
||||
wantData: []string{testUser + "\x00" + testPwd},
|
||||
})
|
||||
}
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
testPlainAuth(t, &plainAuthTest{
|
||||
auths: []string{"LOGIN"},
|
||||
challenges: []string{"Username:", "Password:"},
|
||||
tls: false,
|
||||
wantProto: "LOGIN",
|
||||
wantData: []string{"", testUser, testPwd},
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoginTLS(t *testing.T) {
|
||||
testPlainAuth(t, &plainAuthTest{
|
||||
auths: []string{"LOGIN"},
|
||||
challenges: []string{"Username:", "Password:"},
|
||||
tls: true,
|
||||
wantProto: "LOGIN",
|
||||
wantData: []string{"", testUser, testPwd},
|
||||
})
|
||||
}
|
||||
|
||||
func testPlainAuth(t *testing.T, test *plainAuthTest) {
|
||||
auth := &plainAuth{
|
||||
username: testUser,
|
||||
password: testPwd,
|
||||
host: testHost,
|
||||
}
|
||||
server := &smtp.ServerInfo{
|
||||
Name: testHost,
|
||||
TLS: test.tls,
|
||||
Auth: test.auths,
|
||||
}
|
||||
proto, toServer, err := auth.Start(server)
|
||||
if err != nil && !test.wantError {
|
||||
t.Fatalf("plainAuth.Start(): %v", err)
|
||||
}
|
||||
if err != nil && test.wantError {
|
||||
return
|
||||
}
|
||||
if proto != test.wantProto {
|
||||
t.Errorf("invalid protocol, got %q, want %q", proto, test.wantProto)
|
||||
}
|
||||
|
||||
i := 0
|
||||
got := string(toServer)
|
||||
if got != test.wantData[i] {
|
||||
t.Errorf("Invalid response, got %q, want %q", got, test.wantData[i])
|
||||
}
|
||||
|
||||
if proto == "PLAIN" {
|
||||
return
|
||||
}
|
||||
|
||||
for _, challenge := range test.challenges {
|
||||
i++
|
||||
if i >= len(test.wantData) {
|
||||
t.Fatalf("unexpected challenge: %q", challenge)
|
||||
}
|
||||
|
||||
toServer, err = auth.Next([]byte(challenge), true)
|
||||
if err != nil {
|
||||
t.Fatalf("plainAuth.Auth(): %v", err)
|
||||
}
|
||||
got = string(toServer)
|
||||
if got != test.wantData[i] {
|
||||
t.Errorf("Invalid response, got %q, want %q", got, test.wantData[i])
|
||||
}
|
||||
}
|
||||
}
|
54
login.go
54
login.go
|
@ -1,54 +0,0 @@
|
|||
package gomail
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type loginAuth struct {
|
||||
username string
|
||||
password string
|
||||
host string
|
||||
}
|
||||
|
||||
// LoginAuth returns an Auth that implements the LOGIN authentication mechanism.
|
||||
func LoginAuth(username, password, host string) smtp.Auth {
|
||||
return &loginAuth{username, password, host}
|
||||
}
|
||||
|
||||
func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
|
||||
if !server.TLS {
|
||||
advertised := false
|
||||
for _, mechanism := range server.Auth {
|
||||
if mechanism == "LOGIN" {
|
||||
advertised = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !advertised {
|
||||
return "", nil, errors.New("gomail: unencrypted connection")
|
||||
}
|
||||
}
|
||||
if server.Name != a.host {
|
||||
return "", nil, errors.New("gomail: wrong host name")
|
||||
}
|
||||
return "LOGIN", nil, nil
|
||||
}
|
||||
|
||||
func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
|
||||
if !more {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
command := strings.ToLower(strings.TrimSuffix(string(fromServer), ":"))
|
||||
switch command {
|
||||
case "username":
|
||||
return []byte(fmt.Sprintf("%s", a.username)), nil
|
||||
case "password":
|
||||
return []byte(fmt.Sprintf("%s", a.password)), nil
|
||||
default:
|
||||
return nil, fmt.Errorf("gomail: unexpected server challenge: %s", command)
|
||||
}
|
||||
}
|
|
@ -1,66 +0,0 @@
|
|||
package gomail
|
||||
|
||||
import (
|
||||
"net/smtp"
|
||||
"testing"
|
||||
)
|
||||
|
||||
type output struct {
|
||||
proto string
|
||||
data []string
|
||||
err error
|
||||
}
|
||||
|
||||
const (
|
||||
testUser = "user"
|
||||
testPwd = "pwd"
|
||||
)
|
||||
|
||||
func TestPlainAuth(t *testing.T) {
|
||||
tests := []struct {
|
||||
serverProtos []string
|
||||
serverChallenges []string
|
||||
proto string
|
||||
data []string
|
||||
}{
|
||||
{
|
||||
serverProtos: []string{"LOGIN"},
|
||||
serverChallenges: []string{"Username:", "Password:"},
|
||||
proto: "LOGIN",
|
||||
data: []string{"", testUser, testPwd},
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
auth := LoginAuth(testUser, testPwd, testHost)
|
||||
server := &smtp.ServerInfo{
|
||||
Name: testHost,
|
||||
TLS: true,
|
||||
Auth: test.serverProtos,
|
||||
}
|
||||
proto, toServer, err := auth.Start(server)
|
||||
if err != nil {
|
||||
t.Fatalf("Start error: %v", err)
|
||||
}
|
||||
if proto != test.proto {
|
||||
t.Errorf("Invalid protocol, got %q, want %q", proto, test.proto)
|
||||
}
|
||||
|
||||
i := 0
|
||||
got := string(toServer)
|
||||
if got != test.data[i] {
|
||||
t.Errorf("Invalid response, got %q, want %q", got, test.data[i])
|
||||
}
|
||||
for _, challenge := range test.serverChallenges {
|
||||
toServer, err = auth.Next([]byte(challenge), true)
|
||||
if err != nil {
|
||||
t.Fatalf("Auth error: %v", err)
|
||||
}
|
||||
i++
|
||||
got = string(toServer)
|
||||
if got != test.data[i] {
|
||||
t.Errorf("Invalid response, got %q, want %q", got, test.data[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
13
smtp.go
13
smtp.go
|
@ -26,14 +26,19 @@ type SMTPDialer struct {
|
|||
TLSConfig *tls.Config
|
||||
}
|
||||
|
||||
// NewPlainDialer returns an SMTPDialer. The given parameters are used to
|
||||
// connect to the SMTP server via a PLAIN authentication mechanism.
|
||||
// NewPlainDialer returns a Dialer. The given parameters are used to connect to
|
||||
// the SMTP server via a PLAIN authentication mechanism. It fallbacks to the
|
||||
// LOGIN mechanism if it is the only mechanism advertised by the server.
|
||||
func NewPlainDialer(host, username, password string, port int) *SMTPDialer {
|
||||
return &SMTPDialer{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Auth: smtp.PlainAuth("", username, password, host),
|
||||
SSL: port == 465,
|
||||
Auth: &plainAuth{
|
||||
username: username,
|
||||
password: password,
|
||||
host: host,
|
||||
},
|
||||
SSL: port == 465,
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -9,13 +9,14 @@ import (
|
|||
"testing"
|
||||
)
|
||||
|
||||
var (
|
||||
testHost = "smtp.example.com"
|
||||
const (
|
||||
testPort = 587
|
||||
testSSLPort = 465
|
||||
)
|
||||
|
||||
var (
|
||||
testTLSConn = &tls.Conn{}
|
||||
testConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
testAuth = smtp.PlainAuth("", "user", "pwd", testHost)
|
||||
)
|
||||
|
||||
func TestSMTPDialer(t *testing.T) {
|
||||
|
|
Loading…
Reference in New Issue