package connection import ( "bytes" "errors" "fmt" "io" "net" "os" "path/filepath" "strconv" "github.com/rs/zerolog/log" "golang.org/x/crypto/ssh" ) type SSHConn struct { addr string client *ssh.Client } var _ IConnection = (*SSHConn)(nil) var ( ErrSSHDial = errors.New("unable to dial ssh addr") ErrSSHConn = errors.New("unable to establish a new connection") ErrSShCopy = errors.New("unable to copy file") ErrSSHSession = errors.New("unable to open a new session") ErrSSHReadPrivateKey = errors.New("unable to read private key") ErrSSHParsePrivateKey = errors.New("unable to read private key") ErrSSHExecute = errors.New("unable to execute command") ) func NewSSHConn(addr, user string, port int, privkey string) (SSHConn, error) { var newconn SSHConn sshAddr := addr + ":" + strconv.Itoa(port) newconn.addr = sshAddr conn, err := net.Dial("tcp", sshAddr) if err != nil { return newconn, fmt.Errorf("%w, addr=%s, err=%v", ErrSSHDial, addr, err) } c, err := os.ReadFile(privkey) if err != nil { return newconn, fmt.Errorf("%w, privkey=%s, err=%v", ErrSSHReadPrivateKey, privkey, err) } sshPrivKey, err := ssh.ParsePrivateKey(c) if err != nil { return newconn, fmt.Errorf("%w, privkey=%s, err=%v", ErrSSHParsePrivateKey, privkey, err) } sshConfig := ssh.ClientConfig{ User: user, HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // no need Auth: []ssh.AuthMethod{ ssh.PublicKeys(sshPrivKey), }, } sshConn, chNewChannel, chReq, err := ssh.NewClientConn(conn, sshAddr, &sshConfig) if err != nil { return newconn, fmt.Errorf("%w, addr=%s, err=%v", ErrSSHConn, sshAddr, err) } sshClient := ssh.NewClient(sshConn, chNewChannel, chReq) newconn.client = sshClient return newconn, nil } func (c *SSHConn) Close() error { return c.client.Close() } func (c *SSHConn) CopyFile(src, dest string) error { sshSession, err := c.client.NewSession() if err != nil { return fmt.Errorf("%w, addr=%s, err=%v", ErrSSHSession, c.addr, err) } defer sshSession.Close() //nolint: errcheck // defered fileInfo, err := os.Stat(src) if err != nil { return fmt.Errorf("unable to stat scp source file src=%s, err=%v", src, err) } file, err := os.Open(src) if err != nil { return fmt.Errorf("unable to open scp source file src=%s, err=%v", src, err) } defer file.Close() //nolint: errcheck // defered go func() { w, _ := sshSession.StdinPipe() defer w.Close() //nolint: errcheck // defered if _, err := fmt.Fprintf(w, "C0644 %d %s\n", fileInfo.Size(), filepath.Base(dest)); err != nil { log.Debug(). Err(err). Str("src", src). Str("dest", dest). Msg("unable to write file info to scp") return } if _, err := io.Copy(w, file); err != nil { log.Debug().Err(err).Str("src", src).Str("dest", dest).Msg("unable to scp src to dest") return } if _, err := fmt.Fprint(w, "\x00"); err != nil { log.Debug(). Err(err). Str("src", src). Str("dest", dest). Msg("unable to write scp termination string") } }() if err := sshSession.Run(fmt.Sprintf("scp -t %s", dest)); err != nil { return fmt.Errorf( "%w, addr=%s, src=%s, dest=%s, err=%v", ErrSShCopy, c.addr, src, dest, err, ) } return nil } func (c *SSHConn) Execute(cmd string) (string, error) { sshSession, err := c.client.NewSession() if err != nil { return "", fmt.Errorf("%w, addr=%s, err=%v", ErrSSHSession, c.addr, err) } defer sshSession.Close() //nolint: errcheck // defered var buf bytes.Buffer sshSession.Stdout = &buf if err := sshSession.Run(cmd); err != nil { return "", fmt.Errorf("%w, addr=%s, cmd=%s, err=%v", ErrSSHExecute, c.addr, cmd, err) } return "", nil }