Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix nullability inference from query plan #42

Merged
merged 5 commits into from
Nov 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:
- uses: TanklesXL/gleam_actions/.github/actions/install_gleam@main
with:
erlang_version: 27
gleam_version: 1.5.1
gleam_version: 1.6.0
- uses: TanklesXL/gleam_actions/.github/actions/hex_publish@main
with:
hex_user: ${{ secrets.HEXPM_USER }}
Expand Down
10 changes: 5 additions & 5 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
- uses: actions/checkout@v4
- uses: TanklesXL/gleam_actions/.github/actions/install_gleam@main
with:
gleam_version: 1.5.1
gleam_version: 1.6.0
erlang_version: 27
- uses: TanklesXL/gleam_actions/.github/actions/format@main

Expand All @@ -24,19 +24,19 @@ jobs:
- uses: actions/checkout@v4
- uses: TanklesXL/gleam_actions/.github/actions/install_gleam@main
with:
gleam_version: 1.5.1
gleam_version: 1.6.0
erlang_version: 27
- uses: TanklesXL/gleam_actions/.github/actions/deps_cache@main
with:
gleam_version: 1.5.1
gleam_version: 1.6.0

test:
runs-on: ubuntu-latest
needs: deps
strategy:
fail-fast: true
matrix:
erlang: ["26", "27"]
erlang: ["27"]

env:
DATABASE_URL: postgres://squirrel_test:postgres_password@localhost:5432/squirrel_test
Expand Down Expand Up @@ -65,7 +65,7 @@ jobs:
- uses: TanklesXL/gleam_actions/.github/actions/deps_cache@main
- uses: TanklesXL/gleam_actions/.github/actions/install_gleam@main
with:
gleam_version: 1.5.1
gleam_version: 1.6.0
erlang_version: ${{ matrix.erlang }}
- name: "integration test"
run: ./integration_test
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# CHANGELOG

## Unreleased

