152 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
			
		
		
	
	
			152 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Go
		
	
	
	
	
	
| package connection
 | |
| 
 | |
| import (
 | |
| 	"bytes"
 | |
| 	"errors"
 | |
| 	"fmt"
 | |
| 	"io"
 | |
| 	"net"
 | |
| 	"os"
 | |
| 	"path/filepath"
 | |
| 	"strconv"
 | |
| 
 | |
| 	"github.com/rs/zerolog/log"
 | |
| 	"golang.org/x/crypto/ssh"
 | |
| )
 | |
| 
 | |
| type SSHConn struct {
 | |
| 	addr   string
 | |
| 	client *ssh.Client
 | |
| }
 | |
| 
 | |
| var _ IConnection = (*SSHConn)(nil)
 | |
| 
 | |
// Sentinel errors wrapped by the SSH connection functions. Match them with
// errors.Is.
//
// NOTE: ErrSShCopy keeps its historical (mis-capitalised) name so existing
// callers continue to compile.
var (
	ErrSSHDial            = errors.New("unable to dial ssh addr")
	ErrSSHConn            = errors.New("unable to establish a new connection")
	ErrSShCopy            = errors.New("unable to copy file")
	ErrSSHSession         = errors.New("unable to open a new session")
	ErrSSHReadPrivateKey  = errors.New("unable to read private key")
	ErrSSHParsePrivateKey = errors.New("unable to parse private key")
	ErrSSHExecute         = errors.New("unable to execute command")
)
 | |
| 
 | |
| func NewSSHConn(addr, user string, port int, privkey string) (SSHConn, error) {
 | |
| 	var newconn SSHConn
 | |
| 
 | |
| 	sshAddr := addr + ":" + strconv.Itoa(port)
 | |
| 	newconn.addr = sshAddr
 | |
| 
 | |
| 	conn, err := net.Dial("tcp", sshAddr)
 | |
| 	if err != nil {
 | |
| 		return newconn, fmt.Errorf("%w, addr=%s, err=%v", ErrSSHDial, addr, err)
 | |
| 	}
 | |
| 
 | |
| 	c, err := os.ReadFile(privkey)
 | |
| 	if err != nil {
 | |
| 		return newconn, fmt.Errorf("%w, privkey=%s, err=%v", ErrSSHReadPrivateKey, privkey, err)
 | |
| 	}
 | |
| 
 | |
| 	sshPrivKey, err := ssh.ParsePrivateKey(c)
 | |
| 	if err != nil {
 | |
| 		return newconn, fmt.Errorf("%w, privkey=%s, err=%v", ErrSSHParsePrivateKey, privkey, err)
 | |
| 	}
 | |
| 
 | |
| 	sshConfig := ssh.ClientConfig{
 | |
| 		User:            user,
 | |
| 		HostKeyCallback: ssh.InsecureIgnoreHostKey(), //nolint:gosec // no need
 | |
| 		Auth: []ssh.AuthMethod{
 | |
| 			ssh.PublicKeys(sshPrivKey),
 | |
| 		},
 | |
| 	}
 | |
| 
 | |
| 	sshConn, chNewChannel, chReq, err := ssh.NewClientConn(conn, sshAddr, &sshConfig)
 | |
| 	if err != nil {
 | |
| 		return newconn, fmt.Errorf("%w, addr=%s, err=%v", ErrSSHConn, sshAddr, err)
 | |
| 	}
 | |
| 
 | |
| 	sshClient := ssh.NewClient(sshConn, chNewChannel, chReq)
 | |
| 	newconn.client = sshClient
 | |
| 
 | |
| 	return newconn, nil
 | |
| }
 | |
| 
 | |
| func (c *SSHConn) Close() error {
 | |
| 	return c.client.Close()
 | |
| }
 | |
| 
 | |
| func (c *SSHConn) CopyFile(src, dest string) error {
 | |
| 	sshSession, err := c.client.NewSession()
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("%w, addr=%s, err=%v", ErrSSHSession, c.addr, err)
 | |
| 	}
 | |
| 	defer sshSession.Close() //nolint: errcheck // defered
 | |
| 
 | |
| 	fileInfo, err := os.Stat(src)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("unable to stat scp source file src=%s, err=%v", src, err)
 | |
| 	}
 | |
| 
 | |
| 	file, err := os.Open(src)
 | |
| 	if err != nil {
 | |
| 		return fmt.Errorf("unable to open scp source file src=%s, err=%v", src, err)
 | |
| 	}
 | |
| 	defer file.Close() //nolint: errcheck // defered
 | |
| 
 | |
| 	go func() {
 | |
| 		w, _ := sshSession.StdinPipe()
 | |
| 		defer w.Close() //nolint: errcheck // defered
 | |
| 
 | |
| 		if _, err := fmt.Fprintf(w, "C0644 %d %s\n", fileInfo.Size(), filepath.Base(dest)); err != nil {
 | |
| 			log.Debug().
 | |
| 				Err(err).
 | |
| 				Str("src", src).
 | |
| 				Str("dest", dest).
 | |
| 				Msg("unable to write file info to scp")
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		if _, err := io.Copy(w, file); err != nil {
 | |
| 			log.Debug().Err(err).Str("src", src).Str("dest", dest).Msg("unable to scp src to dest")
 | |
| 			return
 | |
| 		}
 | |
| 
 | |
| 		if _, err := fmt.Fprint(w, "\x00"); err != nil {
 | |
| 			log.Debug().
 | |
| 				Err(err).
 | |
| 				Str("src", src).
 | |
| 				Str("dest", dest).
 | |
| 				Msg("unable to write scp termination string")
 | |
| 		}
 | |
| 	}()
 | |
| 
 | |
| 	if err := sshSession.Run(fmt.Sprintf("scp -t %s", dest)); err != nil {
 | |
| 		return fmt.Errorf(
 | |
| 			"%w, addr=%s, src=%s, dest=%s, err=%v",
 | |
| 			ErrSShCopy,
 | |
| 			c.addr,
 | |
| 			src,
 | |
| 			dest,
 | |
| 			err,
 | |
| 		)
 | |
| 	}
 | |
| 
 | |
| 	return nil
 | |
| }
 | |
| 
 | |
| func (c *SSHConn) Execute(cmd string) (string, error) {
 | |
| 	sshSession, err := c.client.NewSession()
 | |
| 	if err != nil {
 | |
| 		return "", fmt.Errorf("%w, addr=%s, err=%v", ErrSSHSession, c.addr, err)
 | |
| 	}
 | |
| 	defer sshSession.Close() //nolint: errcheck // defered
 | |
| 
 | |
| 	var buf bytes.Buffer
 | |
| 	sshSession.Stdout = &buf
 | |
| 	if err := sshSession.Run(cmd); err != nil {
 | |
| 		return "", fmt.Errorf("%w, addr=%s, cmd=%s, err=%v", ErrSSHExecute, c.addr, cmd, err)
 | |
| 	}
 | |
| 
 | |
| 	return "", nil
 | |
| }
 | 
