Skip to content

Commit 06e7909

Browse files
committed
Mysql and SQLite schema
1 parent a9f6fd6 commit 06e7909

File tree

5 files changed

+235
-35
lines changed

5 files changed

+235
-35
lines changed

schema/mysql.go

Lines changed: 55 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,11 @@ import (
44
"database/sql"
55
"errors"
66
"fmt"
7+
"net"
8+
"strconv"
79
"sync"
810

9-
_ "github.com/go-sql-driver/mysql"
11+
mysql "github.com/go-sql-driver/mysql"
1012
"github.com/xwb1989/sqlparser"
1113
)
1214

@@ -15,44 +17,69 @@ var (
1517
ErrDDLNotFound = errors.New("ddl not found")
1618
)
1719

18-
var _ TableScanner = &MysqlTableScanner{}
20+
var _ TableScanner = &MysqlScanner{}
1921

20-
type MysqlTableScanner struct {
21-
DbName string
22-
TableName string
23-
Host string
24-
Port int
25-
User string
26-
Password string
27-
conn *sql.DB
22+
type MysqlScanner struct {
23+
DbName string
24+
Host string
25+
Port int
26+
User string
27+
Password string
28+
conn *sql.DB
2829
sync.Once
2930
}
3031

31-
func NewMysqlTableScanner(dbName, tableName, host, user, password string, port int) *MysqlTableScanner {
32-
return &MysqlTableScanner{
33-
DbName: dbName,
34-
TableName: tableName,
35-
Host: host,
36-
Port: port,
37-
User: user,
38-
Password: password,
32+
func ParseDsn(dsn string) (host string, port int,
33+
user string, password string, dbName string, err error) {
34+
var (
35+
addr string
36+
portStr string
37+
dsnInfo *mysql.Config
38+
)
39+
if dsnInfo, err = mysql.ParseDSN(dsn); err != nil {
40+
return
41+
}
42+
addr = dsnInfo.Addr
43+
user = dsnInfo.User
44+
password = dsnInfo.Passwd
45+
if host, portStr, err = net.SplitHostPort(addr); err != nil {
46+
return
47+
}
48+
if portStr != "" {
49+
if port, err = strconv.Atoi(portStr); err != nil {
50+
return
51+
}
52+
} else {
53+
port = 3306
3954
}
55+
dbName = dsnInfo.DBName
56+
return
4057
}
4158

42-
func (s *MysqlTableScanner) initConn() (err error) {
59+
func NewMysqlScanner(dbName, host, user, password string, port int) *MysqlScanner {
60+
return &MysqlScanner{
61+
DbName: dbName,
62+
Host: host,
63+
Port: port,
64+
User: user,
65+
Password: password,
66+
}
67+
}
68+
69+
func (s *MysqlScanner) initConn() (err error) {
4370
s.Do(func() {
4471
s.conn, err = sql.Open("mysql",
4572
s.User+":"+s.Password+"@tcp("+s.Host+":"+fmt.Sprint(s.Port)+")/"+s.DbName)
4673
})
4774
return err
4875
}
4976

50-
func (s *MysqlTableScanner) GetSchema() (schema *Schema, err error) {
77+
func (s *MysqlScanner) GetSchema(tableName string) (schema *Schema, err error) {
5178
if err = s.initConn(); err != nil {
5279
return
5380
}
5481

55-
ddl, err := s.conn.Query("SHOW CREATE TABLE " + s.TableName)
82+
ddl, err := s.conn.Query("SHOW CREATE TABLE " + tableName)
5683
if err != nil {
5784
return
5885
}
@@ -81,12 +108,12 @@ func (s *MysqlTableScanner) GetSchema() (schema *Schema, err error) {
81108
}
82109

83110
// GetRows returns a row Scanner for the given table.
84-
func (s *MysqlTableScanner) GetRows() (rows *sql.Rows, err error) {
111+
func (s *MysqlScanner) GetRows(query string) (rows *sql.Rows, err error) {
85112
if s.initConn() != nil {
86113
return
87114
}
88115

89-
rows, err = s.conn.Query("SELECT * FROM " + s.TableName)
116+
rows, err = s.conn.Query(query)
90117
return
91118
}
92119

@@ -107,10 +134,12 @@ func ParseMysqlDDL(dbName string, ddlStr string) (schema *Schema, err error) {
107134
schema.TableName = stmt.NewName.Name.String()
108135
for i, def := range stmt.TableSpec.Columns {
109136
schema.Columns = append(schema.Columns, Column{
110-
Name: def.Name.String(),
111-
Type: def.Type.Type,
112-
Comment: string(def.Type.Comment.Val),
137+
Name: def.Name.String(),
138+
Type: def.Type.Type,
113139
})
140+
if def.Type.Comment != nil {
141+
schema.Columns[i].Comment = string(def.Type.Comment.Val)
142+
}
114143
if def.Type.Length != nil {
115144
schema.Columns[i].Size = string(def.Type.Length.Val)
116145
}

schema/mysql_test.go

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,25 @@ const (
1717
)
1818

1919
func TestSchema(t *testing.T) {
20-
scaner := NewMysqlTableScanner(mysqlTestDbName, mysqlTestTableName, mysqlTestHost, mysqlTestUser, mysqlTestPassword, mysqlTestPort)
20+
scaner := NewMysqlScanner(mysqlTestDbName, mysqlTestHost, mysqlTestUser, mysqlTestPassword, mysqlTestPort)
2121
Convey("get schema", t, func() {
22-
schema, err := scaner.GetSchema()
22+
schema, err := scaner.GetSchema(mysqlTestTableName)
2323
So(err, ShouldBeNil)
2424
So(schema, ShouldNotBeNil)
25-
So(schema.DbName, ShouldEqual, "auxten")
26-
So(schema.TableName, ShouldEqual, "task_ddl")
25+
So(schema.DbName, ShouldEqual, mysqlTestDbName)
26+
So(schema.TableName, ShouldEqual, mysqlTestTableName)
2727
So(schema.Columns, ShouldHaveLength, 8)
2828
})
2929

3030
Convey("schema not exist", t, func() {
31-
scaner := NewMysqlTableScanner(mysqlTestDbName, "not_exist", mysqlTestHost, mysqlTestUser, mysqlTestPassword, mysqlTestPort)
32-
schema, err := scaner.GetSchema()
31+
scaner := NewMysqlScanner(mysqlTestDbName, mysqlTestHost, mysqlTestUser, mysqlTestPassword, mysqlTestPort)
32+
schema, err := scaner.GetSchema("not_exist")
3333
So(fmt.Sprint(err), ShouldContainSubstring, "Error 1146: Table 'auxten.not_exist' doesn't exist")
3434
So(schema, ShouldBeNil)
3535
})
3636

3737
Convey("get rows", t, func() {
38-
rows, err := scaner.GetRows()
38+
rows, err := scaner.GetRows("select * from " + mysqlTestTableName)
3939
So(rows.Next(), ShouldBeTrue)
4040
var (
4141
task_id int64
@@ -62,3 +62,26 @@ func TestSchema(t *testing.T) {
6262
So(created_at, ShouldEqual, "2022-06-19 23:01:47")
6363
})
6464
}
65+
66+
func TestParseDsn(t *testing.T) {
67+
Convey("parse dsn", t, func() {
68+
dsn := fmt.Sprintf("%s:%s@tcp(%s:%d)/%s", mysqlTestUser, mysqlTestPassword, mysqlTestHost, mysqlTestPort, mysqlTestDbName)
69+
host, port, user, password, dbName, err := ParseDsn(dsn)
70+
So(err, ShouldBeNil)
71+
So(host, ShouldEqual, mysqlTestHost)
72+
So(port, ShouldEqual, mysqlTestPort)
73+
So(user, ShouldEqual, mysqlTestUser)
74+
So(password, ShouldEqual, mysqlTestPassword)
75+
So(dbName, ShouldEqual, mysqlTestDbName)
76+
})
77+
Convey("parse dsn no port", t, func() {
78+
dsn := fmt.Sprintf("%s:%s@tcp(%s)/", mysqlTestUser, mysqlTestPassword, mysqlTestHost)
79+
host, port, user, password, dbName, err := ParseDsn(dsn)
80+
So(err, ShouldBeNil)
81+
So(host, ShouldEqual, mysqlTestHost)
82+
So(port, ShouldEqual, 3306)
83+
So(user, ShouldEqual, mysqlTestUser)
84+
So(password, ShouldEqual, mysqlTestPassword)
85+
So(dbName, ShouldEqual, "")
86+
})
87+
}

schema/scan.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ import (
66

77
//TableScanner could get schema or content of given mysql table
88
type TableScanner interface {
9-
GetSchema() (schema *Schema, err error)
10-
GetRows() (rows *sql.Rows, err error)
9+
GetSchema(tableName string) (schema *Schema, err error)
10+
GetRows(query string) (rows *sql.Rows, err error)
1111
}

schema/sqlite.go

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
package schema
2+
3+
import (
4+
"database/sql"
5+
"fmt"
6+
"sync"
7+
8+
_ "github.com/mattn/go-sqlite3" //keep
9+
)
10+
11+
type SqliteScanner struct {
12+
DbPath string
13+
conn *sql.DB
14+
sync.Once
15+
}
16+
17+
func NewSqliteScanner(dbPath string) *SqliteScanner {
18+
return &SqliteScanner{
19+
DbPath: dbPath,
20+
}
21+
}
22+
23+
func (s *SqliteScanner) initConn() (err error) {
24+
s.Do(func() {
25+
s.conn, err = sql.Open("sqlite3",
26+
fmt.Sprintf("file:%s?cache=shared", s.DbPath),
27+
)
28+
})
29+
return err
30+
}
31+
32+
func (s *SqliteScanner) GetSchema(tableName string) (schema *Schema, err error) {
33+
if s.initConn() != nil {
34+
return
35+
}
36+
37+
ddl, err := s.conn.Query(fmt.Sprintf("PRAGMA table_info(%s)", tableName))
38+
if err != nil {
39+
return
40+
}
41+
defer ddl.Close()
42+
var (
43+
cid sql.NullInt64
44+
name, typeStr, notNull, defaultVal sql.NullString
45+
pk sql.NullInt64
46+
)
47+
schema = &Schema{
48+
TableName: tableName,
49+
Columns: make([]Column, 0),
50+
}
51+
for ddl.Next() {
52+
if err = ddl.Scan(&cid, &name, &typeStr, &notNull, &defaultVal, &pk); err != nil {
53+
return
54+
}
55+
schema.Columns = append(schema.Columns, Column{
56+
Name: name.String,
57+
Type: typeStr.String,
58+
Size: "",
59+
Extra: fmt.Sprintf("notNull:%s, defaultVal:%s, pk:%d",
60+
notNull.String, defaultVal.String, pk.Int64),
61+
Comment: "",
62+
})
63+
}
64+
65+
return
66+
}
67+
68+
func (s *SqliteScanner) GetRows(query string) (rows *sql.Rows, err error) {
69+
if s.initConn() != nil {
70+
return
71+
}
72+
73+
rows, err = s.conn.Query(query)
74+
return
75+
}

schema/sqlite_test.go

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
package schema
2+
3+
import (
4+
"database/sql"
5+
"os"
6+
"testing"
7+
8+
. "github.com/smartystreets/goconvey/convey"
9+
)
10+
11+
func TestSqliteScanner(t *testing.T) {
12+
//create temp db
13+
tempDbFile, err := os.CreateTemp("", "sqlite_test.db")
14+
if err != nil {
15+
t.Fatal(err)
16+
}
17+
defer os.Remove(tempDbFile.Name())
18+
db, err := sql.Open("sqlite3", tempDbFile.Name())
19+
if err != nil {
20+
t.Fatal(err)
21+
}
22+
_, err = db.Exec("CREATE TABLE task_ddl (" +
23+
"task_id INTEGER PRIMARY KEY, title TEXT, start_date TEXT, due_date TEXT, " +
24+
"status INT, priority INT, description TEXT, created_at TEXT)")
25+
if err != nil {
26+
t.Fatal(err)
27+
}
28+
29+
// insert some data
30+
_, err = db.Exec("INSERT INTO task_ddl (task_id, title, start_date, due_date, status, priority, description, created_at) " +
31+
"VALUES (1, 't1', '2022-06-19', '2022-06-20', 0, 1, 'd1', '2022-06-19'), " +
32+
"(2, 't2', '2022-06-19', '2022-06-20', 0, 1, 'd2', '2022-06-19')")
33+
if err != nil {
34+
t.Fatal(err)
35+
}
36+
37+
_ = db.Close()
38+
Convey("get schema", t, func() {
39+
scanner := NewSqliteScanner(tempDbFile.Name())
40+
schema, err := scanner.GetSchema("task_ddl")
41+
So(err, ShouldBeNil)
42+
So(schema.Columns, ShouldHaveLength, 8)
43+
})
44+
45+
Convey("get rows", t, func() {
46+
scanner := NewSqliteScanner(tempDbFile.Name())
47+
rows, err := scanner.GetRows("select * from task_ddl")
48+
So(rows.Next(), ShouldBeTrue)
49+
var (
50+
task_id int64
51+
title string
52+
start_date string
53+
due_date string
54+
status int
55+
priority int
56+
description string
57+
created_at string
58+
)
59+
err = rows.Scan(
60+
&task_id, &title, &start_date, &due_date,
61+
&status, &priority, &description, &created_at)
62+
So(err, ShouldBeNil)
63+
So(rows, ShouldNotBeNil)
64+
So(task_id, ShouldEqual, 1)
65+
So(title, ShouldEqual, "t1")
66+
So(start_date, ShouldEqual, "2022-06-19")
67+
So(due_date, ShouldEqual, "2022-06-20")
68+
So(status, ShouldEqual, 0)
69+
So(priority, ShouldEqual, 1)
70+
So(description, ShouldEqual, "d1")
71+
So(created_at, ShouldEqual, "2022-06-19")
72+
})
73+
}

0 commit comments

Comments
 (0)