Removed LoginAuth and updated NewPlainDialer to handle LOGIN auth
This commit is contained in:
parent
31a7bd9a49
commit
f01c0a3645
|
@ -0,0 +1,67 @@
|
||||||
|
package gomail
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/smtp"
|
||||||
|
)
|
||||||
|
|
||||||
|
// plainAuth is an smtp.Auth that implements the PLAIN authentication mechanism.
|
||||||
|
// It fallbacks to the LOGIN mechanism if it is the only mechanism advertised
|
||||||
|
// by the server.
|
||||||
|
type plainAuth struct {
|
||||||
|
username string
|
||||||
|
password string
|
||||||
|
host string
|
||||||
|
login bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *plainAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
|
||||||
|
if server.Name != a.host {
|
||||||
|
return "", nil, errors.New("gomail: wrong host name")
|
||||||
|
}
|
||||||
|
|
||||||
|
var plain, login bool
|
||||||
|
for _, a := range server.Auth {
|
||||||
|
switch a {
|
||||||
|
case "PLAIN":
|
||||||
|
plain = true
|
||||||
|
case "LOGIN":
|
||||||
|
login = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !server.TLS && !plain && !login {
|
||||||
|
return "", nil, errors.New("gomail: unencrypted connection")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !plain && login {
|
||||||
|
a.login = true
|
||||||
|
return "LOGIN", nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return "PLAIN", []byte(a.username + "\x00" + a.password), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *plainAuth) Next(fromServer []byte, more bool) ([]byte, error) {
|
||||||
|
if !a.login {
|
||||||
|
if more {
|
||||||
|
return nil, errors.New("gomail: unexpected server challenge")
|
||||||
|
}
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !more {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case bytes.Equal(fromServer, []byte("Username:")):
|
||||||
|
return []byte(a.username), nil
|
||||||
|
case bytes.Equal(fromServer, []byte("Password:")):
|
||||||
|
return []byte(a.password), nil
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("gomail: unexpected server challenge: %s", fromServer)
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,156 @@
|
||||||
|
package gomail
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/smtp"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
testUser = "user"
|
||||||
|
testPwd = "pwd"
|
||||||
|
testHost = "smtp.example.com"
|
||||||
|
)
|
||||||
|
|
||||||
|
var testAuth = &plainAuth{
|
||||||
|
username: testUser,
|
||||||
|
password: testPwd,
|
||||||
|
host: testHost,
|
||||||
|
}
|
||||||
|
|
||||||
|
type plainAuthTest struct {
|
||||||
|
auths []string
|
||||||
|
challenges []string
|
||||||
|
tls bool
|
||||||
|
wantProto string
|
||||||
|
wantData []string
|
||||||
|
wantError bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNoAdvertisement(t *testing.T) {
|
||||||
|
testPlainAuth(t, &plainAuthTest{
|
||||||
|
auths: []string{},
|
||||||
|
challenges: []string{"Username:", "Password:"},
|
||||||
|
tls: false,
|
||||||
|
wantProto: "PLAIN",
|
||||||
|
wantError: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNoAdvertisementTLS(t *testing.T) {
|
||||||
|
testPlainAuth(t, &plainAuthTest{
|
||||||
|
auths: []string{},
|
||||||
|
challenges: []string{"Username:", "Password:"},
|
||||||
|
tls: true,
|
||||||
|
wantProto: "PLAIN",
|
||||||
|
wantData: []string{testUser + "\x00" + testPwd},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPlain(t *testing.T) {
|
||||||
|
testPlainAuth(t, &plainAuthTest{
|
||||||
|
auths: []string{"PLAIN"},
|
||||||
|
challenges: []string{"Username:", "Password:"},
|
||||||
|
tls: false,
|
||||||
|
wantProto: "PLAIN",
|
||||||
|
wantData: []string{testUser + "\x00" + testPwd},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPlainTLS(t *testing.T) {
|
||||||
|
testPlainAuth(t, &plainAuthTest{
|
||||||
|
auths: []string{"PLAIN"},
|
||||||
|
challenges: []string{"Username:", "Password:"},
|
||||||
|
tls: true,
|
||||||
|
wantProto: "PLAIN",
|
||||||
|
wantData: []string{testUser + "\x00" + testPwd},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPlainAndLogin(t *testing.T) {
|
||||||
|
testPlainAuth(t, &plainAuthTest{
|
||||||
|
auths: []string{"PLAIN", "LOGIN"},
|
||||||
|
challenges: []string{"Username:", "Password:"},
|
||||||
|
tls: false,
|
||||||
|
wantProto: "PLAIN",
|
||||||
|
wantData: []string{testUser + "\x00" + testPwd},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPlainAndLoginTLS(t *testing.T) {
|
||||||
|
testPlainAuth(t, &plainAuthTest{
|
||||||
|
auths: []string{"PLAIN", "LOGIN"},
|
||||||
|
challenges: []string{"Username:", "Password:"},
|
||||||
|
tls: true,
|
||||||
|
wantProto: "PLAIN",
|
||||||
|
wantData: []string{testUser + "\x00" + testPwd},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLogin(t *testing.T) {
|
||||||
|
testPlainAuth(t, &plainAuthTest{
|
||||||
|
auths: []string{"LOGIN"},
|
||||||
|
challenges: []string{"Username:", "Password:"},
|
||||||
|
tls: false,
|
||||||
|
wantProto: "LOGIN",
|
||||||
|
wantData: []string{"", testUser, testPwd},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoginTLS(t *testing.T) {
|
||||||
|
testPlainAuth(t, &plainAuthTest{
|
||||||
|
auths: []string{"LOGIN"},
|
||||||
|
challenges: []string{"Username:", "Password:"},
|
||||||
|
tls: true,
|
||||||
|
wantProto: "LOGIN",
|
||||||
|
wantData: []string{"", testUser, testPwd},
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func testPlainAuth(t *testing.T, test *plainAuthTest) {
|
||||||
|
auth := &plainAuth{
|
||||||
|
username: testUser,
|
||||||
|
password: testPwd,
|
||||||
|
host: testHost,
|
||||||
|
}
|
||||||
|
server := &smtp.ServerInfo{
|
||||||
|
Name: testHost,
|
||||||
|
TLS: test.tls,
|
||||||
|
Auth: test.auths,
|
||||||
|
}
|
||||||
|
proto, toServer, err := auth.Start(server)
|
||||||
|
if err != nil && !test.wantError {
|
||||||
|
t.Fatalf("plainAuth.Start(): %v", err)
|
||||||
|
}
|
||||||
|
if err != nil && test.wantError {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if proto != test.wantProto {
|
||||||
|
t.Errorf("invalid protocol, got %q, want %q", proto, test.wantProto)
|
||||||
|
}
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
got := string(toServer)
|
||||||
|
if got != test.wantData[i] {
|
||||||
|
t.Errorf("Invalid response, got %q, want %q", got, test.wantData[i])
|
||||||
|
}
|
||||||
|
|
||||||
|
if proto == "PLAIN" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, challenge := range test.challenges {
|
||||||
|
i++
|
||||||
|
if i >= len(test.wantData) {
|
||||||
|
t.Fatalf("unexpected challenge: %q", challenge)
|
||||||
|
}
|
||||||
|
|
||||||
|
toServer, err = auth.Next([]byte(challenge), true)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("plainAuth.Auth(): %v", err)
|
||||||
|
}
|
||||||
|
got = string(toServer)
|
||||||
|
if got != test.wantData[i] {
|
||||||
|
t.Errorf("Invalid response, got %q, want %q", got, test.wantData[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
54
login.go
54
login.go
|
@ -1,54 +0,0 @@
|
||||||
package gomail
|
|
||||||
|
|
||||||
import (
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"net/smtp"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
type loginAuth struct {
|
|
||||||
username string
|
|
||||||
password string
|
|
||||||
host string
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoginAuth returns an Auth that implements the LOGIN authentication mechanism.
|
|
||||||
func LoginAuth(username, password, host string) smtp.Auth {
|
|
||||||
return &loginAuth{username, password, host}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
|
|
||||||
if !server.TLS {
|
|
||||||
advertised := false
|
|
||||||
for _, mechanism := range server.Auth {
|
|
||||||
if mechanism == "LOGIN" {
|
|
||||||
advertised = true
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !advertised {
|
|
||||||
return "", nil, errors.New("gomail: unencrypted connection")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if server.Name != a.host {
|
|
||||||
return "", nil, errors.New("gomail: wrong host name")
|
|
||||||
}
|
|
||||||
return "LOGIN", nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
|
|
||||||
if !more {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
command := strings.ToLower(strings.TrimSuffix(string(fromServer), ":"))
|
|
||||||
switch command {
|
|
||||||
case "username":
|
|
||||||
return []byte(fmt.Sprintf("%s", a.username)), nil
|
|
||||||
case "password":
|
|
||||||
return []byte(fmt.Sprintf("%s", a.password)), nil
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("gomail: unexpected server challenge: %s", command)
|
|
||||||
}
|
|
||||||
}
|
|
|
@ -1,66 +0,0 @@
|
||||||
package gomail
|
|
||||||
|
|
||||||
import (
|
|
||||||
"net/smtp"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
type output struct {
|
|
||||||
proto string
|
|
||||||
data []string
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
testUser = "user"
|
|
||||||
testPwd = "pwd"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPlainAuth(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
serverProtos []string
|
|
||||||
serverChallenges []string
|
|
||||||
proto string
|
|
||||||
data []string
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
serverProtos: []string{"LOGIN"},
|
|
||||||
serverChallenges: []string{"Username:", "Password:"},
|
|
||||||
proto: "LOGIN",
|
|
||||||
data: []string{"", testUser, testPwd},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, test := range tests {
|
|
||||||
auth := LoginAuth(testUser, testPwd, testHost)
|
|
||||||
server := &smtp.ServerInfo{
|
|
||||||
Name: testHost,
|
|
||||||
TLS: true,
|
|
||||||
Auth: test.serverProtos,
|
|
||||||
}
|
|
||||||
proto, toServer, err := auth.Start(server)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Start error: %v", err)
|
|
||||||
}
|
|
||||||
if proto != test.proto {
|
|
||||||
t.Errorf("Invalid protocol, got %q, want %q", proto, test.proto)
|
|
||||||
}
|
|
||||||
|
|
||||||
i := 0
|
|
||||||
got := string(toServer)
|
|
||||||
if got != test.data[i] {
|
|
||||||
t.Errorf("Invalid response, got %q, want %q", got, test.data[i])
|
|
||||||
}
|
|
||||||
for _, challenge := range test.serverChallenges {
|
|
||||||
toServer, err = auth.Next([]byte(challenge), true)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("Auth error: %v", err)
|
|
||||||
}
|
|
||||||
i++
|
|
||||||
got = string(toServer)
|
|
||||||
if got != test.data[i] {
|
|
||||||
t.Errorf("Invalid response, got %q, want %q", got, test.data[i])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
13
smtp.go
13
smtp.go
|
@ -26,14 +26,19 @@ type SMTPDialer struct {
|
||||||
TLSConfig *tls.Config
|
TLSConfig *tls.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPlainDialer returns an SMTPDialer. The given parameters are used to
|
// NewPlainDialer returns a Dialer. The given parameters are used to connect to
|
||||||
// connect to the SMTP server via a PLAIN authentication mechanism.
|
// the SMTP server via a PLAIN authentication mechanism. It fallbacks to the
|
||||||
|
// LOGIN mechanism if it is the only mechanism advertised by the server.
|
||||||
func NewPlainDialer(host, username, password string, port int) *SMTPDialer {
|
func NewPlainDialer(host, username, password string, port int) *SMTPDialer {
|
||||||
return &SMTPDialer{
|
return &SMTPDialer{
|
||||||
Host: host,
|
Host: host,
|
||||||
Port: port,
|
Port: port,
|
||||||
Auth: smtp.PlainAuth("", username, password, host),
|
Auth: &plainAuth{
|
||||||
SSL: port == 465,
|
username: username,
|
||||||
|
password: password,
|
||||||
|
host: host,
|
||||||
|
},
|
||||||
|
SSL: port == 465,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -9,13 +9,14 @@ import (
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
const (
|
||||||
testHost = "smtp.example.com"
|
|
||||||
testPort = 587
|
testPort = 587
|
||||||
testSSLPort = 465
|
testSSLPort = 465
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
testTLSConn = &tls.Conn{}
|
testTLSConn = &tls.Conn{}
|
||||||
testConfig = &tls.Config{InsecureSkipVerify: true}
|
testConfig = &tls.Config{InsecureSkipVerify: true}
|
||||||
testAuth = smtp.PlainAuth("", "user", "pwd", testHost)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestSMTPDialer(t *testing.T) {
|
func TestSMTPDialer(t *testing.T) {
|
||||||
|
|
Loading…
Reference in New Issue