mirror of
https://github.com/ayflying/p2p.git
synced 2026-03-05 01:39:23 +00:00
增加分布式更新方法
This commit is contained in:
@@ -8,4 +8,5 @@ import (
|
||||
_ "github.com/ayflying/p2p/internal/logic/os"
|
||||
_ "github.com/ayflying/p2p/internal/logic/p2p"
|
||||
_ "github.com/ayflying/p2p/internal/logic/s3"
|
||||
_ "github.com/ayflying/p2p/internal/logic/system"
|
||||
)
|
||||
|
||||
@@ -5,13 +5,14 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"path"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/ayflying/p2p/internal/service"
|
||||
"github.com/gogf/gf/v2/encoding/gjson"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gctx"
|
||||
"github.com/gogf/gf/v2/os/gfile"
|
||||
"github.com/gogf/gf/v2/os/glog"
|
||||
"github.com/gogf/gf/v2/os/gtimer"
|
||||
"github.com/gogf/gf/v2/util/grand"
|
||||
@@ -420,18 +421,23 @@ func (s *sP2P) receiveGatewayMessages(ctx context.Context) {
|
||||
glog.Errorf(ctx, "网关错误: %s", data.Error)
|
||||
case MsgUpdate: //更新节点信息
|
||||
var msgData struct {
|
||||
Server string `json:"server"`
|
||||
Version string `json:"version"`
|
||||
Files []struct {
|
||||
File []byte `json:"file"`
|
||||
Name string `json:"name"`
|
||||
} `json:"files"`
|
||||
}
|
||||
//var msgData *dataType
|
||||
json.Unmarshal(msg.Data, &msgData)
|
||||
for _, v := range msgData.Files {
|
||||
err = gfile.PutBytes(path.Join("download", v.Name), v.File)
|
||||
}
|
||||
// 更新器路径(假设与主程序同目录)
|
||||
//updaterPath := filepath.Join(filepath.Dir(selfPath), "updater.exe")
|
||||
|
||||
g.Log().Infof(ctx, "更新节点信息: %v", data)
|
||||
|
||||
// 调用不同系统的更新服务
|
||||
service.OS().Update(msgData.Version, msgData.Server)
|
||||
//// 调用不同系统的更新服务
|
||||
//service.OS().Update(msgData.Version, msgData.Server)
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -66,14 +65,14 @@ func (s *sP2P) GatewayStart(ctx context.Context, group *ghttp.RouterGroup) (err
|
||||
_, data, _err := ws.ReadMessage()
|
||||
if _err != nil {
|
||||
//g.Log().Errorf(ctx, "读取消息失败: %v", err)
|
||||
//s.sendError(ws, err.Error())
|
||||
//s.SendError(ws, err.Error())
|
||||
break
|
||||
}
|
||||
|
||||
var msg GatewayMessage
|
||||
var msg *GatewayMessage
|
||||
if err = json.Unmarshal(data, &msg); err != nil {
|
||||
//g.Log().Error(ctx, "消息格式错误")
|
||||
s.sendError(ws, "消息格式错误")
|
||||
s.SendError(ws, "消息格式错误")
|
||||
continue
|
||||
}
|
||||
|
||||
@@ -96,7 +95,7 @@ func (s *sP2P) GatewayStart(ctx context.Context, group *ghttp.RouterGroup) (err
|
||||
}
|
||||
|
||||
// 处理注册请求
|
||||
func (s *sP2P) handleRegister(ctx context.Context, conn *websocket.Conn, msg GatewayMessage) {
|
||||
func (s *sP2P) handleRegister(ctx context.Context, conn *websocket.Conn, msg *GatewayMessage) {
|
||||
if msg.From == "" {
|
||||
g.Log().Error(ctx, "客户端ID不能为空")
|
||||
return
|
||||
@@ -108,22 +107,22 @@ func (s *sP2P) handleRegister(ctx context.Context, conn *websocket.Conn, msg Gat
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(msg.Data, &data); err != nil {
|
||||
s.sendError(conn, "注册数据格式错误")
|
||||
s.SendError(conn, "注册数据格式错误")
|
||||
return
|
||||
}
|
||||
|
||||
// 追加公网ip
|
||||
publicIp, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
ParseIP := net.ParseIP(publicIp)
|
||||
var ipType string
|
||||
if ParseIP.To4() != nil {
|
||||
ipType = "ip4"
|
||||
} else {
|
||||
ipType = "ip6"
|
||||
}
|
||||
port2 := 53533
|
||||
data.Addrs = append(data.Addrs, fmt.Sprintf("/%s/%s/tcp/%d", ipType, publicIp, port2))
|
||||
data.Addrs = append(data.Addrs, fmt.Sprintf("/%s/%s/udp/%d/quic-v1", ipType, publicIp, port2))
|
||||
//// 追加公网ip
|
||||
//publicIp, _, _ := net.SplitHostPort(conn.RemoteAddr().String())
|
||||
//ParseIP := net.ParseIP(publicIp)
|
||||
//var ipType string
|
||||
//if ParseIP.To4() != nil {
|
||||
// ipType = "ip4"
|
||||
//} else {
|
||||
// ipType = "ip6"
|
||||
//}
|
||||
//port2 := 53533
|
||||
//data.Addrs = append(data.Addrs, fmt.Sprintf("/%s/%s/tcp/%d", ipType, publicIp, port2))
|
||||
//data.Addrs = append(data.Addrs, fmt.Sprintf("/%s/%s/udp/%d/quic-v1", ipType, publicIp, port2))
|
||||
|
||||
// 过滤回环地址
|
||||
data.Addrs = s.filterLoopbackAddrs(data.Addrs)
|
||||
@@ -144,12 +143,12 @@ func (s *sP2P) handleRegister(ctx context.Context, conn *websocket.Conn, msg Gat
|
||||
glog.Infof(ctx, "客户端 ip=%s,%s 注册成功,PeerID: %s", conn.RemoteAddr(), msg.From, data.PeerID)
|
||||
|
||||
// 发送注册成功响应
|
||||
err := s.sendMessage(conn, GatewayMessage{
|
||||
err := s.sendMessage(conn, &GatewayMessage{
|
||||
Type: MsgTypeRegisterAck,
|
||||
Data: json.RawMessage(`{"success": true, "message": "注册成功"}`),
|
||||
})
|
||||
if err != nil {
|
||||
s.sendError(conn, err.Error())
|
||||
s.SendError(conn, err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -178,15 +177,41 @@ func (s *sP2P) cleanupClients(ctx context.Context) {
|
||||
}
|
||||
|
||||
// 发送错误消息
|
||||
func (s *sP2P) sendError(conn *websocket.Conn, errMsg string) {
|
||||
s.sendMessage(conn, GatewayMessage{
|
||||
func (s *sP2P) SendError(conn *websocket.Conn, errMsg string) {
|
||||
s.sendMessage(conn, &GatewayMessage{
|
||||
Type: "error",
|
||||
Data: json.RawMessage(fmt.Sprintf(`{"error": "%s"}`, errMsg)),
|
||||
})
|
||||
}
|
||||
|
||||
// SendAll 发送消息给所有客户端
|
||||
func (s *sP2P) SendAll(typ string, data any) (err error) {
|
||||
for _, client := range s.Clients {
|
||||
conn := client.Conn
|
||||
err = s.sendMessage(conn, &GatewayMessage{
|
||||
Type: MsgType(typ),
|
||||
Data: gjson.MustEncode(data),
|
||||
})
|
||||
if err != nil {
|
||||
g.Log().Error(gctx.New(), err)
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Send 发送消息给指定客户端
|
||||
func (s *sP2P) Send(conn *websocket.Conn, typ string, data any) (err error) {
|
||||
err = s.sendMessage(conn, &GatewayMessage{
|
||||
Type: MsgType(typ),
|
||||
Data: gjson.MustEncode(data),
|
||||
})
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// 发送消息
|
||||
func (s *sP2P) sendMessage(conn *websocket.Conn, msg GatewayMessage) error {
|
||||
func (s *sP2P) sendMessage(conn *websocket.Conn, msg *GatewayMessage) error {
|
||||
data, err := json.Marshal(msg)
|
||||
if err != nil {
|
||||
glog.Errorf(gctx.New(), "序列化消息失败: %v", err)
|
||||
@@ -196,9 +221,9 @@ func (s *sP2P) sendMessage(conn *websocket.Conn, msg GatewayMessage) error {
|
||||
}
|
||||
|
||||
// 处理发现请求
|
||||
func (s *sP2P) handleDiscover(ctx context.Context, conn *websocket.Conn, msg GatewayMessage) {
|
||||
func (s *sP2P) handleDiscover(ctx context.Context, conn *websocket.Conn, msg *GatewayMessage) {
|
||||
if msg.From == "" {
|
||||
s.sendError(conn, "消息缺少发送方ID(from)")
|
||||
s.SendError(conn, "消息缺少发送方ID(from)")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -207,12 +232,12 @@ func (s *sP2P) handleDiscover(ctx context.Context, conn *websocket.Conn, msg Gat
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(msg.Data, &data); err != nil {
|
||||
s.sendError(conn, "发现请求格式错误,需包含target_id")
|
||||
s.SendError(conn, "发现请求格式错误,需包含target_id")
|
||||
return
|
||||
}
|
||||
|
||||
if data.TargetID == "" {
|
||||
s.sendError(conn, "目标ID不能为空")
|
||||
s.SendError(conn, "目标ID不能为空")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -225,7 +250,7 @@ func (s *sP2P) handleDiscover(ctx context.Context, conn *websocket.Conn, msg Gat
|
||||
s.lock.RUnlock()
|
||||
|
||||
//if !fromExists {
|
||||
// s.sendError(conn, "请先注册")
|
||||
// s.SendError(conn, "请先注册")
|
||||
// return
|
||||
//}
|
||||
|
||||
@@ -236,11 +261,10 @@ func (s *sP2P) handleDiscover(ctx context.Context, conn *websocket.Conn, msg Gat
|
||||
|
||||
if targetClient == nil {
|
||||
// 目标不存在
|
||||
s.sendMessage(conn, GatewayMessage{
|
||||
s.sendMessage(conn, &GatewayMessage{
|
||||
Type: MsgTypeDiscoverAck,
|
||||
From: "gateway",
|
||||
To: msg.From,
|
||||
//Data: json.RawMessage(`{"found": false}`),
|
||||
Data: gjson.MustEncode(g.Map{
|
||||
"found": false,
|
||||
}),
|
||||
@@ -249,7 +273,7 @@ func (s *sP2P) handleDiscover(ctx context.Context, conn *websocket.Conn, msg Gat
|
||||
}
|
||||
|
||||
// 向请求方发送目标信息
|
||||
s.sendMessage(conn, GatewayMessage{
|
||||
s.sendMessage(conn, &GatewayMessage{
|
||||
Type: MsgTypeDiscoverAck,
|
||||
From: "gateway", // 发送方是网关
|
||||
To: msg.From, // 接收方是原请求方
|
||||
@@ -262,7 +286,7 @@ func (s *sP2P) handleDiscover(ctx context.Context, conn *websocket.Conn, msg Gat
|
||||
})
|
||||
|
||||
// 向目标方发送打洞请求(协调时机)
|
||||
s.sendMessage(targetClient.Conn, GatewayMessage{
|
||||
s.sendMessage(targetClient.Conn, &GatewayMessage{
|
||||
Type: MsgTypePunchRequest,
|
||||
From: msg.From, // 发送方是原请求方
|
||||
To: data.TargetID, // 接收方是目标方
|
||||
|
||||
@@ -26,8 +26,8 @@ var (
|
||||
|
||||
// 常量定义
|
||||
const (
|
||||
ProtocolID string = "/ay"
|
||||
DefaultPort = 51888
|
||||
ProtocolID string = "/ay"
|
||||
//DefaultPort = 51888
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -58,7 +58,10 @@ func New(_name ...string) *sS3 {
|
||||
if len(_name) > 0 {
|
||||
name = _name[0]
|
||||
} else {
|
||||
getName, _ := g.Cfg("local").Get(gctx.New(), "s3.type")
|
||||
getName, err := g.Cfg("local").Get(gctx.New(), "s3.type")
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
name = getName.String()
|
||||
}
|
||||
|
||||
|
||||
17
internal/logic/system/system.go
Normal file
17
internal/logic/system/system.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"github.com/ayflying/p2p/internal/service"
|
||||
)
|
||||
|
||||
type sSystem struct{}
|
||||
|
||||
func New() *sSystem {
|
||||
return &sSystem{}
|
||||
}
|
||||
|
||||
func init() {
|
||||
service.RegisterSystem(New())
|
||||
}
|
||||
|
||||
func (system *sSystem) Init() {}
|
||||
98
internal/logic/system/update.go
Normal file
98
internal/logic/system/update.go
Normal file
@@ -0,0 +1,98 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/ayflying/p2p/internal/service"
|
||||
"github.com/gogf/gf/v2/encoding/gcompress"
|
||||
"github.com/gogf/gf/v2/frame/g"
|
||||
"github.com/gogf/gf/v2/os/gcmd"
|
||||
)
|
||||
|
||||
func (s *sSystem) Update(ctx context.Context) (err error) {
|
||||
//拼接操作系统和架构(格式:OS_ARCH)
|
||||
platform := fmt.Sprintf("%s_%s", runtime.GOOS, runtime.GOARCH)
|
||||
|
||||
runFile := gcmd.GetArg(0).String()
|
||||
oldFile, err := service.System().RenameRunningFile(runFile)
|
||||
g.Log().Debugf(ctx, "执行文件改名为%v", oldFile)
|
||||
gz := path.Join("download", platform+".gz")
|
||||
err = gcompress.UnGzipFile(gz, runFile)
|
||||
|
||||
go func() {
|
||||
log.Println("5秒后开始重启...")
|
||||
time.Sleep(5 * time.Second)
|
||||
|
||||
if err = service.System().RestartSelf(); err != nil {
|
||||
log.Fatalf("重启失败:%v", err)
|
||||
}
|
||||
}()
|
||||
return
|
||||
}
|
||||
|
||||
// RestartSelf 实现 Windows 平台下的程序自重启
|
||||
func (s *sSystem) RestartSelf() error {
|
||||
// 1. 获取当前程序的绝对路径
|
||||
exePath, err := os.Executable()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// 处理路径中的符号链接(确保路径正确)
|
||||
exePath, err = filepath.EvalSymlinks(exePath)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 2. 获取命令行参数(os.Args[0] 是程序名,实际参数从 os.Args[1:] 开始)
|
||||
args := os.Args[1:]
|
||||
|
||||
// 3. 构建新进程命令(路径为当前程序,参数为原参数)
|
||||
cmd := exec.Command(exePath, args...)
|
||||
// 设置新进程的工作目录与当前进程一致
|
||||
cmd.Dir, err = os.Getwd()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 新进程的输出继承当前进程的标准输出(可选,根据需求调整)
|
||||
cmd.Stdout = os.Stdout
|
||||
cmd.Stderr = os.Stderr
|
||||
cmd.Stdin = os.Stdin
|
||||
|
||||
// 4. 启动新进程(非阻塞,Start() 后立即返回)
|
||||
if err := cmd.Start(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// 5. 新进程启动成功后,退出当前进程
|
||||
os.Exit(0)
|
||||
return nil // 理论上不会执行到这里
|
||||
}
|
||||
|
||||
// RenameRunningFile 重命名正在运行的程序文件(如 p2p.exe → p2p.exe~)
|
||||
func (s *sSystem) RenameRunningFile(exePath string) (string, error) {
|
||||
// 目标备份文件名(p2p.exe → p2p.exe~)
|
||||
backupPath := exePath + "~"
|
||||
|
||||
// 先删除已存在的备份文件(若有)
|
||||
if _, err := os.Stat(backupPath); err == nil {
|
||||
if err := os.Remove(backupPath); err != nil {
|
||||
return "", fmt.Errorf("删除旧备份文件失败: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// 重命名正在运行的 exe 文件
|
||||
// 关键:Windows 允许对锁定的文件执行重命名操作
|
||||
if err := os.Rename(exePath, backupPath); err != nil {
|
||||
return "", fmt.Errorf("重命名运行中文件失败: %v", err)
|
||||
}
|
||||
return backupPath, nil
|
||||
}
|
||||
Reference in New Issue
Block a user