hmdeploy/connection/ssh_connection.go
2025-04-02 11:18:23 +02:00

132 lines
3.2 KiB
Go

package connection
import (
"bytes"
"errors"
"fmt"
"io"
"net"
"os"
"path/filepath"
"strconv"
"github.com/rs/zerolog/log"
"golang.org/x/crypto/ssh"
)
// SSHConn bundles an SSH client with the address it was dialed on.
// Values are created by NewSSHConn.
type SSHConn struct {
// addr is the "host:port" string built by NewSSHConn; the methods echo it
// in their error messages.
addr string
// client is the SSH client set by NewSSHConn; every method opens its
// sessions on it.
client *ssh.Client
}
// Compile-time assertion that *SSHConn implements IConnection.
var _ IConnection = (*SSHConn)(nil)
// Sentinel errors wrapped (via %w) by the SSH connection helpers in this file.
// Match them with errors.Is.
var (
	// ErrSSHDial reports that the TCP dial to the SSH address failed.
	ErrSSHDial = errors.New("unable to dial ssh addr")
	// ErrSSHConn reports that the SSH handshake on the dialed socket failed.
	ErrSSHConn = errors.New("unable to establish a new connection")
	// ErrSShCopy reports a failed file copy. The mixed-case "SSh" spelling is
	// kept because the identifier is exported.
	ErrSShCopy = errors.New("unable to copy file")
	// ErrSSHSession reports that a new session could not be opened.
	ErrSSHSession = errors.New("unable to open a new session")
	// ErrSSHReadPrivateKey reports that the private key file could not be read.
	ErrSSHReadPrivateKey = errors.New("unable to read private key")
	// ErrSSHParsePrivateKey reports that the private key bytes could not be
	// parsed. Its message is distinct from ErrSSHReadPrivateKey so the two
	// failures can be told apart in logs.
	ErrSSHParsePrivateKey = errors.New("unable to parse private key")
	// ErrSSHExecute reports that a remote command failed to run.
	ErrSSHExecute = errors.New("unable to execute command")
)
// NewSSHConn opens an SSH client connection to addr:port, authenticating as
// user with the private key stored in the file privkey.
//
// The key file is read and parsed before the TCP socket is dialed, so a
// missing or malformed key returns without leaving an open connection behind.
//
// Returned errors wrap one of ErrSSHReadPrivateKey, ErrSSHParsePrivateKey,
// ErrSSHDial or ErrSSHConn. On error the returned SSHConn is the zero value
// and must not be used.
func NewSSHConn(addr, user string, port int, privkey string) (SSHConn, error) {
	sshAddr := addr + ":" + strconv.Itoa(port)

	keyBytes, err := os.ReadFile(privkey)
	if err != nil {
		return SSHConn{}, fmt.Errorf("%w, privkey=%s, err=%v", ErrSSHReadPrivateKey, privkey, err)
	}

	signer, err := ssh.ParsePrivateKey(keyBytes)
	if err != nil {
		return SSHConn{}, fmt.Errorf("%w, privkey=%s, err=%v", ErrSSHParsePrivateKey, privkey, err)
	}

	sshConfig := ssh.ClientConfig{
		User: user,
		// NOTE(review): host keys are not verified, which leaves the connection
		// open to man-in-the-middle attacks. Consider a known_hosts based
		// callback before using this outside trusted networks.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Auth: []ssh.AuthMethod{
			ssh.PublicKeys(signer),
		},
	}

	conn, err := net.Dial("tcp", sshAddr)
	if err != nil {
		// Report the full host:port that was dialed, not just the host.
		return SSHConn{}, fmt.Errorf("%w, addr=%s, err=%v", ErrSSHDial, sshAddr, err)
	}

	sshConn, chNewChannel, chReq, err := ssh.NewClientConn(conn, sshAddr, &sshConfig)
	if err != nil {
		return SSHConn{}, fmt.Errorf("%w, addr=%s, err=%v", ErrSSHConn, sshAddr, err)
	}

	return SSHConn{
		addr:   sshAddr,
		client: ssh.NewClient(sshConn, chNewChannel, chReq),
	}, nil
}
// Close closes the underlying SSH client and returns its error.
//
// Calling Close on a nil receiver or on an SSHConn without a client (for
// example the value returned by a failed NewSSHConn) is a no-op instead of a
// nil-pointer panic.
func (c *SSHConn) Close() error {
	if c == nil || c.client == nil {
		return nil
	}
	return c.client.Close()
}
// CopyFile uploads the local file src to the remote path dest by running
// "scp -t dest" on the remote host and feeding it the scp payload over the
// session's stdin. The remote file is created with mode 0644.
//
// It returns an error wrapping ErrSSHSession when no session can be opened,
// a plain error when src cannot be stat'ed or opened, and an error wrapping
// ErrSShCopy when the stdin pipe cannot be obtained, the remote command fails
// or the payload cannot be written.
//
// NOTE(review): dest is interpolated into the remote command line unquoted,
// so paths containing spaces or shell metacharacters will not work as-is.
func (c *SSHConn) CopyFile(src, dest string) error {
	sshSession, err := c.client.NewSession()
	if err != nil {
		return fmt.Errorf("%w, addr=%s, err=%v", ErrSSHSession, c.addr, err)
	}
	defer sshSession.Close()

	fileInfo, err := os.Stat(src)
	if err != nil {
		return fmt.Errorf("unable to stat scp source file src=%s, err=%v", src, err)
	}

	file, err := os.Open(src)
	if err != nil {
		return fmt.Errorf("unable to open scp source file src=%s, err=%v", src, err)
	}
	defer file.Close()

	// The pipe must be requested before the session is started. Doing it here,
	// not inside the goroutine, removes the race with Run and lets the error
	// be handled instead of dereferencing a nil writer.
	stdin, err := sshSession.StdinPipe()
	if err != nil {
		return fmt.Errorf("%w, addr=%s, src=%s, dest=%s, err=%v", ErrSShCopy, c.addr, src, dest, err)
	}

	// Buffered so the writer goroutine never blocks on the send, even when
	// the error path below returns without receiving.
	writeErr := make(chan error, 1)
	go func() {
		// Closing stdin signals end-of-input to the remote scp so Run can return.
		defer stdin.Close()

		// scp payload: header line, raw file content, single NUL terminator.
		_, err := fmt.Fprintf(stdin, "C0644 %d %s\n", fileInfo.Size(), filepath.Base(dest))
		if err == nil {
			_, err = io.Copy(stdin, file)
		}
		if err == nil {
			_, err = fmt.Fprint(stdin, "\x00")
		}
		if err != nil {
			log.Debug().Err(err).Str("src", src).Str("dest", dest).Msg("unable to scp src to dest")
		}
		writeErr <- err
	}()

	if err := sshSession.Run(fmt.Sprintf("scp -t %s", dest)); err != nil {
		return fmt.Errorf("%w, addr=%s, src=%s, dest=%s, err=%v", ErrSShCopy, c.addr, src, dest, err)
	}

	// Run only returns successfully after the writer has closed stdin, so this
	// receive does not block; it surfaces upload errors that were previously
	// only logged.
	if err := <-writeErr; err != nil {
		return fmt.Errorf("%w, addr=%s, src=%s, dest=%s, err=%v", ErrSShCopy, c.addr, src, dest, err)
	}
	return nil
}
// Execute runs cmd in a fresh session on the remote host and returns
// everything the command wrote to stdout.
//
// It returns an error wrapping ErrSSHSession when no session can be opened
// and an error wrapping ErrSSHExecute when the command fails to run; in both
// cases the returned string is empty.
func (c *SSHConn) Execute(cmd string) (string, error) {
	sshSession, err := c.client.NewSession()
	if err != nil {
		return "", fmt.Errorf("%w, addr=%s, err=%v", ErrSSHSession, c.addr, err)
	}
	defer sshSession.Close()

	var buf bytes.Buffer
	sshSession.Stdout = &buf

	if err := sshSession.Run(cmd); err != nil {
		return "", fmt.Errorf("%w, addr=%s, cmd=%s, err=%v", ErrSSHExecute, c.addr, cmd, err)
	}

	// Hand the captured stdout back to the caller; it used to be discarded.
	return buf.String(), nil
}