From 301ee009d01b5c69c387e566681334001ccfdf41 Mon Sep 17 00:00:00 2001
From: Jake Wheat <jakewheat@tutanota.com>
Date: Wed, 10 Jan 2024 16:10:00 +0000
Subject: [PATCH] fix pretty printer formatting

---
 Language/SQL/SimpleSQL/Pretty.hs | 94 +++++++++++++++++++-------------
 1 file changed, 56 insertions(+), 38 deletions(-)

diff --git a/Language/SQL/SimpleSQL/Pretty.hs b/Language/SQL/SimpleSQL/Pretty.hs
index 8de4c08..93418b6 100644
--- a/Language/SQL/SimpleSQL/Pretty.hs
+++ b/Language/SQL/SimpleSQL/Pretty.hs
@@ -19,20 +19,20 @@ import Prelude hiding (show)
 import qualified Prelude as P
 
 import Prettyprinter (Doc
-                     ,parens
                      ,nest
-                     ,(<+>)
-                     ,sep
                      ,punctuate
                      ,comma
                      ,squotes
                      ,vsep
-                     ,hsep
                      ,layoutPretty
                      ,defaultLayoutOptions
                      ,brackets
+                     ,align
+                     ,hcat
+                     ,line
                      )
 import qualified Prettyprinter as P
+import Prettyprinter.Internal.Type (Doc(Empty))
 
 import Prettyprinter.Render.Text (renderStrict)
 
@@ -55,7 +55,7 @@ prettyScalarExpr d = render . scalarExpr d
 
 -- | A terminating semicolon.
 terminator :: Doc a
-terminator = pretty ";\n"
+terminator = pretty ";" <> line
 
 -- | Convert a statement ast to concrete syntax.
 prettyStatement :: Dialect -> Statement -> Text
@@ -98,13 +98,13 @@ scalarExpr _ (HostParameter p i) =
 scalarExpr d (App f es) = names f <> parens (commaSep (map (scalarExpr d) es))
 
 scalarExpr dia (AggregateApp f d es od fil) =
-    names f
+    (names f
     <> parens ((case d of
                   Distinct -> pretty "distinct"
                   All -> pretty "all"
                   SQDefault -> mempty)
                <+> commaSep (map (scalarExpr dia) es)
-               <+> orderBy dia od)
+               <+> orderBy dia od))
     <+> me (\x -> pretty "filter"
                   <+> parens (pretty "where" <+> scalarExpr dia x)) fil
 
@@ -120,8 +120,8 @@ scalarExpr d (WindowApp f es pb od fr) =
     <+> pretty "over"
     <+> parens ((case pb of
                     [] -> mempty
-                    _ -> pretty "partition by"
-                          <+> nest 13 (commaSep $ map (scalarExpr d) pb))
+                    _ -> (pretty "partition by") <+> align
+                                   (commaSep $ map (scalarExpr d) pb))
                 <+> orderBy d od
     <+> me frd fr)
   where
@@ -138,11 +138,13 @@ scalarExpr d (WindowApp f es pb od fr) =
     fpd (Preceding e) = scalarExpr d e <+> pretty "preceding"
     fpd (Following e) = scalarExpr d e <+> pretty "following"
 
-scalarExpr dia (SpecialOp nm [a,b,c]) | nm `elem` [[Name Nothing "between"]
-                                                 ,[Name Nothing "not between"]] =
+scalarExpr dia (SpecialOp nm [a,b,c])
+    | nm `elem` [[Name Nothing "between"]
+                ,[Name Nothing "not between"]] =
   sep [scalarExpr dia a
-      ,names nm <+> scalarExpr dia b
-      ,nest (T.length (unnames nm) + 1) $ pretty "and" <+> scalarExpr dia c]
+      ,names nm <+> nest ((T.length (unnames nm) - 3)) (sep
+          [scalarExpr dia b
+          ,pretty "and" <+> scalarExpr dia c])]
 
 scalarExpr d (SpecialOp [Name Nothing "rowctor"] as) =
     parens $ commaSep $ map (scalarExpr d) as
@@ -181,10 +183,11 @@ scalarExpr dia (Case t ws els) =
           <> [pretty "end"]
   where
     w (t0,t1) =
