diff --git a/internal/api/redpacket.go b/internal/api/redpacket.go index 62f50a9ac..87e1f9845 100644 --- a/internal/api/redpacket.go +++ b/internal/api/redpacket.go @@ -88,6 +88,34 @@ func (h *RedPacketApi) ClaimResult(ctx *gin.Context) { apiresp.GinSuccess(ctx, resp) } +func (h *RedPacketApi) RequestRefund(ctx *gin.Context) { + req, err := a2r.ParseRequestNotCheck[pbredpacket.RequestRefundReq](ctx) + if err != nil { + apiresp.GinError(ctx, err) + return + } + resp, err := h.Client.RequestRefund(ctx, req) + if err != nil { + apiresp.GinError(ctx, err) + return + } + apiresp.GinSuccess(ctx, resp) +} + +func (h *RedPacketApi) GetRefund(ctx *gin.Context) { + req, err := a2r.ParseRequestNotCheck[pbredpacket.GetRefundReq](ctx) + if err != nil { + apiresp.GinError(ctx, err) + return + } + resp, err := h.Client.GetRefund(ctx, req) + if err != nil { + apiresp.GinError(ctx, err) + return + } + apiresp.GinSuccess(ctx, resp) +} + func (h *RedPacketApi) IssueWalletBindChallenge(ctx *gin.Context) { req, err := a2r.ParseRequestNotCheck[pbredpacket.IssueWalletBindChallengeReq](ctx) if err != nil { diff --git a/internal/api/router.go b/internal/api/router.go index 9e94a8098..c1448aeb2 100644 --- a/internal/api/router.go +++ b/internal/api/router.go @@ -377,6 +377,8 @@ func newGinRouter(ctx context.Context, client discovery.SvcDiscoveryRegistry, co redpacketGroup.POST("/detail", rp.GetDetail) redpacketGroup.POST("/issue_claim_sign", rp.IssueClaimSign) redpacketGroup.POST("/claim_result", rp.ClaimResult) + redpacketGroup.POST("/request_refund", rp.RequestRefund) + redpacketGroup.POST("/get_refund", rp.GetRefund) redpacketGroup.POST("/wallet_bind/challenge", rp.IssueWalletBindChallenge) redpacketGroup.POST("/wallet_bind/confirm", rp.ConfirmWalletBind) redpacketGroup.POST("/wallet_bind/detail", rp.GetWalletBinding) diff --git a/internal/rpc/redpacket/admin.go b/internal/rpc/redpacket/admin.go index 5b459e28f..b7ea91a37 100644 --- a/internal/rpc/redpacket/admin.go +++ b/internal/rpc/redpacket/admin.go @@ -2,16 +2,58 @@ package redpacket import ( "context" + "encoding/json" "fmt" "math/big" + "time" "github.com/ethereum/go-ethereum/common" + "github.com/openimsdk/open-im-server/v3/pkg/authverify" + "github.com/openimsdk/open-im-server/v3/pkg/common/storage/model" pbredpacket "github.com/openimsdk/protocol/redpacket" "github.com/openimsdk/tools/errs" "github.com/openimsdk/tools/log" + "github.com/openimsdk/tools/mcontext" + "go.mongodb.org/mongo-driver/bson/primitive" ) -func (s *redPacketServer) SetSigner(ctx context.Context, req *pbredpacket.SetSignerReq) (*pbredpacket.SetSignerResp, error) { +// checkAdminPermission is a convenience wrapper used by every admin handler. +func (s *redPacketServer) checkAdminPermission(ctx context.Context) error { + return authverify.CheckAdmin(ctx, s.config.Share.IMAdminUserID) +} + +// recordAudit persists an admin audit entry asynchronously; errors are only +// logged so they never block the primary operation. +func (s *redPacketServer) recordAudit(ctx context.Context, action string, req interface{}, opErr error) { + params := "" + if b, err := json.Marshal(req); err == nil { + params = string(b) + } + result := "success" + errMsg := "" + if opErr != nil { + result = "failed" + errMsg = opErr.Error() + } + entry := &model.AdminAuditLog{ + ID: primitive.NewObjectID(), + OperatorID: mcontext.GetOpUserID(ctx), + Action: action, + Params: params, + Result: result, + ErrMsg: errMsg, + CreatedAt: time.Now().UTC(), + } + if err := s.db.CreateAdminAuditLog(ctx, entry); err != nil { + log.ZWarn(ctx, "redpacket admin audit log write failed", err, "action", action) + } +} + +func (s *redPacketServer) SetSigner(ctx context.Context, req *pbredpacket.SetSignerReq) (resp *pbredpacket.SetSignerResp, retErr error) { + defer func() { s.recordAudit(ctx, "SetSigner", req, retErr) }() + if err := s.checkAdminPermission(ctx); err != nil { + return nil, err + } if req.SignerAddress == "" { return nil, errs.ErrArgs.WrapMsg("signer_address is required") } @@ -28,7 +70,11 @@ func (s *redPacketServer) SetSigner(ctx context.Context, req *pbredpacket.SetSig return nil, errs.ErrInternalServer.WrapMsg("no blockchain client configured") } -func (s *redPacketServer) SetToken(ctx context.Context, req *pbredpacket.SetTokenReq) (*pbredpacket.SetTokenResp, error) { +func (s *redPacketServer) SetToken(ctx context.Context, req *pbredpacket.SetTokenReq) (resp *pbredpacket.SetTokenResp, retErr error) { + defer func() { s.recordAudit(ctx, "SetToken", req, retErr) }() + if err := s.checkAdminPermission(ctx); err != nil { + return nil, err + } if req.TokenAddress == "" { return nil, errs.ErrArgs.WrapMsg("token_address is required") } @@ -55,7 +101,11 @@ func (s *redPacketServer) SetToken(ctx context.Context, req *pbredpacket.SetToke return nil, errs.ErrInternalServer.WrapMsg("no blockchain client configured") } -func (s *redPacketServer) SetExpiry(ctx context.Context, req *pbredpacket.SetExpiryReq) (*pbredpacket.SetExpiryResp, error) { +func (s *redPacketServer) SetExpiry(ctx context.Context, req *pbredpacket.SetExpiryReq) (resp *pbredpacket.SetExpiryResp, retErr error) { + defer func() { s.recordAudit(ctx, "SetExpiry", req, retErr) }() + if err := s.checkAdminPermission(ctx); err != nil { + return nil, err + } if req.ExpirySeconds <= 0 { return nil, errs.ErrArgs.WrapMsg("expiry_seconds must be positive") } @@ -72,7 +122,11 @@ func (s *redPacketServer) SetExpiry(ctx context.Context, req *pbredpacket.SetExp return nil, errs.ErrInternalServer.WrapMsg("no blockchain client configured") } -func (s *redPacketServer) SetAllowAllTokens(ctx context.Context, req *pbredpacket.SetAllowAllTokensReq) (*pbredpacket.SetAllowAllTokensResp, error) { +func (s *redPacketServer) SetAllowAllTokens(ctx context.Context, req *pbredpacket.SetAllowAllTokensReq) (resp *pbredpacket.SetAllowAllTokensResp, retErr error) { + defer func() { s.recordAudit(ctx, "SetAllowAllTokens", req, retErr) }() + if err := s.checkAdminPermission(ctx); err != nil { + return nil, err + } if s.chainClient != nil { log.ZInfo(ctx, "redpacket admin setAllowAllTokens (eth mock)", "allowAll", req.AllowAll) return &pbredpacket.SetAllowAllTokensResp{Message: "allow all tokens setting updated"}, nil @@ -86,7 +140,11 @@ func (s *redPacketServer) SetAllowAllTokens(ctx context.Context, req *pbredpacke return nil, errs.ErrInternalServer.WrapMsg("no blockchain client configured") } -func (s *redPacketServer) SetNativeTokenEnabled(ctx context.Context, req *pbredpacket.SetNativeTokenEnabledReq) (*pbredpacket.SetNativeTokenEnabledResp, error) { +func (s *redPacketServer) SetNativeTokenEnabled(ctx context.Context, req *pbredpacket.SetNativeTokenEnabledReq) (resp *pbredpacket.SetNativeTokenEnabledResp, retErr error) { + defer func() { s.recordAudit(ctx, "SetNativeTokenEnabled", req, retErr) }() + if err := s.checkAdminPermission(ctx); err != nil { + return nil, err + } if s.chainClient != nil { log.ZInfo(ctx, "redpacket admin setNativeTokenEnabled (eth mock)", "enabled", req.Enabled) return &pbredpacket.SetNativeTokenEnabledResp{Message: "native token setting updated"}, nil @@ -100,7 +158,11 @@ func (s *redPacketServer) SetNativeTokenEnabled(ctx context.Context, req *pbredp return nil, errs.ErrInternalServer.WrapMsg("no blockchain client configured") } -func (s *redPacketServer) ParseTxEvents(ctx context.Context, req *pbredpacket.ParseTxEventsReq) (*pbredpacket.ParseTxEventsResp, error) { +func (s *redPacketServer) ParseTxEvents(ctx context.Context, req *pbredpacket.ParseTxEventsReq) (resp *pbredpacket.ParseTxEventsResp, retErr error) { + defer func() { s.recordAudit(ctx, "ParseTxEvents", req, retErr) }() + if err := s.checkAdminPermission(ctx); err != nil { + return nil, err + } if req.TxHash == "" { return nil, errs.ErrArgs.WrapMsg("tx_hash is required") } diff --git a/internal/rpc/redpacket/chain/client.go b/internal/rpc/redpacket/chain/client.go index 0057545c3..896e8c903 100644 --- a/internal/rpc/redpacket/chain/client.go +++ b/internal/rpc/redpacket/chain/client.go @@ -11,6 +11,7 @@ import ( "github.com/ethereum/go-ethereum" "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/core/types" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/ethclient" ) @@ -141,6 +142,57 @@ func (c *ChainClient) ContractABI() abi.ABI { return c.contractABI } +// RefundPacket submits an on-chain refund transaction for an expired red +// packet. It uses the configAdminKey to sign and broadcast the transaction. +// Returns the transaction hash on success. +func (c *ChainClient) RefundPacket(ctx context.Context, packetIDStr string) (string, error) { + if c.configAdminKey == nil { + return "", fmt.Errorf("config admin key not configured") + } + + packetID, ok := new(big.Int).SetString(packetIDStr, 10) + if !ok { + return "", fmt.Errorf("invalid packetID: %s", packetIDStr) + } + + data, err := c.contractABI.Pack("refundPacket", packetID) + if err != nil { + return "", fmt.Errorf("pack refundPacket failed: %w", err) + } + + fromAddr := crypto.PubkeyToAddress(c.configAdminKey.PublicKey) + nonce, err := c.client.PendingNonceAt(ctx, fromAddr) + if err != nil { + return "", fmt.Errorf("get nonce failed: %w", err) + } + + gasPrice, err := c.client.SuggestGasPrice(ctx) + if err != nil { + return "", fmt.Errorf("suggest gas price failed: %w", err) + } + + gasLimit, err := c.client.EstimateGas(ctx, ethereum.CallMsg{ + From: fromAddr, + To: &c.contractAddr, + Data: data, + }) + if err != nil { + gasLimit = 200000 // fallback + } + + tx := types.NewTransaction(nonce, c.contractAddr, big.NewInt(0), gasLimit, gasPrice, data) + signedTx, err := types.SignTx(tx, types.NewEIP155Signer(c.chainID), c.configAdminKey) + if err != nil { + return "", fmt.Errorf("sign refund tx failed: %w", err) + } + + if err := c.client.SendTransaction(ctx, signedTx); err != nil { + return "", fmt.Errorf("send refund tx failed: %w", err) + } + + return signedTx.Hash().Hex(), nil +} + func (c *ChainClient) Close() { if c.client != nil { c.client.Close() diff --git a/internal/rpc/redpacket/chain/indexer.go b/internal/rpc/redpacket/chain/indexer.go index 0ea5525ff..fd330560a 100644 --- a/internal/rpc/redpacket/chain/indexer.go +++ b/internal/rpc/redpacket/chain/indexer.go @@ -41,7 +41,6 @@ func (i *Indexer) Start(ctx context.Context) { go func() { ticker := time.NewTicker(i.pollInterval) defer ticker.Stop() - for { select { case <-ctx.Done(): @@ -54,6 +53,40 @@ func (i *Indexer) Start(ctx context.Context) { } } }() + + // Compensation loop: periodically scan DB for expired-but-unclosed packets + // and mark them EXPIRED so the UI reflects the correct state even if the + // on-chain refund event was missed. + go func() { + ticker := time.NewTicker(60 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := i.compensate(ctx); err != nil { + log.ZWarn(ctx, "redpacket eth compensation error", err) + } + } + } + }() +} + +func (i *Indexer) compensate(ctx context.Context) error { + now := time.Now().Unix() + packets, err := i.db.GetExpiredPendingPackets(ctx, now) + if err != nil { + return fmt.Errorf("get expired packets failed: %w", err) + } + for _, rp := range packets { + if err := i.db.UpdateRedPacketStatus(ctx, rp.PacketID, "EXPIRED"); err != nil { + log.ZWarn(ctx, "redpacket eth compensation mark expired failed", err, "packetID", rp.PacketID) + continue + } + log.ZInfo(ctx, "redpacket eth compensation: marked packet EXPIRED", "packetID", rp.PacketID) + } + return nil } func (i *Indexer) poll(ctx context.Context) error { diff --git a/internal/rpc/redpacket/chain/tron_indexer.go b/internal/rpc/redpacket/chain/tron_indexer.go index be7b1f2b8..81f56f3c8 100644 --- a/internal/rpc/redpacket/chain/tron_indexer.go +++ b/internal/rpc/redpacket/chain/tron_indexer.go @@ -39,7 +39,6 @@ func (t *TronIndexer) Start(ctx context.Context) { go func() { ticker := time.NewTicker(t.pollInterval) defer ticker.Stop() - for { select { case <-ctx.Done(): @@ -53,6 +52,37 @@ func (t *TronIndexer) Start(ctx context.Context) { } } }() + + go func() { + ticker := time.NewTicker(60 * time.Second) + defer ticker.Stop() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + if err := t.compensate(ctx); err != nil { + log.ZWarn(ctx, "redpacket tron compensation error", err) + } + } + } + }() +} + +func (t *TronIndexer) compensate(ctx context.Context) error { + now := time.Now().Unix() + packets, err := t.db.GetExpiredPendingPackets(ctx, now) + if err != nil { + return fmt.Errorf("get expired packets failed: %w", err) + } + for _, rp := range packets { + if err := t.db.UpdateRedPacketStatus(ctx, rp.PacketID, "EXPIRED"); err != nil { + log.ZWarn(ctx, "redpacket tron compensation mark expired failed", err, "packetID", rp.PacketID) + continue + } + log.ZInfo(ctx, "redpacket tron compensation: marked packet EXPIRED", "packetID", rp.PacketID) + } + return nil } func (t *TronIndexer) poll(ctx context.Context) error { @@ -131,84 +161,86 @@ func (t *TronIndexer) scanBlock(ctx context.Context, blockNum int64) error { return nil } +// processTransaction parses the on-chain receipt through the ABI (same path as +// the ETH indexer) and dispatches each decoded event to the appropriate handler. func (t *TronIndexer) processTransaction(ctx context.Context, txID string) error { - var txInfo map[string]interface{} - err := postJSON(ctx, t.client.fullNodeURL+"/wallet/gettransactioninfobyid", map[string]interface{}{ - "value": txID, - }, &txInfo) + events, err := t.client.ParseTransactionReceipt(ctx, txID) if err != nil { - return err + return fmt.Errorf("parse tron tx receipt failed: %w", err) } - contractAddress := t.client.contractBase58 - if logs, ok := txInfo["log"].([]interface{}); ok && len(logs) > 0 { - for _, logEntry := range logs { - if logMap, ok := logEntry.(map[string]interface{}); ok { - if address, ok := logMap["address"].(string); ok && address == contractAddress { - eventType := t.parseTronEvent(logMap) - log.ZDebug(ctx, "redpacket tron event detected", "event", eventType, "txID", txID) - - switch eventType { - case "PacketCreated": - t.handleTronPacketCreated(ctx, logMap, txID) - case "PacketClaimed": - t.handleTronPacketClaimed(ctx, logMap, txID) - case "PacketRefunded": - t.handleTronPacketRefunded(ctx, logMap, txID) - } - } + for _, event := range events { + log.ZDebug(ctx, "redpacket tron event detected", "event", event.Name, "txID", txID) + switch event.Name { + case "PacketCreated": + if err := t.handleTronPacketCreated(ctx, event, txID); err != nil { + log.ZWarn(ctx, "redpacket tron handlePacketCreated failed", err, "txID", txID) } - } - } - - return nil -} - -func (t *TronIndexer) parseTronEvent(logEntry map[string]interface{}) string { - if topics, ok := logEntry["topics"].([]interface{}); ok && len(topics) > 0 { - if topic0, ok := topics[0].(string); ok { - switch topic0 { - case "0x8be0079c531659141344cd1fd0a4f28419497f9722a3daafe3b4186f6b6457e0": - return "Transfer" - default: - return "UnknownEvent" + case "PacketClaimed": + if err := t.handleTronPacketClaimed(ctx, event, txID); err != nil { + log.ZWarn(ctx, "redpacket tron handlePacketClaimed failed", err, "txID", txID) + } + case "PacketRefunded": + if err := t.handleTronPacketRefunded(ctx, event, txID); err != nil { + log.ZWarn(ctx, "redpacket tron handlePacketRefunded failed", err, "txID", txID) } } } - return "UnknownEvent" + return nil } -func (t *TronIndexer) handleTronPacketCreated(ctx context.Context, logData map[string]interface{}, txID string) { - log.ZInfo(ctx, "tron PacketCreated event", "txID", txID) +func (t *TronIndexer) handleTronPacketCreated(ctx context.Context, event *ParsedEvent, txID string) error { + packetID := GetPacketIDFromEvent(event) + creator := GetAddressFromEvent(event, "creator") + log.ZInfo(ctx, "tron PacketCreated event", "packetID", packetID.String(), "creator", creator.Hex(), "txID", txID) + return nil } -func (t *TronIndexer) handleTronPacketClaimed(ctx context.Context, logData map[string]interface{}, txID string) { - log.ZInfo(ctx, "tron PacketClaimed event", "txID", txID) +func (t *TronIndexer) handleTronPacketClaimed(ctx context.Context, event *ParsedEvent, txID string) error { + packetID := GetPacketIDFromEvent(event) + claimer := GetAddressFromEvent(event, "claimer") + amount := GetAmountFromEvent(event) + authNonce := GetUintFromEvent(event, "authNonce") - claimer := "unknown" - amount := "0" - - if topics, ok := logData["topics"].([]interface{}); ok && len(topics) > 1 { - if claimerTopic, ok := topics[1].(string); ok { - claimer = claimerTopic - } - } + log.ZInfo(ctx, "tron PacketClaimed event", "packetID", packetID.String(), "claimer", claimer.Hex(), "amount", amount.String(), "txID", txID) claim := &model.RedPacketClaim{ - PacketID: "tron-packet-" + txID[:8], - ClaimerWallet: claimer, + PacketID: packetID.String(), + ClaimerWallet: claimer.Hex(), + AuthNonce: authNonce.String(), ClaimTxHash: txID, - ClaimedAmount: amount, + ClaimedAmount: amount.String(), + BlockNumber: event.BlockNumber, Status: "CONFIRMED", + CreatedAt: time.Now(), + UpdatedAt: time.Now(), } - if err := t.db.SaveClaim(ctx, claim); err != nil { - log.ZWarn(ctx, "redpacket tron save claim failed", err) + return err + } + if err := t.db.MarkClaimAuthUsed(ctx, authNonce.String()); err != nil { + return err } + return t.db.UpdateRedPacketClaimProgress(ctx, packetID.String(), amount.String(), "") } -func (t *TronIndexer) handleTronPacketRefunded(ctx context.Context, logData map[string]interface{}, txID string) { - log.ZInfo(ctx, "tron PacketRefunded event", "txID", txID) +func (t *TronIndexer) handleTronPacketRefunded(ctx context.Context, event *ParsedEvent, txID string) error { + packetID := GetPacketIDFromEvent(event) + refundTo := GetAddressFromEvent(event, "refundTo") + amount := GetAmountFromEvent(event) + + log.ZInfo(ctx, "tron PacketRefunded event", "packetID", packetID.String(), "refundTo", refundTo.Hex(), "amount", amount.String(), "txID", txID) + + if err := t.db.SaveRefund(ctx, &model.RedPacketRefund{ + PacketID: packetID.String(), + RefundTo: refundTo.Hex(), + TxHash: txID, + Amount: amount.String(), + CreatedAt: time.Now(), + }); err != nil { + return err + } + return t.db.UpdateRedPacketStatus(ctx, packetID.String(), "REFUNDED") } func (t *TronIndexer) GetLastProcessedBlock() int64 { diff --git a/internal/rpc/redpacket/redpacket.go b/internal/rpc/redpacket/redpacket.go index 360b50e74..15b9b1139 100644 --- a/internal/rpc/redpacket/redpacket.go +++ b/internal/rpc/redpacket/redpacket.go @@ -66,8 +66,12 @@ func Start(ctx context.Context, conf *Config, registry discovery.SvcDiscoveryReg if err != nil { return err } + auditLogDB, err := mgo.NewAdminAuditLogMongo(db) + if err != nil { + return err + } - repo := controller.NewRedPacketDatabase(rpDB, claimDB, claimAuthDB, refundDB, challengeDB, bindingDB) + repo := controller.NewRedPacketDatabase(rpDB, claimDB, claimAuthDB, refundDB, challengeDB, bindingDB, auditLogDB) chainClient, err := chain.NewClient( conf.RpcConfig.Chain.RPCURL, diff --git a/internal/rpc/redpacket/service.go b/internal/rpc/redpacket/service.go index 220ee9cba..ca8339c23 100644 --- a/internal/rpc/redpacket/service.go +++ b/internal/rpc/redpacket/service.go @@ -895,6 +895,73 @@ func redPacketModelToProto(rp *model.RedPacket) *pbredpacket.RedPacketRecord { } } +// RequestRefund allows the red-packet creator to submit an on-chain refund +// transaction for an expired packet. The indexer will asynchronously pick up +// the on-chain RefundPacket event and mark the packet as REFUNDED in the DB. +func (s *redPacketServer) RequestRefund(ctx context.Context, req *pbredpacket.RequestRefundReq) (*pbredpacket.RequestRefundResp, error) { + currentUserID := mcontext.GetOpUserID(ctx) + if currentUserID == "" { + return nil, servererrs.ErrNoPermission.WrapMsg("op user id is empty") + } + if req.GetPacketID() == "" { + return nil, errs.ErrArgs.WrapMsg("packet_id is required") + } + + rp, err := s.db.GetRedPacketByPacketID(ctx, req.GetPacketID()) + if err != nil { + return nil, err + } + if rp.CreatorUserID != currentUserID { + return nil, errs.ErrNoPermission.WrapMsg("only the creator can request a refund") + } + if rp.Status == "REFUNDED" { + return &pbredpacket.RequestRefundResp{TxHash: "", Status: "REFUNDED"}, nil + } + if rp.ExpiryAt > 0 && time.Now().Unix() < rp.ExpiryAt { + return nil, errs.ErrArgs.WrapMsg("red packet has not expired yet") + } + + // Submit the on-chain refund transaction. + var txHash string + if s.chainClient != nil { + txHash, err = s.chainClient.RefundPacket(ctx, rp.PacketID) + if err != nil { + return nil, errs.ErrInternalServer.WrapMsg("submit refund tx failed: " + err.Error()) + } + } else if s.tronClient != nil { + packetIDBig, ok := new(big.Int).SetString(rp.PacketID, 10) + if !ok { + return nil, errs.ErrInternalServer.WrapMsg("invalid packet id format") + } + txHash, err = s.tronClient.SendAdminTransaction(ctx, "refundPacket", packetIDBig) + if err != nil { + return nil, errs.ErrInternalServer.WrapMsg("submit tron refund tx failed: " + err.Error()) + } + } else { + return nil, errs.ErrInternalServer.WrapMsg("no blockchain client configured") + } + + log.ZInfo(ctx, "redpacket refund submitted", "packetID", rp.PacketID, "txHash", txHash) + return &pbredpacket.RequestRefundResp{TxHash: txHash, Status: "PENDING"}, nil +} + +func (s *redPacketServer) GetRefund(ctx context.Context, req *pbredpacket.GetRefundReq) (*pbredpacket.GetRefundResp, error) { + if req.GetPacketID() == "" { + return nil, errs.ErrArgs.WrapMsg("packet_id is required") + } + refund, err := s.db.GetRefundByPacketID(ctx, req.GetPacketID()) + if err != nil { + return nil, err + } + return &pbredpacket.GetRefundResp{ + PacketID: refund.PacketID, + RefundTo: refund.RefundTo, + TxHash: refund.TxHash, + Amount: refund.Amount, + CreatedAt: refund.CreatedAt.Unix(), + }, nil +} + func claimsModelToProto(claims []*model.RedPacketClaim) []*pbredpacket.RedPacketClaimRecord { out := make([]*pbredpacket.RedPacketClaimRecord, 0, len(claims)) for _, c := range claims { diff --git a/internal/rpc/redpacket/wallet.go b/internal/rpc/redpacket/wallet.go index 569d96e84..f9f7de2d5 100644 --- a/internal/rpc/redpacket/wallet.go +++ b/internal/rpc/redpacket/wallet.go @@ -1,9 +1,12 @@ package redpacket import ( + "bytes" "context" + "crypto/sha256" "encoding/hex" "fmt" + "math/big" "strings" "time" @@ -99,20 +102,22 @@ func (s *redPacketServer) ConfirmWalletBind(ctx context.Context, req *pbredpacke return nil, errs.ErrArgs.WrapMsg("challenge is expired") } + var verifyErr error switch challenge.ChainType { case "EVM": - if err := verifyEVMBindSignature(challenge.Message, challenge.WalletAddress, req.Signature); err != nil { - challenge.Status = "FAILED" - challenge.Signature = req.Signature - challenge.UpdatedAt = time.Now() - _ = s.db.UpdateWalletBindingChallenge(ctx, challenge) - return nil, err - } + verifyErr = verifyEVMBindSignature(challenge.Message, challenge.WalletAddress, req.Signature) case "TRON": - return nil, errs.ErrInternalServer.WrapMsg("TRON wallet binding verification is not implemented yet") + verifyErr = verifyTRONBindSignature(challenge.Message, challenge.WalletAddress, req.Signature) default: return nil, errs.ErrArgs.WrapMsg("unsupported chain_type: " + challenge.ChainType) } + if verifyErr != nil { + challenge.Status = "FAILED" + challenge.Signature = req.Signature + challenge.UpdatedAt = time.Now() + _ = s.db.UpdateWalletBindingChallenge(ctx, challenge) + return nil, verifyErr + } now := time.Now().UTC() challenge.Status = "VERIFIED" @@ -249,3 +254,96 @@ func verifyEVMBindSignature(message, walletAddress, signature string) error { func personalSignMessage(message string) string { return fmt.Sprintf("\x19Ethereum Signed Message:\n%d%s", len(message), message) } + +// verifyTRONBindSignature verifies a TRON signMessageV2 (TronLink) signature. +// TRON uses the same secp256k1 curve as Ethereum; the only differences are: +// - message prefix: "\x19TRON Signed Message:\n" +// - wallet address: base58check-encoded with a leading 0x41 byte +func verifyTRONBindSignature(message, walletAddress, signature string) error { + if strings.TrimSpace(message) == "" { + return errs.ErrArgs.WrapMsg("bind message is empty") + } + + sig, err := hex.DecodeString(strings.TrimPrefix(signature, "0x")) + if err != nil { + return errs.ErrArgs.WrapMsg("decode tron signature failed: " + err.Error()) + } + if len(sig) != 65 { + return errs.ErrArgs.WrapMsg(fmt.Sprintf("invalid tron signature length: %d", len(sig))) + } + // Some TRON wallets encode v as 27/28; normalise to 0/1. + if sig[64] >= 27 { + sig[64] -= 27 + } + + prefix := fmt.Sprintf("\x19TRON Signed Message:\n%d", len(message)) + hash := crypto.Keccak256Hash([]byte(prefix + message)) + + pubKey, err := crypto.SigToPub(hash.Bytes(), sig) + if err != nil { + return errs.ErrInternalServer.WrapMsg("recover tron signer failed: " + err.Error()) + } + + // Derive the raw 20-byte address (identical derivation to Ethereum). + recoveredAddr := crypto.PubkeyToAddress(*pubKey) + + // Decode the TRON base58check address to its 20 raw bytes. + addrBytes, err := decodeTRONAddress(walletAddress) + if err != nil { + return errs.ErrArgs.WrapMsg("invalid tron address: " + err.Error()) + } + + if !bytes.Equal(recoveredAddr.Bytes(), addrBytes) { + return errs.ErrNoPermission.WrapMsg("tron signature does not match wallet address") + } + return nil +} + +// decodeTRONAddress decodes a TRON base58check address and returns the 20 +// raw address bytes (i.e., without the leading 0x41 network prefix byte). +func decodeTRONAddress(addr string) ([]byte, error) { + decoded := tronBase58Decode(addr) + if len(decoded) != 25 { + return nil, fmt.Errorf("invalid length %d", len(decoded)) + } + + payload := decoded[:21] + checksum := decoded[21:25] + h1 := sha256.Sum256(payload) + h2 := sha256.Sum256(h1[:]) + if !bytes.Equal(h2[:4], checksum) { + return nil, fmt.Errorf("invalid base58check checksum") + } + if payload[0] != 0x41 { + return nil, fmt.Errorf("invalid tron address prefix byte: 0x%02x", payload[0]) + } + return payload[1:], nil +} + +const tronBase58Alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz" + +func tronBase58Decode(s string) []byte { + n := new(big.Int) + base := big.NewInt(58) + for _, c := range s { + idx := strings.IndexRune(tronBase58Alphabet, c) + if idx < 0 { + return nil + } + n.Mul(n, base) + n.Add(n, big.NewInt(int64(idx))) + } + + decoded := n.Bytes() + leadingOnes := 0 + for _, c := range s { + if c == '1' { + leadingOnes++ + } else { + break + } + } + out := make([]byte, leadingOnes+len(decoded)) + copy(out[leadingOnes:], decoded) + return out +} diff --git a/pkg/common/storage/controller/redpacket.go b/pkg/common/storage/controller/redpacket.go index d052a9008..cf7268c65 100644 --- a/pkg/common/storage/controller/redpacket.go +++ b/pkg/common/storage/controller/redpacket.go @@ -17,6 +17,7 @@ type RedPacketDatabase interface { UpdateRedPacketCreated(ctx context.Context, rp *model.RedPacket) error UpdateRedPacketStatus(ctx context.Context, packetID, status string) error UpdateRedPacketClaimProgress(ctx context.Context, packetID, claimedAmount, status string) error + GetExpiredPendingPackets(ctx context.Context, nowUnix int64) ([]*model.RedPacket, error) CreateClaimAuth(ctx context.Context, auth *model.RedPacketClaimAuth) error GetClaimAuth(ctx context.Context, packetID, claimer string) (*model.RedPacketClaimAuth, error) @@ -28,6 +29,7 @@ type RedPacketDatabase interface { GetClaimsByPacketID(ctx context.Context, packetID string) ([]*model.RedPacketClaim, error) SaveRefund(ctx context.Context, refund *model.RedPacketRefund) error + GetRefundByPacketID(ctx context.Context, packetID string) (*model.RedPacketRefund, error) CreateWalletBindingChallenge(ctx context.Context, challenge *model.WalletBindingChallenge) error GetWalletBindingChallenge(ctx context.Context, challengeID string) (*model.WalletBindingChallenge, error) @@ -35,6 +37,8 @@ type RedPacketDatabase interface { UpsertWalletBinding(ctx context.Context, binding *model.WalletBinding) error GetActiveWalletBinding(ctx context.Context, userID, chainType, walletAddress string) (*model.WalletBinding, error) + + CreateAdminAuditLog(ctx context.Context, entry *model.AdminAuditLog) error } type redPacketDatabase struct { @@ -44,6 +48,7 @@ type redPacketDatabase struct { refund database.RedPacketRefund challenge database.WalletBindingChallenge binding database.WalletBinding + auditLog database.AdminAuditLog } func NewRedPacketDatabase( @@ -53,6 +58,7 @@ func NewRedPacketDatabase( refund database.RedPacketRefund, challenge database.WalletBindingChallenge, binding database.WalletBinding, + auditLog database.AdminAuditLog, ) RedPacketDatabase { return &redPacketDatabase{ rp: rp, @@ -61,6 +67,7 @@ func NewRedPacketDatabase( refund: refund, challenge: challenge, binding: binding, + auditLog: auditLog, } } @@ -120,6 +127,18 @@ func (d *redPacketDatabase) SaveRefund(ctx context.Context, refund *model.RedPac return d.refund.Save(ctx, refund) } +func (d *redPacketDatabase) GetRefundByPacketID(ctx context.Context, packetID string) (*model.RedPacketRefund, error) { + return d.refund.GetByPacketID(ctx, packetID) +} + +func (d *redPacketDatabase) GetExpiredPendingPackets(ctx context.Context, nowUnix int64) ([]*model.RedPacket, error) { + return d.rp.GetExpiredPending(ctx, nowUnix) +} + +func (d *redPacketDatabase) CreateAdminAuditLog(ctx context.Context, entry *model.AdminAuditLog) error { + return d.auditLog.Create(ctx, entry) +} + func (d *redPacketDatabase) CreateWalletBindingChallenge(ctx context.Context, challenge *model.WalletBindingChallenge) error { return d.challenge.Create(ctx, challenge) } diff --git a/pkg/common/storage/database/mgo/redpacket.go b/pkg/common/storage/database/mgo/redpacket.go index bf0579228..bf61033d4 100644 --- a/pkg/common/storage/database/mgo/redpacket.go +++ b/pkg/common/storage/database/mgo/redpacket.go @@ -331,6 +331,18 @@ func (m *RedPacketRefundMgo) Save(ctx context.Context, refund *model.RedPacketRe return err } +func (m *RedPacketRefundMgo) GetByPacketID(ctx context.Context, packetID string) (*model.RedPacketRefund, error) { + var r model.RedPacketRefund + err := m.coll.FindOne(ctx, bson.M{"packet_id": packetID}).Decode(&r) + if err != nil { + if err == mongo.ErrNoDocuments { + return nil, errs.ErrRecordNotFound.WrapMsg("refund not found", "packetID", packetID) + } + return nil, err + } + return &r, nil +} + // ---- WalletBindingChallenge ---- type WalletBindingChallengeMgo struct { @@ -414,6 +426,24 @@ func NewWalletBindingMongo(db *mongo.Database) (database.WalletBinding, error) { return &WalletBindingMgo{coll: coll}, nil } +// GetExpiredPending returns red packets that have expired but are still in +// "CREATED" status (i.e., not yet refunded or fully claimed). +func (m *RedPacketMgo) GetExpiredPending(ctx context.Context, now int64) ([]*model.RedPacket, error) { + cur, err := m.coll.Find(ctx, bson.M{ + "status": "CREATED", + "expiry_at": bson.M{"$lt": now, "$gt": 0}, + }) + if err != nil { + return nil, err + } + defer cur.Close(ctx) + var out []*model.RedPacket + if err := cur.All(ctx, &out); err != nil { + return nil, err + } + return out, nil +} + func (m *WalletBindingMgo) Upsert(ctx context.Context, b *model.WalletBinding) error { filter := bson.M{ "user_id": b.UserID, @@ -454,3 +484,26 @@ func (m *WalletBindingMgo) GetActive(ctx context.Context, userID, chainType, wal } return &b, nil } + +// ---- AdminAuditLog ---- + +type AdminAuditLogMgo struct { + coll *mongo.Collection +} + +func NewAdminAuditLogMongo(db *mongo.Database) (database.AdminAuditLog, error) { + coll := db.Collection("admin_audit_log") + _, err := coll.Indexes().CreateMany(context.Background(), []mongo.IndexModel{ + {Keys: bson.D{{Key: "operator_id", Value: 1}}}, + {Keys: bson.D{{Key: "created_at", Value: -1}}}, + }) + if err != nil { + return nil, err + } + return &AdminAuditLogMgo{coll: coll}, nil +} + +func (m *AdminAuditLogMgo) Create(ctx context.Context, entry *model.AdminAuditLog) error { + _, err := m.coll.InsertOne(ctx, entry) + return err +} diff --git a/pkg/common/storage/database/redpacket.go b/pkg/common/storage/database/redpacket.go index dff792fcc..5beddf600 100644 --- a/pkg/common/storage/database/redpacket.go +++ b/pkg/common/storage/database/redpacket.go @@ -13,6 +13,8 @@ type RedPacket interface { UpdateCreated(ctx context.Context, rp *model.RedPacket) error UpdateStatus(ctx context.Context, packetID, status string) error UpdateClaimProgress(ctx context.Context, packetID, claimedAmount, status string) error + // GetExpiredPending returns CREATED packets whose expiry_at < now (unix seconds). + GetExpiredPending(ctx context.Context, now int64) ([]*model.RedPacket, error) } type RedPacketClaim interface { @@ -30,6 +32,11 @@ type RedPacketClaimAuth interface { type RedPacketRefund interface { Save(ctx context.Context, refund *model.RedPacketRefund) error + GetByPacketID(ctx context.Context, packetID string) (*model.RedPacketRefund, error) +} + +type AdminAuditLog interface { + Create(ctx context.Context, log *model.AdminAuditLog) error } type WalletBindingChallenge interface { diff --git a/pkg/common/storage/model/redpacket.go b/pkg/common/storage/model/redpacket.go index 1b014b93c..ce697840c 100644 --- a/pkg/common/storage/model/redpacket.go +++ b/pkg/common/storage/model/redpacket.go @@ -1,6 +1,10 @@ package model -import "time" +import ( + "time" + + "go.mongodb.org/mongo-driver/bson/primitive" +) type RedPacket struct { BizID string `bson:"biz_id"` @@ -89,3 +93,14 @@ type WalletBinding struct { CreatedAt time.Time `bson:"created_at"` UpdatedAt time.Time `bson:"updated_at"` } + +// AdminAuditLog records each admin operation for accountability. +type AdminAuditLog struct { + ID primitive.ObjectID `bson:"_id"` + OperatorID string `bson:"operator_id"` + Action string `bson:"action"` + Params string `bson:"params"` // JSON-encoded request + Result string `bson:"result"` // "success" | "failed" + ErrMsg string `bson:"err_msg"` + CreatedAt time.Time `bson:"created_at"` +} diff --git a/protocol b/protocol index 34a58a77d..c69f02cf6 160000 --- a/protocol +++ b/protocol @@ -1 +1 @@ -Subproject commit 34a58a77d26a3c133a4be9ce00affdca8b158ba4 +Subproject commit c69f02cf664231e963501889263d4c9963dc3fca