From 7a9914325de779b6d8cbdcd5a31a87b413f70a86 Mon Sep 17 00:00:00 2001
From: prescientmoon <git@moonythm.dev>
Date: Wed, 20 Nov 2024 00:30:49 +0100
Subject: [PATCH] Prevent naming conflicts for _by_pk queries/mutations

---
 source/AirGQL/GraphQL.hs                      | 20 +++---
 source/AirGQL/Introspection.hs                | 64 ++++++++++++-------
 source/AirGQL/Introspection/NamingConflict.hs | 42 ++++++++++++
 source/AirGQL/Introspection/Types.hs          |  2 +-
 source/AirGQL/Lib.hs                          | 31 +++++----
 5 files changed, 113 insertions(+), 46 deletions(-)
 create mode 100644 source/AirGQL/Introspection/NamingConflict.hs

diff --git a/source/AirGQL/GraphQL.hs b/source/AirGQL/GraphQL.hs
index cc7d4e2..414ae4b 100644
--- a/source/AirGQL/GraphQL.hs
+++ b/source/AirGQL/GraphQL.hs
@@ -82,6 +82,7 @@ import AirGQL.Config (
  )
 
 import AirGQL.Introspection qualified as Introspection
+import AirGQL.Introspection.NamingConflict (encodeOutsidePKNames)
 import AirGQL.Introspection.Resolver qualified as Introspection
 import AirGQL.Introspection.Types qualified as Introspection
 import AirGQL.Lib (
@@ -523,8 +524,7 @@ executeUpdateMutation
   -> HashMap Text Value
   -> [(Text, Value)]
   -> IO (Int, [[SQLData]])
-executeUpdateMutation connection table args filterElements = do
-  pairsToSet :: HashMap Text Value <- getArg "set" args
+executeUpdateMutation connection table pairsToSet filterElements = do
   let
     columnsToSet :: [(ColumnEntry, Value)]
     columnsToSet =
@@ -770,7 +770,7 @@ queryType connection accessMode dbId tables = do
 
         getTableByPKTuple :: TableEntry -> IO (Maybe (Text, Resolver IO))
         getTableByPKTuple table =
-          P.for (Introspection.tableQueryByPKField table) $ \field ->
+          P.for (Introspection.tableQueryByPKField tables table) $ \field ->
             makeResolver field (getDbEntriesByPK table)
 
       queryMany <- P.for tables getTableTuple
@@ -998,13 +998,14 @@ mutationType connection maxRowsPerTable accessMode dbId tables = do
       let Arguments args = context.arguments
       liftIO $ do
         filterObj <- getArg "filter" args
+        pairsToSet <- getArg "set" args
         (numOfChanges, updatedRows) <- case HashMap.toList filterObj of
           [] -> P.throwIO $ userError "Error: Filter must not be empty"
           filterElements ->
             executeUpdateMutation
               connection
               table
-              args
+              pairsToSet
               filterElements
 
         mutationResponse table numOfChanges updatedRows
@@ -1019,11 +1020,12 @@ mutationType connection maxRowsPerTable accessMode dbId tables = do
               & getByPKFilterElements
 
       liftIO $ do
+        pairsToSet <- getArg (encodeOutsidePKNames table "set") args
         (numOfChanges, updatedRows) <-
           executeUpdateMutation
             connection
             table
-            args
+            pairsToSet
             filterElements
 
         mutationByPKResponse table numOfChanges $ P.head updatedRows
@@ -1083,8 +1085,8 @@ mutationType connection maxRowsPerTable accessMode dbId tables = do
 
         getUpdateByPKTableTuple :: TableEntry -> IO (Maybe (Text, Resolver IO))
         getUpdateByPKTableTuple table =
-          P.for (Introspection.tableUpdateFieldByPk accessMode table) $ \field ->
-            makeResolver field (executeDbUpdatesByPK table)
+          P.for (Introspection.tableUpdateFieldByPk accessMode tables table) $
+            \field -> makeResolver field (executeDbUpdatesByPK table)
 
         getDeleteTableTuple :: TableEntry -> IO (Text, Resolver IO)
         getDeleteTableTuple table =
@@ -1094,8 +1096,8 @@ mutationType connection maxRowsPerTable accessMode dbId tables = do
 
         getDeleteByPKTableTuple :: TableEntry -> IO (Maybe (Text, Resolver IO))
         getDeleteByPKTableTuple table =
-          P.for (Introspection.tableDeleteFieldByPK accessMode table) $ \field ->
-            makeResolver field (executeDbDeletionsByPK table)
+          P.for (Introspection.tableDeleteFieldByPK accessMode tables table) $
+            \field -> makeResolver field (executeDbDeletionsByPK table)
 
         getTableTuples :: IO [(Text, Resolver IO)]
         getTableTuples =
diff --git a/source/AirGQL/Introspection.hs b/source/AirGQL/Introspection.hs
index 62fb63e..e22c5b2 100644
--- a/source/AirGQL/Introspection.hs
+++ b/source/AirGQL/Introspection.hs
@@ -33,12 +33,16 @@ import Language.GraphQL.Type.Out as Out (
   Type (NonNullScalarType),
  )
 
+import AirGQL.Introspection.NamingConflict (
+  encodeOutsidePKNames,
+  encodeOutsideTableNames,
+ )
 import AirGQL.Introspection.Resolver (makeType)
 import AirGQL.Introspection.Types (IntrospectionType)
 import AirGQL.Introspection.Types qualified as Type
 import AirGQL.Lib (
   AccessMode,
-  ColumnEntry (isRowid, primary_key),
+  ColumnEntry,
   GqlTypeName (full, root),
   ObjectType (Table),
   TableEntry (columns, name, object_type),
@@ -46,6 +50,7 @@ import AirGQL.Lib (
   canWrite,
   column_name_gql,
   datatype_gql,
+  getPKColumns,
   isOmittable,
   notnull,
  )
@@ -182,26 +187,20 @@ tableQueryField table =
 
 tablePKArguments :: TableEntry -> Maybe [Type.InputValue]
 tablePKArguments table = do
-  let pks = List.filter (\col -> col.primary_key) table.columns
-
-  -- We filter out the rowid column, unless it is the only one
-  withoutRowid <- case pks of
-    [] -> Nothing
-    [first] | first.isRowid -> Just [first]
-    _ -> Just $ List.filter (\col -> P.not col.isRowid) pks
+  pks <- getPKColumns table
 
   pure $
-    withoutRowid <&> \column -> do
+    pks <&> \column -> do
       let name = doubleXEncodeGql column.column_name_gql
       Type.inputValue name $ Type.nonNull $ columnType column
 
 
-tableQueryByPKField :: TableEntry -> Maybe Type.Field
-tableQueryByPKField table = do
+tableQueryByPKField :: [TableEntry] -> TableEntry -> Maybe Type.Field
+tableQueryByPKField tables table = do
   pkArguments <- tablePKArguments table
   pure $
     Type.field
-      (doubleXEncodeGql table.name <> "_by_pk")
+      (encodeOutsideTableNames tables $ doubleXEncodeGql table.name <> "_by_pk")
       (tableRowType table)
       & Type.fieldWithDescription
         ( "Rows from the table \""
@@ -237,7 +236,6 @@ mutationResponseType accessMode table = do
 
 mutationByPkResponseType :: AccessMode -> TableEntry -> Type.IntrospectionType
 mutationByPkResponseType accessMode table = do
-  let tableName = doubleXEncodeGql table.name
   let readonlyFields =
         if canRead accessMode
           then
@@ -247,7 +245,7 @@ mutationByPkResponseType accessMode table = do
           else []
 
   Type.object
-    (tableName <> "_mutation_by_pk_response")
+    (doubleXEncodeGql table.name <> "_mutation_by_pk_response")
     ( [ Type.field "affected_rows" (Type.nonNull Type.typeInt)
       ]
         <> readonlyFields
@@ -363,13 +361,17 @@ tableUpdateField accessMode table = do
       ]
 
 
-tableUpdateFieldByPk :: AccessMode -> TableEntry -> Maybe Type.Field
-tableUpdateFieldByPk accessMode table = do
+tableUpdateFieldByPk
+  :: AccessMode
+  -> [TableEntry]
+  -> TableEntry
+  -> Maybe Type.Field
+tableUpdateFieldByPk accessMode tables table = do
   pkArguments <- tablePKArguments table
 
   let arguments =
         [ Type.inputValue
-            "set"
+            (encodeOutsidePKNames table "set")
             (Type.nonNull $ tableSetInput table)
             & Type.inputValueWithDescription "Fields to be updated"
         ]
@@ -377,7 +379,13 @@ tableUpdateFieldByPk accessMode table = do
 
   pure $
     Type.field
-      ("update_" <> doubleXEncodeGql table.name <> "_by_pk")
+      ( "update_"
+          <> encodeOutsideTableNames
+            tables
+            ( doubleXEncodeGql table.name
+                <> "_by_pk"
+            )
+      )
       (Type.nonNull $ mutationByPkResponseType accessMode table)
       & Type.fieldWithDescription
         ("Update row in table \"" <> table.name <> "\"")
@@ -399,12 +407,20 @@ tableDeleteField accessMode table = do
       ]
 
 
-tableDeleteFieldByPK :: AccessMode -> TableEntry -> Maybe Type.Field
-tableDeleteFieldByPK accessMode table = do
+tableDeleteFieldByPK
+  :: AccessMode
+  -> [TableEntry]
+  -> TableEntry
+  -> Maybe Type.Field
+tableDeleteFieldByPK accessMode tables table = do
   args <- tablePKArguments table
   pure $
     Type.field
-      ("delete_" <> doubleXEncodeGql table.name <> "_by_pk")
+      ( "delete_"
+          <> encodeOutsideTableNames
+            tables
+            (doubleXEncodeGql table.name <> "_by_pk")
+      )
       (Type.nonNull $ mutationByPkResponseType accessMode table)
       & Type.fieldWithDescription
         ("Delete row in table \"" <> table.name <> "\"")
@@ -445,7 +461,7 @@ getSchema accessMode tables = do
         then
           P.fold
             [ tables <&> tableQueryField
-            , tables & P.mapMaybe tableQueryByPKField
+            , tables & P.mapMaybe (tableQueryByPKField tables)
             ]
         else []
 
@@ -462,9 +478,9 @@ getSchema accessMode tables = do
             , tablesWithoutViews <&> tableUpdateField accessMode
             , tablesWithoutViews <&> tableDeleteField accessMode
             , tablesWithoutViews
-                & P.mapMaybe (tableUpdateFieldByPk accessMode)
+                & P.mapMaybe (tableUpdateFieldByPk accessMode tables)
             , tablesWithoutViews
-                & P.mapMaybe (tableDeleteFieldByPK accessMode)
+                & P.mapMaybe (tableDeleteFieldByPK accessMode tables)
             ]
         else []
 
diff --git a/source/AirGQL/Introspection/NamingConflict.hs b/source/AirGQL/Introspection/NamingConflict.hs
new file mode 100644
index 0000000..921e95c
--- /dev/null
+++ b/source/AirGQL/Introspection/NamingConflict.hs
@@ -0,0 +1,42 @@
+{-| Each table, say `foo`, generates a `foo` and a `foo_by_pk`. If a table
+named `foo_by_pk` also exists, this would create a naming conflict. This
+issue also occurs in a few other places.
+
+To solve this, we implement the `encodeOutsideList` function, which encodes a name
+such that it does not conflict with any other name from a given list. This
+is done by repeatedly appending _ at the end, until the name does not reside in
+the given list anymore.
+-}
+module AirGQL.Introspection.NamingConflict (
+  encodeOutsideList,
+  encodeOutsideTableNames,
+  encodeOutsidePKNames,
+) where
+
+import Protolude (Text, fromMaybe, ($), (<$>), (<>))
+
+import AirGQL.Lib (ColumnEntry (column_name_gql), TableEntry (name), getPKColumns)
+import Data.List qualified as List
+import DoubleXEncoding (doubleXEncodeGql)
+
+
+encodeOutsideList :: [Text] -> Text -> Text
+encodeOutsideList list name = do
+  if name `List.elem` list
+    then encodeOutsideList list (name <> "_")
+    else name
+
+
+-- | Encode a name so it does not conflict with any table name
+encodeOutsideTableNames :: [TableEntry] -> Text -> Text
+encodeOutsideTableNames tables =
+  encodeOutsideList $ (\t -> doubleXEncodeGql t.name) <$> tables
+
+
+{-| Encode a name so it does not conflict with any column that is part of a
+PK constraint for a given table.
+-}
+encodeOutsidePKNames :: TableEntry -> Text -> Text
+encodeOutsidePKNames table = do
+  let cols = fromMaybe [] $ getPKColumns table
+  encodeOutsideList $ column_name_gql <$> cols
diff --git a/source/AirGQL/Introspection/Types.hs b/source/AirGQL/Introspection/Types.hs
index 932f953..56b757f 100644
--- a/source/AirGQL/Introspection/Types.hs
+++ b/source/AirGQL/Introspection/Types.hs
@@ -369,7 +369,7 @@ instance ToGraphQL Directive where
         , ("description", toGraphQL value.description)
         , ("isRepeatable", toGraphQL value.isRepeatable)
         , ("args", toGraphQL value.args)
-        , ("locations", Value.List $ Value.Enum <$> value.locations)
+        , ("locations", toGraphQL $ Value.Enum <$> value.locations)
         ]
 
 
