diff --git a/.github/workflows/codeql-analysis.yml b/.github/workflows/codeql-analysis.yml new file mode 100644 index 000000000..f106438e3 --- /dev/null +++ b/.github/workflows/codeql-analysis.yml @@ -0,0 +1,71 @@ +# For most projects, this workflow file will not need changing; you simply need +# to commit it to your repository. +# +# You may wish to alter this file to override the set of languages analyzed, +# or to provide custom queries or build logic. +# +# ******** NOTE ******** +# We have attempted to detect the languages in your repository. Please check +# the `language` matrix defined below to confirm you have the correct set of +# supported CodeQL languages. +# +name: "CodeQL" + +on: + push: + branches: [ main ] + pull_request: + # The branches below must be a subset of the branches above + branches: [ main ] + schedule: + - cron: '23 2 * * 2' + +jobs: + analyze: + name: Analyze + runs-on: ubuntu-latest + permissions: + actions: read + contents: read + security-events: write + + strategy: + fail-fast: false + matrix: + language: [ 'go' ] + # CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ] + # Learn more: + # https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed + + steps: + - name: Checkout repository + uses: actions/checkout@v2 + + # Initializes the CodeQL tools for scanning. + - name: Initialize CodeQL + uses: github/codeql-action/init@v1 + with: + languages: ${{ matrix.language }} + # If you wish to specify custom queries, you can do so here or in a config file. + # By default, queries listed here will override any specified in a config file. + # Prefix the list here with "+" to use these queries and those in the config file. + # queries: ./path/to/local/query, your-org/your-repo/queries@main + + # Autobuild attempts to build any compiled languages (C/C++, C#, or Java). + # If this step fails, then you should remove it and run the build manually (see below) + - name: Autobuild + uses: github/codeql-action/autobuild@v1 + + # ℹ️ Command-line programs to run using the OS shell. + # 📚 https://git.io/JvXDl + + # ✏️ If the Autobuild fails above, remove it and uncomment the following three lines + # and modify them (or add more) to build your code if your project + # uses a compiled language + + #- run: | + # make bootstrap + # make release + + - name: Perform CodeQL Analysis + uses: github/codeql-action/analyze@v1 diff --git a/.gitignore b/.gitignore new file mode 100644 index 000000000..74a748ab0 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.devcontainer +components +logs + diff --git a/config/config.yaml b/config/config.yaml index 9e8741b48..d7becbbb4 100644 --- a/config/config.yaml +++ b/config/config.yaml @@ -134,8 +134,8 @@ multiloginpolicy: tokenpolicy: accessSecret: "open_im_server" # Token effective time seconds as a unit - #Seven days 7*24*60*60 - accessExpire: 604800 + #Seven days + accessExpire: 7 messagecallback: callbackSwitch: false diff --git a/docker-compose.yaml b/docker-compose.yaml index a9fa9a1c9..1d64dcd9f 100644 --- a/docker-compose.yaml +++ b/docker-compose.yaml @@ -15,7 +15,7 @@ services: restart: always mongodb: - image: mongo + image: mongo:4.0 ports: - 27017:27017 container_name: mongo diff --git a/go.mod b/go.mod index ac0c9edf0..00fe22cb0 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,6 @@ require ( github.com/Shopify/toxiproxy v2.1.4+incompatible // indirect github.com/antonfisher/nested-logrus-formatter v1.3.0 github.com/coreos/go-semver v0.3.0 // indirect - github.com/dgrijalva/jwt-go v3.2.0+incompatible github.com/dustin/go-humanize v1.0.0 // indirect github.com/eapache/go-resiliency v1.2.0 // indirect github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 // indirect @@ -16,6 +15,7 @@ require ( github.com/garyburd/redigo v1.6.2 github.com/gin-gonic/gin v1.7.0 github.com/go-playground/validator/v10 v10.4.1 + github.com/golang-jwt/jwt/v4 v4.1.0 // indirect github.com/golang/protobuf v1.5.2 github.com/golang/snappy v0.0.3 // indirect github.com/gorilla/websocket v1.4.2 @@ -32,6 +32,7 @@ require ( github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5 github.com/sirupsen/logrus v1.6.0 + github.com/stretchr/testify v1.7.0 github.com/tencentyun/qcloud-cos-sts-sdk v0.0.0-20210325043845-84a0811633ca github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5 // indirect go.etcd.io/etcd v0.0.0-20200402134248-51bdeb39e698 diff --git a/go.sum b/go.sum index fbbe5ea96..664a73e93 100644 --- a/go.sum +++ b/go.sum @@ -74,6 +74,8 @@ github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5x github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.2.1 h1:/s5zKNz0uPFCZ5hddgPdo2TK2TVrUNMn0OOX8/aZMTE= github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= +github.com/golang-jwt/jwt/v4 v4.1.0 h1:XUgk2Ex5veyVFVeLm0xhusUTQybEbexJXrvPNOKkSY0= +github.com/golang-jwt/jwt/v4 v4.1.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -233,6 +235,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/tencentyun/qcloud-cos-sts-sdk v0.0.0-20210325043845-84a0811633ca h1:G/aIr3WiUesWHL2YGYgEqjM5tCAJ43Ml+0C18wDkWWs= github.com/tencentyun/qcloud-cos-sts-sdk v0.0.0-20210325043845-84a0811633ca/go.mod h1:b18KQa4IxHbxeseW1GcZox53d7J0z39VNONTxvvlkXw= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= diff --git a/src/common/config/config.go b/src/common/config/config.go index 441a2127f..5d7084653 100644 --- a/src/common/config/config.go +++ b/src/common/config/config.go @@ -1,11 +1,17 @@ package config import ( - "gopkg.in/yaml.v3" "io/ioutil" + "os" + + "path/filepath" + "runtime" + + "gopkg.in/yaml.v3" ) + var Config config type config struct { @@ -152,11 +158,9 @@ func init() { bytes, err := ioutil.ReadFile(path + "/config/config.yaml") if err != nil { panic(err) - return } if err = yaml.Unmarshal(bytes, &Config); err != nil { panic(err) - return } } diff --git a/src/common/db/redisModel.go b/src/common/db/redisModel.go index f86dbb638..2e3277914 100644 --- a/src/common/db/redisModel.go +++ b/src/common/db/redisModel.go @@ -66,7 +66,7 @@ func (d *DataBases) SetLastGetSeq(uid string) (err error) { //获取用户上一次主动拉取Seq的值 func (d *DataBases) GetLastGetSeq(uid string) (int64, error) { - key := userIncrSeq + uid + key := lastGetSeq + uid return redis.Int64(d.Exec("GET", key)) } diff --git a/src/msg_gateway/gate/rpc_server.go b/src/msg_gateway/gate/rpc_server.go index e9aa28033..54bb78076 100644 --- a/src/msg_gateway/gate/rpc_server.go +++ b/src/msg_gateway/gate/rpc_server.go @@ -10,10 +10,11 @@ import ( "context" "encoding/json" "fmt" - "github.com/gorilla/websocket" - "google.golang.org/grpc" "net" "strings" + + "github.com/gorilla/websocket" + "google.golang.org/grpc" ) type RPCServer struct { @@ -41,7 +42,7 @@ func (r *RPCServer) run() { srv := grpc.NewServer() defer srv.GracefulStop() pbRelay.RegisterOnlineMessageRelayServiceServer(srv, r) - err = getcdv3.RegisterEtcd4Unique(r.etcdSchema, strings.Join(r.etcdAddr, ","), ip, r.rpcPort, r.rpcRegisterName, 10) + err = getcdv3.RegisterEtcd(r.etcdSchema, strings.Join(r.etcdAddr, ","), ip, r.rpcPort, r.rpcRegisterName, 10) if err != nil { log.ErrorByKv("register push message rpc to etcd err", "", "err", err.Error()) } diff --git a/src/rpc/auth/auth/user_token.go b/src/rpc/auth/auth/user_token.go index 59e91c33b..c89e4312d 100644 --- a/src/rpc/auth/auth/user_token.go +++ b/src/rpc/auth/auth/user_token.go @@ -18,7 +18,7 @@ func (rpc *rpcAuth) UserToken(_ context.Context, pb *pbAuth.UserTokenReq) (*pbAu } log.Info("", "", "rpc user_token call..., im_mysql_model.AppServerFindFromUserByUserID") - tokens, expTime, err := utils.CreateToken(pb.UID, "", pb.Platform) + tokens, expTime, err := utils.CreateToken(pb.UID, pb.Platform) if err != nil { log.Error("", "", "rpc user_token call..., utils.CreateToken fail [uid: %s] [err: %s]", pb.UID, err.Error()) return &pbAuth.UserTokenResp{ErrCode: 500, ErrMsg: err.Error()}, err diff --git a/src/rpc/chat/chat/send_msg.go b/src/rpc/chat/chat/send_msg.go index 21a1a7be8..602ad03e3 100644 --- a/src/rpc/chat/chat/send_msg.go +++ b/src/rpc/chat/chat/send_msg.go @@ -88,84 +88,77 @@ func (rpc *rpcChat) UserSendMsg(_ context.Context, pb *pbChat.UserSendMsgReq) (* return returnMsg(&replay, pb, m.ResponseErrCode, m.ErrMsg, "", 0) } else { pbData.Content = m.ResponseResult.ModifiedMsg - err1 := rpc.sendMsgToKafka(&pbData, pbData.RecvID) - err2 := rpc.sendMsgToKafka(&pbData, pbData.SendID) - if err1 != nil || err2 != nil { - return returnMsg(&replay, pb, 201, "kafka send msg err", "", 0) - } - return returnMsg(&replay, pb, 0, "", serverMsgID, pbData.SendTime) } } - } else { - switch pbData.SessionType { - case constant.SingleChatType: - err1 := rpc.sendMsgToKafka(&pbData, pbData.RecvID) - err2 := rpc.sendMsgToKafka(&pbData, pbData.SendID) - if err1 != nil || err2 != nil { - return returnMsg(&replay, pb, 201, "kafka send msg err", "", 0) - } - return returnMsg(&replay, pb, 0, "", serverMsgID, pbData.SendTime) - case constant.GroupChatType: - etcdConn := getcdv3.GetConn(config.Config.Etcd.EtcdSchema, strings.Join(config.Config.Etcd.EtcdAddr, ","), config.Config.RpcRegisterName.OpenImGroupName) - client := pbGroup.NewGroupClient(etcdConn) - req := &pbGroup.GetGroupAllMemberReq{ - GroupID: pbData.RecvID, - Token: pbData.Token, - OperationID: pbData.OperationID, - } - reply, err := client.GetGroupAllMember(context.Background(), req) + } + switch pbData.SessionType { + case constant.SingleChatType: + err1 := rpc.sendMsgToKafka(&pbData, pbData.RecvID) + err2 := rpc.sendMsgToKafka(&pbData, pbData.SendID) + if err1 != nil || err2 != nil { + return returnMsg(&replay, pb, 201, "kafka send msg err", "", 0) + } + return returnMsg(&replay, pb, 0, "", serverMsgID, pbData.SendTime) + case constant.GroupChatType: + etcdConn := getcdv3.GetConn(config.Config.Etcd.EtcdSchema, strings.Join(config.Config.Etcd.EtcdAddr, ","), config.Config.RpcRegisterName.OpenImGroupName) + client := pbGroup.NewGroupClient(etcdConn) + req := &pbGroup.GetGroupAllMemberReq{ + GroupID: pbData.RecvID, + Token: pbData.Token, + OperationID: pbData.OperationID, + } + reply, err := client.GetGroupAllMember(context.Background(), req) + if err != nil { + log.Error(pbData.Token, pbData.OperationID, "rpc send_msg getGroupInfo failed, err = %s", err.Error()) + return returnMsg(&replay, pb, 201, err.Error(), "", 0) + } + if reply.ErrorCode != 0 { + log.Error(pbData.Token, pbData.OperationID, "rpc send_msg getGroupInfo failed, err = %s", reply.ErrorMsg) + return returnMsg(&replay, pb, reply.ErrorCode, reply.ErrorMsg, "", 0) + } + var addUidList []string + switch pbData.ContentType { + case constant.KickGroupMemberTip: + var notification content_struct.NotificationContent + var kickContent group.KickGroupMemberReq + err := utils.JsonStringToStruct(pbData.Content, ¬ification) if err != nil { - log.Error(pbData.Token, pbData.OperationID, "rpc send_msg getGroupInfo failed, err = %s", err.Error()) - return returnMsg(&replay, pb, 201, err.Error(), "", 0) - } - if reply.ErrorCode != 0 { - log.Error(pbData.Token, pbData.OperationID, "rpc send_msg getGroupInfo failed, err = %s", reply.ErrorMsg) - return returnMsg(&replay, pb, reply.ErrorCode, reply.ErrorMsg, "", 0) - } - var addUidList []string - switch pbData.ContentType { - case constant.KickGroupMemberTip: - var notification content_struct.NotificationContent - var kickContent group.KickGroupMemberReq - err := utils.JsonStringToStruct(pbData.Content, ¬ification) + log.ErrorByKv("json unmarshall err", pbData.OperationID, "err", err.Error()) + return returnMsg(&replay, pb, 200, err.Error(), "", 0) + } else { + err := utils.JsonStringToStruct(notification.Detail, &kickContent) if err != nil { log.ErrorByKv("json unmarshall err", pbData.OperationID, "err", err.Error()) return returnMsg(&replay, pb, 200, err.Error(), "", 0) - } else { - err := utils.JsonStringToStruct(notification.Detail, &kickContent) - if err != nil { - log.ErrorByKv("json unmarshall err", pbData.OperationID, "err", err.Error()) - return returnMsg(&replay, pb, 200, err.Error(), "", 0) - } - for _, v := range kickContent.UidListInfo { - addUidList = append(addUidList, v.UserId) - } - } - case constant.QuitGroupTip: - addUidList = append(addUidList, pbData.SendID) - default: - } - groupID := pbData.RecvID - for i, v := range reply.MemberList { - pbData.RecvID = v.UserId + " " + groupID - err := rpc.sendMsgToKafka(&pbData, utils.IntToString(i)) - if err != nil { - return returnMsg(&replay, pb, 201, "kafka send msg err", "", 0) } - } - for i, v := range addUidList { - pbData.RecvID = v + " " + groupID - err := rpc.sendMsgToKafka(&pbData, utils.IntToString(i+1)) - if err != nil { - return returnMsg(&replay, pb, 201, "kafka send msg err", "", 0) + for _, v := range kickContent.UidListInfo { + addUidList = append(addUidList, v.UserId) } } - return returnMsg(&replay, pb, 0, "", serverMsgID, pbData.SendTime) + case constant.QuitGroupTip: + addUidList = append(addUidList, pbData.SendID) default: - } + groupID := pbData.RecvID + for i, v := range reply.MemberList { + pbData.RecvID = v.UserId + " " + groupID + err := rpc.sendMsgToKafka(&pbData, utils.IntToString(i)) + if err != nil { + return returnMsg(&replay, pb, 201, "kafka send msg err", "", 0) + } + } + for i, v := range addUidList { + pbData.RecvID = v + " " + groupID + err := rpc.sendMsgToKafka(&pbData, utils.IntToString(i+1)) + if err != nil { + return returnMsg(&replay, pb, 201, "kafka send msg err", "", 0) + } + } + return returnMsg(&replay, pb, 0, "", serverMsgID, pbData.SendTime) + default: } + return returnMsg(&replay, pb, 203, "unkonwn sessionType", "", 0) } diff --git a/src/utils/get_server_ip.go b/src/utils/get_server_ip.go index 21092ffa1..ec5824cb8 100644 --- a/src/utils/get_server_ip.go +++ b/src/utils/get_server_ip.go @@ -13,23 +13,14 @@ func init() { ServerIP = config.Config.ServerIP return } - //fixme Get the ip of the local network card - netInterfaces, err := net.Interfaces() + + // see https://gist.github.com/jniltinho/9787946#gistcomment-3019898 + conn, err := net.Dial("udp", "8.8.8.8:80") if err != nil { panic(err) } - for i := 0; i < len(netInterfaces); i++ { - //Exclude useless network cards by judging the net.flag Up flag - if (netInterfaces[i].Flags & net.FlagUp) != 0 { - address, _ := netInterfaces[i].Addrs() - for _, addr := range address { - if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() { - if ipNet.IP.To4() != nil { - ServerIP = ipNet.IP.String() - return - } - } - } - } - } + + defer conn.Close() + localAddr := conn.LocalAddr().(*net.UDPAddr) + ServerIP = localAddr.IP.String() } diff --git a/src/utils/get_server_ip_test.go b/src/utils/get_server_ip_test.go new file mode 100644 index 000000000..54e2142ef --- /dev/null +++ b/src/utils/get_server_ip_test.go @@ -0,0 +1,12 @@ +package utils + +import ( + "net" + "testing" +) + +func TestServerIP(t *testing.T) { + if net.ParseIP(ServerIP) == nil { + t.Fail() + } +} diff --git a/src/utils/jwt_token.go b/src/utils/jwt_token.go index 264da2483..14977f7ae 100644 --- a/src/utils/jwt_token.go +++ b/src/utils/jwt_token.go @@ -4,7 +4,7 @@ import ( "Open_IM/src/common/config" "Open_IM/src/common/db" "errors" - "github.com/dgrijalva/jwt-go" + "github.com/golang-jwt/jwt/v4" "time" ) @@ -19,38 +19,27 @@ var ( type Claims struct { UID string Platform string //login platform - jwt.StandardClaims + jwt.RegisteredClaims } -func BuildClaims(uid, accountAddr, platform string, ttl int64) Claims { - now := time.Now().Unix() - //if ttl=-1 Permanent token - if ttl == -1 { - return Claims{ - UID: uid, - Platform: platform, - StandardClaims: jwt.StandardClaims{ - ExpiresAt: -1, - IssuedAt: now, - NotBefore: now, - }} - } +func BuildClaims(uid, platform string, ttl int64) Claims { + now := time.Now() return Claims{ UID: uid, Platform: platform, - StandardClaims: jwt.StandardClaims{ - ExpiresAt: now + ttl, //Expiration time - IssuedAt: now, //Issuing time - NotBefore: now, //Begin Effective time + RegisteredClaims: jwt.RegisteredClaims{ + ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(ttl*24) * time.Hour)), //Expiration time + IssuedAt: jwt.NewNumericDate(now), //Issuing time + NotBefore: jwt.NewNumericDate(now), //Begin Effective time }} } -func CreateToken(userID, accountAddr string, platform int32) (string, int64, error) { - claims := BuildClaims(userID, accountAddr, PlatformIDToName(platform), config.Config.TokenPolicy.AccessExpire) +func CreateToken(userID string, platform int32) (string, int64, error) { + claims := BuildClaims(userID, PlatformIDToName(platform), config.Config.TokenPolicy.AccessExpire) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) tokenString, err := token.SignedString([]byte(config.Config.TokenPolicy.AccessSecret)) - return tokenString, claims.ExpiresAt, err + return tokenString, claims.ExpiresAt.Time.Unix(), err } func secret() jwt.Keyfunc { @@ -59,7 +48,7 @@ func secret() jwt.Keyfunc { } } -func ParseToken(tokensString string) (claims *Claims, err error) { +func getClaimFromToken(tokensString string) (*Claims, error) { token, err := jwt.ParseWithClaims(tokensString, &Claims{}, secret()) if err != nil { if ve, ok := err.(*jwt.ValidationError); ok { @@ -75,76 +64,66 @@ func ParseToken(tokensString string) (claims *Claims, err error) { } } if claims, ok := token.Claims.(*Claims); ok && token.Valid { - // 1.check userid and platform class 0 not exists and 1 exists - existsInterface, err := db.DB.ExistsUserIDAndPlatform(claims.UID, Platform2class[claims.Platform]) + return claims, nil + } + return nil, err +} + +func ParseToken(tokensString string) (claims *Claims, err error) { + claims, err = getClaimFromToken(tokensString) + + if err != nil { + return nil, err + } + + // 1.check userid and platform class 0 not exists and 1 exists + existsInterface, err := db.DB.ExistsUserIDAndPlatform(claims.UID, Platform2class[claims.Platform]) + if err != nil { + return nil, err + } + exists := existsInterface.(int64) + //get config multi login policy + if config.Config.MultiLoginPolicy.OnlyOneTerminalAccess { + //OnlyOneTerminalAccess policy need to check all terminal + //When only one end is allowed to log in, there is a situation that needs to be paid attention to. After PC login, + //mobile login should check two platform times. One of them is less than the redis storage time, which is the invalid token. + platform := "PC" + if Platform2class[claims.Platform] == "PC" { + platform = "Mobile" + } + + existsInterface, err = db.DB.ExistsUserIDAndPlatform(claims.UID, platform) if err != nil { return nil, err } - exists := existsInterface.(int64) - //get config multi login policy - if config.Config.MultiLoginPolicy.OnlyOneTerminalAccess { - //OnlyOneTerminalAccess policy need to check all terminal - //When only one end is allowed to log in, there is a situation that needs to be paid attention to. After PC login, - //mobile login should check two platform times. One of them is less than the redis storage time, which is the invalid token. - if Platform2class[claims.Platform] == "PC" { - existsInterface, err = db.DB.ExistsUserIDAndPlatform(claims.UID, "Mobile") - if err != nil { - return nil, err - } - exists = existsInterface.(int64) - if exists == 1 { - res, err := MakeTheTokenInvalid(*claims, "Mobile") - if err != nil { - return nil, err - } - if res { - return nil, TokenInvalid - } - } - } else { - existsInterface, err = db.DB.ExistsUserIDAndPlatform(claims.UID, "PC") - if err != nil { - return nil, err - } - exists = existsInterface.(int64) - if exists == 1 { - res, err := MakeTheTokenInvalid(*claims, "PC") - if err != nil { - return nil, err - } - if res { - return nil, TokenInvalid - } - } - } - if exists == 1 { - res, err := MakeTheTokenInvalid(*claims, Platform2class[claims.Platform]) - if err != nil { - return nil, err - } - if res { - return nil, TokenInvalid - } + exists = existsInterface.(int64) + if exists == 1 { + res, err := MakeTheTokenInvalid(claims, platform) + if err != nil { + return nil, err } - - } else if config.Config.MultiLoginPolicy.MobileAndPCTerminalAccessButOtherTerminalKickEachOther { - if exists == 1 { - res, err := MakeTheTokenInvalid(*claims, Platform2class[claims.Platform]) - if err != nil { - return nil, err - } - if res { - return nil, TokenInvalid - } + if res { + return nil, TokenInvalid } } - return claims, nil } - return nil, TokenUnknown + // config.Config.MultiLoginPolicy.MobileAndPCTerminalAccessButOtherTerminalKickEachOther == true + // or PC/Mobile validate success + // final check + if exists == 1 { + res, err := MakeTheTokenInvalid(claims, Platform2class[claims.Platform]) + if err != nil { + return nil, err + } + if res { + return nil, TokenInvalid + } + } + return claims, nil } -func MakeTheTokenInvalid(currentClaims Claims, platformClass string) (bool, error) { +func MakeTheTokenInvalid(currentClaims *Claims, platformClass string) (bool, error) { storedRedisTokenInterface, err := db.DB.GetPlatformToken(currentClaims.UID, platformClass) if err != nil { return false, err @@ -154,40 +133,21 @@ func MakeTheTokenInvalid(currentClaims Claims, platformClass string) (bool, erro return false, err } //if issue time less than redis token then make this token invalid - if currentClaims.IssuedAt < storedRedisPlatformClaims.IssuedAt { + if currentClaims.IssuedAt.Time.Unix() < storedRedisPlatformClaims.IssuedAt.Time.Unix() { return true, TokenInvalid } return false, nil } + func ParseRedisInterfaceToken(redisToken interface{}) (*Claims, error) { - token, err := jwt.ParseWithClaims(string(redisToken.([]uint8)), &Claims{}, secret()) - if err != nil { - if ve, ok := err.(*jwt.ValidationError); ok { - if ve.Errors&jwt.ValidationErrorMalformed != 0 { - return nil, TokenMalformed - } else if ve.Errors&jwt.ValidationErrorExpired != 0 { - return nil, TokenExpired - } else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 { - return nil, TokenNotValidYet - } else { - return nil, TokenInvalid - } - } - } - if claims, ok := token.Claims.(*Claims); ok && token.Valid { - return claims, nil - } - return nil, err + return getClaimFromToken(string(redisToken.([]uint8))) } //Validation token, false means failure, true means successful verification func VerifyToken(token, uid string) bool { claims, err := ParseToken(token) - if err != nil { - return false - } else if claims.UID != uid { + if err != nil || claims.UID != uid { return false - } else { - return true } + return true } diff --git a/src/utils/jwt_token_test.go b/src/utils/jwt_token_test.go new file mode 100644 index 000000000..83b9bb91c --- /dev/null +++ b/src/utils/jwt_token_test.go @@ -0,0 +1,81 @@ +package utils + +import ( + "Open_IM/src/common/config" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +func Test_BuildClaims(t *testing.T) { + uid := "1" + platform := "PC" + ttl := int64(-1) + claim := BuildClaims(uid, platform, ttl) + now := time.Now().Unix() + + assert.Equal(t, claim.UID, uid, "uid should equal") + assert.Equal(t, claim.Platform, platform, "platform should equal") + assert.Equal(t, claim.RegisteredClaims.ExpiresAt, int64(-1), "StandardClaims.ExpiresAt should be equal") + // time difference within 1s + assert.Equal(t, claim.RegisteredClaims.IssuedAt, now, "StandardClaims.IssuedAt should be equal") + assert.Equal(t, claim.RegisteredClaims.NotBefore, now, "StandardClaims.NotBefore should be equal") + + ttl = int64(60) + now = time.Now().Unix() + claim = BuildClaims(uid, platform, ttl) + // time difference within 1s + assert.Equal(t, claim.RegisteredClaims.ExpiresAt, int64(60)+now, "StandardClaims.ExpiresAt should be equal") + assert.Equal(t, claim.RegisteredClaims.IssuedAt, now, "StandardClaims.IssuedAt should be equal") + assert.Equal(t, claim.RegisteredClaims.NotBefore, now, "StandardClaims.NotBefore should be equal") +} + +func Test_CreateToken(t *testing.T) { + uid := "1" + platform := int32(1) + now := time.Now().Unix() + + tokenString, expiresAt, err := CreateToken(uid, platform) + + assert.NotEmpty(t, tokenString) + assert.Equal(t, expiresAt, 604800+now) + assert.Nil(t, err) +} + +func Test_VerifyToken(t *testing.T) { + uid := "1" + platform := int32(1) + tokenString, _, _ := CreateToken(uid, platform) + result := VerifyToken(tokenString, uid) + assert.True(t, result) + result = VerifyToken(tokenString, "2") + assert.False(t, result) +} + +func Test_ParseRedisInterfaceToken(t *testing.T) { + uid := "1" + platform := int32(1) + tokenString, _, _ := CreateToken(uid, platform) + + claims, err := ParseRedisInterfaceToken([]uint8(tokenString)) + assert.Nil(t, err) + assert.Equal(t, claims.UID, uid) + + // timeout + config.Config.TokenPolicy.AccessExpire = -80 + tokenString, _, _ = CreateToken(uid, platform) + claims, err = ParseRedisInterfaceToken([]uint8(tokenString)) + assert.Equal(t, err, TokenExpired) + assert.Nil(t, claims) +} + +func Test_ParseToken(t *testing.T) { + uid := "1" + platform := int32(1) + tokenString, _, _ := CreateToken(uid, platform) + claims, err := ParseToken(tokenString) + if err == nil { + assert.Equal(t, claims.UID, uid) + } +}