-      pretty "when" <+> nest 5 (commaSep $ map (scalarExpr dia) t0)
-      <+> pretty "then" <+> nest 5 (scalarExpr dia t1)
-    e el = pretty "else" <+> nest 5 (scalarExpr dia el)
-scalarExpr d (Parens e) = parens $ scalarExpr d e
+      pretty "when" <+> align (sep [commaSep $ map (scalarExpr dia) t0
+                                   ,pretty "then" <+> align (scalarExpr dia t1)])
+    e el = pretty "else" <+> align (scalarExpr dia el)
+scalarExpr d (Parens e) =
+    parens (scalarExpr d e)
 scalarExpr d (Cast e tn) =
     pretty "cast" <> parens (sep [scalarExpr d e
                                  ,pretty "as"
@@ -219,8 +222,7 @@ scalarExpr d (In b se x) =
     scalarExpr d se <+>
     (if b then mempty else pretty "not")
     <+> pretty "in"
-    <+> parens (nest (if b then 3 else 7) $
-                 case x of
+    <+> parens (case x of
                      InList es -> commaSep $ map (scalarExpr d) es
                      InQueryExpr qe -> queryExpr d qe)
 
@@ -294,7 +296,7 @@ name (Name Nothing n) = pretty n
 name (Name (Just (s,e)) n) = pretty s <> pretty n <> pretty e
 
 names :: [Name] -> Doc a
-names ns = hsep $ punctuate (pretty ".") $ map name ns
+names ns = hcat $ punctuate (pretty ".") $ map name ns
 
 typeName :: TypeName -> Doc a
 typeName (TypeName t) = names t
@@ -314,8 +316,8 @@ typeName (PrecLengthTypeName t i m u) =
                        PrecCharacters -> pretty "CHARACTERS"
                        PrecOctets -> pretty "OCTETS") u)
 typeName (CharTypeName t i cs col) =
-    names t
-    <> me (\x -> parens (pretty $ show x)) i
+    (names t
+    <> me (\x -> parens (pretty $ show x)) i)
     <+> (if null cs
          then mempty
          else pretty "character set" <+> names cs)
@@ -323,8 +325,8 @@ typeName (CharTypeName t i cs col) =
          then mempty
          else pretty "collate" <+> names col)
 typeName (TimeTypeName t i tz) =
-    names t
-    <> me (\x -> parens (pretty $ show x)) i
+    (names t
+    <> me (\x -> parens (pretty $ show x)) i)
     <+> pretty (if tz
               then "with time zone"
               else "without time zone")
@@ -355,12 +357,12 @@ intervalTypeField (Itf n p) =
 
 queryExpr :: Dialect -> QueryExpr -> Doc a
 queryExpr dia (Select d sl fr wh gb hv od off fe) =
-  sep [pretty "select"
-      ,case d of
-          SQDefault -> mempty
-          All -> pretty "all"
-          Distinct -> pretty "distinct"
-      ,nest 7 $ sep [selectList dia sl]
+  sep [pretty "select" <+> align (sep
+          [case d of
+               SQDefault -> mempty
+               All -> pretty "all"
+               Distinct -> pretty "distinct"
+          ,selectList dia sl])
       ,from dia fr
       ,maybeScalarExpr dia "where" wh
       ,grpBy dia gb
@@ -423,8 +425,7 @@ selectList d is = commaSep $ map si is
 from :: Dialect -> [TableRef] -> Doc a
 from _ [] = mempty
 from d ts =
-    sep [pretty "from"
-        ,nest 5 $ vsep $ punctuate comma $ map tr ts]
+    pretty "from" <+> align (vsep (punctuate comma $ map tr ts))
   where
     tr (TRSimple t) = names t
     tr (TRLateral t) = pretty "lateral" <+> tr t
@@ -454,13 +455,11 @@ from d ts =
 
 maybeScalarExpr :: Dialect -> Text -> Maybe ScalarExpr -> Doc a
 maybeScalarExpr d k = me
-      (\e -> sep [pretty k
-                 ,nest (T.length k + 1) $ scalarExpr d e])
+      (\e -> pretty k <+> align (scalarExpr d e))
 
 grpBy :: Dialect -> [GroupingExpr] -> Doc a
 grpBy _ [] = mempty
-grpBy d gs = sep [pretty "group by"
-                 ,nest 9 $ commaSep $ map ge gs]
+grpBy d gs = pretty "group by" <+> align (commaSep $ map ge gs)
   where
     ge (SimpleGroup e) = scalarExpr d e
     ge (GroupingParens g) = parens (commaSep $ map ge g)
@@ -470,8 +469,7 @@ grpBy d gs = sep [pretty "group by"
 
 orderBy :: Dialect -> [SortSpec] -> Doc a
 orderBy _ [] = mempty
-orderBy dia os = sep [pretty "order by"
-                 ,nest 9 $ commaSep $ map f os]
+orderBy dia os = pretty "order by" <+> align (commaSep $ map f os)
   where
     f (SortSpec e d n) =
         scalarExpr dia e
@@ -876,3 +874,23 @@ pretty = P.pretty
 
 show :: Show a => a -> Text
 show = T.pack . P.show
+
+-- restore the correct behaviour of mempty
+-- this doesn't quite work when you chain <> and <+> together,
+-- so use parens in those cases
+
+sep :: [Doc a] -> Doc a
+sep = P.sep . filter isEmpty
+  where
+    isEmpty Empty = False
+    isEmpty _ = True
+
+(<+>) :: Doc a -> Doc a -> Doc a
+(<+>) a b = case (a,b) of
+    (Empty, Empty) -> Empty
+    (Empty, x) -> x
+    (x, Empty) -> x
+    _ ->  a P.<+> b
+
+parens :: Doc a -> Doc a
+parens a = P.parens (align a)