Nigiri/schema.go

212 lines
5.5 KiB
Go

package nigiri
import (
"reflect"
"strings"
)
// ============================================================================
// Type Definitions
// ============================================================================
// ConstraintType names a rule attached to a schema field. Values are either
// parsed from a field's `db` struct tag or inferred from its Go type.
type ConstraintType string

// Constraint kinds.
const (
	// Tag-driven constraints (see parseDBTag).
	ConstraintUnique   ConstraintType = "unique"
	ConstraintForeign  ConstraintType = "fkey"
	ConstraintRequired ConstraintType = "required"
	ConstraintIndex    ConstraintType = "index"

	// Relationship constraints. Their string values match the corresponding
	// RelationshipType values. Only many-to-one and one-to-many are produced
	// by detectRelationship in this file.
	ConstraintOneToOne   ConstraintType = "one_to_one"
	ConstraintOneToMany  ConstraintType = "one_to_many"
	ConstraintManyToOne  ConstraintType = "many_to_one"
	ConstraintManyToMany ConstraintType = "many_to_many"
)

// RelationshipType describes the cardinality of a link between two entities.
type RelationshipType string

// Relationship cardinalities.
const (
	RelationshipOneToOne   RelationshipType = "one_to_one"
	RelationshipOneToMany  RelationshipType = "one_to_many"
	RelationshipManyToOne  RelationshipType = "many_to_one"
	RelationshipManyToMany RelationshipType = "many_to_many"
)

// FieldConstraint is a single rule attached to one struct field.
type FieldConstraint struct {
	// Type is the kind of constraint.
	Type ConstraintType

	// Field is the Go name of the constrained struct field.
	Field string

	// Target is the text after "fkey:" for foreign-key constraints, or the
	// lowercased related entity name for inferred relationships.
	Target string

	// IndexName is the explicit name from an "index:<name>" tag option.
	// It is empty otherwise.
	IndexName string

	// Relationship is set only on constraints inferred from the field type.
	Relationship RelationshipType

	// TargetType is the related entity's struct type. It is set only on
	// inferred relationship constraints.
	TargetType reflect.Type
}

// SchemaInfo is the reflected description of a struct produced by
// ParseSchema. Every map is keyed by the Go field name.
type SchemaInfo struct {
	// Fields maps each struct field to its Go type.
	Fields map[string]reflect.Type

	// Constraints lists every constraint on a field: the inferred
	// relationship (if any) followed by constraints from the `db` tag.
	Constraints map[string][]FieldConstraint

	// Indices maps each unique or indexed field to its index name.
	Indices map[string]string

	// Relationships holds the relationship inferred from a field's type.
	Relationships map[string]FieldConstraint
}
// ============================================================================
// Schema Parsing
// ============================================================================
// ParseSchema builds a SchemaInfo for T by reflecting over its fields.
//
// T may be a struct type or a pointer to one; a single level of pointer is
// dereferenced. If the resulting type is not a struct, an empty, non-nil
// SchemaInfo is returned instead of panicking. This includes the case where
// T is an interface type.
//
// For every field (exported or not), ParseSchema:
//   - records the field's Go type in Fields;
//   - when detectRelationship recognises a relationship, records it in
//     Relationships and appends it to Constraints;
//   - appends the constraints parsed from the field's `db` struct tag to
//     Constraints;
//   - records an index name in Indices when the field has a unique or index
//     constraint. An explicitly named index (`db:"index:name"`) takes
//     precedence over the default name "<Field>_idx".
func ParseSchema[T any]() *SchemaInfo {
	schema := &SchemaInfo{
		Fields:        make(map[string]reflect.Type),
		Constraints:   make(map[string][]FieldConstraint),
		Indices:       make(map[string]string),
		Relationships: make(map[string]FieldConstraint),
	}

	// reflect.TypeFor reports T's static type even when T is an interface
	// type. reflect.TypeOf on a zero interface value returns nil, which would
	// panic on the Kind call below.
	t := reflect.TypeFor[T]()
	if t.Kind() == reflect.Pointer {
		t = t.Elem()
	}
	// NumField panics on non-struct types, so report an empty schema instead.
	if t.Kind() != reflect.Struct {
		return schema
	}

	for i := 0; i < t.NumField(); i++ {
		field := t.Field(i)
		fieldName := field.Name
		fieldType := field.Type
		schema.Fields[fieldName] = fieldType

		if relationship := detectRelationship(fieldName, fieldType); relationship != nil {
			schema.Relationships[fieldName] = *relationship
			schema.Constraints[fieldName] = append(schema.Constraints[fieldName], *relationship)
		}

		if dbTag := field.Tag.Get("db"); dbTag != "" {
			// Only touch the map when something was parsed, so fields without
			// constraints never get an entry in Constraints.
			if constraints := parseDBTag(fieldName, dbTag); len(constraints) > 0 {
				schema.Constraints[fieldName] = append(schema.Constraints[fieldName], constraints...)
			}
		}

		if indexName := indexNameFor(fieldName, schema.Constraints[fieldName]); indexName != "" {
			schema.Indices[fieldName] = indexName
		}
	}
	return schema
}

