diff --git a/export.go b/export.go index 65f60fc..e238171 100644 --- a/export.go +++ b/export.go @@ -1,18 +1,29 @@ package gomail import ( - "bytes" "encoding/base64" + "errors" "io" "mime/multipart" "mime/quotedprintable" - "net/mail" "time" ) -// Export converts the message into a net/mail.Message. -func (msg *Message) Export() *mail.Message { - w := newMessageWriter(msg) +// WriteTo implements io.WriterTo. +func (msg *Message) WriteTo(w io.Writer) (int64, error) { + mw := &messageWriter{w: w} + mw.writeMessage(msg) + return mw.n, mw.err +} + +func (w *messageWriter) writeMessage(msg *Message) { + if _, ok := msg.header["Mime-Version"]; !ok { + w.writeString("Mime-Version: 1.0\r\n") + } + if _, ok := msg.header["Date"]; !ok { + w.writeHeader("Date", msg.FormatDate(now())) + } + w.writeHeaders(msg.header) if msg.hasMixedPart() { w.openMultipart("mixed") @@ -26,11 +37,12 @@ func (msg *Message) Export() *mail.Message { w.openMultipart("alternative") } for _, part := range msg.parts { - h := make(map[string][]string) - h["Content-Type"] = []string{part.contentType + "; charset=" + msg.charset} - h["Content-Transfer-Encoding"] = []string{string(msg.encoding)} - - w.write(h, part.body.Bytes(), msg.encoding) + contentType := part.contentType + "; charset=" + msg.charset + w.writeHeaders(map[string][]string{ + "Content-Type": []string{contentType}, + "Content-Transfer-Encoding": []string{string(msg.encoding)}, + }) + w.writeBody(part.body.Bytes(), msg.encoding) } if msg.hasAlternativePart() { w.closeMultipart() @@ -45,8 +57,6 @@ func (msg *Message) Export() *mail.Message { if msg.hasMixedPart() { w.closeMultipart() } - - return w.export() } func (msg *Message) hasMixedPart() bool { @@ -61,52 +71,33 @@ func (msg *Message) hasAlternativePart() bool { return len(msg.parts) > 1 } -// messageWriter helps converting the message into a net/mail.Message type messageWriter struct { - header map[string][]string - buf *bytes.Buffer + w io.Writer + n int64 writers [3]*multipart.Writer partWriter io.Writer depth uint8 + err error } -func newMessageWriter(msg *Message) *messageWriter { - // We copy the header so Export does not modify the message - header := make(map[string][]string, len(msg.header)+2) - for k, v := range msg.header { - header[k] = v - } - - if _, ok := header["Mime-Version"]; !ok { - header["Mime-Version"] = []string{"1.0"} - } - if _, ok := header["Date"]; !ok { - header["Date"] = []string{msg.FormatDate(now())} - } - - return &messageWriter{header: header, buf: new(bytes.Buffer)} -} - -// Stubbed out for testing. -var now = time.Now - func (w *messageWriter) openMultipart(mimeType string) { - w.writers[w.depth] = multipart.NewWriter(w.buf) - contentType := "multipart/" + mimeType + "; boundary=" + w.writers[w.depth].Boundary() + mw := multipart.NewWriter(w) + contentType := "multipart/" + mimeType + "; boundary=" + mw.Boundary() + w.writers[w.depth] = mw if w.depth == 0 { - w.header["Content-Type"] = []string{contentType} + w.writeHeader("Content-Type", contentType) + w.writeString("\r\n") } else { - h := make(map[string][]string) - h["Content-Type"] = []string{contentType} - w.createPart(h) + w.createPart(map[string][]string{ + "Content-Type": []string{contentType}, + }) } w.depth++ } func (w *messageWriter) createPart(h map[string][]string) { - // No need to check the error since the underlying writer is a bytes.Buffer - w.partWriter, _ = w.writers[w.depth-1].CreatePart(h) + w.partWriter, w.err = w.writers[w.depth-1].CreatePart(h) } func (w *messageWriter) closeMultipart() { @@ -131,20 +122,53 @@ func (w *messageWriter) addFiles(files []*File, isAttachment bool) { h["Content-ID"] = []string{"<" + f.Name + ">"} } } - - w.write(h, f.Content, Base64) + w.writeHeaders(h) + w.writeBody(f.Content, Base64) } } -func (w *messageWriter) write(h map[string][]string, body []byte, enc Encoding) { - w.writeHeader(h) - w.writeBody(body, enc) +func (w *messageWriter) Write(p []byte) (int, error) { + if w.err != nil { + return 0, errors.New("gomail: cannot write as writer is in error") + } + + var n int + n, w.err = w.w.Write(p) + w.n += int64(n) + return n, w.err } -func (w *messageWriter) writeHeader(h map[string][]string) { +func (w *messageWriter) writeString(s string) { + n, _ := io.WriteString(w.w, s) + w.n += int64(n) +} + +func (w *messageWriter) writeStrings(a []string, sep string) { + if len(a) > 0 { + w.writeString(a[0]) + if len(a) == 1 { + return + } + } + for _, s := range a[1:] { + w.writeString(sep) + w.writeString(s) + } +} + +func (w *messageWriter) writeHeader(k string, v ...string) { + w.writeString(k) + w.writeString(": ") + w.writeStrings(v, ", ") + w.writeString("\r\n") +} + +func (w *messageWriter) writeHeaders(h map[string][]string) { if w.depth == 0 { - for field, value := range h { - w.header[field] = value + for k, v := range h { + if k != "Bcc" { + w.writeHeader(k, v...) + } } } else { w.createPart(h) @@ -154,30 +178,25 @@ func (w *messageWriter) writeHeader(h map[string][]string) { func (w *messageWriter) writeBody(body []byte, enc Encoding) { var subWriter io.Writer if w.depth == 0 { - subWriter = w.buf + w.writeString("\r\n") + subWriter = w.w } else { subWriter = w.partWriter } - // The errors returned by writers are not checked since these writers cannot - // return errors. if enc == Base64 { - writer := base64.NewEncoder(base64.StdEncoding, newBase64LineWriter(subWriter)) - writer.Write(body) - writer.Close() + wc := base64.NewEncoder(base64.StdEncoding, newBase64LineWriter(subWriter)) + wc.Write(body) + wc.Close() } else if enc == Unencoded { subWriter.Write(body) } else { - writer := quotedprintable.NewWriter(subWriter) - writer.Write(body) - writer.Close() + wc := quotedprintable.NewWriter(subWriter) + wc.Write(body) + wc.Close() } } -func (w *messageWriter) export() *mail.Message { - return &mail.Message{Header: w.header, Body: w.buf} -} - // As required by RFC 2045, 6.7. (page 21) for quoted-printable, and // RFC 2045, 6.8. (page 25) for base64. const maxLineLen = 76 @@ -207,3 +226,6 @@ func (w *base64LineWriter) Write(p []byte) (int, error) { return n + len(p), nil } + +// Stubbed out for testing. +var now = time.Now diff --git a/message_test.go b/message_test.go index 4d837c2..5a40e2d 100644 --- a/message_test.go +++ b/message_test.go @@ -1,9 +1,9 @@ package gomail import ( + "bytes" "encoding/base64" "io" - "io/ioutil" "path/filepath" "regexp" "strconv" @@ -476,7 +476,7 @@ func testMessage(t *testing.T, msg *Message, bCount int, want *message) { } func stubSendMail(t *testing.T, bCount int, want *message) SendFunc { - return func(from string, to []string, msg io.Reader) error { + return func(from string, to []string, msg io.WriterTo) error { if from != want.from { t.Fatalf("Invalid from, got %q, want %q", from, want.from) } @@ -495,11 +495,12 @@ func stubSendMail(t *testing.T, bCount int, want *message) SendFunc { } } - content, err := ioutil.ReadAll(msg) + buf := new(bytes.Buffer) + _, err := msg.WriteTo(buf) if err != nil { t.Error(err) } - got := string(content) + got := buf.String() wantMsg := string("Mime-Version: 1.0\r\n" + "Date: Wed, 25 Jun 2014 17:46:00 +0000\r\n" + want.content) @@ -580,7 +581,7 @@ func getBoundaries(t *testing.T, count int, msg string) []string { var boundaryRegExp = regexp.MustCompile("boundary=(\\w+)") func BenchmarkFull(b *testing.B) { - emptyFunc := func(from string, to []string, msg io.Reader) error { + emptyFunc := func(from string, to []string, msg io.WriterTo) error { return nil } diff --git a/send.go b/send.go index 2c7bedb..a069c64 100644 --- a/send.go +++ b/send.go @@ -1,20 +1,17 @@ package gomail import ( - "bytes" "errors" "fmt" "io" - "io/ioutil" "net/mail" - "strings" ) // Sender is the interface that wraps the Send method. // // Send sends an email to the given addresses. type Sender interface { - Send(from string, to []string, msg io.Reader) error + Send(from string, to []string, msg io.WriterTo) error } // SendCloser is the interface that groups the Send and Close methods. @@ -27,10 +24,10 @@ type SendCloser interface { // The SendFunc type is an adapter to allow the use of ordinary functions as // email senders. If f is a function with the appropriate signature, SendFunc(f) // is a Sender object that calls f. -type SendFunc func(from string, to []string, msg io.Reader) error +type SendFunc func(from string, to []string, msg io.WriterTo) error // Send calls f(from, to, msg). -func (f SendFunc) Send(from string, to []string, msg io.Reader) error { +func (f SendFunc) Send(from string, to []string, msg io.WriterTo) error { return f(from, to, msg) } @@ -46,64 +43,39 @@ func Send(s Sender, msg ...*Message) error { } func send(s Sender, msg *Message) error { - message := msg.Export() - - from, err := getFrom(message) - if err != nil { - return err - } - to, err := getRecipients(message) + from, err := msg.getFrom() if err != nil { return err } - h := flattenHeader(message) - body, err := ioutil.ReadAll(message.Body) + to, err := msg.getRecipients() if err != nil { return err } - mail := bytes.NewReader(append(h, body...)) - if err := s.Send(from, to, mail); err != nil { + if err := s.Send(from, to, msg); err != nil { return err } return nil } -func flattenHeader(msg *mail.Message) []byte { - buf := getBuffer() - defer putBuffer(buf) - - for field, value := range msg.Header { - if field != "Bcc" { - buf.WriteString(field) - buf.WriteString(": ") - buf.WriteString(strings.Join(value, ", ")) - buf.WriteString("\r\n") - } - } - buf.WriteString("\r\n") - - return buf.Bytes() -} - -func getFrom(msg *mail.Message) (string, error) { - from := msg.Header.Get("Sender") - if from == "" { - from = msg.Header.Get("From") - if from == "" { - return "", errors.New("gomail: invalid message, \"From\" field is absent") +func (msg *Message) getFrom() (string, error) { + from := msg.header["Sender"] + if len(from) == 0 { + from = msg.header["From"] + if len(from) == 0 { + return "", errors.New(`gomail: invalid message, "From" field is absent`) } } - return parseAddress(from) + return parseAddress(from[0]) } -func getRecipients(msg *mail.Message) ([]string, error) { +func (msg *Message) getRecipients() ([]string, error) { var list []string for _, field := range []string{"To", "Cc", "Bcc"} { - if addresses, ok := msg.Header[field]; ok { + if addresses, ok := msg.header[field]; ok { for _, a := range addresses { addr, err := parseAddress(a) if err != nil { diff --git a/send_test.go b/send_test.go index f252ac5..ba59cd3 100644 --- a/send_test.go +++ b/send_test.go @@ -1,8 +1,8 @@ package gomail import ( + "bytes" "io" - "io/ioutil" "reflect" "testing" ) @@ -24,7 +24,7 @@ const ( type mockSender SendFunc -func (s mockSender) Send(from string, to []string, msg io.Reader) error { +func (s mockSender) Send(from string, to []string, msg io.WriterTo) error { return s(from, to, msg) } @@ -60,7 +60,7 @@ func getTestMessage() *Message { } func stubSend(t *testing.T, wantFrom string, wantTo []string, wantBody string) mockSender { - return func(from string, to []string, msg io.Reader) error { + return func(from string, to []string, msg io.WriterTo) error { if from != wantFrom { t.Errorf("invalid from, got %q, want %q", from, wantFrom) } @@ -68,11 +68,12 @@ func stubSend(t *testing.T, wantFrom string, wantTo []string, wantBody string) m t.Errorf("invalid to, got %v, want %v", to, wantTo) } - content, err := ioutil.ReadAll(msg) + buf := new(bytes.Buffer) + _, err := msg.WriteTo(buf) if err != nil { t.Fatal(err) } - compareBodies(t, string(content), wantBody) + compareBodies(t, buf.String(), wantBody) return nil } diff --git a/smtp.go b/smtp.go index f82845a..9b5e494 100644 --- a/smtp.go +++ b/smtp.go @@ -122,7 +122,7 @@ type smtpSender struct { smtpClient } -func (c *smtpSender) Send(from string, to []string, msg io.Reader) error { +func (c *smtpSender) Send(from string, to []string, msg io.WriterTo) error { if err := c.Mail(from); err != nil { return err } @@ -138,7 +138,7 @@ func (c *smtpSender) Send(from string, to []string, msg io.Reader) error { return err } - if _, err = io.Copy(w, msg); err != nil { + if _, err = msg.WriteTo(w); err != nil { w.Close() return err } diff --git a/smtp_test.go b/smtp_test.go index 1c17f90..c9b5e48 100644 --- a/smtp_test.go +++ b/smtp_test.go @@ -1,6 +1,7 @@ package gomail import ( + "bytes" "crypto/tls" "io" "net" @@ -173,15 +174,19 @@ func (c *mockClient) do(cmd string) { type mockWriter struct { want string c *mockClient + buf bytes.Buffer } func (w *mockWriter) Write(p []byte) (int, error) { - w.c.do("Write message") - compareBodies(w.c.t, string(p), w.want) + if w.buf.Len() == 0 { + w.c.do("Write message") + } + w.buf.Write(p) return len(p), nil } func (w *mockWriter) Close() error { + compareBodies(w.c.t, w.buf.String(), w.want) w.c.do("Close writer") return nil }