163 lines
4.2 KiB
Go
163 lines
4.2 KiB
Go
package connection
|
|
|
|
import (
|
|
"bytes"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
|
|
"github.com/rs/zerolog/log"
|
|
"golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
type SSHConn struct {
|
|
client *ssh.Client
|
|
addr string
|
|
}
|
|
|
|
// Compile-time check that *SSHConn satisfies the IConnection interface.
var _ IConnection = &SSHConn{}
|
|
|
|
// Sentinel errors returned (wrapped with extra context) by SSHConn and
// NewSSHConn. Match them with errors.Is.
var (
	// ErrSSHDial is returned when the TCP connection to the SSH address fails.
	ErrSSHDial = errors.New("unable to dial ssh addr")

	// ErrSSHConn is returned when the SSH client connection cannot be
	// established over an already dialed TCP connection.
	ErrSSHConn = errors.New("unable to establish a new connection")

	// ErrSSHCopy is returned when copying a file to the remote host fails.
	ErrSSHCopy = errors.New("unable to copy file")

	// ErrSShCopy is the original, inconsistently-cased name of ErrSSHCopy.
	// Both names refer to the same error value.
	//
	// Deprecated: use ErrSSHCopy instead.
	ErrSShCopy = ErrSSHCopy

	// ErrSSHSession is returned when a new session cannot be opened on the client.
	ErrSSHSession = errors.New("unable to open a new session")

	// ErrSSHReadPrivateKey is returned when the private key file cannot be read.
	ErrSSHReadPrivateKey = errors.New("unable to read private key")

	// ErrSSHParsePrivateKey is returned when the private key contents cannot
	// be parsed. Its message previously duplicated ErrSSHReadPrivateKey's.
	ErrSSHParsePrivateKey = errors.New("unable to parse private key")

	// ErrSSHExecute is returned when a remote command fails.
	ErrSSHExecute = errors.New("unable to execute command")
)
|
|
|
|
// NewSSHConn instanciates a new SSH connection.
|
|
// The `privkey` arg is the path of private key where the corresponding
|
|
// public has been deployed on the `user` server.
|
|
func NewSSHConn(addr, user string, port int, privkey string) (SSHConn, error) {
|
|
var newconn SSHConn
|
|
|
|
sshAddr := addr + ":" + strconv.Itoa(port)
|
|
newconn.addr = sshAddr
|
|
|
|
conn, err := net.Dial("tcp", sshAddr)
|
|
if err != nil {
|
|
return newconn, fmt.Errorf("%w, addr=%s, err=%v", ErrSSHDial, addr, err)
|
|
}
|
|
|
|
c, err := os.ReadFile(privkey)
|
|
if err != nil {
|
|
return newconn, fmt.Errorf("%w, privkey=%s, err=%v", ErrSSHReadPrivateKey, privkey, err)
|
|
}
|
|
|
|
sshPrivKey, err := ssh.ParsePrivateKey(c)
|
|
if err != nil {
|
|
return newconn, fmt.Errorf("%w, privkey=%s, err=%v", ErrSSHParsePrivateKey, privkey, err)
|
|
}
|
|
|
|
sshConfig := ssh.ClientConfig{
|
|
User: user,
|
|
HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // no need
|
|
Auth: []ssh.AuthMethod{
|
|
ssh.PublicKeys(sshPrivKey),
|
|
},
|
|
}
|
|
|
|
sshConn, chNewChannel, chReq, err := ssh.NewClientConn(conn, sshAddr, &sshConfig)
|
|
if err != nil {
|
|
return newconn, fmt.Errorf("%w, addr=%s, err=%v", ErrSSHConn, sshAddr, err)
|
|
}
|
|
|
|
sshClient := ssh.NewClient(sshConn, chNewChannel, chReq)
|
|
newconn.client = sshClient
|
|
|
|
return newconn, nil
|
|
}
|
|
|
|
func (c *SSHConn) Close() error {
|
|
return c.client.Close()
|
|
}
|
|
|
|
// CopyFile copies a local `src` file to the remote `dest` server.
|
|
//
|
|
// NOTE: for now the `dest` filepath (absolute or relative) does not
|
|
// create a push the file to the desired location.
|
|
// All the files are copied in the remote user HOME.
|
|
//
|
|
// TODO: create the `dest` if not exist.
|
|
func (c *SSHConn) CopyFile(src, dest string) error {
|
|
sshSession, err := c.client.NewSession()
|
|
if err != nil {
|
|
return fmt.Errorf("%w, addr=%s, err=%v", ErrSSHSession, c.addr, err)
|
|
}
|
|
defer sshSession.Close() //nolint: errcheck // defered
|
|
|
|
fileInfo, err := os.Stat(src)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to stat scp source file src=%s, err=%v", src, err)
|
|
}
|
|
|
|
file, err := os.Open(src)
|
|
if err != nil {
|
|
return fmt.Errorf("unable to open scp source file src=%s, err=%v", src, err)
|
|
}
|
|
defer file.Close() //nolint: errcheck // defered
|
|
|
|
go func() {
|
|
w, _ := sshSession.StdinPipe()
|
|
defer w.Close() //nolint: errcheck // defered
|
|
|
|
if _, err := fmt.Fprintf(w, "C0644 %d %s\n", fileInfo.Size(), filepath.Base(dest)); err != nil {
|
|
log.Debug().
|
|
Err(err).
|
|
Str("src", src).
|
|
Str("dest", dest).
|
|
Msg("unable to write file info to scp")
|
|
return
|
|
}
|
|
|
|
if _, err := io.Copy(w, file); err != nil {
|
|
log.Debug().Err(err).Str("src", src).Str("dest", dest).Msg("unable to scp src to dest")
|
|
return
|
|
}
|
|
|
|
if _, err := fmt.Fprint(w, "\x00"); err != nil {
|
|
log.Debug().
|
|
Err(err).
|
|
Str("src", src).
|
|
Str("dest", dest).
|
|
Msg("unable to write scp termination string")
|
|
}
|
|
}()
|
|
|
|
if err := sshSession.Run(fmt.Sprintf("scp -t %s", dest)); err != nil {
|
|
return fmt.Errorf(
|
|
"%w, addr=%s, src=%s, dest=%s, err=%v",
|
|
ErrSShCopy,
|
|
c.addr,
|
|
src,
|
|
dest,
|
|
err,
|
|
)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Execute executes a shell command remotly and returns the output.
|
|
func (c *SSHConn) Execute(cmd string) (string, error) {
|
|
sshSession, err := c.client.NewSession()
|
|
if err != nil {
|
|
return "", fmt.Errorf("%w, addr=%s, err=%v", ErrSSHSession, c.addr, err)
|
|
}
|
|
defer sshSession.Close() //nolint: errcheck // defered
|
|
|
|
var buf bytes.Buffer
|
|
sshSession.Stdout = &buf
|
|
if err := sshSession.Run(cmd); err != nil {
|
|
return "", fmt.Errorf("%w, addr=%s, cmd=%s, err=%v", ErrSSHExecute, c.addr, cmd, err)
|
|
}
|
|
|
|
return buf.String(), nil
|
|
}
|