From cffa2403f80be85cf248423b1d51c9eaeef437a6 Mon Sep 17 00:00:00 2001 From: hawklin2017 <32898629+hawklin2017@users.noreply.github.com> Date: Thu, 30 Apr 2026 21:24:24 +0800 Subject: [PATCH] redpacket --- internal/rpc/redpacket/admin.go | 27 +++- internal/rpc/redpacket/chain/indexer.go | 15 +- internal/rpc/redpacket/chain/tron.go | 4 +- internal/rpc/redpacket/chain/tron_indexer.go | 29 +++- internal/rpc/redpacket/service.go | 162 ++++++++++--------- pkg/common/storage/controller/redpacket.go | 6 +- pkg/common/storage/database/mgo/redpacket.go | 44 ++++- pkg/common/storage/database/redpacket.go | 8 +- pkg/common/storage/model/redpacket.go | 7 +- 9 files changed, 191 insertions(+), 111 deletions(-) diff --git a/internal/rpc/redpacket/admin.go b/internal/rpc/redpacket/admin.go index b7ea91a37..e2802d7cb 100644 --- a/internal/rpc/redpacket/admin.go +++ b/internal/rpc/redpacket/admin.go @@ -81,7 +81,9 @@ func (s *redPacketServer) SetToken(ctx context.Context, req *pbredpacket.SetToke minAmountBig := new(big.Int) if req.MinAmount != "" { - minAmountBig.SetString(req.MinAmount, 10) + if _, ok := minAmountBig.SetString(req.MinAmount, 10); !ok { + return nil, errs.ErrArgs.WrapMsg("invalid min_amount", "minAmount", req.MinAmount) + } } if s.chainClient != nil { @@ -167,12 +169,23 @@ func (s *redPacketServer) ParseTxEvents(ctx context.Context, req *pbredpacket.Pa return nil, errs.ErrArgs.WrapMsg("tx_hash is required") } - if req.Chain == "tron" && s.tronClient != nil { - return &pbredpacket.ParseTxEventsResp{ - Chain: "tron", - TxHash: req.TxHash, - Note: "TRON event parsing not fully implemented in this version", - }, nil + if req.Chain == "tron" { + if s.tronClient == nil { + return nil, errs.ErrInternalServer.WrapMsg("TRON client not configured") + } + events, err := s.tronClient.ParseTransactionReceipt(ctx, req.TxHash) + if err != nil { + return nil, errs.ErrInternalServer.WrapMsg("parse TRON tx receipt failed: " + err.Error()) + } + out := make([]*pbredpacket.ParsedEvent, 0, len(events)) + for _, e := range events { + data := make(map[string]string, len(e.Data)) + for k, v := range e.Data { + data[k] = fmt.Sprintf("%v", v) + } + out = append(out, &pbredpacket.ParsedEvent{Name: e.Name, Data: data}) + } + return &pbredpacket.ParseTxEventsResp{Chain: "tron", TxHash: req.TxHash, Events: out}, nil } if s.chainClient != nil { diff --git a/internal/rpc/redpacket/chain/indexer.go b/internal/rpc/redpacket/chain/indexer.go index fd330560a..590b6049d 100644 --- a/internal/rpc/redpacket/chain/indexer.go +++ b/internal/rpc/redpacket/chain/indexer.go @@ -39,6 +39,11 @@ func (i *Indexer) Start(ctx context.Context) { log.ZInfo(ctx, "starting RedPacket ETH event indexer") go func() { + defer func() { + if r := recover(); r != nil { + log.ZError(ctx, "redpacket eth indexer panic recovered", fmt.Errorf("%v", r)) + } + }() ticker := time.NewTicker(i.pollInterval) defer ticker.Stop() for { @@ -58,6 +63,11 @@ func (i *Indexer) Start(ctx context.Context) { // and mark them EXPIRED so the UI reflects the correct state even if the // on-chain refund event was missed. go func() { + defer func() { + if r := recover(); r != nil { + log.ZError(ctx, "redpacket eth compensation panic recovered", fmt.Errorf("%v", r)) + } + }() ticker := time.NewTicker(60 * time.Second) defer ticker.Stop() for { @@ -178,7 +188,10 @@ func (i *Indexer) handlePacketClaimed(ctx context.Context, event *ParsedEvent) e if err := i.db.MarkClaimAuthUsed(ctx, authNonce.String()); err != nil { return err } - return i.db.UpdateRedPacketClaimProgress(ctx, packetID.String(), amount.String(), "") + // Pass "" for forced status; DB layer auto-derives COMPLETED/ACTIVE. + // TxHash is the idempotency key: prevents double-counting if ClaimResult RPC + // already processed this same transaction. + return i.db.UpdateRedPacketClaimProgress(ctx, packetID.String(), amount.String(), "", event.TxHash.Hex()) } func (i *Indexer) handlePacketRefunded(ctx context.Context, event *ParsedEvent) error { diff --git a/internal/rpc/redpacket/chain/tron.go b/internal/rpc/redpacket/chain/tron.go index 93f965522..08ff077da 100644 --- a/internal/rpc/redpacket/chain/tron.go +++ b/internal/rpc/redpacket/chain/tron.go @@ -10,6 +10,7 @@ import ( "math/big" "net/http" "strings" + "time" "github.com/ethereum/go-ethereum/accounts/abi" "github.com/ethereum/go-ethereum/common" @@ -274,7 +275,8 @@ func postJSON(ctx context.Context, url string, body interface{}, out interface{} } req.Header.Set("Content-Type", "application/json") - resp, err := http.DefaultClient.Do(req) + httpClient := &http.Client{Timeout: 10 * time.Second} + resp, err := httpClient.Do(req) if err != nil { return err } diff --git a/internal/rpc/redpacket/chain/tron_indexer.go b/internal/rpc/redpacket/chain/tron_indexer.go index 81f56f3c8..526513367 100644 --- a/internal/rpc/redpacket/chain/tron_indexer.go +++ b/internal/rpc/redpacket/chain/tron_indexer.go @@ -16,7 +16,6 @@ type TronIndexer struct { pollInterval time.Duration lastBlockNum int64 contractAddress string - processedTxs map[string]bool } func NewTronIndexer(client *TronClient, db controller.RedPacketDatabase, pollInterval int, startBlock int64) *TronIndexer { @@ -29,7 +28,6 @@ func NewTronIndexer(client *TronClient, db controller.RedPacketDatabase, pollInt pollInterval: time.Duration(pollInterval) * time.Second, lastBlockNum: startBlock, contractAddress: client.contractBase58, - processedTxs: make(map[string]bool), } } @@ -37,6 +35,11 @@ func (t *TronIndexer) Start(ctx context.Context) { log.ZInfo(ctx, "starting RedPacket TRON event indexer") go func() { + defer func() { + if r := recover(); r != nil { + log.ZError(ctx, "redpacket tron indexer panic recovered", fmt.Errorf("%v", r)) + } + }() ticker := time.NewTicker(t.pollInterval) defer ticker.Stop() for { @@ -54,6 +57,11 @@ func (t *TronIndexer) Start(ctx context.Context) { }() go func() { + defer func() { + if r := recover(); r != nil { + log.ZError(ctx, "redpacket tron compensation panic recovered", fmt.Errorf("%v", r)) + } + }() ticker := time.NewTicker(60 * time.Second) defer ticker.Stop() for { @@ -97,14 +105,18 @@ func (t *TronIndexer) poll(ctx context.Context) error { log.ZDebug(ctx, "redpacket tron scanning blocks", "from", t.lastBlockNum+1, "to", currentBlock) + // Advance the cursor only up to the last successfully processed block so + // that a transient RPC failure does not cause blocks to be silently skipped. + lastOK := t.lastBlockNum for blockNum := t.lastBlockNum + 1; blockNum <= currentBlock; blockNum++ { if err := t.scanBlock(ctx, blockNum); err != nil { log.ZWarn(ctx, "redpacket tron scan block failed", err, "block", blockNum) - continue + break } + lastOK = blockNum } - t.lastBlockNum = currentBlock + t.lastBlockNum = lastOK return nil } @@ -147,14 +159,12 @@ func (t *TronIndexer) scanBlock(ctx context.Context, blockNum int64) error { } txID, _ := tx["txID"].(string) - if txID == "" || t.processedTxs[txID] { + if txID == "" { continue } if err := t.processTransaction(ctx, txID); err != nil { log.ZWarn(ctx, "redpacket tron process tx failed", err, "txID", txID) - } else { - t.processedTxs[txID] = true } } @@ -221,7 +231,10 @@ func (t *TronIndexer) handleTronPacketClaimed(ctx context.Context, event *Parsed if err := t.db.MarkClaimAuthUsed(ctx, authNonce.String()); err != nil { return err } - return t.db.UpdateRedPacketClaimProgress(ctx, packetID.String(), amount.String(), "") + // Pass "" for forced status; DB layer auto-derives COMPLETED/ACTIVE. + // txID is the idempotency key: prevents double-counting if ClaimResult RPC + // already processed this same transaction. + return t.db.UpdateRedPacketClaimProgress(ctx, packetID.String(), amount.String(), "", txID) } func (t *TronIndexer) handleTronPacketRefunded(ctx context.Context, event *ParsedEvent, txID string) error { diff --git a/internal/rpc/redpacket/service.go b/internal/rpc/redpacket/service.go index ca8339c23..71aad905a 100644 --- a/internal/rpc/redpacket/service.go +++ b/internal/rpc/redpacket/service.go @@ -85,6 +85,10 @@ func (s *redPacketServer) CreateOrder(ctx context.Context, req *pbredpacket.Crea } func (s *redPacketServer) CreatedCallback(ctx context.Context, req *pbredpacket.CreatedCallbackReq) (*pbredpacket.CreatedCallbackResp, error) { + opUserID := mcontext.GetOpUserID(ctx) + if opUserID == "" { + return nil, servererrs.ErrNoPermission.WrapMsg("op user id is empty") + } if strings.TrimSpace(req.BizID) == "" || strings.TrimSpace(req.TxHash) == "" { return nil, errs.ErrArgs.WrapMsg("biz_id and tx_hash are required") } @@ -93,6 +97,9 @@ func (s *redPacketServer) CreatedCallback(ctx context.Context, req *pbredpacket. if err != nil { return nil, err } + if rp.CreatorUserID != opUserID { + return nil, servererrs.ErrNoPermission.WrapMsg("only the creator can submit the creation callback") + } groupID := firstNonEmpty(req.GroupID, rp.GroupID) scopeType := normalizeScopeType(firstNonEmpty(req.ScopeType, rp.ScopeType)) @@ -202,7 +209,7 @@ func (s *redPacketServer) IssueClaimSign(ctx context.Context, req *pbredpacket.I signature[64] += 27 } } else { - signature = []byte("0xplaceholder-signature-for-testing") + return nil, errs.ErrInternalServer.WrapMsg("signer key not configured; cannot issue claim signature") } sigHex := "0x" + hex.EncodeToString(signature) @@ -295,8 +302,10 @@ func (s *redPacketServer) ClaimResult(ctx context.Context, req *pbredpacket.Clai } } - nextStatus := derivePacketStatusAfterClaim(rp, claimedEvent.Amount) - if err := s.db.UpdateRedPacketClaimProgress(ctx, req.PacketID, claimedEvent.Amount, nextStatus); err != nil { + // Pass "" for status so the DB layer auto-derives COMPLETED/ACTIVE. + // Pass req.TxHash as the idempotency key so concurrent indexer processing + // of the same transaction cannot double-count the claim. + if err := s.db.UpdateRedPacketClaimProgress(ctx, req.PacketID, claimedEvent.Amount, "", req.TxHash); err != nil { return nil, err } return &pbredpacket.ClaimResultResp{}, nil @@ -350,6 +359,7 @@ type createdPacketSnapshot struct { func (s *redPacketServer) resolveCreatedPacket(ctx context.Context, rp *model.RedPacket, txHashHex, fallbackPacketID string) (*createdPacketSnapshot, error) { switch rp.ChainType { case "EVM": + // Offline mode: no chain client configured; caller must supply packet_id directly. if s.chainClient == nil { if fallbackPacketID == "" { return nil, errs.ErrArgs.WrapMsg("packet_id is required when EVM client is unavailable") @@ -359,10 +369,7 @@ func (s *redPacketServer) resolveCreatedPacket(ctx context.Context, rp *model.Re events, err := s.chainClient.ParseTransactionReceipt(ctx, common.HexToHash(txHashHex)) if err != nil { - if fallbackPacketID == "" { - return nil, errs.ErrInternalServer.WrapMsg("parse created tx failed: " + err.Error()) - } - return buildFallbackCreatedPacket(rp, fallbackPacketID), nil + return nil, errs.ErrInternalServer.WrapMsg("parse created tx failed: " + err.Error()) } for _, event := range events { @@ -379,12 +386,9 @@ func (s *redPacketServer) resolveCreatedPacket(ctx context.Context, rp *model.Re } return createdPacket, nil } - - if fallbackPacketID == "" { - return nil, errs.ErrInternalServer.WrapMsg("PacketCreated event not found in tx: " + txHashHex) - } - return buildFallbackCreatedPacket(rp, fallbackPacketID), nil + return nil, errs.ErrInternalServer.WrapMsg("PacketCreated event not found in tx: " + txHashHex) case "TRON": + // Offline mode: no chain client configured; caller must supply packet_id directly. if s.tronClient == nil { if fallbackPacketID == "" { return nil, errs.ErrArgs.WrapMsg("packet_id is required when TRON client is unavailable") @@ -394,10 +398,7 @@ func (s *redPacketServer) resolveCreatedPacket(ctx context.Context, rp *model.Re events, err := s.tronClient.ParseTransactionReceipt(ctx, txHashHex) if err != nil { - if fallbackPacketID == "" { - return nil, errs.ErrInternalServer.WrapMsg("parse tron created tx failed: " + err.Error()) - } - return buildFallbackCreatedPacket(rp, fallbackPacketID), nil + return nil, errs.ErrInternalServer.WrapMsg("parse tron created tx failed: " + err.Error()) } for _, event := range events { @@ -411,11 +412,7 @@ func (s *redPacketServer) resolveCreatedPacket(ctx context.Context, rp *model.Re } return createdPacket, nil } - - if fallbackPacketID == "" { - return nil, errs.ErrInternalServer.WrapMsg("PacketCreated event not found in TRON tx: " + txHashHex) - } - return buildFallbackCreatedPacket(rp, fallbackPacketID), nil + return nil, errs.ErrInternalServer.WrapMsg("PacketCreated event not found in TRON tx: " + txHashHex) default: return nil, errs.ErrArgs.WrapMsg("unsupported chain_type: " + rp.ChainType) } @@ -486,17 +483,24 @@ func (s *redPacketServer) validateCreatorScope(ctx context.Context, req *pbredpa // validateFixedPacketCreate validates fixed red packets: // - shared base fields -// - total_shares > 0 +// - scope_type must be GROUP (fixed packets are group-only; claim validators require group_id) +// - 0 < total_shares <= maxTotalShares // - total_amount must be divisible by total_shares (each share is an integer in min units) -// - scope-based group/friend relationship for the creator +// - creator must be an active member of the group func (s *redPacketServer) validateFixedPacketCreate(ctx context.Context, req *pbredpacket.CreateOrderReq) error { total, err := validateCreateBaseFields(req) if err != nil { return err } + if normalizeScopeType(req.ScopeType) != "GROUP" { + return errs.ErrArgs.WrapMsg("fixed packet must use scope_type=GROUP") + } if req.TotalShares <= 0 { return errs.ErrArgs.WrapMsg("total_shares must be positive for fixed packet", "totalShares", req.TotalShares) } + if req.TotalShares > maxTotalShares { + return errs.ErrArgs.WrapMsg(fmt.Sprintf("total_shares must not exceed %d for fixed packet", maxTotalShares), "totalShares", req.TotalShares) + } shares := big.NewInt(int64(req.TotalShares)) if new(big.Int).Mod(total, shares).Sign() != 0 { return errs.ErrArgs.WrapMsg("total_amount must be divisible by total_shares for fixed packet", @@ -507,17 +511,24 @@ func (s *redPacketServer) validateFixedPacketCreate(ctx context.Context, req *pb // validateRandomPacketCreate validates random (lucky) red packets: // - shared base fields -// - total_shares > 0 +// - scope_type must be GROUP (random packets are group-only; claim validators require group_id) +// - 0 < total_shares <= maxTotalShares // - total_amount >= total_shares (at least 1 min unit per share) -// - scope-based group/friend relationship for the creator +// - creator must be an active member of the group func (s *redPacketServer) validateRandomPacketCreate(ctx context.Context, req *pbredpacket.CreateOrderReq) error { total, err := validateCreateBaseFields(req) if err != nil { return err } + if normalizeScopeType(req.ScopeType) != "GROUP" { + return errs.ErrArgs.WrapMsg("random packet must use scope_type=GROUP") + } if req.TotalShares <= 0 { return errs.ErrArgs.WrapMsg("total_shares must be positive for random packet", "totalShares", req.TotalShares) } + if req.TotalShares > maxTotalShares { + return errs.ErrArgs.WrapMsg(fmt.Sprintf("total_shares must not exceed %d for random packet", maxTotalShares), "totalShares", req.TotalShares) + } shares := big.NewInt(int64(req.TotalShares)) if total.Cmp(shares) < 0 { return errs.ErrArgs.WrapMsg("total_amount must be >= total_shares for random packet", @@ -528,26 +539,36 @@ func (s *redPacketServer) validateRandomPacketCreate(ctx context.Context, req *p // validateTransferPacketCreate validates transfer red packets: // - shared base fields +// - scope_type must be DIRECT (transfer is a 1-to-1 direct send) // - total_shares == 1 -// - exactly one receiver_user_id, must be a friend of the creator +// - exactly one receiver_user_id (receiver_user_ids must be empty) +// - receiver must not be the creator (no self-transfer) +// - creator and receiver must be friends func (s *redPacketServer) validateTransferPacketCreate(ctx context.Context, req *pbredpacket.CreateOrderReq) error { if _, err := validateCreateBaseFields(req); err != nil { return err } + if normalizeScopeType(req.ScopeType) != "DIRECT" { + return errs.ErrArgs.WrapMsg("transfer packet must use scope_type=DIRECT") + } if req.TotalShares != 1 { return errs.ErrArgs.WrapMsg("transfer packet must have total_shares == 1", "totalShares", req.TotalShares) } + // Reject ambiguous input: receiver_user_ids is not applicable for transfer. + if len(req.ReceiverUserIDs) > 0 { + return errs.ErrArgs.WrapMsg("transfer packet uses receiver_user_id (singular), not receiver_user_ids") + } receiverUserID := strings.TrimSpace(req.ReceiverUserID) if receiverUserID == "" { return errs.ErrArgs.WrapMsg("receiver_user_id is required for transfer packet") } - if len(req.ReceiverUserIDs) > 0 { - return errs.ErrArgs.WrapMsg("transfer packet only supports a single receiver_user_id") - } creatorUserID := mcontext.GetOpUserID(ctx) if creatorUserID == "" { return servererrs.ErrNoPermission.WrapMsg("op user id is empty") } + if creatorUserID == receiverUserID { + return errs.ErrArgs.WrapMsg("transfer packet cannot be sent to yourself") + } return s.ensureFriendRelationship(ctx, creatorUserID, receiverUserID) } @@ -615,14 +636,20 @@ func validateClaimBase(rp *model.RedPacket, userID, claimer string) error { if strings.TrimSpace(claimer) == "" { return errs.ErrArgs.WrapMsg("claimer is required") } - if rp.Status != "ACTIVE" { - return errs.ErrArgs.WrapMsg("packet is not active, current status: " + rp.Status) + // Check status first to give precise error messages for each terminal state. + switch rp.Status { + case "ACTIVE": + // ok, continue to expiry check + case "REFUNDED": + return errs.ErrArgs.WrapMsg("packet has been refunded") + case "EXPIRED": + return errs.ErrArgs.WrapMsg("packet has expired") + default: + return errs.ErrArgs.WrapMsg("packet is not claimable, current status: " + rp.Status) } + // Guard against the race where status is still ACTIVE but expiry has passed. if rp.ExpiryAt > 0 && rp.ExpiryAt <= time.Now().Unix() { - return errs.ErrArgs.WrapMsg("packet is expired") - } - if rp.Status == "REFUNDED" { - return errs.ErrArgs.WrapMsg("packet is refunded") + return errs.ErrArgs.WrapMsg("packet has expired") } return nil } @@ -713,27 +740,34 @@ func (s *redPacketServer) ensureGroupEligibility(ctx context.Context, groupID, u return nil } -// ensureFriendRelationship verifies that creatorUserID and receiverUserID are friends -// (used by transfer red packets to require a pre-existing relationship). -func (s *redPacketServer) ensureFriendRelationship(ctx context.Context, creatorUserID, receiverUserID string) error { - creatorUserID = strings.TrimSpace(creatorUserID) - receiverUserID = strings.TrimSpace(receiverUserID) - if creatorUserID == "" || receiverUserID == "" { - return errs.ErrArgs.WrapMsg("creator_user_id and receiver_user_id are required") - } - if creatorUserID == receiverUserID { +// ensureFriendRelationship verifies that userA and userB are mutual friends. +// It is used in two contexts: +// - validateCreatorScope (DIRECT scope): checking that each listed receiver is +// a friend of the creator. In that path userA == userB is theoretically possible +// (creator adding themselves to a list), which is allowed here; the transfer +// validator has its own explicit self-transfer prohibition. +// - validateTransferPacketClaim: re-confirming the friendship at claim time. +// +// Self-transfer is intentionally allowed at this level; call sites that need to +// prohibit it (e.g. validateTransferPacketCreate) must do so before calling here. +func (s *redPacketServer) ensureFriendRelationship(ctx context.Context, userA, userB string) error { + userA = strings.TrimSpace(userA) + userB = strings.TrimSpace(userB) + if userA == "" || userB == "" { + return errs.ErrArgs.WrapMsg("both user IDs are required for friend relationship check") + } + if userA == userB { return nil } if s.relationClient == nil { return servererrs.ErrInternalServer.WrapMsg("relation client is not initialized") } - ok, err := s.relationClient.IsFriend(ctx, creatorUserID, receiverUserID) + ok, err := s.relationClient.IsFriend(ctx, userA, userB) if err != nil { return err } if !ok { - return errs.ErrNoPermission.WrapMsg("creator and receiver are not friends", - "creatorUserID", creatorUserID, "receiverUserID", receiverUserID) + return errs.ErrNoPermission.WrapMsg("users are not friends", "userA", userA, "userB", userB) } return nil } @@ -782,38 +816,8 @@ func (s *redPacketServer) resolveClaimedEvent(ctx context.Context, rp *model.Red return nil, nil } -func derivePacketStatusAfterClaim(rp *model.RedPacket, claimedAmount string) string { - if rp == nil { - return "" - } - if rp.PacketType == 2 { - return "COMPLETED" - } - - nextShares := rp.ClaimedShares + 1 - if rp.TotalShares > 0 && nextShares >= rp.TotalShares { - return "COMPLETED" - } - - totalClaimed := addNumericStrings(rp.ClaimedAmount, claimedAmount) - if rp.TotalAmount != "" && totalClaimed == rp.TotalAmount { - return "COMPLETED" - } - - return "ACTIVE" -} - -func addNumericStrings(current, delta string) string { - left := new(big.Int) - if current != "" { - left.SetString(current, 10) - } - right := new(big.Int) - if delta != "" { - right.SetString(delta, 10) - } - return new(big.Int).Add(left, right).String() -} +// maxTotalShares caps the number of shares to prevent abuse. +const maxTotalShares = 10_000 func normalizeScopeType(scopeType string) string { switch strings.ToUpper(strings.TrimSpace(scopeType)) { diff --git a/pkg/common/storage/controller/redpacket.go b/pkg/common/storage/controller/redpacket.go index cf7268c65..7bdab8992 100644 --- a/pkg/common/storage/controller/redpacket.go +++ b/pkg/common/storage/controller/redpacket.go @@ -16,7 +16,7 @@ type RedPacketDatabase interface { GetRedPacketByPacketID(ctx context.Context, packetID string) (*model.RedPacket, error) UpdateRedPacketCreated(ctx context.Context, rp *model.RedPacket) error UpdateRedPacketStatus(ctx context.Context, packetID, status string) error - UpdateRedPacketClaimProgress(ctx context.Context, packetID, claimedAmount, status string) error + UpdateRedPacketClaimProgress(ctx context.Context, packetID, claimedAmount, status, claimTxHash string) error GetExpiredPendingPackets(ctx context.Context, nowUnix int64) ([]*model.RedPacket, error) CreateClaimAuth(ctx context.Context, auth *model.RedPacketClaimAuth) error @@ -91,8 +91,8 @@ func (d *redPacketDatabase) UpdateRedPacketStatus(ctx context.Context, packetID, return d.rp.UpdateStatus(ctx, packetID, status) } -func (d *redPacketDatabase) UpdateRedPacketClaimProgress(ctx context.Context, packetID, claimedAmount, status string) error { - return d.rp.UpdateClaimProgress(ctx, packetID, claimedAmount, status) +func (d *redPacketDatabase) UpdateRedPacketClaimProgress(ctx context.Context, packetID, claimedAmount, status, claimTxHash string) error { + return d.rp.UpdateClaimProgress(ctx, packetID, claimedAmount, status, claimTxHash) } func (d *redPacketDatabase) CreateClaimAuth(ctx context.Context, auth *model.RedPacketClaimAuth) error { diff --git a/pkg/common/storage/database/mgo/redpacket.go b/pkg/common/storage/database/mgo/redpacket.go index bf61033d4..0cf51b4c5 100644 --- a/pkg/common/storage/database/mgo/redpacket.go +++ b/pkg/common/storage/database/mgo/redpacket.go @@ -104,7 +104,7 @@ func (m *RedPacketMgo) UpdateStatus(ctx context.Context, packetID, status string return nil } -func (m *RedPacketMgo) UpdateClaimProgress(ctx context.Context, packetID, claimedAmount, status string) error { +func (m *RedPacketMgo) UpdateClaimProgress(ctx context.Context, packetID, claimedAmount, status, claimTxHash string) error { var rp model.RedPacket err := m.coll.FindOne(ctx, bson.M{"packet_id": packetID}).Decode(&rp) if err != nil { @@ -116,15 +116,45 @@ func (m *RedPacketMgo) UpdateClaimProgress(ctx context.Context, packetID, claime totalClaimed := addNumericStrings(rp.ClaimedAmount, claimedAmount) nextShares := rp.ClaimedShares + 1 - updates := bson.M{ + + // Auto-derive status when the caller does not force one. + nextStatus := status + if nextStatus == "" { + if rp.PacketType == 2 { + nextStatus = "COMPLETED" + } else if rp.TotalShares > 0 && nextShares >= rp.TotalShares { + nextStatus = "COMPLETED" + } else { + tcBig, tok := new(big.Int).SetString(totalClaimed, 10) + taBig, taok := new(big.Int).SetString(rp.TotalAmount, 10) + if tok && taok && tcBig.Cmp(taBig) >= 0 { + nextStatus = "COMPLETED" + } + } + } + + setFields := bson.M{ "claimed_amount": totalClaimed, "claimed_shares": nextShares, "updated_at": time.Now(), } - if status != "" { - updates["status"] = status + if nextStatus != "" { + setFields["status"] = nextStatus + } + + // The $addToSet + $ne filter makes the whole update idempotent per claimTxHash: + // if two code paths (RPC handler and indexer) both attempt to process the same + // transaction, only the first UpdateOne will match and the second is a no-op. + filter := bson.M{"packet_id": packetID} + if claimTxHash != "" { + filter["processed_claim_hashes"] = bson.M{"$ne": claimTxHash} } - _, err = m.coll.UpdateOne(ctx, bson.M{"packet_id": packetID}, bson.M{"$set": updates}) + update := bson.M{"$set": setFields} + if claimTxHash != "" { + update["$addToSet"] = bson.M{"processed_claim_hashes": claimTxHash} + } + + _, err = m.coll.UpdateOne(ctx, filter, update) return err } @@ -427,10 +457,10 @@ func NewWalletBindingMongo(db *mongo.Database) (database.WalletBinding, error) { } // GetExpiredPending returns red packets that have expired but are still in -// "CREATED" status (i.e., not yet refunded or fully claimed). +// "ACTIVE" status (i.e., on-chain creation confirmed, not yet fully claimed or refunded). func (m *RedPacketMgo) GetExpiredPending(ctx context.Context, now int64) ([]*model.RedPacket, error) { cur, err := m.coll.Find(ctx, bson.M{ - "status": "CREATED", + "status": "ACTIVE", "expiry_at": bson.M{"$lt": now, "$gt": 0}, }) if err != nil { diff --git a/pkg/common/storage/database/redpacket.go b/pkg/common/storage/database/redpacket.go index 5beddf600..1a958e9c7 100644 --- a/pkg/common/storage/database/redpacket.go +++ b/pkg/common/storage/database/redpacket.go @@ -12,8 +12,12 @@ type RedPacket interface { GetByPacketID(ctx context.Context, packetID string) (*model.RedPacket, error) UpdateCreated(ctx context.Context, rp *model.RedPacket) error UpdateStatus(ctx context.Context, packetID, status string) error - UpdateClaimProgress(ctx context.Context, packetID, claimedAmount, status string) error - // GetExpiredPending returns CREATED packets whose expiry_at < now (unix seconds). + // UpdateClaimProgress atomically increments the claim counter for packetID. + // claimTxHash is used as an idempotency key so that re-processing the same + // on-chain transaction never double-counts. When status is empty the method + // auto-derives the correct status (COMPLETED or ACTIVE). + UpdateClaimProgress(ctx context.Context, packetID, claimedAmount, status, claimTxHash string) error + // GetExpiredPending returns ACTIVE packets whose expiry_at < now (unix seconds). GetExpiredPending(ctx context.Context, now int64) ([]*model.RedPacket, error) } diff --git a/pkg/common/storage/model/redpacket.go b/pkg/common/storage/model/redpacket.go index ce697840c..82c2876ef 100644 --- a/pkg/common/storage/model/redpacket.go +++ b/pkg/common/storage/model/redpacket.go @@ -22,9 +22,10 @@ type RedPacket struct { Token string `bson:"token"` TotalAmount string `bson:"total_amount"` TotalShares int32 `bson:"total_shares"` - ClaimedAmount string `bson:"claimed_amount"` - ClaimedShares int32 `bson:"claimed_shares"` - ExpiryAt int64 `bson:"expiry_at"` + ClaimedAmount string `bson:"claimed_amount"` + ClaimedShares int32 `bson:"claimed_shares"` + ProcessedClaimHashes []string `bson:"processed_claim_hashes"` + ExpiryAt int64 `bson:"expiry_at"` TxHash string `bson:"tx_hash"` Status string `bson:"status"` CreatedAt time.Time `bson:"created_at"`