nip/internal/server/server.go

89 lines
1.6 KiB
Go

package server
import (
	"crypto/subtle"
	"errors"
	"fmt"
	"io"
	"log"
	"net"

	"git.lovezsh.com/lovezsh/nip/internal/pkg/encoding"
	"git.lovezsh.com/lovezsh/nip/internal/proto"
)
// Server relays TCP traffic for authenticated clients: each accepted
// connection performs a handshake naming a target address, after which the
// server copies bytes between the client and that target.
type Server struct {
	token string // secret a client's handshake request must carry
	addr  string // TCP address that ListenAndServe listens on
}
// New returns a Server that requires clients to present token during the
// handshake and that listens on the TCP address addr once ListenAndServe is
// called.
func New(token string, addr string) *Server {
	srv := new(Server)
	srv.token = token
	srv.addr = addr
	return srv
}
// ListenAndServe listens for TCP connections on s.addr and hands each accepted
// connection to Serve in its own goroutine.
//
// It only returns on failure: either the error from net.Listen or the first
// error from Accept. The listener is closed before it returns.
// NOTE(review): any Accept error, including a transient one such as running
// out of file descriptors, stops the whole server — confirm that is intended.
func (s *Server) ListenAndServe() (err error) {
	var ln net.Listener
	if ln, err = net.Listen("tcp", s.addr); err != nil {
		return err
	}
	defer ln.Close()

	for {
		var c net.Conn
		if c, err = ln.Accept(); err != nil {
			return err
		}
		go s.Serve(c)
	}
}
// handshake performs the server side of the connection setup on conn.
//
// It decodes a proto.Request from conn and checks the request's Token
// against s.token. It then dials req.Target over TCP and replies with a
// proto.Response:
//   - Code 1 ("token invalid") when the token does not match.
//   - Code 2, carrying the dial error text, when the target cannot be reached.
//   - Code 0 on success.
//
// On success it returns the connected target, which the caller must close.
// On any failure it returns a nil target and a non-nil error.
//
// NOTE(review): no read deadline is set on conn, so a client that never
// completes its request keeps this goroutine blocked in Decode.
// NOTE(review): if encoding.NewDecoder buffers reads beyond the request,
// any bytes the client sends right after the request would be swallowed by
// the decoder and never relayed — confirm it reads exactly one message.
func (s *Server) handshake(conn net.Conn) (target net.Conn, err error) {
	var req proto.Request
	if err = encoding.NewDecoder(conn).Decode(&req); err != nil {
		return nil, fmt.Errorf("decode request: %w", err)
	}
	// Constant-time comparison, so response timing does not reveal how much
	// of a guessed token matched the real one.
	if subtle.ConstantTimeCompare([]byte(req.Token), []byte(s.token)) != 1 {
		// Best effort: the client is rejected whether or not this reply
		// reaches it, so a write error here is deliberately ignored.
		_ = encoding.NewEncoder(conn).Encode(&proto.Response{Code: 1, Message: "token invalid"})
		return nil, errors.New("token invalid")
	}
	target, err = net.Dial("tcp", req.Target)
	if err != nil {
		// Best effort, as above: the dial error is what gets reported.
		_ = encoding.NewEncoder(conn).Encode(&proto.Response{Code: 2, Message: err.Error()})
		return nil, fmt.Errorf("dial target %q: %w", req.Target, err)
	}
	if err = encoding.NewEncoder(conn).Encode(&proto.Response{Code: 0}); err != nil {
		target.Close()
		return nil, fmt.Errorf("send handshake response: %w", err)
	}
	return target, nil
}
// Serve runs the handshake on conn. If the handshake succeeds, Serve relays
// bytes in both directions between conn and the dialed target until either
// direction finishes.
//
// Both connections are closed before Serve returns. Handshake failures are
// logged, and so is a non-nil error from whichever copy finishes first.
func (s *Server) Serve(conn net.Conn) {
	defer conn.Close()
	dst, err := s.handshake(conn)
	if err != nil {
		log.Printf("handshake error: %v", err)
		return
	}
	defer dst.Close()

	// The buffer holds both senders' results. Serve only receives the first
	// one; the deferred Closes then unblock the other io.Copy. With an
	// unbuffered channel, that second goroutine would block forever on its
	// send and leak once per connection.
	errc := make(chan error, 2)
	go func() {
		_, err := io.Copy(dst, conn)
		errc <- err
	}()
	go func() {
		_, err := io.Copy(conn, dst)
		errc <- err
	}()
	if err := <-errc; err != nil {
		log.Printf("data stream copy error: %v", err)
	}
}