diff --git a/source/AirGQL/Lib.hs b/source/AirGQL/Lib.hs
index cc957af..a2a5e87 100644
--- a/source/AirGQL/Lib.hs
+++ b/source/AirGQL/Lib.hs
@@ -19,6 +19,7 @@ module AirGQL.Lib (
   getTableNames,
   getColumnNames,
   getEnrichedTables,
+  getPKColumns,
   ObjectType (..),
   parseSql,
   replaceCaseInsensitive,
@@ -73,8 +74,9 @@ import Control.Monad (MonadFail (fail))
 import Control.Monad.Catch (catchAll)
 import Data.Aeson (FromJSON, ToJSON, Value (Bool, Null, Number, Object, String))
 import Data.Aeson.KeyMap qualified as KeyMap
+import Data.List qualified as List
 import Data.Scientific qualified as Scientific
-import Data.Text (isInfixOf, isSuffixOf, toUpper)
+import Data.Text (isInfixOf, toUpper)
 import Data.Text qualified as T
 import Database.SQLite.Simple (
   Connection,
@@ -733,18 +735,8 @@ lintTable allEntries parsed =
                 <> " does not have a rowid column. "
                 <> "Such tables are not currently supported by Airsequel."
       _ -> []
-
-    illegalName = case parsed.statement of
-      CreateTable names _ _
-        | Just name <- getFirstName (Just names)
-        , "_by_pk" `isSuffixOf` name ->
-            pure $
-              "Table names shouldn't contain \"_by_pk\", yet \""
-                <> name
-                <> "\" does"
-      _ -> []
   in
-    rowidReferenceWarnings <> withoutRowidWarning <> illegalName
+    rowidReferenceWarnings <> withoutRowidWarning
 
 
 {-| Lint the sql code for creating a table
@@ -778,6 +770,21 @@ getRowidColumnName colNames
   | otherwise = "rowid" -- TODO: Return error to user
 
 
+{-| Select the column(s) that form this table's primary key. If no non-rowid
+columns are marked as part of a PK constraint, the rowid column will be
+returned instead.
+-}
+getPKColumns :: TableEntry -> Maybe [ColumnEntry]
+getPKColumns table = do
+  let pks = List.filter (\col -> col.primary_key) table.columns
+
+  -- We filter out the rowid column, unless it is the only one
+  case pks of
+    [] -> Nothing
+    [first] | first.isRowid -> Just [first]
+    _ -> Just $ List.filter (\col -> P.not col.isRowid) pks
+
+
 columnDefName :: ColumnDef -> Text
 columnDefName (ColumnDef name _ _) = nameAsText name