Merge branch 'main' into main

pull/31/head
Gordon 4 years ago committed by GitHub
commit 4ffc9f8919
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -0,0 +1,71 @@
# For most projects, this workflow file will not need changing; you simply need
# to commit it to your repository.
#
# You may wish to alter this file to override the set of languages analyzed,
# or to provide custom queries or build logic.
#
# ******** NOTE ********
# We have attempted to detect the languages in your repository. Please check
# the `language` matrix defined below to confirm you have the correct set of
# supported CodeQL languages.
#
name: "CodeQL"
on:
push:
branches: [ main ]
pull_request:
# The branches below must be a subset of the branches above
branches: [ main ]
schedule:
- cron: '23 2 * * 2'
jobs:
analyze:
name: Analyze
runs-on: ubuntu-latest
permissions:
actions: read
contents: read
security-events: write
strategy:
fail-fast: false
matrix:
language: [ 'go' ]
# CodeQL supports [ 'cpp', 'csharp', 'go', 'java', 'javascript', 'python' ]
# Learn more:
# https://docs.github.com/en/free-pro-team@latest/github/finding-security-vulnerabilities-and-errors-in-your-code/configuring-code-scanning#changing-the-languages-that-are-analyzed
steps:
- name: Checkout repository
uses: actions/checkout@v2
# Initializes the CodeQL tools for scanning.
- name: Initialize CodeQL
uses: github/codeql-action/init@v1
with:
languages: ${{ matrix.language }}
# If you wish to specify custom queries, you can do so here or in a config file.
# By default, queries listed here will override any specified in a config file.
# Prefix the list here with "+" to use these queries and those in the config file.
# queries: ./path/to/local/query, your-org/your-repo/queries@main
# Autobuild attempts to build any compiled languages (C/C++, C#, or Java).
# If this step fails, then you should remove it and run the build manually (see below)
- name: Autobuild
uses: github/codeql-action/autobuild@v1
# Command-line programs to run using the OS shell.
# 📚 https://git.io/JvXDl
# ✏️ If the Autobuild fails above, remove it and uncomment the following three lines
# and modify them (or add more) to build your code if your project
# uses a compiled language
#- run: |
# make bootstrap
# make release
- name: Perform CodeQL Analysis
uses: github/codeql-action/analyze@v1

4
.gitignore vendored

@ -0,0 +1,4 @@
.devcontainer
components
logs

@ -1,3 +1,4 @@
# The class cannot be named by Pascal or camel case. # The class cannot be named by Pascal or camel case.
# If it is not used, the corresponding structure will not be set, # If it is not used, the corresponding structure will not be set,
# and it will not be read naturally. # and it will not be read naturally.

@ -15,7 +15,7 @@ services:
restart: always restart: always
mongodb: mongodb:
image: mongo image: mongo:4.0
ports: ports:
- 27017:27017 - 27017:27017
container_name: mongo container_name: mongo

