gomail/smtp_test.go

249 lines
4.8 KiB
Go
Raw Normal View History

package gomail
import (
"crypto/tls"
"io"
"net"
"net/smtp"
"reflect"
"testing"
)
// Fixtures shared by the SMTP dialer tests in this file.
var (
	// Server coordinates handed to the dialers under test.
	testHost    = "smtp.example.com"
	testPort    = 587
	testSSLPort = 465

	// testTLSConn is the connection returned by the tlsDial stub; the
	// newClient stub recognises it by pointer identity.
	testTLSConn = new(tls.Conn)

	// testConfig is the custom TLS configuration used by the *Config tests.
	testConfig = &tls.Config{InsecureSkipVerify: true}

	// testAuth is the auth value the mock client expects to receive in Auth.
	testAuth = smtp.PlainAuth("", "user", "pwd", testHost)
)
// TestSMTPDialer checks the exact command sequence seen by the mock client
// when a plain dialer on testPort sends the test message: a STARTTLS
// exchange, then AUTH, then the envelope, the body, and the shutdown.
func TestSMTPDialer(t *testing.T) {
	dialer := NewPlainDialer(testHost, "user", "pwd", testPort)

	wantCmds := []string{
		"Extension STARTTLS",
		"StartTLS",
		"Extension AUTH",
		"Auth",
		"Mail " + testFrom,
		"Rcpt " + testTo1,
		"Rcpt " + testTo2,
		"Data",
		"Write message",
		"Close writer",
		"Quit",
		"Close",
	}
	testSendMail(t, dialer, wantCmds)
}
// TestSMTPDialerSSL checks the command sequence for a plain dialer on
// testSSLPort. Unlike the testPort case, no STARTTLS exchange is expected
// before AUTH.
// NOTE(review): presumably the dialer goes through the tlsDial hook for this
// port — confirm against the dialer implementation.
func TestSMTPDialerSSL(t *testing.T) {
	dialer := NewPlainDialer(testHost, "user", "pwd", testSSLPort)

	wantCmds := []string{
		"Extension AUTH",
		"Auth",
		"Mail " + testFrom,
		"Rcpt " + testTo1,
		"Rcpt " + testTo2,
		"Data",
		"Write message",
		"Close writer",
		"Quit",
		"Close",
	}
	testSendMail(t, dialer, wantCmds)
}
// TestSMTPDialerConfig is the testPort scenario with a caller-supplied
// TLSConfig. The expected command sequence is the same as without one; the
// mock client additionally compares the config it receives in StartTLS
// against the dialer's TLSConfig.
func TestSMTPDialerConfig(t *testing.T) {
	dialer := NewPlainDialer(testHost, "user", "pwd", testPort)
	dialer.TLSConfig = testConfig

	wantCmds := []string{
		"Extension STARTTLS",
		"StartTLS",
		"Extension AUTH",
		"Auth",
		"Mail " + testFrom,
		"Rcpt " + testTo1,
		"Rcpt " + testTo2,
		"Data",
		"Write message",
		"Close writer",
		"Quit",
		"Close",
	}
	testSendMail(t, dialer, wantCmds)
}
// TestSMTPDialerSSLConfig is the testSSLPort scenario with a caller-supplied
// TLSConfig: no STARTTLS exchange is expected, and the tlsDial stub compares
// the config it receives against the dialer's TLSConfig.
func TestSMTPDialerSSLConfig(t *testing.T) {
	dialer := NewPlainDialer(testHost, "user", "pwd", testSSLPort)
	dialer.TLSConfig = testConfig

	wantCmds := []string{
		"Extension AUTH",
		"Auth",
		"Mail " + testFrom,
		"Rcpt " + testTo1,
		"Rcpt " + testTo2,
		"Data",
		"Write message",
		"Close writer",
		"Quit",
		"Close",
	}
	testSendMail(t, dialer, wantCmds)
}
// TestSMTPDialerNoAuth builds a dialer with only Host and Port set and checks
// that the session goes straight from the STARTTLS exchange to the envelope,
// with no AUTH extension query and no Auth call.
func TestSMTPDialerNoAuth(t *testing.T) {
	dialer := &SMTPDialer{Host: testHost, Port: testPort}

	wantCmds := []string{
		"Extension STARTTLS",
		"StartTLS",
		"Mail " + testFrom,
		"Rcpt " + testTo1,
		"Rcpt " + testTo2,
		"Data",
		"Write message",
		"Close writer",
		"Quit",
		"Close",
	}
	testSendMail(t, dialer, wantCmds)
}
// mockClient is a scripted stand-in for the SMTP client. Every method records
// the command it represents through do, which fails the test as soon as a
// command arrives out of the order given in want.
type mockClient struct {
	t      *testing.T
	i      int      // index into want of the next expected command
	want   []string // expected commands, in order
	addr   string   // address the dial stubs expect to be called with
	auth   smtp.Auth
	config *tls.Config
}

// Extension records the query and reports every extension as supported,
// with an empty parameter string.
func (m *mockClient) Extension(ext string) (bool, string) {
	m.do("Extension " + ext)
	return true, ""
}

// StartTLS checks the received config against the expected one, then records
// the command.
func (m *mockClient) StartTLS(config *tls.Config) error {
	assertConfig(m.t, config, m.config)
	m.do("StartTLS")
	return nil
}

// Auth checks the received auth value against the expected one, then records
// the command.
func (m *mockClient) Auth(a smtp.Auth) error {
	assertAuth(m.t, a, m.auth)
	m.do("Auth")
	return nil
}

