package server

import (
	"crypto/subtle"
	"errors"
	"io"
	"log"
	"net"

	"git.lovezsh.com/lovezsh/nip/internal/pkg/encoding"
	"git.lovezsh.com/lovezsh/nip/internal/proto"
)

// Server is a token-authenticated TCP relay. A client first sends a
// handshake request carrying a token and a target address; if the token
// matches, the server dials the target and copies bytes in both directions.
type Server struct {
	token string // secret every client handshake must present
	addr  string // local address passed to net.Listen
}

// New returns a Server that will listen on addr and accept only clients whose
// handshake request carries token.
func New(token string, addr string) *Server {
	return &Server{
		token: token,
		addr:  addr,
	}
}

// ListenAndServe listens for TCP connections on s.addr and handles each
// accepted connection in its own goroutine via Serve. It returns the error
// from net.Listen, or the first error returned by Accept; the listener is
// closed before returning.
func (s *Server) ListenAndServe() error {
	l, err := net.Listen("tcp", s.addr)
	if err != nil {
		return err
	}
	defer l.Close()

	for {
		conn, err := l.Accept()
		if err != nil {
			return err
		}
		go s.Serve(conn)
	}
}

// handshake decodes a proto.Request from conn, verifies its token and dials
// req.Target over TCP. It replies to the client with a proto.Response:
// Code 1 when the token is invalid, Code 2 (carrying the dial error text) when
// the target cannot be dialed, and Code 0 on success. On success the caller
// owns the returned target connection and must close it; on any error no
// target connection is left open.
func (s *Server) handshake(conn net.Conn) (target net.Conn, err error) {
	var req proto.Request
	if err = encoding.NewDecoder(conn).Decode(&req); err != nil {
		return nil, err
	}

	// Constant-time comparison so response timing does not reveal how much of
	// a guessed token matched.
	if subtle.ConstantTimeCompare([]byte(req.Token), []byte(s.token)) != 1 {
		// Best effort: the handshake fails regardless, so an error while
		// delivering the rejection is deliberately not reported.
		_ = encoding.NewEncoder(conn).Encode(&proto.Response{Code: 1, Message: "token invalid"})
		return nil, errors.New("token invalid")
	}

	target, err = net.Dial("tcp", req.Target)
	if err != nil {
		// Best effort, as above: the dial error is what gets returned.
		_ = encoding.NewEncoder(conn).Encode(&proto.Response{Code: 2, Message: err.Error()})
		return nil, err
	}

	if err = encoding.NewEncoder(conn).Encode(&proto.Response{Code: 0}); err != nil {
		target.Close()
		return nil, err
	}
	return target, nil
}

// Serve runs the handshake on conn and, if it succeeds, relays data between
// conn and the dialed target until either copy direction finishes. Handshake
// and copy errors are logged. Both connections are closed before Serve
// returns.
func (s *Server) Serve(conn net.Conn) {
	defer conn.Close()

	dst, err := s.handshake(conn)
	if err != nil {
		log.Printf("handshake error: %v", err)
		return
	}
	defer dst.Close()

	// Buffered for both senders: Serve waits only for the first direction to
	// finish, then its deferred Closes make the other io.Copy return. With an
	// unbuffered channel that second goroutine would block on its send
	// forever, leaking one goroutine per relayed connection.
	errc := make(chan error, 2)
	go func() {
		_, err := io.Copy(dst, conn)
		errc <- err
	}()
	go func() {
		_, err := io.Copy(conn, dst)
		errc <- err
	}()

	if err := <-errc; err != nil {
		log.Printf("data stream copy error: %v", err)
	}
}