@ -7,7 +7,6 @@ require (
github.com/Shopify/toxiproxy v2.1.4+incompatible // indirect github.com/Shopify/toxiproxy v2.1.4+incompatible // indirect
github.com/antonfisher/nested-logrus-formatter v1.3.0 github.com/antonfisher/nested-logrus-formatter v1.3.0
github.com/coreos/go-semver v0.3.0 // indirect github.com/coreos/go-semver v0.3.0 // indirect
github.com/dgrijalva/jwt-go v3.2.0+incompatible
github.com/dustin/go-humanize v1.0.0 // indirect github.com/dustin/go-humanize v1.0.0 // indirect
github.com/eapache/go-resiliency v1.2.0 // indirect github.com/eapache/go-resiliency v1.2.0 // indirect
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 // indirect github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21 // indirect
@ -16,6 +15,7 @@ require (
github.com/garyburd/redigo v1.6.2 github.com/garyburd/redigo v1.6.2
github.com/gin-gonic/gin v1.7.0 github.com/gin-gonic/gin v1.7.0
github.com/go-playground/validator/v10 v10.4.1 github.com/go-playground/validator/v10 v10.4.1
github.com/golang-jwt/jwt/v4 v4.1.0 // indirect
github.com/golang/protobuf v1.5.2 github.com/golang/protobuf v1.5.2
github.com/golang/snappy v0.0.3 // indirect github.com/golang/snappy v0.0.3 // indirect
github.com/gorilla/websocket v1.4.2 github.com/gorilla/websocket v1.4.2
@ -32,6 +32,7 @@ require (
github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect github.com/rcrowley/go-metrics v0.0.0-20201227073835-cf1acfcdf475 // indirect
github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5 github.com/rifflock/lfshook v0.0.0-20180920164130-b9218ef580f5
github.com/sirupsen/logrus v1.6.0 github.com/sirupsen/logrus v1.6.0
github.com/stretchr/testify v1.7.0
github.com/tencentyun/qcloud-cos-sts-sdk v0.0.0-20210325043845-84a0811633ca github.com/tencentyun/qcloud-cos-sts-sdk v0.0.0-20210325043845-84a0811633ca
github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5 // indirect github.com/tmc/grpc-websocket-proxy v0.0.0-20190109142713-0ad062ec5ee5 // indirect
go.etcd.io/etcd v0.0.0-20200402134248-51bdeb39e698 go.etcd.io/etcd v0.0.0-20200402134248-51bdeb39e698

@ -74,6 +74,8 @@ github.com/godbus/dbus/v5 v5.0.3/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5x
github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ=
github.com/gogo/protobuf v1.2.1 h1:/s5zKNz0uPFCZ5hddgPdo2TK2TVrUNMn0OOX8/aZMTE= github.com/gogo/protobuf v1.2.1 h1:/s5zKNz0uPFCZ5hddgPdo2TK2TVrUNMn0OOX8/aZMTE=
github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4= github.com/gogo/protobuf v1.2.1/go.mod h1:hp+jE20tsWTFYpLwKvXlhS1hjn+gTNwPg2I6zVXpSg4=
github.com/golang-jwt/jwt/v4 v4.1.0 h1:XUgk2Ex5veyVFVeLm0xhusUTQybEbexJXrvPNOKkSY0=
github.com/golang-jwt/jwt/v4 v4.1.0/go.mod h1:/xlHOz8bRuivTWchD4jCa+NbatV+wEUSzwAxVc6locg=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY=
github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0=
github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q=
@ -233,6 +235,8 @@ github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UV
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0= github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/tencentyun/qcloud-cos-sts-sdk v0.0.0-20210325043845-84a0811633ca h1:G/aIr3WiUesWHL2YGYgEqjM5tCAJ43Ml+0C18wDkWWs= github.com/tencentyun/qcloud-cos-sts-sdk v0.0.0-20210325043845-84a0811633ca h1:G/aIr3WiUesWHL2YGYgEqjM5tCAJ43Ml+0C18wDkWWs=
github.com/tencentyun/qcloud-cos-sts-sdk v0.0.0-20210325043845-84a0811633ca/go.mod h1:b18KQa4IxHbxeseW1GcZox53d7J0z39VNONTxvvlkXw= github.com/tencentyun/qcloud-cos-sts-sdk v0.0.0-20210325043845-84a0811633ca/go.mod h1:b18KQa4IxHbxeseW1GcZox53d7J0z39VNONTxvvlkXw=
github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U= github.com/tmc/grpc-websocket-proxy v0.0.0-20170815181823-89b8d40f7ca8/go.mod h1:ncp9v5uamzpCO7NfCPTXjqaC+bZgJeR0sMTm6dMHP7U=

@ -1,10 +1,17 @@
package config package config
import ( import (
"gopkg.in/yaml.v3"
"io/ioutil" "io/ioutil"
"os"
"path/filepath"
"runtime"
"gopkg.in/yaml.v3"
) )
var Config config var Config config
type config struct { type config struct {
@ -147,14 +154,13 @@ type config struct {
} }
func init() { func init() {
bytes, err := ioutil.ReadFile("../config/config.yaml") path, _ := os.Getwd()
bytes, err := ioutil.ReadFile(path + "/config/config.yaml")
if err != nil { if err != nil {
panic(err) panic(err)
return
} }
if err = yaml.Unmarshal(bytes, &Config); err != nil { if err = yaml.Unmarshal(bytes, &Config); err != nil {
panic(err) panic(err)
return
} }
} }

@ -66,7 +66,7 @@ func (d *DataBases) SetLastGetSeq(uid string) (err error) {
//获取用户上一次主动拉取Seq的值 //获取用户上一次主动拉取Seq的值
func (d *DataBases) GetLastGetSeq(uid string) (int64, error) { func (d *DataBases) GetLastGetSeq(uid string) (int64, error) {
key := userIncrSeq + uid key := lastGetSeq + uid
return redis.Int64(d.Exec("GET", key)) return redis.Int64(d.Exec("GET", key))
} }

@ -10,10 +10,11 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/gorilla/websocket"
"google.golang.org/grpc"
"net" "net"
"strings" "strings"
"github.com/gorilla/websocket"
"google.golang.org/grpc"
) )
type RPCServer struct { type RPCServer struct {
@ -41,7 +42,7 @@ func (r *RPCServer) run() {
srv := grpc.NewServer() srv := grpc.NewServer()
defer srv.GracefulStop() defer srv.GracefulStop()
pbRelay.RegisterOnlineMessageRelayServiceServer(srv, r) pbRelay.RegisterOnlineMessageRelayServiceServer(srv, r)
err = getcdv3.RegisterEtcd4Unique(r.etcdSchema, strings.Join(r.etcdAddr, ","), ip, r.rpcPort, r.rpcRegisterName, 10) err = getcdv3.RegisterEtcd(r.etcdSchema, strings.Join(r.etcdAddr, ","), ip, r.rpcPort, r.rpcRegisterName, 10)
if err != nil { if err != nil {
log.ErrorByKv("register push message rpc to etcd err", "", "err", err.Error()) log.ErrorByKv("register push message rpc to etcd err", "", "err", err.Error())
} }

@ -18,7 +18,7 @@ func (rpc *rpcAuth) UserToken(_ context.Context, pb *pbAuth.UserTokenReq) (*pbAu
} }
log.Info("", "", "rpc user_token call..., im_mysql_model.AppServerFindFromUserByUserID") log.Info("", "", "rpc user_token call..., im_mysql_model.AppServerFindFromUserByUserID")
tokens, expTime, err := utils.CreateToken(pb.UID, "", pb.Platform) tokens, expTime, err := utils.CreateToken(pb.UID, pb.Platform)
if err != nil { if err != nil {
log.Error("", "", "rpc user_token call..., utils.CreateToken fail [uid: %s] [err: %s]", pb.UID, err.Error()) log.Error("", "", "rpc user_token call..., utils.CreateToken fail [uid: %s] [err: %s]", pb.UID, err.Error())
return &pbAuth.UserTokenResp{ErrCode: 500, ErrMsg: err.Error()}, err return &pbAuth.UserTokenResp{ErrCode: 500, ErrMsg: err.Error()}, err

@ -88,84 +88,77 @@ func (rpc *rpcChat) UserSendMsg(_ context.Context, pb *pbChat.UserSendMsgReq) (*
return returnMsg(&replay, pb, m.ResponseErrCode, m.ErrMsg, "", 0) return returnMsg(&replay, pb, m.ResponseErrCode, m.ErrMsg, "", 0)
} else { } else {
pbData.Content = m.ResponseResult.ModifiedMsg pbData.Content = m.ResponseResult.ModifiedMsg
err1 := rpc.sendMsgToKafka(&pbData, pbData.RecvID)
err2 := rpc.sendMsgToKafka(&pbData, pbData.SendID)
if err1 != nil || err2 != nil {
return returnMsg(&replay, pb, 201, "kafka send msg err", "", 0)
}
return returnMsg(&replay, pb, 0, "", serverMsgID, pbData.SendTime)
} }
} }
} else { }
switch pbData.SessionType { switch pbData.SessionType {
case constant.SingleChatType: case constant.SingleChatType:
err1 := rpc.sendMsgToKafka(&pbData, pbData.RecvID) err1 := rpc.sendMsgToKafka(&pbData, pbData.RecvID)
err2 := rpc.sendMsgToKafka(&pbData, pbData.SendID) err2 := rpc.sendMsgToKafka(&pbData, pbData.SendID)
if err1 != nil || err2 != nil { if err1 != nil || err2 != nil {
return returnMsg(&replay, pb, 201, "kafka send msg err", "", 0) return returnMsg(&replay, pb, 201, "kafka send msg err", "", 0)
} }
return returnMsg(&replay, pb, 0, "", serverMsgID, pbData.SendTime) return returnMsg(&replay, pb, 0, "", serverMsgID, pbData.SendTime)
case constant.GroupChatType: case constant.GroupChatType:
etcdConn := getcdv3.GetConn(config.Config.Etcd.EtcdSchema, strings.Join(config.Config.Etcd.EtcdAddr, ","), config.Config.RpcRegisterName.OpenImGroupName) etcdConn := getcdv3.GetConn(config.Config.Etcd.EtcdSchema, strings.Join(config.Config.Etcd.EtcdAddr, ","), config.Config.RpcRegisterName.OpenImGroupName)
client := pbGroup.NewGroupClient(etcdConn) client := pbGroup.NewGroupClient(etcdConn)
req := &pbGroup.GetGroupAllMemberReq{ req := &pbGroup.GetGroupAllMemberReq{
GroupID: pbData.RecvID, GroupID: pbData.RecvID,
Token: pbData.Token, Token: pbData.Token,
OperationID: pbData.OperationID, OperationID: pbData.OperationID,
} }
reply, err := client.GetGroupAllMember(context.Background(), req) reply, err := client.GetGroupAllMember(context.Background(), req)
if err != nil {
log.Error(pbData.Token, pbData.OperationID, "rpc send_msg getGroupInfo failed, err = %s", err.Error())
return returnMsg(&replay, pb, 201, err.Error(), "", 0)
}
if reply.ErrorCode != 0 {
log.Error(pbData.Token, pbData.OperationID, "rpc send_msg getGroupInfo failed, err = %s", reply.ErrorMsg)
return returnMsg(&replay, pb, reply.ErrorCode, reply.ErrorMsg, "", 0)
}
var addUidList []string
switch pbData.ContentType {
case constant.KickGroupMemberTip:
var notification content_struct.NotificationContent
var kickContent group.KickGroupMemberReq
err := utils.JsonStringToStruct(pbData.Content, &notification)
if err != nil { if err != nil {
log.Error(pbData.Token, pbData.OperationID, "rpc send_msg getGroupInfo failed, err = %s", err.Error()) log.ErrorByKv("json unmarshall err", pbData.OperationID, "err", err.Error())
return returnMsg(&replay, pb, 201, err.Error(), "", 0) return returnMsg(&replay, pb, 200, err.Error(), "", 0)
} } else {
if reply.ErrorCode != 0 { err := utils.JsonStringToStruct(notification.Detail, &kickContent)
log.Error(pbData.Token, pbData.OperationID, "rpc send_msg getGroupInfo failed, err = %s", reply.ErrorMsg)
return returnMsg(&replay, pb, reply.ErrorCode, reply.ErrorMsg, "", 0)
}
var addUidList []string
switch pbData.ContentType {
case constant.KickGroupMemberTip:
var notification content_struct.NotificationContent
var kickContent group.KickGroupMemberReq
err := utils.JsonStringToStruct(pbData.Content, &notification)
if err != nil { if err != nil {
log.ErrorByKv("json unmarshall err", pbData.OperationID, "err", err.Error()) log.ErrorByKv("json unmarshall err", pbData.OperationID, "err", err.Error())
return returnMsg(&replay, pb, 200, err.Error(), "", 0) return returnMsg(&replay, pb, 200, err.Error(), "", 0)
} else {
err := utils.JsonStringToStruct(notification.Detail, &kickContent)
if err != nil {
log.ErrorByKv("json unmarshall err", pbData.OperationID, "err", err.Error())
return returnMsg(&replay, pb, 200, err.Error(), "", 0)
}
for _, v := range kickContent.UidListInfo {
addUidList = append(addUidList, v.UserId)
}
}
case constant.QuitGroupTip:
addUidList = append(addUidList, pbData.SendID)
default:
}
groupID := pbData.RecvID
for i, v := range reply.MemberList {
pbData.RecvID = v.UserId + " " + groupID
err := rpc.sendMsgToKafka(&pbData, utils.IntToString(i))
if err != nil {
return returnMsg(&replay, pb, 201, "kafka send msg err", "", 0)
} }
} for _, v := range kickContent.UidListInfo {
for i, v := range addUidList { addUidList = append(addUidList, v.UserId)
pbData.RecvID = v + " " + groupID
err := rpc.sendMsgToKafka(&pbData, utils.IntToString(i+1))
if err != nil {
return returnMsg(&replay, pb, 201, "kafka send msg err", "", 0)
} }
} }
return returnMsg(&replay, pb, 0, "", serverMsgID, pbData.SendTime) case constant.QuitGroupTip:
addUidList = append(addUidList, pbData.SendID)
default: default:
} }
groupID := pbData.RecvID
for i, v := range reply.MemberList {
pbData.RecvID = v.UserId + " " + groupID
err := rpc.sendMsgToKafka(&pbData, utils.IntToString(i))
if err != nil {
return returnMsg(&replay, pb, 201, "kafka send msg err", "", 0)
}
}
for i, v := range addUidList {
pbData.RecvID = v + " " + groupID
err := rpc.sendMsgToKafka(&pbData, utils.IntToString(i+1))
if err != nil {
return returnMsg(&replay, pb, 201, "kafka send msg err", "", 0)
}
}
return returnMsg(&replay, pb, 0, "", serverMsgID, pbData.SendTime)
default:
} }
return returnMsg(&replay, pb, 203, "unkonwn sessionType", "", 0) return returnMsg(&replay, pb, 203, "unkonwn sessionType", "", 0)
} }

