Replaced Message.Export by Message.WriteTo

Message now implement io.WriterTo. It allows streaming
the message directly to the SMTP server without buffering it
first into memory.
This commit is contained in:
Alexandre Cesaro 2015-07-02 23:08:18 +02:00
parent 8de74d4f48
commit 4f6286485b
6 changed files with 122 additions and 121 deletions

150
export.go
View File

@ -1,18 +1,29 @@
package gomail package gomail
import ( import (
"bytes"
"encoding/base64" "encoding/base64"
"errors"
"io" "io"
"mime/multipart" "mime/multipart"
"mime/quotedprintable" "mime/quotedprintable"
"net/mail"
"time" "time"
) )
// Export converts the message into a net/mail.Message. // WriteTo implements io.WriterTo.
func (msg *Message) Export() *mail.Message { func (msg *Message) WriteTo(w io.Writer) (int64, error) {
w := newMessageWriter(msg) mw := &messageWriter{w: w}
mw.writeMessage(msg)
return mw.n, mw.err
}
func (w *messageWriter) writeMessage(msg *Message) {
if _, ok := msg.header["Mime-Version"]; !ok {
w.writeString("Mime-Version: 1.0\r\n")
}
if _, ok := msg.header["Date"]; !ok {
w.writeHeader("Date", msg.FormatDate(now()))
}
w.writeHeaders(msg.header)
if msg.hasMixedPart() { if msg.hasMixedPart() {
w.openMultipart("mixed") w.openMultipart("mixed")
@ -26,11 +37,12 @@ func (msg *Message) Export() *mail.Message {
w.openMultipart("alternative") w.openMultipart("alternative")
} }
for _, part := range msg.parts { for _, part := range msg.parts {
h := make(map[string][]string) contentType := part.contentType + "; charset=" + msg.charset
h["Content-Type"] = []string{part.contentType + "; charset=" + msg.charset} w.writeHeaders(map[string][]string{
h["Content-Transfer-Encoding"] = []string{string(msg.encoding)} "Content-Type": []string{contentType},
"Content-Transfer-Encoding": []string{string(msg.encoding)},
w.write(h, part.body.Bytes(), msg.encoding) })
w.writeBody(part.body.Bytes(), msg.encoding)
} }
if msg.hasAlternativePart() { if msg.hasAlternativePart() {
w.closeMultipart() w.closeMultipart()
@ -45,8 +57,6 @@ func (msg *Message) Export() *mail.Message {
if msg.hasMixedPart() { if msg.hasMixedPart() {
w.closeMultipart() w.closeMultipart()
} }
return w.export()
} }
func (msg *Message) hasMixedPart() bool { func (msg *Message) hasMixedPart() bool {
@ -61,52 +71,33 @@ func (msg *Message) hasAlternativePart() bool {
return len(msg.parts) > 1 return len(msg.parts) > 1
} }
// messageWriter helps converting the message into a net/mail.Message
type messageWriter struct { type messageWriter struct {
header map[string][]string w io.Writer
buf *bytes.Buffer n int64
writers [3]*multipart.Writer writers [3]*multipart.Writer
partWriter io.Writer partWriter io.Writer
depth uint8 depth uint8
err error
} }
func newMessageWriter(msg *Message) *messageWriter {
// We copy the header so Export does not modify the message
header := make(map[string][]string, len(msg.header)+2)
for k, v := range msg.header {
header[k] = v
}
if _, ok := header["Mime-Version"]; !ok {
header["Mime-Version"] = []string{"1.0"}
}
if _, ok := header["Date"]; !ok {
header["Date"] = []string{msg.FormatDate(now())}
}
return &messageWriter{header: header, buf: new(bytes.Buffer)}
}
// Stubbed out for testing.
var now = time.Now
func (w *messageWriter) openMultipart(mimeType string) { func (w *messageWriter) openMultipart(mimeType string) {
w.writers[w.depth] = multipart.NewWriter(w.buf) mw := multipart.NewWriter(w)
contentType := "multipart/" + mimeType + "; boundary=" + w.writers[w.depth].Boundary() contentType := "multipart/" + mimeType + "; boundary=" + mw.Boundary()
w.writers[w.depth] = mw
if w.depth == 0 { if w.depth == 0 {
w.header["Content-Type"] = []string{contentType} w.writeHeader("Content-Type", contentType)
w.writeString("\r\n")
} else { } else {
h := make(map[string][]string) w.createPart(map[string][]string{
h["Content-Type"] = []string{contentType} "Content-Type": []string{contentType},
w.createPart(h) })
} }
w.depth++ w.depth++
} }
func (w *messageWriter) createPart(h map[string][]string) { func (w *messageWriter) createPart(h map[string][]string) {
// No need to check the error since the underlying writer is a bytes.Buffer w.partWriter, w.err = w.writers[w.depth-1].CreatePart(h)
w.partWriter, _ = w.writers[w.depth-1].CreatePart(h)
} }
func (w *messageWriter) closeMultipart() { func (w *messageWriter) closeMultipart() {
@ -131,20 +122,53 @@ func (w *messageWriter) addFiles(files []*File, isAttachment bool) {
h["Content-ID"] = []string{"<" + f.Name + ">"} h["Content-ID"] = []string{"<" + f.Name + ">"}
} }
} }
w.writeHeaders(h)
w.write(h, f.Content, Base64) w.writeBody(f.Content, Base64)
} }
} }
func (w *messageWriter) write(h map[string][]string, body []byte, enc Encoding) { func (w *messageWriter) Write(p []byte) (int, error) {
w.writeHeader(h) if w.err != nil {
w.writeBody(body, enc) return 0, errors.New("gomail: cannot write as writer is in error")
}
var n int
n, w.err = w.w.Write(p)
w.n += int64(n)
return n, w.err
} }
func (w *messageWriter) writeHeader(h map[string][]string) { func (w *messageWriter) writeString(s string) {
n, _ := io.WriteString(w.w, s)
w.n += int64(n)
}
func (w *messageWriter) writeStrings(a []string, sep string) {
if len(a) > 0 {
w.writeString(a[0])
if len(a) == 1 {
return
}
}
for _, s := range a[1:] {
w.writeString(sep)
w.writeString(s)
}
}
func (w *messageWriter) writeHeader(k string, v ...string) {
w.writeString(k)
w.writeString(": ")
w.writeStrings(v, ", ")
w.writeString("\r\n")
}
func (w *messageWriter) writeHeaders(h map[string][]string) {
if w.depth == 0 { if w.depth == 0 {
for field, value := range h { for k, v := range h {
w.header[field] = value if k != "Bcc" {
w.writeHeader(k, v...)
}
} }
} else { } else {
w.createPart(h) w.createPart(h)
@ -154,30 +178,25 @@ func (w *messageWriter) writeHeader(h map[string][]string) {
func (w *messageWriter) writeBody(body []byte, enc Encoding) { func (w *messageWriter) writeBody(body []byte, enc Encoding) {
var subWriter io.Writer var subWriter io.Writer
if w.depth == 0 { if w.depth == 0 {
subWriter = w.buf w.writeString("\r\n")
subWriter = w.w
} else { } else {
subWriter = w.partWriter subWriter = w.partWriter
} }
// The errors returned by writers are not checked since these writers cannot
// return errors.
if enc == Base64 { if enc == Base64 {
writer := base64.NewEncoder(base64.StdEncoding, newBase64LineWriter(subWriter)) wc := base64.NewEncoder(base64.StdEncoding, newBase64LineWriter(subWriter))
writer.Write(body) wc.Write(body)
writer.Close() wc.Close()
} else if enc == Unencoded { } else if enc == Unencoded {
subWriter.Write(body) subWriter.Write(body)
} else { } else {
writer := quotedprintable.NewWriter(subWriter) wc := quotedprintable.NewWriter(subWriter)
writer.Write(body) wc.Write(body)
writer.Close() wc.Close()
} }
} }
func (w *messageWriter) export() *mail.Message {
return &mail.Message{Header: w.header, Body: w.buf}
}
// As required by RFC 2045, 6.7. (page 21) for quoted-printable, and // As required by RFC 2045, 6.7. (page 21) for quoted-printable, and
// RFC 2045, 6.8. (page 25) for base64. // RFC 2045, 6.8. (page 25) for base64.
const maxLineLen = 76 const maxLineLen = 76
@ -207,3 +226,6 @@ func (w *base64LineWriter) Write(p []byte) (int, error) {
return n + len(p), nil return n + len(p), nil
} }
// Stubbed out for testing.
var now = time.Now

View File

@ -1,9 +1,9 @@
package gomail package gomail
import ( import (
"bytes"
"encoding/base64" "encoding/base64"
"io" "io"
"io/ioutil"
"path/filepath" "path/filepath"
"regexp" "regexp"
"strconv" "strconv"
@ -476,7 +476,7 @@ func testMessage(t *testing.T, msg *Message, bCount int, want *message) {
} }
func stubSendMail(t *testing.T, bCount int, want *message) SendFunc { func stubSendMail(t *testing.T, bCount int, want *message) SendFunc {
return func(from string, to []string, msg io.Reader) error { return func(from string, to []string, msg io.WriterTo) error {
if from != want.from { if from != want.from {
t.Fatalf("Invalid from, got %q, want %q", from, want.from) t.Fatalf("Invalid from, got %q, want %q", from, want.from)
} }
@ -495,11 +495,12 @@ func stubSendMail(t *testing.T, bCount int, want *message) SendFunc {
} }
} }
content, err := ioutil.ReadAll(msg) buf := new(bytes.Buffer)
_, err := msg.WriteTo(buf)
if err != nil { if err != nil {
t.Error(err) t.Error(err)
} }
got := string(content) got := buf.String()
wantMsg := string("Mime-Version: 1.0\r\n" + wantMsg := string("Mime-Version: 1.0\r\n" +
"Date: Wed, 25 Jun 2014 17:46:00 +0000\r\n" + "Date: Wed, 25 Jun 2014 17:46:00 +0000\r\n" +
want.content) want.content)
@ -580,7 +581,7 @@ func getBoundaries(t *testing.T, count int, msg string) []string {
var boundaryRegExp = regexp.MustCompile("boundary=(\\w+)") var boundaryRegExp = regexp.MustCompile("boundary=(\\w+)")
func BenchmarkFull(b *testing.B) { func BenchmarkFull(b *testing.B) {
emptyFunc := func(from string, to []string, msg io.Reader) error { emptyFunc := func(from string, to []string, msg io.WriterTo) error {
return nil return nil
} }

58
send.go
View File

@ -1,20 +1,17 @@
package gomail package gomail
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"net/mail" "net/mail"
"strings"
) )
// Sender is the interface that wraps the Send method. // Sender is the interface that wraps the Send method.
// //
// Send sends an email to the given addresses. // Send sends an email to the given addresses.
type Sender interface { type Sender interface {
Send(from string, to []string, msg io.Reader) error Send(from string, to []string, msg io.WriterTo) error
} }
// SendCloser is the interface that groups the Send and Close methods. // SendCloser is the interface that groups the Send and Close methods.
@ -27,10 +24,10 @@ type SendCloser interface {
// The SendFunc type is an adapter to allow the use of ordinary functions as // The SendFunc type is an adapter to allow the use of ordinary functions as
// email senders. If f is a function with the appropriate signature, SendFunc(f) // email senders. If f is a function with the appropriate signature, SendFunc(f)
// is a Sender object that calls f. // is a Sender object that calls f.
type SendFunc func(from string, to []string, msg io.Reader) error type SendFunc func(from string, to []string, msg io.WriterTo) error
// Send calls f(from, to, msg). // Send calls f(from, to, msg).
func (f SendFunc) Send(from string, to []string, msg io.Reader) error { func (f SendFunc) Send(from string, to []string, msg io.WriterTo) error {
return f(from, to, msg) return f(from, to, msg)
} }
@ -46,64 +43,39 @@ func Send(s Sender, msg ...*Message) error {
} }
func send(s Sender, msg *Message) error { func send(s Sender, msg *Message) error {
message := msg.Export() from, err := msg.getFrom()
from, err := getFrom(message)
if err != nil {
return err
}
to, err := getRecipients(message)
if err != nil { if err != nil {
return err return err
} }
h := flattenHeader(message) to, err := msg.getRecipients()
body, err := ioutil.ReadAll(message.Body)
if err != nil { if err != nil {
return err return err
} }
mail := bytes.NewReader(append(h, body...)) if err := s.Send(from, to, msg); err != nil {
if err := s.Send(from, to, mail); err != nil {
return err return err
} }
return nil return nil
} }
func flattenHeader(msg *mail.Message) []byte { func (msg *Message) getFrom() (string, error) {
buf := getBuffer() from := msg.header["Sender"]
defer putBuffer(buf) if len(from) == 0 {
from = msg.header["From"]
for field, value := range msg.Header { if len(from) == 0 {
if field != "Bcc" { return "", errors.New(`gomail: invalid message, "From" field is absent`)
buf.WriteString(field)
buf.WriteString(": ")
buf.WriteString(strings.Join(value, ", "))
buf.WriteString("\r\n")
}
}
buf.WriteString("\r\n")
return buf.Bytes()
}
func getFrom(msg *mail.Message) (string, error) {
from := msg.Header.Get("Sender")
if from == "" {
from = msg.Header.Get("From")
if from == "" {
return "", errors.New("gomail: invalid message, \"From\" field is absent")
} }
} }
return parseAddress(from) return parseAddress(from[0])
} }
func getRecipients(msg *mail.Message) ([]string, error) { func (msg *Message) getRecipients() ([]string, error) {
var list []string var list []string
for _, field := range []string{"To", "Cc", "Bcc"} { for _, field := range []string{"To", "Cc", "Bcc"} {
if addresses, ok := msg.Header[field]; ok { if addresses, ok := msg.header[field]; ok {
for _, a := range addresses { for _, a := range addresses {
addr, err := parseAddress(a) addr, err := parseAddress(a)
if err != nil { if err != nil {

View File

@ -1,8 +1,8 @@
package gomail package gomail
import ( import (
"bytes"
"io" "io"
"io/ioutil"
"reflect" "reflect"
"testing" "testing"
) )
@ -24,7 +24,7 @@ const (
type mockSender SendFunc type mockSender SendFunc
func (s mockSender) Send(from string, to []string, msg io.Reader) error { func (s mockSender) Send(from string, to []string, msg io.WriterTo) error {
return s(from, to, msg) return s(from, to, msg)
} }
@ -60,7 +60,7 @@ func getTestMessage() *Message {
} }
func stubSend(t *testing.T, wantFrom string, wantTo []string, wantBody string) mockSender { func stubSend(t *testing.T, wantFrom string, wantTo []string, wantBody string) mockSender {
return func(from string, to []string, msg io.Reader) error { return func(from string, to []string, msg io.WriterTo) error {
if from != wantFrom { if from != wantFrom {
t.Errorf("invalid from, got %q, want %q", from, wantFrom) t.Errorf("invalid from, got %q, want %q", from, wantFrom)
} }
@ -68,11 +68,12 @@ func stubSend(t *testing.T, wantFrom string, wantTo []string, wantBody string) m
t.Errorf("invalid to, got %v, want %v", to, wantTo) t.Errorf("invalid to, got %v, want %v", to, wantTo)
} }
content, err := ioutil.ReadAll(msg) buf := new(bytes.Buffer)
_, err := msg.WriteTo(buf)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
compareBodies(t, string(content), wantBody) compareBodies(t, buf.String(), wantBody)
return nil return nil
} }

View File

@ -122,7 +122,7 @@ type smtpSender struct {
smtpClient smtpClient
} }
func (c *smtpSender) Send(from string, to []string, msg io.Reader) error { func (c *smtpSender) Send(from string, to []string, msg io.WriterTo) error {
if err := c.Mail(from); err != nil { if err := c.Mail(from); err != nil {
return err return err
} }
@ -138,7 +138,7 @@ func (c *smtpSender) Send(from string, to []string, msg io.Reader) error {
return err return err
} }
if _, err = io.Copy(w, msg); err != nil { if _, err = msg.WriteTo(w); err != nil {
w.Close() w.Close()
return err return err
} }

View File

@ -1,6 +1,7 @@
package gomail package gomail
import ( import (
"bytes"
"crypto/tls" "crypto/tls"
"io" "io"
"net" "net"
@ -173,15 +174,19 @@ func (c *mockClient) do(cmd string) {
type mockWriter struct { type mockWriter struct {
want string want string
c *mockClient c *mockClient
buf bytes.Buffer
} }
func (w *mockWriter) Write(p []byte) (int, error) { func (w *mockWriter) Write(p []byte) (int, error) {
w.c.do("Write message") if w.buf.Len() == 0 {
compareBodies(w.c.t, string(p), w.want) w.c.do("Write message")
}
w.buf.Write(p)
return len(p), nil return len(p), nil
} }
func (w *mockWriter) Close() error { func (w *mockWriter) Close() error {
compareBodies(w.c.t, w.buf.String(), w.want)
w.c.do("Close writer") w.c.do("Close writer")
return nil return nil
} }