// Mail records the sender address.
func (m *mockClient) Mail(from string) error {
	m.do("Mail " + from)
	return nil
}

// Rcpt records one recipient address.
func (m *mockClient) Rcpt(to string) error {
	m.do("Rcpt " + to)
	return nil
}

// Data records the command and hands back a writer that compares what is
// written to it with testMsg.
func (m *mockClient) Data() (io.WriteCloser, error) {
	m.do("Data")
	body := &mockWriter{want: testMsg, c: m}
	return body, nil
}

// Quit records the command.
func (m *mockClient) Quit() error {
	m.do("Quit")
	return nil
}

// Close records the command.
func (m *mockClient) Close() error {
	m.do("Close")
	return nil
}

// do matches cmd against the next expected command and advances the cursor.
// A command beyond the end of want, or one that differs from the expected
// entry, aborts the test via Fatalf.
func (m *mockClient) do(cmd string) {
	switch {
	case m.i >= len(m.want):
		m.t.Fatalf("Invalid command %q", cmd)
	case cmd != m.want[m.i]:
		m.t.Fatalf("Invalid command, got %q, want %q", cmd, m.want[m.i])
	}
	m.i++
}
// mockWriter is the io.WriteCloser returned by mockClient.Data. It reports
// its calls to the owning mockClient and checks the bytes written to it.
type mockWriter struct {
	want string      // expected message body
	c    *mockClient // client whose command script is advanced
}

// Write records a "Write message" command, compares p with the expected body
// and reports the whole of p as written.
func (mw *mockWriter) Write(p []byte) (int, error) {
	mw.c.do("Write message")
	got := string(p)
	compareBodies(mw.c.t, got, mw.want)
	return len(p), nil
}

// Close records a "Close writer" command.
func (mw *mockWriter) Close() error {
	mw.c.do("Close writer")
	return nil
}
// testSendMail sends the test message through d with the package-level
// network hooks (smtpDial, tlsDial, newClient) replaced by stubs, and checks
// that the mock client receives exactly the commands in want, in order.
//
// The stubs also verify the address, TLS config, connection and host they are
// called with. The original hooks are restored when the function returns.
func testSendMail(t *testing.T, d *SMTPDialer, want []string) {
	testClient := &mockClient{
		t:      t,
		want:   want,
		addr:   addr(d.Host, d.Port),
		auth:   testAuth,
		config: d.TLSConfig,
	}

	// Restore the real hooks on return: the stubs capture t and testClient
	// and must not stay installed for whatever runs after this test.
	origSMTPDial, origTLSDial, origNewClient := smtpDial, tlsDial, newClient
	defer func() {
		smtpDial, tlsDial, newClient = origSMTPDial, origTLSDial, origNewClient
	}()

	smtpDial = func(gotAddr string) (smtpClient, error) {
		assertAddr(t, gotAddr, testClient.addr)
		return testClient, nil
	}
	tlsDial = func(network, gotAddr string, config *tls.Config) (*tls.Conn, error) {
		if network != "tcp" {
			t.Errorf("Invalid network, got %q, want tcp", network)
		}
		assertAddr(t, gotAddr, testClient.addr)
		assertConfig(t, config, testClient.config)
		return testTLSConn, nil
	}
	newClient = func(conn net.Conn, host string) (smtpClient, error) {
		if conn != testTLSConn {
			t.Error("Invalid TLS connection used")
		}
		if host != testHost {
			t.Errorf("Invalid host, got %q, want %q", host, testHost)
		}
		return testClient, nil
	}

	if err := d.DialAndSend(getTestMessage()); err != nil {
		t.Error(err)
	}

	// mockClient.do only rejects wrong or surplus commands; a session that
	// stops early would otherwise go unnoticed.
	if n := testClient.i; n < len(want) {
		t.Errorf("Missing commands, got %d of %d, next expected %q", n, len(want), want[n])
	}
}
// assertAuth reports a test error when got and want are not deeply equal.
func assertAuth(t *testing.T, got, want smtp.Auth) {
	if reflect.DeepEqual(got, want) {
		return
	}
	t.Errorf("Invalid auth, got %#v, want %#v", got, want)
}
// assertAddr reports a test error when the address got differs from want.
func assertAddr(t *testing.T, got, want string) {
	if got == want {
		return
	}
	t.Errorf("Invalid addr, got %q, want %q", got, want)
}
// assertConfig compares the ServerName and InsecureSkipVerify fields of got
// with those of want and reports a test error for each field that differs.
//
// A nil want stands for the default expectation: a config whose ServerName is
// testHost. A nil got is reported as a test error instead of being
// dereferenced, so a dialer that passes no config fails the assertion rather
// than panicking the whole test binary.
func assertConfig(t *testing.T, got, want *tls.Config) {
	if want == nil {
		want = &tls.Config{ServerName: testHost}
	}
	if got == nil {
		t.Errorf("Invalid config, got nil, want ServerName %q and InsecureSkipVerify %v", want.ServerName, want.InsecureSkipVerify)
		return
	}
	if got.ServerName != want.ServerName {
		t.Errorf("Invalid field ServerName in config, got %q, want %q", got.ServerName, want.ServerName)
	}
	if got.InsecureSkipVerify != want.InsecureSkipVerify {
		t.Errorf("Invalid field InsecureSkipVerify in config, got %v, want %v", got.InsecureSkipVerify, want.InsecureSkipVerify)
	}
}