@ -13,23 +13,14 @@ func init() {
ServerIP = config.Config.ServerIP ServerIP = config.Config.ServerIP
return return
} }
//fixme Get the ip of the local network card
netInterfaces, err := net.Interfaces() // see https://gist.github.com/jniltinho/9787946#gistcomment-3019898
conn, err := net.Dial("udp", "8.8.8.8:80")
if err != nil { if err != nil {
panic(err) panic(err)
} }
for i := 0; i < len(netInterfaces); i++ {
//Exclude useless network cards by judging the net.flag Up flag defer conn.Close()
if (netInterfaces[i].Flags & net.FlagUp) != 0 { localAddr := conn.LocalAddr().(*net.UDPAddr)
address, _ := netInterfaces[i].Addrs() ServerIP = localAddr.IP.String()
for _, addr := range address {
if ipNet, ok := addr.(*net.IPNet); ok && !ipNet.IP.IsLoopback() {
if ipNet.IP.To4() != nil {
ServerIP = ipNet.IP.String()
return
}
}
}
}
}
} }

@ -0,0 +1,12 @@
package utils
import (
"net"
"testing"
)
func TestServerIP(t *testing.T) {
if net.ParseIP(ServerIP) == nil {
t.Fail()
}
}

@ -4,7 +4,7 @@ import (
"Open_IM/src/common/config" "Open_IM/src/common/config"
"Open_IM/src/common/db" "Open_IM/src/common/db"
"errors" "errors"
"github.com/dgrijalva/jwt-go" "github.com/golang-jwt/jwt/v4"
"time" "time"
) )
@ -19,38 +19,27 @@ var (
type Claims struct { type Claims struct {
UID string UID string
Platform string //login platform Platform string //login platform
jwt.StandardClaims jwt.RegisteredClaims
} }
func BuildClaims(uid, accountAddr, platform string, ttl int64) Claims { func BuildClaims(uid, platform string, ttl int64) Claims {
now := time.Now().Unix() now := time.Now()
//if ttl=-1 Permanent token
if ttl == -1 {
return Claims{
UID: uid,
Platform: platform,
StandardClaims: jwt.StandardClaims{
ExpiresAt: -1,
IssuedAt: now,
NotBefore: now,
}}
}
return Claims{ return Claims{
UID: uid, UID: uid,
Platform: platform, Platform: platform,
StandardClaims: jwt.StandardClaims{ RegisteredClaims: jwt.RegisteredClaims{
ExpiresAt: now + ttl, //Expiration time ExpiresAt: jwt.NewNumericDate(now.Add(time.Duration(ttl*24) * time.Hour)), //Expiration time
IssuedAt: now, //Issuing time IssuedAt: jwt.NewNumericDate(now), //Issuing time
NotBefore: now, //Begin Effective time NotBefore: jwt.NewNumericDate(now), //Begin Effective time
}} }}
} }
func CreateToken(userID, accountAddr string, platform int32) (string, int64, error) { func CreateToken(userID string, platform int32) (string, int64, error) {
claims := BuildClaims(userID, accountAddr, PlatformIDToName(platform), config.Config.TokenPolicy.AccessExpire) claims := BuildClaims(userID, PlatformIDToName(platform), config.Config.TokenPolicy.AccessExpire)
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
tokenString, err := token.SignedString([]byte(config.Config.TokenPolicy.AccessSecret)) tokenString, err := token.SignedString([]byte(config.Config.TokenPolicy.AccessSecret))
return tokenString, claims.ExpiresAt, err return tokenString, claims.ExpiresAt.Time.Unix(), err
} }
func secret() jwt.Keyfunc { func secret() jwt.Keyfunc {
@ -59,7 +48,7 @@ func secret() jwt.Keyfunc {
} }
} }
func ParseToken(tokensString string) (claims *Claims, err error) { func getClaimFromToken(tokensString string) (*Claims, error) {
token, err := jwt.ParseWithClaims(tokensString, &Claims{}, secret()) token, err := jwt.ParseWithClaims(tokensString, &Claims{}, secret())
if err != nil { if err != nil {
if ve, ok := err.(*jwt.ValidationError); ok { if ve, ok := err.(*jwt.ValidationError); ok {
@ -75,76 +64,66 @@ func ParseToken(tokensString string) (claims *Claims, err error) {
} }
} }
if claims, ok := token.Claims.(*Claims); ok && token.Valid { if claims, ok := token.Claims.(*Claims); ok && token.Valid {
// 1.check userid and platform class 0 not exists and 1 exists return claims, nil
existsInterface, err := db.DB.ExistsUserIDAndPlatform(claims.UID, Platform2class[claims.Platform]) }
return nil, err
}
func ParseToken(tokensString string) (claims *Claims, err error) {
claims, err = getClaimFromToken(tokensString)
if err != nil {
return nil, err
}
// 1.check userid and platform class 0 not exists and 1 exists
existsInterface, err := db.DB.ExistsUserIDAndPlatform(claims.UID, Platform2class[claims.Platform])
if err != nil {
return nil, err
}
exists := existsInterface.(int64)
//get config multi login policy
if config.Config.MultiLoginPolicy.OnlyOneTerminalAccess {
//OnlyOneTerminalAccess policy need to check all terminal
//When only one end is allowed to log in, there is a situation that needs to be paid attention to. After PC login,
//mobile login should check two platform times. One of them is less than the redis storage time, which is the invalid token.
platform := "PC"
if Platform2class[claims.Platform] == "PC" {
platform = "Mobile"
}
existsInterface, err = db.DB.ExistsUserIDAndPlatform(claims.UID, platform)
if err != nil { if err != nil {
return nil, err return nil, err
} }
exists := existsInterface.(int64)
//get config multi login policy
if config.Config.MultiLoginPolicy.OnlyOneTerminalAccess {
//OnlyOneTerminalAccess policy need to check all terminal
//When only one end is allowed to log in, there is a situation that needs to be paid attention to. After PC login,
//mobile login should check two platform times. One of them is less than the redis storage time, which is the invalid token.
if Platform2class[claims.Platform] == "PC" {
existsInterface, err = db.DB.ExistsUserIDAndPlatform(claims.UID, "Mobile")
if err != nil {
return nil, err
}
exists = existsInterface.(int64)
if exists == 1 {
res, err := MakeTheTokenInvalid(*claims, "Mobile")
if err != nil {
return nil, err
}
if res {
return nil, TokenInvalid
}
}
} else {
existsInterface, err = db.DB.ExistsUserIDAndPlatform(claims.UID, "PC")
if err != nil {
return nil, err
}
exists = existsInterface.(int64)
if exists == 1 {
res, err := MakeTheTokenInvalid(*claims, "PC")
if err != nil {
return nil, err
}
if res {
return nil, TokenInvalid
}
}
}
if exists == 1 { exists = existsInterface.(int64)
res, err := MakeTheTokenInvalid(*claims, Platform2class[claims.Platform]) if exists == 1 {
if err != nil { res, err := MakeTheTokenInvalid(claims, platform)
return nil, err if err != nil {
} return nil, err
if res {
return nil, TokenInvalid
}
} }
if res {
} else if config.Config.MultiLoginPolicy.MobileAndPCTerminalAccessButOtherTerminalKickEachOther { return nil, TokenInvalid
if exists == 1 {
res, err := MakeTheTokenInvalid(*claims, Platform2class[claims.Platform])
if err != nil {
return nil, err
}
if res {
return nil, TokenInvalid
}
} }
} }
return claims, nil
} }
return nil, TokenUnknown // config.Config.MultiLoginPolicy.MobileAndPCTerminalAccessButOtherTerminalKickEachOther == true
// or PC/Mobile validate success
// final check
if exists == 1 {
res, err := MakeTheTokenInvalid(claims, Platform2class[claims.Platform])
if err != nil {
return nil, err
}
if res {
return nil, TokenInvalid
}
}
return claims, nil
} }
func MakeTheTokenInvalid(currentClaims Claims, platformClass string) (bool, error) { func MakeTheTokenInvalid(currentClaims *Claims, platformClass string) (bool, error) {
storedRedisTokenInterface, err := db.DB.GetPlatformToken(currentClaims.UID, platformClass) storedRedisTokenInterface, err := db.DB.GetPlatformToken(currentClaims.UID, platformClass)
if err != nil { if err != nil {
return false, err return false, err
@ -154,40 +133,21 @@ func MakeTheTokenInvalid(currentClaims Claims, platformClass string) (bool, erro
return false, err return false, err
} }
//if issue time less than redis token then make this token invalid //if issue time less than redis token then make this token invalid
if currentClaims.IssuedAt < storedRedisPlatformClaims.IssuedAt { if currentClaims.IssuedAt.Time.Unix() < storedRedisPlatformClaims.IssuedAt.Time.Unix() {
return true, TokenInvalid return true, TokenInvalid
} }
return false, nil return false, nil
} }
func ParseRedisInterfaceToken(redisToken interface{}) (*Claims, error) { func ParseRedisInterfaceToken(redisToken interface{}) (*Claims, error) {
token, err := jwt.ParseWithClaims(string(redisToken.([]uint8)), &Claims{}, secret()) return getClaimFromToken(string(redisToken.([]uint8)))
if err != nil {
if ve, ok := err.(*jwt.ValidationError); ok {
if ve.Errors&jwt.ValidationErrorMalformed != 0 {
return nil, TokenMalformed
} else if ve.Errors&jwt.ValidationErrorExpired != 0 {
return nil, TokenExpired
} else if ve.Errors&jwt.ValidationErrorNotValidYet != 0 {
return nil, TokenNotValidYet
} else {
return nil, TokenInvalid
}
}
}
if claims, ok := token.Claims.(*Claims); ok && token.Valid {
return claims, nil
}
return nil, err
} }
//Validation token, false means failure, true means successful verification //Validation token, false means failure, true means successful verification
func VerifyToken(token, uid string) bool { func VerifyToken(token, uid string) bool {
claims, err := ParseToken(token) claims, err := ParseToken(token)
if err != nil { if err != nil || claims.UID != uid {
return false
} else if claims.UID != uid {
return false return false
} else {
return true
} }
return true
} }

@ -0,0 +1,81 @@
package utils
import (
"Open_IM/src/common/config"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
func Test_BuildClaims(t *testing.T) {
uid := "1"
platform := "PC"
ttl := int64(-1)
claim := BuildClaims(uid, platform, ttl)
now := time.Now().Unix()
assert.Equal(t, claim.UID, uid, "uid should equal")
assert.Equal(t, claim.Platform, platform, "platform should equal")
assert.Equal(t, claim.RegisteredClaims.ExpiresAt, int64(-1), "StandardClaims.ExpiresAt should be equal")
// time difference within 1s
assert.Equal(t, claim.RegisteredClaims.IssuedAt, now, "StandardClaims.IssuedAt should be equal")
assert.Equal(t, claim.RegisteredClaims.NotBefore, now, "StandardClaims.NotBefore should be equal")
ttl = int64(60)
now = time.Now().Unix()
claim = BuildClaims(uid, platform, ttl)
// time difference within 1s
assert.Equal(t, claim.RegisteredClaims.ExpiresAt, int64(60)+now, "StandardClaims.ExpiresAt should be equal")
assert.Equal(t, claim.RegisteredClaims.IssuedAt, now, "StandardClaims.IssuedAt should be equal")
assert.Equal(t, claim.RegisteredClaims.NotBefore, now, "StandardClaims.NotBefore should be equal")
}
func Test_CreateToken(t *testing.T) {
uid := "1"
platform := int32(1)
now := time.Now().Unix()
tokenString, expiresAt, err := CreateToken(uid, platform)
assert.NotEmpty(t, tokenString)
assert.Equal(t, expiresAt, 604800+now)
assert.Nil(t, err)
}
func Test_VerifyToken(t *testing.T) {
uid := "1"
platform := int32(1)
tokenString, _, _ := CreateToken(uid, platform)
result := VerifyToken(tokenString, uid)
assert.True(t, result)
result = VerifyToken(tokenString, "2")
assert.False(t, result)
}
func Test_ParseRedisInterfaceToken(t *testing.T) {
uid := "1"
platform := int32(1)
tokenString, _, _ := CreateToken(uid, platform)
claims, err := ParseRedisInterfaceToken([]uint8(tokenString))
assert.Nil(t, err)
assert.Equal(t, claims.UID, uid)
// timeout
config.Config.TokenPolicy.AccessExpire = -80
tokenString, _, _ = CreateToken(uid, platform)
claims, err = ParseRedisInterfaceToken([]uint8(tokenString))
assert.Equal(t, err, TokenExpired)
assert.Nil(t, claims)
}
func Test_ParseToken(t *testing.T) {
uid := "1"
platform := int32(1)
tokenString, _, _ := CreateToken(uid, platform)
claims, err := ParseToken(tokenString)
if err == nil {
assert.Equal(t, claims.UID, uid)
}
}
Loading…
Cancel
Save