Skip to content

Commit 43533de

Browse files
authored
Merge pull request #117 from sev-2/tweak
tweak : register postgres Date and rpc register
2 parents f18822b + 9e57c00 commit 43533de

File tree

9 files changed

+302
-6
lines changed

9 files changed

+302
-6
lines changed

pkg/cli/generate/command.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,12 @@ func Run(flags *Flags, config *raiden.Config, projectPath string, initialize boo
194194
}
195195
GenerateLogger.Debug("finish generate role register file")
196196

197+
// generate rpc register
198+
GenerateLogger.Debug("start generate rpc register file")
199+
if err := generator.GenerateRpcRegister(projectPath, config.ProjectName, generator.Generate); err != nil {
200+
errChan <- err
201+
}
202+
GenerateLogger.Debug("finish generate rpc register file")
197203
}
198204

199205
// generate job register

pkg/generator/model.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,8 @@ func MapTableAttributes(projectName string, table objects.Table, mapDataType map
209209
"*postgres.Point": true,
210210
"postgres.DateTime": true,
211211
"*postgres.DateTime": true,
212+
"postgres.Date": true,
213+
"*postgres.Date": true,
212214
}
213215

214216
if postgresCustomTypes[column.Type] {

pkg/generator/rpc.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,8 @@ func (r *ExtractRpcDataResult) GetParams(mapImports map[string]bool) (columns []
637637
importPackageName = "github.com/google/uuid"
638638
case "json":
639639
importPackageName = "encoding/json"
640+
case "postgres":
641+
importPackageName = "github.com/sev-2/raiden/pkg/postgres"
640642
}
641643
key := fmt.Sprintf("%q", importPackageName)
642644
mapImports[key] = true
@@ -725,6 +727,8 @@ func (r *ExtractRpcDataResult) GetReturn(mapImports map[string]bool) (returnDecl
725727
importPackageName = "github.com/google/uuid"
726728
case "json":
727729
importPackageName = "encoding/json"
730+
case "postgres":
731+
importPackageName = "github.com/sev-2/raiden/pkg/postgres"
728732
}
729733
key := fmt.Sprintf("%q", importPackageName)
730734
mapImports[key] = true

pkg/generator/rpc_test.go

Lines changed: 111 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,8 @@ func TestExtractRpcWithPrefix(t *testing.T) {
178178
where sc.name = in_scouter_name and c.name = in_candidate_name ; end;
179179
`,
180180
CompleteStatement: `
181-
CREATE OR REPLACE FUNCTION public.get_submissions(in_scouter_name character varying, in_candidate_name character varying)\n
182-
RETURNS TABLE(id integer, created_at timestamp without time zone, sc_name character varying, c_name character varying)\n
181+
CREATE OR REPLACE FUNCTION public.get_submissions(in_scouter_name character varying, in_candidate_name character varying, in_register date)\n
182+
RETURNS TABLE(id integer, created_at timestamp without time zone, sc_name character varying, c_name character varying, in_register date)\n
183183
LANGUAGE plpgsql\n
184184
AS $function$
185185
begin return query
@@ -202,6 +202,12 @@ func TestExtractRpcWithPrefix(t *testing.T) {
202202
TypeId: 1043,
203203
HasDefault: false,
204204
},
205+
{
206+
Mode: "in",
207+
Name: "in_register",
208+
TypeId: 1043,
209+
HasDefault: false,
210+
},
205211
{
206212
Mode: "table",
207213
Name: "id",
@@ -227,10 +233,10 @@ func TestExtractRpcWithPrefix(t *testing.T) {
227233
HasDefault: false,
228234
},
229235
},
230-
ArgumentTypes: "in_scouter_name character varying, in_candidate_name character varying",
231-
IdentityArgumentTypes: "in_scouter_name character varying, in_candidate_name character varying",
236+
ArgumentTypes: "in_scouter_name character varying, in_candidate_name character varying, in_register date",
237+
IdentityArgumentTypes: "in_scouter_name character varying, in_candidate_name character varying, in_register date",
232238
ReturnTypeID: 2249,
233-
ReturnType: "TABLE(id integer, created_at timestamp without time zone, sc_name character varying, c_name character varying)",
239+
ReturnType: "TABLE(id integer, created_at timestamp without time zone, sc_name character varying, c_name character varying, register date)",
234240
ReturnTypeRelationID: 0,
235241
IsSetReturningFunction: true,
236242
Behavior: string(raiden.RpcBehaviorVolatile),
@@ -518,6 +524,106 @@ func TestGenerateRpc(t *testing.T) {
518524
assert.FileExists(t, dir+"/internal/rpc/get_submissions.go")
519525
}
520526

527+
func TestGenerateRpc_DateType(t *testing.T) {
528+
fns := []objects.Function{
529+
{
530+
Schema: "public",
531+
Name: "get_latest_active_rates_by_tenant",
532+
Language: "sql",
533+
Definition: "\n select distinct on (tr.tax_id)\n tr.id,\n tr.tax_id,\n tr.type,\n tr.rate,\n tr.start_date,\n tr.end_date,\n tr.rate as applicable_rate,\n t.name as tax_name,\n t.tenant::text\n from tax_rates tr\n join taxes t on tr.tax_id = t.id\n where t.tenant = input_tenant::public.tenant\n and tr.start_date <= now()\n and (tr.end_date is null or tr.end_date > now())\n order by tr.tax_id, tr.start_date desc\n",
534+
CompleteStatement: "CREATE OR REPLACE FUNCTION public.get_latest_active_rates_by_tenant(input_tenant text)\n RETURNS TABLE(id bigint, tax_id bigint, type text, rate numeric, start_date date, end_date date, applicable_rate numeric, tax_name text, tenant text)\n LANGUAGE sql\nAS $function$\n select distinct on (tr.tax_id)\n tr.id,\n tr.tax_id,\n tr.type,\n tr.rate,\n tr.start_date,\n tr.end_date,\n tr.rate as applicable_rate,\n t.name as tax_name,\n t.tenant::text\n from tax_rates tr\n join taxes t on tr.tax_id = t.id\n where t.tenant = input_tenant::public.tenant\n and tr.start_date <= now()\n and (tr.end_date is null or tr.end_date > now())\n order by tr.tax_id, tr.start_date desc\n$function$\n",
535+
Args: []objects.FunctionArg{
536+
{
537+
Mode: "in",
538+
Name: "input_tenant",
539+
TypeId: 25,
540+
HasDefault: false,
541+
},
542+
{
543+
Mode: "in",
544+
Name: "input_register",
545+
TypeId: 25,
546+
HasDefault: false,
547+
},
548+
{
549+
Mode: "table",
550+
Name: "id",
551+
TypeId: 20,
552+
HasDefault: false,
553+
},
554+
{
555+
Mode: "table",
556+
Name: "tax_id",
557+
TypeId: 20,
558+
HasDefault: false,
559+
},
560+
{
561+
Mode: "table",
562+
Name: "type",
563+
TypeId: 25,
564+
HasDefault: false,
565+
},
566+
{
567+
Mode: "table",
568+
Name: "rate",
569+
TypeId: 1700,
570+
HasDefault: false,
571+
},
572+
{
573+
Mode: "table",
574+
Name: "start_date",
575+
TypeId: 1082,
576+
HasDefault: false,
577+
},
578+
{
579+
Mode: "table",
580+
Name: "end_date",
581+
TypeId: 1082,
582+
HasDefault: false,
583+
},
584+
{
585+
Mode: "table",
586+
Name: "applicable_rate",
587+
TypeId: 1700,
588+
HasDefault: false,
589+
},
590+
{
591+
Mode: "table",
592+
Name: "tax_name",
593+
TypeId: 25,
594+
HasDefault: false,
595+
},
596+
{
597+
Mode: "table",
598+
Name: "tenant",
599+
TypeId: 25,
600+
HasDefault: false,
601+
},
602+
},
603+
ArgumentTypes: "input_tenant text, input_register date",
604+
IdentityArgumentTypes: "input_tenant text, input_register date",
605+
ReturnTypeID: 2249,
606+
ReturnType: "TABLE(id bigint, tax_id bigint, type text, rate numeric, start_date date, end_date date, applicable_rate numeric, tax_name text, tenant text)",
607+
ReturnTypeRelationID: 0,
608+
IsSetReturningFunction: true,
609+
Behavior: string(raiden.RpcBehaviorVolatile),
610+
SecurityDefiner: false,
611+
ConfigParams: nil,
612+
},
613+
}
614+
615+
dir, err := os.MkdirTemp("", "rpc")
616+
assert.NoError(t, err)
617+
618+
rpcPath := filepath.Join(dir, "internal")
619+
err1 := utils.CreateFolder(rpcPath)
620+
assert.NoError(t, err1)
621+
622+
err2 := generator.GenerateRpc(dir, "test", fns, []objects.Table{}, generator.GenerateFn(generator.Generate))
623+
assert.NoError(t, err2)
624+
assert.FileExists(t, dir+"/internal/rpc/get_latest_active_rates_by_tenant.go")
625+
}
626+
521627
func TestRpcWithTrigger(t *testing.T) {
522628
fn := objects.Function{
523629
Schema: "public",

pkg/postgres/data_type.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,10 @@ func ToGoType(pgType DataType, isNullable bool) (goType string) {
114114
goType = "float64"
115115
case VarcharType, VarcharTypeAlias, CharType, BpcharType, TextType:
116116
goType = "string"
117-
case TimestampType, TimestampTypeAlias, TimestampTzType, TimestampTzTypeAlias, TimeType, TimeTypeAlias, TimeTzType, TimeTzTypeAlias, DateType:
117+
case TimestampType, TimestampTypeAlias, TimestampTzType, TimestampTzTypeAlias, TimeType, TimeTypeAlias, TimeTzType, TimeTzTypeAlias:
118118
goType = "postgres.DateTime"
119+
case DateType:
120+
goType = "postgres.Date"
119121
case IntervalType:
120122
goType = "time.Duration"
121123
case BooleanType:
@@ -160,6 +162,8 @@ func ToPostgresType(goType string) (pgType DataType) {
160162
pgType = TextType
161163
case "postgres.DateTime":
162164
pgType = TimestampType
165+
case "postgres.Date":
166+
pgType = DateType
163167
case "time.Duration":
164168
pgType = IntervalType
165169
case "bool":

pkg/postgres/date.go

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
package postgres
2+
3+
import (
4+
"database/sql/driver"
5+
"fmt"
6+
"strings"
7+
"time"
8+
)
9+
10+
type Date struct {
11+
time.Time
12+
}
13+
14+
// Ensure Date implements sql.Scanner and driver.Valuer interfaces
15+
var _ driver.Valuer = (*Date)(nil)
16+
var _ fmt.Stringer = (*Date)(nil)
17+
18+
const dateFormat = "2006-01-02"
19+
20+
func (d *Date) Scan(value interface{}) error {
21+
switch v := value.(type) {
22+
case time.Time:
23+
d.Time = v.UTC().Truncate(24 * time.Hour)
24+
return nil
25+
case string:
26+
t, err := time.Parse(dateFormat, v)
27+
if err != nil {
28+
return err
29+
}
30+
d.Time = t
31+
return nil
32+
case []byte:
33+
t, err := time.Parse(dateFormat, string(v))
34+
if err != nil {
35+
return err
36+
}
37+
d.Time = t
38+
return nil
39+
default:
40+
return fmt.Errorf("cannot scan type %T into Date", value)
41+
}
42+
}
43+
44+
func (d Date) Value() (driver.Value, error) {
45+
return d.Format(dateFormat), nil
46+
}
47+
48+
func (d Date) String() string {
49+
return d.Format(dateFormat)
50+
}
51+
52+
// JSON Marshal/Unmarshal support
53+
func (d Date) MarshalJSON() ([]byte, error) {
54+
return []byte(fmt.Sprintf(`"%s"`, d.Format(dateFormat))), nil
55+
}
56+
57+
func (d *Date) UnmarshalJSON(b []byte) error {
58+
s := strings.Trim(string(b), `"`)
59+
t, err := time.Parse(dateFormat, s)
60+
if err != nil {
61+
return err
62+
}
63+
d.Time = t
64+
return nil
65+
}

pkg/postgres/date_test.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package postgres_test
2+
3+
import (
4+
"encoding/json"
5+
"testing"
6+
"time"
7+
8+
"github.com/sev-2/raiden/pkg/postgres"
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
func TestDate_ScanFromString(t *testing.T) {
13+
var d postgres.Date
14+
err := d.Scan("2025-05-22")
15+
if err != nil {
16+
t.Fatalf("unexpected error: %v", err)
17+
}
18+
if d.Format("2006-01-02") != "2025-05-22" {
19+
t.Errorf("expected 2025-05-22, got %s", d.String())
20+
}
21+
}
22+
23+
func TestDate_ScanFromString_Err(t *testing.T) {
24+
var d postgres.Date
25+
err := d.Scan("2025-05-22-01")
26+
assert.Error(t, err)
27+
}
28+
29+
func TestDate_ScanFromBytes(t *testing.T) {
30+
var d postgres.Date
31+
err := d.Scan([]byte("2025-05-22"))
32+
if err != nil {
33+
t.Fatalf("unexpected error: %v", err)
34+
}
35+
if d.Format("2006-01-02") != "2025-05-22" {
36+
t.Errorf("expected 2025-05-22, got %s", d.String())
37+
}
38+
}
39+
40+
func TestDate_ScanFromTime(t *testing.T) {
41+
expected := time.Date(2025, 5, 22, 12, 0, 0, 0, time.UTC)
42+
var d postgres.Date
43+
err := d.Scan(expected)
44+
if err != nil {
45+
t.Fatalf("unexpected error: %v", err)
46+
}
47+
if !d.Time.Equal(expected.Truncate(24 * time.Hour)) {
48+
t.Errorf("expected %v, got %v", expected, d.Time)
49+
}
50+
}
51+
52+
func TestDateValue(t *testing.T) {
53+
d := postgres.Date{Time: time.Date(2025, 5, 22, 0, 0, 0, 0, time.UTC)}
54+
val, err := d.Value()
55+
if err != nil {
56+
t.Fatalf("unexpected error: %v", err)
57+
}
58+
if str, ok := val.(string); !ok || str != "2025-05-22" {
59+
t.Errorf("expected string 2025-05-22, got %v", val)
60+
}
61+
}
62+
63+
func TestDate_JSONMarshaling(t *testing.T) {
64+
d := postgres.Date{Time: time.Date(2025, 5, 22, 0, 0, 0, 0, time.UTC)}
65+
jsonBytes, err := json.Marshal(d)
66+
if err != nil {
67+
t.Fatalf("unexpected error: %v", err)
68+
}
69+
expected := `"2025-05-22"`
70+
if string(jsonBytes) != expected {
71+
t.Errorf("expected %s, got %s", expected, string(jsonBytes))
72+
}
73+
}
74+
75+
func TestDate_JSONUnmarshaling(t *testing.T) {
76+
var d postgres.Date
77+
err := json.Unmarshal([]byte(`"2025-05-22"`), &d)
78+
if err != nil {
79+
t.Fatalf("unexpected error: %v", err)
80+
}
81+
if d.String() != "2025-05-22" {
82+
t.Errorf("expected 2025-05-22, got %s", d.String())
83+
}
84+
}
85+
86+
func TestDate_ScanInvalidType(t *testing.T) {
87+
var d postgres.Date
88+
err := d.Scan(123)
89+
if err == nil {
90+
t.Fatalf("expected error, got nil")
91+
}
92+
}

0 commit comments

Comments
 (0)