pull/247/head
HFO4 5 years ago
parent 53da4655ad
commit ec7ab83d7c

@ -7,7 +7,7 @@ import (
) )
func init() { func init() {
conf.Init() conf.Init("conf/conf.ini")
model.Init() model.Init()
} }

@ -1,7 +1,9 @@
package model package model
import ( import (
"Cloudreve/pkg/conf"
"Cloudreve/pkg/util" "Cloudreve/pkg/util"
"fmt"
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"time" "time"
@ -15,13 +17,32 @@ var DB *gorm.DB
func Init() { func Init() {
//TODO 从配置文件中读取 包括DEBUG模式 //TODO 从配置文件中读取 包括DEBUG模式
util.Log().Info("初始化数据库连接\n") util.Log().Info("初始化数据库连接\n")
db, err := gorm.Open("mysql", "root:root@(localhost)/v3?charset=utf8&parseTime=True&loc=Local")
var (
db *gorm.DB
err error
)
if conf.DatabaseConfig.Type == "UNSET" {
//TODO 使用内置SQLite数据库
} else {
db, err = gorm.Open(conf.DatabaseConfig.Type, fmt.Sprintf("%s:%s@(%s)/%s?charset=utf8&parseTime=True&loc=Local",
conf.DatabaseConfig.User,
conf.DatabaseConfig.Password,
conf.DatabaseConfig.Host,
conf.DatabaseConfig.Name))
}
// 处理表前缀
gorm.DefaultTableNameHandler = func(db *gorm.DB, defaultTableName string) string {
return conf.DatabaseConfig.TablePrefix + defaultTableName
}
db.LogMode(true) db.LogMode(true)
//db.SetLogger(util.Log()) //db.SetLogger(util.Log())
// Error
if err != nil { if err != nil {
util.Log().Panic("连接数据库不成功", err) util.Log().Panic("连接数据库不成功", err)
} }
//设置连接池 //设置连接池
//空闲 //空闲
db.DB().SetMaxIdleConns(50) db.DB().SetMaxIdleConns(50)
@ -32,5 +53,6 @@ func Init() {
DB = db DB = db
//执行迁移
migration() migration()
} }

@ -2,15 +2,11 @@ package conf
import ( import (
"Cloudreve/pkg/util" "Cloudreve/pkg/util"
"fmt"
"github.com/go-ini/ini" "github.com/go-ini/ini"
) )
type Conf struct { // Database 数据库
Database Database type database struct {
}
type Database struct {
Type string Type string
User string User string
Password string Password string
@ -19,27 +15,33 @@ type Database struct {
TablePrefix string TablePrefix string
} }
var database = &Database{ var DatabaseConfig = &database{
Type: "UNSET", Type: "UNSET",
} }
var cfg *ini.File var cfg *ini.File
func Init() { // Init 初始化配置文件
func Init(path string) {
var err error var err error
//TODO 配置文件不存在时创建 //TODO 配置文件不存在时创建
cfg, err = ini.Load("conf/conf.ini") //TODO 配置合法性验证
cfg, err = ini.Load(path)
if err != nil {
util.Log().Panic("无法解析配置文件 '%s': ", path, err)
}
err = mapSection("Database", DatabaseConfig)
if err != nil { if err != nil {
util.Log().Panic("无法解析配置文件 'conf/conf.ini': ", err) util.Log().Warning("配置文件 %s 分区解析失败: ", "Database", err)
} }
mapSection("Database", database)
fmt.Println(database)
} }
func mapSection(section string, confStruct interface{}) { // mapSection 将配置文件的 Section 映射到结构体上
err := cfg.Section("Database").MapTo(database) func mapSection(section string, confStruct interface{}) error {
err := cfg.Section(section).MapTo(confStruct)
if err != nil { if err != nil {
util.Log().Warning("配置文件 Database 分区解析失败") return err
} }
return nil
} }

@ -0,0 +1,78 @@
package conf
import (
"github.com/stretchr/testify/assert"
"io/ioutil"
"testing"
)
// 测试Init日志路径错误
func TestInitPanic(t *testing.T) {
asserts := assert.New(t)
// 日志路径不存在时
asserts.Panics(func() {
Init("not/exist/path")
})
}
// TestInitDelimiterNotFound 日志路径存在但 Key 格式错误时
func TestInitDelimiterNotFound(t *testing.T) {
asserts := assert.New(t)
testCase := `[Database]
Type = mysql
User = root
Password233root
Host = 127.0.0.1:3306
Name = v3
TablePrefix = v3_`
err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644)
if err != nil {
panic(err)
}
asserts.Panics(func() {
Init("testConf.ini")
})
}
// TestInitNoPanic 日志路径存在且合法时
func TestInitNoPanic(t *testing.T) {
asserts := assert.New(t)
testCase := `[Database]
Type = mysql
User = root
Password = root
Host = 127.0.0.1:3306
Name = v3
TablePrefix = v3_`
err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644)
if err != nil {
panic(err)
}
asserts.NotPanics(func() {
Init("testConf.ini")
})
}
func TestMapSection(t *testing.T) {
asserts := assert.New(t)
//正常情况
testCase := `[Database]
Type = mysql
User = root
Password:root
Host = 127.0.0.1:3306
Name = v3
TablePrefix = v3_`
err := ioutil.WriteFile("testConf.ini", []byte(testCase), 0644)
if err != nil {
panic(err)
}
Init("testConf.ini")
err = mapSection("Database", DatabaseConfig)
asserts.NoError(err)
// TODO 类型不匹配测试
}

@ -2,7 +2,6 @@ package util
import ( import (
"fmt" "fmt"
"os"
"time" "time"
) )
@ -36,7 +35,7 @@ func (ll *Logger) Panic(format string, v ...interface{}) {
} }
msg := fmt.Sprintf("[Panic] "+format, v...) msg := fmt.Sprintf("[Panic] "+format, v...)
ll.Println(msg) ll.Println(msg)
os.Exit(0) panic(msg)
} }
// Error 错误 // Error 错误
@ -80,7 +79,7 @@ func (ll *Logger) Print(v ...interface{}) {
if LevelDebug > ll.level { if LevelDebug > ll.level {
return return
} }
msg := fmt.Sprintf("[SQL] ", v...) msg := fmt.Sprintf("[SQL] %s", v...)
ll.Println(msg) ll.Println(msg)
} }

Loading…
Cancel
Save