// indexNameFor returns the index name for a field carrying the given
// constraints, or "" when none of them is a unique or index constraint.
// Among explicitly named indexes the last one wins. Unnamed unique/index
// constraints fall back to "<fieldName>_idx" and never replace an explicit
// name, regardless of tag order.
func indexNameFor(fieldName string, constraints []FieldConstraint) string {
	name := ""
	for _, c := range constraints {
		if c.Type != ConstraintUnique && c.Type != ConstraintIndex {
			continue
		}
		switch {
		case c.IndexName != "":
			name = c.IndexName
		case name == "":
			name = fieldName + "_idx"
		}
	}
	return name
}
// ============================================================================
// Relationship Detection
// ============================================================================
// detectRelationship infers a relationship constraint from a field's Go type.
//
//   - *E yields a many-to-one constraint.
//   - []*E yields a one-to-many constraint.
//
// In both cases E must satisfy isEntityType. Any other type yields nil,
// including slices of struct values ([]E). The returned constraint's Target
// is getEntityName(E) and its TargetType is E.
func detectRelationship(fieldName string, fieldType reflect.Type) *FieldConstraint {
	var (
		kind   ConstraintType
		rel    RelationshipType
		target reflect.Type
	)

	switch fieldType.Kind() {
	case reflect.Pointer:
		target = fieldType.Elem()
		kind, rel = ConstraintManyToOne, RelationshipManyToOne
	case reflect.Slice:
		elem := fieldType.Elem()
		if elem.Kind() != reflect.Pointer {
			return nil
		}
		target = elem.Elem()
		kind, rel = ConstraintOneToMany, RelationshipOneToMany
	default:
		return nil
	}

	if !isEntityType(target) {
		return nil
	}
	return &FieldConstraint{
		Type:         kind,
		Field:        fieldName,
		Target:       getEntityName(target),
		Relationship: rel,
		TargetType:   target,
	}
}
// ============================================================================
// Tag Parsing
// ============================================================================
// parseDBTag converts a comma-separated `db` struct tag into constraints for
// fieldName. Whitespace around each option, and around the value after a
// colon, is ignored. Recognised options:
//
//	unique        -> ConstraintUnique
//	required      -> ConstraintRequired
//	index         -> ConstraintIndex with no explicit name
//	index:<name>  -> ConstraintIndex with IndexName <name>; an empty <name>
//	                 produces the same constraint as a bare "index"
//	fkey:<target> -> ConstraintForeign with Target <target>
//
// Empty and unrecognised options are skipped. The result is nil when no
// option is recognised.
func parseDBTag(fieldName, tag string) []FieldConstraint {
	var constraints []FieldConstraint
	for part := range strings.SplitSeq(tag, ",") {
		part = strings.TrimSpace(part)
		switch {
		case part == "":
			// Tolerate stray commas such as `db:"unique,,index"`.
		case part == "unique":
			constraints = append(constraints, FieldConstraint{
				Type:  ConstraintUnique,
				Field: fieldName,
			})
		case part == "required":
			constraints = append(constraints, FieldConstraint{
				Type:  ConstraintRequired,
				Field: fieldName,
			})
		case part == "index":
			constraints = append(constraints, FieldConstraint{
				Type:  ConstraintIndex,
				Field: fieldName,
			})
		case strings.HasPrefix(part, "index:"):
			// Trim the value so `db:"index: email"` names the index "email",
			// consistent with the whitespace tolerance applied to options.
			constraints = append(constraints, FieldConstraint{
				Type:      ConstraintIndex,
				Field:     fieldName,
				IndexName: strings.TrimSpace(strings.TrimPrefix(part, "index:")),
			})
		case strings.HasPrefix(part, "fkey:"):
			constraints = append(constraints, FieldConstraint{
				Type:   ConstraintForeign,
				Field:  fieldName,
				Target: strings.TrimSpace(strings.TrimPrefix(part, "fkey:")),
			})
		}
	}
	return constraints
}
// ============================================================================
// Utility Functions
// ============================================================================
// isEntityType reports whether t is a struct type that directly declares a
// field named "ID" whose kind is int. Fields promoted from embedded structs
// are not considered, and other integer kinds (int64, uint, ...) do not
// qualify.
func isEntityType(t reflect.Type) bool {
	if t.Kind() == reflect.Struct {
		for idx, total := 0, t.NumField(); idx < total; idx++ {
			f := t.Field(idx)
			if f.Name != "ID" {
				continue
			}
			if f.Type.Kind() == reflect.Int {
				return true
			}
		}
	}
	return false
}
// getEntityName returns the lowercased name of t. Named types use their bare
// type name. Unnamed types (pointers, anonymous structs, ...) fall back to
// their full reflect string form, e.g. "*pkg.user".
func getEntityName(t reflect.Type) string {
	if named := t.Name(); named != "" {
		return strings.ToLower(named)
	}
	return strings.ToLower(t.String())
}