Dialer.Dial() now automatically uses CRAM-MD5 when it's available
Also deprecated NewPlainDialer() in favor of NewDialer(). Fixes #52
This commit is contained in:
parent
6ea1c86967
commit
5ceb8e6541
|
@ -60,7 +60,7 @@ bypass the verification of the server's certificate chain and host name by using
|
|||
)
|
||||
|
||||
func main() {
|
||||
d := gomail.NewPlainDialer("smtp.example.com", 587, "user", "123456")
|
||||
d := gomail.NewDialer("smtp.example.com", 587, "user", "123456")
|
||||
d.TLSConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
// Send emails using d.
|
||||
|
|
50
auth.go
50
auth.go
|
@ -7,51 +7,33 @@ import (
|
|||
"net/smtp"
|
||||
)
|
||||
|
||||
// plainAuth is an smtp.Auth that implements the PLAIN authentication mechanism.
|
||||
// It fallbacks to the LOGIN mechanism if it is the only mechanism advertised
|
||||
// by the server.
|
||||
type plainAuth struct {
|
||||
// loginAuth is an smtp.Auth that implements the LOGIN authentication mechanism.
|
||||
type loginAuth struct {
|
||||
username string
|
||||
password string
|
||||
host string
|
||||
login bool
|
||||
}
|
||||
|
||||
func (a *plainAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
|
||||
func (a *loginAuth) Start(server *smtp.ServerInfo) (string, []byte, error) {
|
||||
if !server.TLS {
|
||||
advertised := false
|
||||
for _, mechanism := range server.Auth {
|
||||
if mechanism == "LOGIN" {
|
||||
advertised = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !advertised {
|
||||
return "", nil, errors.New("gomail: unencrypted connection")
|
||||
}
|
||||
}
|
||||
if server.Name != a.host {
|
||||
return "", nil, errors.New("gomail: wrong host name")
|
||||
}
|
||||
|
||||
var plain, login bool
|
||||
for _, a := range server.Auth {
|
||||
switch a {
|
||||
case "PLAIN":
|
||||
plain = true
|
||||
case "LOGIN":
|
||||
login = true
|
||||
}
|
||||
}
|
||||
|
||||
if !server.TLS && !plain && !login {
|
||||
return "", nil, errors.New("gomail: unencrypted connection")
|
||||
}
|
||||
|
||||
if !plain && login {
|
||||
a.login = true
|
||||
return "LOGIN", nil, nil
|
||||
}
|
||||
|
||||
return "PLAIN", []byte("\x00" + a.username + "\x00" + a.password), nil
|
||||
}
|
||||
|
||||
func (a *plainAuth) Next(fromServer []byte, more bool) ([]byte, error) {
|
||||
if !a.login {
|
||||
if more {
|
||||
return nil, errors.New("gomail: unexpected server challenge")
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (a *loginAuth) Next(fromServer []byte, more bool) ([]byte, error) {
|
||||
if !more {
|
||||
return nil, nil
|
||||
}
|
||||
|
|
82
auth_test.go
82
auth_test.go
|
@ -11,103 +11,51 @@ const (
|
|||
testHost = "smtp.example.com"
|
||||
)
|
||||
|
||||
var testAuth = &plainAuth{
|
||||
username: testUser,
|
||||
password: testPwd,
|
||||
host: testHost,
|
||||
}
|
||||
|
||||
type plainAuthTest struct {
|
||||
type authTest struct {
|
||||
auths []string
|
||||
challenges []string
|
||||
tls bool
|
||||
wantProto string
|
||||
wantData []string
|
||||
wantError bool
|
||||
}
|
||||
|
||||
func TestNoAdvertisement(t *testing.T) {
|
||||
testPlainAuth(t, &plainAuthTest{
|
||||
testLoginAuth(t, &authTest{
|
||||
auths: []string{},
|
||||
challenges: []string{"Username:", "Password:"},
|
||||
tls: false,
|
||||
wantProto: "PLAIN",
|
||||
wantError: true,
|
||||
})
|
||||
}
|
||||
|
||||
func TestNoAdvertisementTLS(t *testing.T) {
|
||||
testPlainAuth(t, &plainAuthTest{
|
||||
testLoginAuth(t, &authTest{
|
||||
auths: []string{},
|
||||
challenges: []string{"Username:", "Password:"},
|
||||
tls: true,
|
||||
wantProto: "PLAIN",
|
||||
wantData: []string{"\x00" + testUser + "\x00" + testPwd},
|
||||
})
|
||||
}
|
||||
|
||||
func TestPlain(t *testing.T) {
|
||||
testPlainAuth(t, &plainAuthTest{
|
||||
auths: []string{"PLAIN"},
|
||||
challenges: []string{"Username:", "Password:"},
|
||||
tls: false,
|
||||
wantProto: "PLAIN",
|
||||
wantData: []string{"\x00" + testUser + "\x00" + testPwd},
|
||||
})
|
||||
}
|
||||
|
||||
func TestPlainTLS(t *testing.T) {
|
||||
testPlainAuth(t, &plainAuthTest{
|
||||
auths: []string{"PLAIN"},
|
||||
challenges: []string{"Username:", "Password:"},
|
||||
tls: true,
|
||||
wantProto: "PLAIN",
|
||||
wantData: []string{"\x00" + testUser + "\x00" + testPwd},
|
||||
})
|
||||
}
|
||||
|
||||
func TestPlainAndLogin(t *testing.T) {
|
||||
testPlainAuth(t, &plainAuthTest{
|
||||
auths: []string{"PLAIN", "LOGIN"},
|
||||
challenges: []string{"Username:", "Password:"},
|
||||
tls: false,
|
||||
wantProto: "PLAIN",
|
||||
wantData: []string{"\x00" + testUser + "\x00" + testPwd},
|
||||
})
|
||||
}
|
||||
|
||||
func TestPlainAndLoginTLS(t *testing.T) {
|
||||
testPlainAuth(t, &plainAuthTest{
|
||||
auths: []string{"PLAIN", "LOGIN"},
|
||||
challenges: []string{"Username:", "Password:"},
|
||||
tls: true,
|
||||
wantProto: "PLAIN",
|
||||
wantData: []string{"\x00" + testUser + "\x00" + testPwd},
|
||||
wantData: []string{"", testUser, testPwd},
|
||||
})
|
||||
}
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
testPlainAuth(t, &plainAuthTest{
|
||||
auths: []string{"LOGIN"},
|
||||
testLoginAuth(t, &authTest{
|
||||
auths: []string{"PLAIN", "LOGIN"},
|
||||
challenges: []string{"Username:", "Password:"},
|
||||
tls: false,
|
||||
wantProto: "LOGIN",
|
||||
wantData: []string{"", testUser, testPwd},
|
||||
})
|
||||
}
|
||||
|
||||
func TestLoginTLS(t *testing.T) {
|
||||
testPlainAuth(t, &plainAuthTest{
|
||||
testLoginAuth(t, &authTest{
|
||||
auths: []string{"LOGIN"},
|
||||
challenges: []string{"Username:", "Password:"},
|
||||
tls: true,
|
||||
wantProto: "LOGIN",
|
||||
wantData: []string{"", testUser, testPwd},
|
||||
})
|
||||
}
|
||||
|
||||
func testPlainAuth(t *testing.T, test *plainAuthTest) {
|
||||
auth := &plainAuth{
|
||||
func testLoginAuth(t *testing.T, test *authTest) {
|
||||
auth := &loginAuth{
|
||||
username: testUser,
|
||||
password: testPwd,
|
||||
host: testHost,
|
||||
|
@ -119,13 +67,13 @@ func testPlainAuth(t *testing.T, test *plainAuthTest) {
|
|||
}
|
||||
proto, toServer, err := auth.Start(server)
|
||||
if err != nil && !test.wantError {
|
||||
t.Fatalf("plainAuth.Start(): %v", err)
|
||||
t.Fatalf("loginAuth.Start(): %v", err)
|
||||
}
|
||||
if err != nil && test.wantError {
|
||||
return
|
||||
}
|
||||
if proto != test.wantProto {
|
||||
t.Errorf("invalid protocol, got %q, want %q", proto, test.wantProto)
|
||||
if proto != "LOGIN" {
|
||||
t.Errorf("invalid protocol, got %q, want LOGIN", proto)
|
||||
}
|
||||
|
||||
i := 0
|
||||
|
@ -134,10 +82,6 @@ func testPlainAuth(t *testing.T, test *plainAuthTest) {
|
|||
t.Errorf("Invalid response, got %q, want %q", got, test.wantData[i])
|
||||
}
|
||||
|
||||
if proto == "PLAIN" {
|
||||
return
|
||||
}
|
||||
|
||||
for _, challenge := range test.challenges {
|
||||
i++
|
||||
if i >= len(test.wantData) {
|
||||
|
@ -146,7 +90,7 @@ func testPlainAuth(t *testing.T, test *plainAuthTest) {
|
|||
|
||||
toServer, err = auth.Next([]byte(challenge), true)
|
||||
if err != nil {
|
||||
t.Fatalf("plainAuth.Auth(): %v", err)
|
||||
t.Fatalf("loginAuth.Auth(): %v", err)
|
||||
}
|
||||
got = string(toServer)
|
||||
if got != test.wantData[i] {
|
||||
|
|
|
@ -5,7 +5,6 @@ import (
|
|||
"html/template"
|
||||
"io"
|
||||
"log"
|
||||
"net/smtp"
|
||||
"time"
|
||||
|
||||
"gopkg.in/gomail.v2"
|
||||
|
@ -20,7 +19,7 @@ func Example() {
|
|||
m.SetBody("text/html", "Hello <b>Bob</b> and <i>Cora</i>!")
|
||||
m.Attach("/home/Alex/lolcat.jpg")
|
||||
|
||||
d := gomail.NewPlainDialer("smtp.example.com", 587, "user", "123456")
|
||||
d := gomail.NewDialer("smtp.example.com", 587, "user", "123456")
|
||||
|
||||
// Send the email to Bob, Cora and Dan.
|
||||
if err := d.DialAndSend(m); err != nil {
|
||||
|
@ -33,7 +32,7 @@ func Example_daemon() {
|
|||
ch := make(chan *gomail.Message)
|
||||
|
||||
go func() {
|
||||
d := gomail.NewPlainDialer("smtp.example.com", 587, "user", "123456")
|
||||
d := gomail.NewDialer("smtp.example.com", 587, "user", "123456")
|
||||
|
||||
var s gomail.SendCloser
|
||||
var err error
|
||||
|
@ -80,7 +79,7 @@ func Example_newsletter() {
|
|||
Address string
|
||||
}
|
||||
|
||||
d := gomail.NewPlainDialer("smtp.example.com", 587, "user", "123456")
|
||||
d := gomail.NewDialer("smtp.example.com", 587, "user", "123456")
|
||||
s, err := d.Dial()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
@ -114,24 +113,6 @@ func Example_noAuth() {
|
|||
}
|
||||
}
|
||||
|
||||
// Send an email using the CRAM-MD5 authentication mechanism.
|
||||
func Example_cRAMMD5() {
|
||||
m := gomail.NewMessage()
|
||||
m.SetHeader("From", "from@example.com")
|
||||
m.SetHeader("To", "to@example.com")
|
||||
m.SetHeader("Subject", "Hello!")
|
||||
m.SetBody("text/plain", "Hello!")
|
||||
|
||||
d := gomail.Dialer{
|
||||
Host: "localhost",
|
||||
Port: 587,
|
||||
Auth: smtp.CRAMMD5Auth("username", "secret"),
|
||||
}
|
||||
if err := d.DialAndSend(m); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// Send an email using an API or postfix.
|
||||
func Example_noSMTP() {
|
||||
m := gomail.NewMessage()
|
||||
|
|
94
smtp.go
94
smtp.go
|
@ -6,6 +6,7 @@ import (
|
|||
"io"
|
||||
"net"
|
||||
"net/smtp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// A Dialer is a dialer to an SMTP server.
|
||||
|
@ -14,6 +15,10 @@ type Dialer struct {
|
|||
Host string
|
||||
// Port represents the port of the SMTP server.
|
||||
Port int
|
||||
// Username is the username to use to authenticate to the SMTP server.
|
||||
Username string
|
||||
// Password is the password to use to authenticate to the SMTP server.
|
||||
Password string
|
||||
// Auth represents the authentication mechanism used to authenticate to the
|
||||
// SMTP server.
|
||||
Auth smtp.Auth
|
||||
|
@ -29,53 +34,39 @@ type Dialer struct {
|
|||
LocalName string
|
||||
}
|
||||
|
||||
// NewPlainDialer returns a Dialer. The given parameters are used to connect to
|
||||
// the SMTP server via a PLAIN authentication mechanism.
|
||||
//
|
||||
// It fallbacks to the LOGIN mechanism if it is the only mechanism advertised by
|
||||
// the server.
|
||||
func NewPlainDialer(host string, port int, username, password string) *Dialer {
|
||||
// NewDialer returns a new SMTP Dialer. The given parameters are used to connect
|
||||
// to the SMTP server.
|
||||
func NewDialer(host string, port int, username, password string) *Dialer {
|
||||
return &Dialer{
|
||||
Host: host,
|
||||
Port: port,
|
||||
Auth: &plainAuth{
|
||||
username: username,
|
||||
password: password,
|
||||
host: host,
|
||||
},
|
||||
Username: username,
|
||||
Password: password,
|
||||
SSL: port == 465,
|
||||
}
|
||||
}
|
||||
|
||||
// NewPlainDialer returns a new SMTP Dialer. The given parameters are used to
|
||||
// connect to the SMTP server.
|
||||
//
|
||||
// Deprecated: Use NewDialer instead.
|
||||
func NewPlainDialer(host string, port int, username, password string) *Dialer {
|
||||
return NewDialer(host, port, username, password)
|
||||
}
|
||||
|
||||
// Dial dials and authenticates to an SMTP server. The returned SendCloser
|
||||
// should be closed when done using it.
|
||||
func (d *Dialer) Dial() (SendCloser, error) {
|
||||
c, err := d.dial()
|
||||
conn, err := netDial("tcp", addr(d.Host, d.Port))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if d.Auth != nil {
|
||||
if ok, _ := c.Extension("AUTH"); ok {
|
||||
if err = c.Auth(d.Auth); err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &smtpSender{c}, nil
|
||||
}
|
||||
|
||||
func (d *Dialer) dial() (smtpClient, error) {
|
||||
if d.SSL {
|
||||
return d.sslDial()
|
||||
}
|
||||
return d.starttlsDial()
|
||||
conn = tlsClient(conn, d.tlsConfig())
|
||||
}
|
||||
|
||||
func (d *Dialer) starttlsDial() (smtpClient, error) {
|
||||
c, err := smtpDial(addr(d.Host, d.Port))
|
||||
c, err := smtpNewClient(conn, d.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -86,41 +77,46 @@ func (d *Dialer) starttlsDial() (smtpClient, error) {
|
|||
}
|
||||
}
|
||||
|
||||
if !d.SSL {
|
||||
if ok, _ := c.Extension("STARTTLS"); ok {
|
||||
if err := c.StartTLS(d.tlsConfig()); err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (d *Dialer) sslDial() (smtpClient, error) {
|
||||
conn, err := tlsDial("tcp", addr(d.Host, d.Port), d.tlsConfig())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if d.Auth == nil && d.Username != "" {
|
||||
if ok, auths := c.Extension("AUTH"); ok {
|
||||
if strings.Contains(auths, "CRAM-MD5") {
|
||||
d.Auth = smtp.CRAMMD5Auth(d.Username, d.Password)
|
||||
} else if strings.Contains(auths, "LOGIN") &&
|
||||
!strings.Contains(auths, "PLAIN") {
|
||||
d.Auth = &loginAuth{
|
||||
username: d.Username,
|
||||
password: d.Password,
|
||||
host: d.Host,
|
||||
}
|
||||
} else {
|
||||
d.Auth = smtp.PlainAuth("", d.Username, d.Password, d.Host)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
c, err := newClient(conn, d.Host)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if d.LocalName != "" {
|
||||
if err := c.Hello(d.LocalName); err != nil {
|
||||
if d.Auth != nil {
|
||||
if err = c.Auth(d.Auth); err != nil {
|
||||
c.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return c, nil
|
||||
return &smtpSender{c}, nil
|
||||
}
|
||||
|
||||
func (d *Dialer) tlsConfig() *tls.Config {
|
||||
if d.TLSConfig == nil {
|
||||
return &tls.Config{ServerName: d.Host}
|
||||
}
|
||||
|
||||
return d.TLSConfig
|
||||
}
|
||||
|
||||
|
@ -174,11 +170,9 @@ func (c *smtpSender) Close() error {
|
|||
|
||||
// Stubbed out for tests.
|
||||
var (
|
||||
smtpDial = func(addr string) (smtpClient, error) {
|
||||
return smtp.Dial(addr)
|
||||
}
|
||||
tlsDial = tls.Dial
|
||||
newClient = func(conn net.Conn, host string) (smtpClient, error) {
|
||||
netDial = net.Dial
|
||||
tlsClient = tls.Client
|
||||
smtpNewClient = func(conn net.Conn, host string) (smtpClient, error) {
|
||||
return smtp.NewClient(conn, host)
|
||||
}
|
||||
)
|
||||
|
|
54
smtp_test.go
54
smtp_test.go
|
@ -16,12 +16,14 @@ const (
|
|||
)
|
||||
|
||||
var (
|
||||
testConn = &net.TCPConn{}
|
||||
testTLSConn = &tls.Conn{}
|
||||
testConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
testAuth = smtp.PlainAuth("", testUser, testPwd, testHost)
|
||||
)
|
||||
|
||||
func TestDialer(t *testing.T) {
|
||||
d := NewPlainDialer(testHost, testPort, "user", "pwd")
|
||||
d := NewDialer(testHost, testPort, "user", "pwd")
|
||||
testSendMail(t, d, []string{
|
||||
"Extension STARTTLS",
|
||||
"StartTLS",
|
||||
|
@ -39,7 +41,7 @@ func TestDialer(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDialerSSL(t *testing.T) {
|
||||
d := NewPlainDialer(testHost, testSSLPort, "user", "pwd")
|
||||
d := NewDialer(testHost, testSSLPort, "user", "pwd")
|
||||
testSendMail(t, d, []string{
|
||||
"Extension AUTH",
|
||||
"Auth",
|
||||
|
@ -55,7 +57,7 @@ func TestDialerSSL(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDialerConfig(t *testing.T) {
|
||||
d := NewPlainDialer(testHost, testPort, "user", "pwd")
|
||||
d := NewDialer(testHost, testPort, "user", "pwd")
|
||||
d.LocalName = "test"
|
||||
d.TLSConfig = testConfig
|
||||
testSendMail(t, d, []string{
|
||||
|
@ -76,7 +78,7 @@ func TestDialerConfig(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestDialerSSLConfig(t *testing.T) {
|
||||
d := NewPlainDialer(testHost, testSSLPort, "user", "pwd")
|
||||
d := NewDialer(testHost, testSSLPort, "user", "pwd")
|
||||
d.LocalName = "test"
|
||||
d.TLSConfig = testConfig
|
||||
testSendMail(t, d, []string{
|
||||
|
@ -118,7 +120,6 @@ type mockClient struct {
|
|||
i int
|
||||
want []string
|
||||
addr string
|
||||
auth smtp.Auth
|
||||
config *tls.Config
|
||||
}
|
||||
|
||||
|
@ -139,7 +140,9 @@ func (c *mockClient) StartTLS(config *tls.Config) error {
|
|||
}
|
||||
|
||||
func (c *mockClient) Auth(a smtp.Auth) error {
|
||||
assertAuth(c.t, a, c.auth)
|
||||
if !reflect.DeepEqual(a, testAuth) {
|
||||
c.t.Errorf("Invalid auth, got %#v, want %#v", a, testAuth)
|
||||
}
|
||||
c.do("Auth")
|
||||
return nil
|
||||
}
|
||||
|
@ -205,28 +208,29 @@ func testSendMail(t *testing.T, d *Dialer, want []string) {
|
|||
t: t,
|
||||
want: want,
|
||||
addr: addr(d.Host, d.Port),
|
||||
auth: testAuth,
|
||||
config: d.TLSConfig,
|
||||
}
|
||||
|
||||
smtpDial = func(addr string) (smtpClient, error) {
|
||||
assertAddr(t, addr, testClient.addr)
|
||||
return testClient, nil
|
||||
}
|
||||
|
||||
tlsDial = func(network, addr string, config *tls.Config) (*tls.Conn, error) {
|
||||
netDial = func(network, address string) (net.Conn, error) {
|
||||
if network != "tcp" {
|
||||
t.Errorf("Invalid network, got %q, want tcp", network)
|
||||
}
|
||||
assertAddr(t, addr, testClient.addr)
|
||||
assertConfig(t, config, testClient.config)
|
||||
return testTLSConn, nil
|
||||
if address != testClient.addr {
|
||||
t.Errorf("Invalid address, got %q, want %q",
|
||||
address, testClient.addr)
|
||||
}
|
||||
return testConn, nil
|
||||
}
|
||||
|
||||
newClient = func(conn net.Conn, host string) (smtpClient, error) {
|
||||
if conn != testTLSConn {
|
||||
t.Error("Invalid TLS connection used")
|
||||
tlsClient = func(conn net.Conn, config *tls.Config) *tls.Conn {
|
||||
if conn != testConn {
|
||||
t.Errorf("Invalid conn, got %#v, want %#v", conn, testConn)
|
||||
}
|
||||
assertConfig(t, config, testClient.config)
|
||||
return testTLSConn
|
||||
}
|
||||
|
||||
smtpNewClient = func(conn net.Conn, host string) (smtpClient, error) {
|
||||
if host != testHost {
|
||||
t.Errorf("Invalid host, got %q, want %q", host, testHost)
|
||||
}
|
||||
|
@ -238,18 +242,6 @@ func testSendMail(t *testing.T, d *Dialer, want []string) {
|
|||
}
|
||||
}
|
||||
|
||||
func assertAuth(t *testing.T, got, want smtp.Auth) {
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("Invalid auth, got %#v, want %#v", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func assertAddr(t *testing.T, got, want string) {
|
||||
if got != want {
|
||||
t.Errorf("Invalid addr, got %q, want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
func assertConfig(t *testing.T, got, want *tls.Config) {
|
||||
if want == nil {
|
||||
want = &tls.Config{ServerName: testHost}
|
||||
|
|
Loading…
Reference in New Issue