- Fixed a bug where certain queries would generate code with the wrong optional
types.
([Giacomo Cavalieri](https://github.com/giacomocavalieri))

## v2.0.0 - 2024-11-11

- The generated code now uses the [`pog`](https://hexdocs.pm/pog/index.html)
Expand Down
47 changes: 47 additions & 0 deletions birdie_snapshots/left_join_nullability_inference.accepted
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
---
version: 1.2.3
title: left join nullability inference
file: ./test/squirrel_test.gleam
test_name: left_join_nullability_inference_test
---
import decode/zero
import gleam/option.{type Option}
import pog

/// A row you get from running the `query` query
/// defined in `query.sql`.
///
/// > 🐿️ This type definition was generated automatically using v-test of the
/// > [squirrel package](https://github.com/giacomocavalieri/squirrel).
///
pub type QueryRow {
QueryRow(user_id: Int, roles: Option(String))
}

/// Runs the `query` query
/// defined in `query.sql`.
///
/// > 🐿️ This function was generated automatically using v-test of
/// > the [squirrel package](https://github.com/giacomocavalieri/squirrel).
///
pub fn query(db) {
let decoder = {
use user_id <- zero.field(0, zero.int)
use roles <- zero.field(1, zero.optional(zero.string))
zero.success(QueryRow(user_id:, roles:))
}

let query = "
select
users_issue41.user_id,
profile_issue41.roles
from
users_issue41
left join profile_issue41
on profile_issue41.user_id = users_issue41.user_id;
"

pog.query(query)
|> pog.returning(zero.run(_, decoder))
|> pog.execute(db)
}
13 changes: 6 additions & 7 deletions manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@ packages = [
{ name = "argv", version = "1.0.2", build_tools = ["gleam"], requirements = [], otp_app = "argv", source = "hex", outer_checksum = "BA1FF0929525DEBA1CE67256E5ADF77A7CDDFE729E3E3F57A5BDCAA031DED09D" },
{ name = "backoff", version = "1.1.6", build_tools = ["rebar3"], requirements = [], otp_app = "backoff", source = "hex", outer_checksum = "CF0CFFF8995FB20562F822E5CC47D8CCF664C5ECDC26A684CBE85C225F9D7C39" },
{ name = "birdie", version = "1.2.3", build_tools = ["gleam"], requirements = ["argv", "edit_distance", "filepath", "glance", "gleam_community_ansi", "gleam_erlang", "gleam_stdlib", "justin", "rank", "simplifile", "trie_again"], otp_app = "birdie", source = "hex", outer_checksum = "AE1207210E9CC8F4170BCE3FB3C23932F314C352C3FD1BCEA44CF4BF8CF60F93" },
{ name = "decode", version = "0.3.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "decode", source = "hex", outer_checksum = "EE9B979C0D8A5E058E2519EC0EE9CA4C7CEE15B12997BFF50492636CDC53D0C7" },
{ name = "decode", version = "0.5.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "decode", source = "hex", outer_checksum = "05E14DC95A550BA51B8774485B04894B87A898C588B9B1C920104B110AED218B" },
{ name = "edit_distance", version = "2.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "edit_distance", source = "hex", outer_checksum = "A1E485C69A70210223E46E63985FA1008B8B2DDA9848B7897469171B29020C05" },
{ name = "envoy", version = "1.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "envoy", source = "hex", outer_checksum = "CFAACCCFC47654F7E8B75E614746ED924C65BD08B1DE21101548AC314A8B6A41" },
{ name = "envoy", version = "1.0.2", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "envoy", source = "hex", outer_checksum = "95FD059345AA982E89A0B6E2A3BF1CF43E17A7048DCD85B5B65D3B9E4E39D359" },
{ name = "eval", version = "1.0.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "eval", source = "hex", outer_checksum = "264DAF4B49DF807F303CA4A4E4EBC012070429E40BE384C58FE094C4958F9BDA" },
{ name = "exception", version = "2.0.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "exception", source = "hex", outer_checksum = "F5580D584F16A20B7FCDCABF9E9BE9A2C1F6AC4F9176FA6DD0B63E3B20D450AA" },
{ name = "filepath", version = "1.0.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "filepath", source = "hex", outer_checksum = "EFB6FF65C98B2A16378ABC3EE2B14124168C0CE5201553DE652E2644DCFDB594" },
{ name = "glam", version = "2.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "glam", source = "hex", outer_checksum = "66EC3BCD632E51EED029678F8DF419659C1E57B1A93D874C5131FE220DFAD2B2" },
{ name = "glance", version = "0.11.0", build_tools = ["gleam"], requirements = ["gleam_stdlib", "glexer"], otp_app = "glance", source = "hex", outer_checksum = "8F3314D27773B7C3B9FB58D8C02C634290422CE531988C0394FA0DF8676B964D" },
{ name = "gleam_community_ansi", version = "1.4.1", build_tools = ["gleam"], requirements = ["gleam_community_colour", "gleam_stdlib"], otp_app = "gleam_community_ansi", source = "hex", outer_checksum = "4CD513FC62523053E62ED7BAC2F36136EC17D6A8942728250A9A00A15E340E4B" },
{ name = "gleam_community_colour", version = "1.4.0", build_tools = ["gleam"], requirements = ["gleam_json", "gleam_stdlib"], otp_app = "gleam_community_colour", source = "hex", outer_checksum = "795964217EBEDB3DA656F5EB8F67D7AD22872EB95182042D3E7AFEF32D3FD2FE" },
{ name = "gleam_community_colour", version = "1.4.1", build_tools = ["gleam"], requirements = ["gleam_json", "gleam_stdlib"], otp_app = "gleam_community_colour", source = "hex", outer_checksum = "386CB9B01B33371538672EEA8A6375A0A0ADEF41F17C86DDCB81C92AD00DA610" },
{ name = "gleam_crypto", version = "1.4.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_crypto", source = "hex", outer_checksum = "8AE56026B3E05EBB1F076778478A762E9EB62B31AEEB4285755452F397029D22" },
{ name = "gleam_erlang", version = "0.27.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_erlang", source = "hex", outer_checksum = "DE468F676D71B313C6C8C5334425CFCF827837333F8AB47B64D8A6D7AA40185D" },
{ name = "gleam_json", version = "1.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib", "thoas"], otp_app = "gleam_json", source = "hex", outer_checksum = "9063D14D25406326C0255BDA0021541E797D8A7A12573D849462CAFED459F6EB" },
{ name = "gleam_stdlib", version = "0.40.0", build_tools = ["gleam"], requirements = [], otp_app = "gleam_stdlib", source = "hex", outer_checksum = "86606B75A600BBD05E539EB59FABC6E307EEEA7B1E5865AFB6D980A93BCB2181" },
{ name = "gleam_erlang", version = "0.30.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_erlang", source = "hex", outer_checksum = "760618870AE4A497B10C73548E6E44F43B76292A54F0207B3771CBB599C675B4" },
{ name = "gleam_json", version = "2.0.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleam_json", source = "hex", outer_checksum = "CB10B0E7BF44282FB25162F1A24C1A025F6B93E777CCF238C4017E4EEF2CDE97" },
{ name = "gleam_stdlib", version = "0.43.0", build_tools = ["gleam"], requirements = [], otp_app = "gleam_stdlib", source = "hex", outer_checksum = "69EF22E78FDCA9097CBE7DF91C05B2A8B5436826D9F66680D879182C0860A747" },
{ name = "gleeunit", version = "1.2.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "gleeunit", source = "hex", outer_checksum = "F7A7228925D3EE7D0813C922E062BFD6D7E9310F0BEE585D3A42F3307E3CFD13" },
{ name = "glexer", version = "1.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "glexer", source = "hex", outer_checksum = "BD477AD657C2B637FEF75F2405FAEFFA533F277A74EF1A5E17B55B1178C228FB" },
{ name = "justin", version = "1.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "justin", source = "hex", outer_checksum = "7FA0C6DB78640C6DC5FBFD59BF3456009F3F8B485BF6825E97E1EB44E9A1E2CD" },
Expand All @@ -33,7 +33,6 @@ packages = [
{ name = "simplifile", version = "2.2.0", build_tools = ["gleam"], requirements = ["filepath", "gleam_stdlib"], otp_app = "simplifile", source = "hex", outer_checksum = "0DFABEF7DC7A9E2FF4BB27B108034E60C81BEBFCB7AB816B9E7E18ED4503ACD8" },
{ name = "temporary", version = "1.0.0", build_tools = ["gleam"], requirements = ["envoy", "exception", "filepath", "gleam_crypto", "gleam_stdlib", "simplifile"], otp_app = "temporary", source = "hex", outer_checksum = "51C0FEF4D72CE7CA507BD188B21C1F00695B3D5B09D7DFE38240BFD3A8E1E9B3" },
{ name = "term_size", version = "1.0.1", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "term_size", source = "hex", outer_checksum = "D00BD2BC8FB3EBB7E6AE076F3F1FF2AC9D5ED1805F004D0896C784D06C6645F1" },
{ name = "thoas", version = "1.2.1", build_tools = ["rebar3"], requirements = [], otp_app = "thoas", source = "hex", outer_checksum = "E38697EDFFD6E91BD12CEA41B155115282630075C2A727E7A6B2947F5408B86A" },
{ name = "tom", version = "1.1.0", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "tom", source = "hex", outer_checksum = "228E667239504B57AD05EC3C332C930391592F6C974D0EFECF32FFD0F3629A27" },
{ name = "tote", version = "1.0.2", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "tote", source = "hex", outer_checksum = "A249892E26A53C668897F8D47845B0007EEE07707A1A03437487F0CD5A452CA5" },
{ name = "trie_again", version = "1.1.2", build_tools = ["gleam"], requirements = ["gleam_stdlib"], otp_app = "trie_again", source = "hex", outer_checksum = "5B19176F52B1BD98831B57FDC97BD1F88C8A403D6D8C63471407E78598E27184" },
Expand Down
144 changes: 76 additions & 68 deletions src/squirrel/internal/database/postgres.gleam
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@
//// > that is a bug! Please do reach out, I'd love to hear your feedback.
////

import decode/zero
import eval
import gleam/bit_array
import gleam/bool
import gleam/dict.{type Dict}
import gleam/dynamic.{type DecodeErrors, type Dynamic} as d
import gleam/int
import gleam/json
import gleam/list
Expand Down Expand Up @@ -210,24 +210,14 @@ type Nullability {
/// A query plan produced by Postgres when we ask it to `explain` a query.
///
type Plan {
Plan(
join_type: Option(JoinType),
parent_relation: Option(ParentRelation),
output: Option(List(String)),
plans: Option(List(Plan)),
)
Plan(join_type: Option(JoinType), output: List(String), plans: List(Plan))
}

type JoinType {
Full
Left
Right
Other
}

type ParentRelation {
Inner
NotInner
FullJoin
LeftJoin
RightJoin
InnerJoin
}

/// This is the type of a database-related action.
Expand Down Expand Up @@ -741,19 +731,19 @@ fn query_plan(query: UntypedQuery, parameters: Int) -> Db(Plan) {
// We know the output will only contain a single row that is the json string
// containing the query plan.
let assert [[plan]] = res
let assert Ok([plan, ..]) = json.decode_bits(plan, json_plans_decoder)
eval.return(plan)
case json.decode_bits(plan, zero.run(_, json_plans_decoder())) {
Ok([plan, ..]) -> eval.return(plan)
Ok([]) -> panic as "unreachable: no query plan"
Error(reason) ->
eval.throw(error.CannotParsePlanForQuery(file: query.file, reason:))
}
}

/// Given a query plan, returns a set with the indices of the output columns
/// that can contain null values.
///
fn nullables_from_plan(plan: Plan) -> Set(Int) {
let outputs = case plan.output {
Some(outputs) -> list.index_fold(outputs, dict.new(), dict.insert)
None -> dict.new()
}

let outputs = list.index_fold(plan.output, dict.new(), dict.insert)
do_nullables_from_plan(plan, outputs, set.new())
}

Expand All @@ -763,27 +753,57 @@ fn do_nullables_from_plan(
query_outputs: Dict(String, Int),
nullables: Set(Int),
) -> Set(Int) {
let nullables = case plan.output, plan.join_type, plan.parent_relation {
// - All the outputs of a full join must be marked as nullable
// - All the outputs of an inner half join must be marked as nullable
Some(outputs), Some(Full), _ | Some(outputs), _, Some(Inner) -> {
use nullables, output <- list.fold(outputs, from: nullables)
case dict.get(query_outputs, output) {
Ok(i) -> set.insert(nullables, i)
Error(_) -> nullables
}
case plan.join_type, plan.plans {
// If this is a full join then all its outputs could be optional!!
Some(FullJoin), _ ->
plan_outputs_indices(plan, query_outputs)
|> set.union(nullables)

// If this is a right join then we must mark the outputs of its left part as
// nullable!
Some(RightJoin), [left, right] -> {
let nullables =
plan_outputs_indices(left, query_outputs)
|> set.union(nullables)

do_nullables_from_plan(right, query_outputs, nullables)
}

// If this is a left join then we must mark the outputs of its right part as
// nullable!
Some(LeftJoin), [left, right] -> {
let nullables =
plan_outputs_indices(right, query_outputs)
|> set.union(nullables)

do_nullables_from_plan(left, query_outputs, nullables)
}

// This should never happen in theory (a join with 0, 1, or more than two
// children), so we just inspect their plans as a safe bet.
Some(RightJoin), plans | Some(LeftJoin), plans | None, plans -> {
use nullables, plan <- list.fold(plans, nullables)
do_nullables_from_plan(plan, query_outputs, nullables)
}
_, _, _ -> nullables
}

case plan.plans, plan.join_type {
// If this is an inner half join we keep inspecting the children to mark
// their outputs as nullable.
Some(plans), Some(Left) | Some(plans), Some(Right) -> {
use nullables, plan <- list.fold(plans, from: nullables)
// If this is an inner join then its outputs are not necessarily nullable,
// we inspect the children's plans to see if they do have some nullable
// columns.
Some(InnerJoin), plans -> {
use nullables, plan <- list.fold(plans, nullables)
do_nullables_from_plan(plan, query_outputs, nullables)
}
_, _ -> nullables
}
}

/// Collects the indices associated with the outputs of `plan`.
///
/// Each name in `plan.output` is looked up in `query_outputs`; when the name
/// is found its index is added to the resulting set, when it is not found the
/// name is simply skipped.
///
fn plan_outputs_indices(
  plan: Plan,
  query_outputs: Dict(String, Int),
) -> Set(Int) {
  list.fold(over: plan.output, from: set.new(), with: fn(indices, name) {
    case dict.get(query_outputs, name) {
      // Names that do not appear among the query outputs add nothing.
      Error(_) -> indices
      Ok(index) -> set.insert(indices, index)
    }
  })
}

Expand Down Expand Up @@ -838,7 +858,7 @@ fn resolve_returns(
// If the name ends with a `?` or `!` we don't want that to be included in
// the gleam name or it would be invalid!
case ends_with_exclamation_mark || ends_with_question_mark {
True -> string.drop_right(name, 1)
True -> string.drop_end(name, 1)
False -> name
}
|> gleam.value_identifier
Expand Down Expand Up @@ -1248,7 +1268,7 @@ fn adjust_parse_error_for_explain(error: Error) -> Error {
CannotParseQuery(
file:,
name:,
content: string.drop_left(content, 31),
content: string.drop_start(content, 31),
starting_line:,
error_code:,
pointer:,
Expand All @@ -1263,37 +1283,25 @@ fn adjust_parse_error_for_explain(error: Error) -> Error {

// --- DECODERS ----------------------------------------------------------------

fn json_plans_decoder(data: Dynamic) -> Result(List(Plan), DecodeErrors) {
d.list(d.field("Plan", plan_decoder))(data)
}

fn plan_decoder(data: Dynamic) -> Result(Plan, DecodeErrors) {
d.decode4(
Plan,
d.optional_field("Join Type", join_type_decoder),
d.optional_field("Parent Relationship", parent_relation_decoder),
d.optional_field("Output", d.list(d.string)),
d.optional_field("Plans", d.list(plan_decoder)),
)(data)
/// Builds the decoder for a list of query plans: every item of the list is
/// expected to hold a plan under its `"Plan"` key.
///
fn json_plans_decoder() {
  let single_plan = zero.at(["Plan"], plan_decoder())
  zero.list(single_plan)
}

fn join_type_decoder(data: Dynamic) -> Result(JoinType, DecodeErrors) {
use data <- result.map(d.string(data))
case data {
"Full" -> Full
"Left" -> Left
"Right" -> Right
_ -> Other
}
/// Builds the decoder for a single `Plan` node.
///
/// All three fields are read as optional fields:
/// - `"Join Type"` falls back to `None`,
/// - `"Output"` falls back to an empty list,
/// - `"Plans"` falls back to an empty list and is decoded recursively with
///   this same decoder.
///
fn plan_decoder() {
use join_type <- zero.optional_field("Join Type", None, join_type_decoder())
use output <- zero.optional_field("Output", [], zero.list(zero.string))
// NOTE(review): the recursive `plan_decoder()` call sits inside the `use`
// continuation, so it is presumably only evaluated while decoding rather than
// when the decoder is built — confirm before moving it out of the callback.
use plans <- zero.optional_field("Plans", [], zero.list(plan_decoder()))
zero.success(Plan(join_type:, output:, plans:))
}

fn parent_relation_decoder(
data: Dynamic,
) -> Result(ParentRelation, DecodeErrors) {
use data <- result.map(d.string(data))
fn join_type_decoder() {
use data <- zero.then(zero.string)
case data {
"Inner" -> Inner
_ -> NotInner
"Full" -> zero.success(Some(FullJoin))
"Left" -> zero.success(Some(LeftJoin))
"Right" -> zero.success(Some(RightJoin))
"Inner" -> zero.success(Some(InnerJoin))
_ -> zero.failure(None, "a join type")
}
}

Expand Down
Loading
Loading