1 /**
2 Commands of various nature. Prepared statement, portal and response demarshalling
3 functions live here.
4 
5 Copyright: Copyright Boris-Barboris 2017.
6 License: MIT
7 Authors: Boris-Barboris
8 */
9 
10 module dpeq.command;
11 
12 import std.algorithm: max, map;
13 import std.exception: enforce;
14 import std.format: format;
15 import std.conv: to;
16 import std.traits;
17 import std.range;
18 import std.variant;
19 import std.meta;
20 import std.typecons;
21 
22 import dpeq.exceptions;
23 import dpeq.connection;
24 import dpeq.constants;
25 import dpeq.marshalling;
26 import dpeq.schema;
27 
28 
29 /*
30 /////////////////////////////////////
31 // Different forms of command input
32 /////////////////////////////////////
33 */
34 
/** Queue a simple-protocol Query message carrying the raw SQL text `query`.

This is the most flexible way to issue commands to PSQL, and also the least
safe one: the text is sent as-is, with no parameter binding. Results of a
simple query always come back in FormatCode.Text.

A simple query triggers ReadyForQuery on its own, so do NOT follow it with a
Sync message. Pair every postSimpleQuery (and every PSQLConnection.sync) with
a getQueryResults call. */
void postSimpleQuery(ConnT)(ConnT conn, string query)
{
    // The only work done here is handing the text to the connection.
    conn.putQueryMessage(query);
}
47 
48 
/// Pre-parsed SQL query with variable parameters (a PSQL prepared statement).
class PreparedStatement(ConnT)
{
    protected
    {
        const(ObjectID)[] paramTypes;
        string query;
        bool parseRequested;
        ConnT conn;
        string parsedName;  // backend-side statement name; "" is the unnamed statement
        short m_paramCount;
    }

    /// Name under which the backend knows this statement ("" when unnamed).
    final @property string preparedName() const { return parsedName; }

    /// Parameter count declared when this object was constructed.
    final @property short paramCount() const { return m_paramCount; }

    /**
    Build the client-side state of a prepared statement. Nothing is written
    to the connection's write buffer here; call postParseMessage for that.

    When `named` is true, a fresh statement name is obtained from
    conn.getNewPreparedName(). Otherwise the unnamed statement ("") is used.

    About `paramTypes`, quoting
    https://www.postgresql.org/docs/9.5/static/protocol-message-formats.html:

    The number of parameter data types specified (can be zero). Note that this is not an
    indication of the number of parameters that might appear in the query string,
    only the number that the frontend wants to prespecify types for.

    Then, for each parameter, there is the following:

    Int32
        Specifies the object ID of the parameter data type. Placing a zero here is
        equivalent to leaving the type unspecified.

    In other words, paramTypes may stay null unless the caller needs to pin
    parameter types explicitly.
    */
    this(ConnT conn, string query, short paramCount, bool named = false,
        const(ObjectID)[] paramTypes = null)
    {
        assert(conn);
        assert(query);
        assert(paramCount >= 0);
        this.conn = conn;
        this.query = query;
        this.paramTypes = paramTypes;
        this.m_paramCount = paramCount;
        this.parsedName = named ? conn.getNewPreparedName() : "";
    }

    /// Append a Parse message for this statement to the connection's write
    /// buffer, and remember that parsing was requested.
    final void postParseMessage()
    {
        conn.putParseMessage(parsedName, query, paramTypes[]);
        parseRequested = true;
    }

    /// ditto
    alias parse = postParseMessage;

    /** Append a Close message that destroys this named prepared statement.

    Only valid after postParseMessage, and only for named statements. The
    unnamed statement lives until the next Parse that targets the unnamed
    statement (a simple Query message destroys it as well), so it never
    needs an explicit Close.
    */
    final void postCloseMessage()
    {
        assert(parseRequested, "prepared statement was never sent to backend");
        assert(parsedName.length, "no need to close unnamed prepared statements");
        conn.putCloseMessage(StmtOrPortal.Statement, parsedName);
        parseRequested = false;
    }

    /// Poll the connection's message queue until ParseComplete is seen.
    /// Throws: PsqlClientException if polling ends without ParseComplete.
    final void ensureParseComplete()
    {
        bool confirmed = false;

        // Returning true for ParseComplete tells pollMessages this message
        // was consumed; every other message is left to pollMessages.
        bool onMessage(Message msg, ref bool err, ref string errMsg)
        {
            if (msg.type != BackendMessageType.ParseComplete)
                return false;
            confirmed = true;
            return true;
        }

        conn.pollMessages(&onMessage, true);
        enforce!PsqlClientException(confirmed, "Parse was not confirmed");
    }
}
147 
148 
/** Parameter tuple, bound to prepared statement (a PSQL portal).

Every bind overload checks whether this is a named portal that was already
bound. If so, it first posts a Close message, because named portals must be
closed before another Bind can redefine them. If writing the messages throws,
two things are undone: the write buffer is rolled back to the safe point taken
with conn.saveBuffer(), and the bound flag gets back its previous value.
This leaves neither a partial Close/Bind nor a stale flag behind. */
class Portal(ConnT)
{
    protected
    {
        PreparedStatement!ConnT prepStmt;
        ConnT conn;
        string portalName;  // name, reserved for this portal in PSQL connection
        bool bindRequested = false;
    }

    /** Create a portal for prepared statement `ps`, using ps's connection.
    When `persist` is true, a fresh named portal name is obtained from
    conn.getNewPortalName(). Otherwise the unnamed portal ("") is used. */
    this(PreparedStatement!ConnT ps, bool persist = true)
    {
        assert(ps);
        this.conn = ps.conn;
        prepStmt = ps;
        if (persist)
            portalName = conn.getNewPortalName();
        else
            portalName = "";
    }

    /// bind empty, parameterless portal. resCodes are requested format codes
    /// of resulting columns, keep it null to request everything in text format.
    final void bind(FormatCode[] resCodes = null)
    {
        assert(prepStmt.paramCount == 0);

        auto safePoint = conn.saveBuffer();
        immutable wasBound = bindRequested;
        scope (failure)
        {
            // the rolled-back buffer no longer holds our Close message
            safePoint.restore();
            bindRequested = wasBound;
        }

        if (bindRequested && portalName.length)
            postCloseMessage();

        conn.putBindMessage(portalName, prepStmt.parsedName, resCodes);
        bindRequested = true;
    }

    /**
    For the 'specs' array of prepared statement parameters types, known at
    compile-time, write Bind message to connection's write buffer from the
    representation of 'args' parameters, marshalled to 'specs' types according
    to 'Marshaller' template. Format codes of the response columns is set
    via 'resCodes' array, known at compile time.
    */
    final void bind(
            FieldSpec[] specs,
            FormatCode[] resCodes = null,
            alias Marshaller = DefaultFieldMarshaller,
            Args...)
        (in Args args)
    {
        assert(prepStmt.paramCount == Args.length);
        assert(prepStmt.paramCount == specs.length);

        auto safePoint = conn.saveBuffer();
        immutable wasBound = bindRequested;
        scope (failure)
        {
            safePoint.restore();
            bindRequested = wasBound;
        }

        if (bindRequested && portalName.length)
            postCloseMessage();

        // parameter format codes, derived from the specs at compile time
        enum fcodesr = [staticMap!(FCodeOfFSpec!(Marshaller).F, aliasSeqOf!specs)];

        // one marshalling delegate per parameter; each writes args[i] into
        // the buffer slice handed to it by putBindMessage
        alias DlgT = int delegate(ubyte[]);
        DlgT[specs.length] marshallers;
        foreach(i, paramSpec; aliasSeqOf!specs)
        {
            marshallers[i] =
                (ubyte[] to) => Marshaller!paramSpec.marshal(to, args[i]);
        }
        conn.putBindMessage(portalName, prepStmt.parsedName, fcodesr,
            marshallers, resCodes);
        bindRequested = true;
    }

    /** This version of bind accept generic InputRanges of format codes and
    field marshallers and passes them directly to putBindMessage method of
    connection object. No parameter count and type validation is performed.
    If this portal is already bound and is a named one, Close message is
    posted.
    */
    final void bind(FR, PR, RR)(scope FR paramCodeRange, scope PR paramMarshRange,
        scope RR returnCodeRange)
    {
        auto safePoint = conn.saveBuffer();
        immutable wasBound = bindRequested;
        scope (failure)
        {
            safePoint.restore();
            bindRequested = wasBound;
        }
        if (bindRequested && portalName.length)
            postCloseMessage();
        conn.putBindMessage(portalName, prepStmt.parsedName, paramCodeRange,
            paramMarshRange, returnCodeRange);
        bindRequested = true;
    }

    /// Simple portal bind, which binds all parameters as strings and requests
    /// all result columns in text format. Like the other overloads, it rolls
    /// the write buffer back if marshalling throws.
    final void bind(scope Nullable!(string)[] args)
    {
        assert(prepStmt.paramCount == args.length);

        // Previously missing here: without a safe point, a throw would
        // leave a Close and/or half-written Bind in the write buffer.
        auto safePoint = conn.saveBuffer();
        immutable wasBound = bindRequested;
        scope (failure)
        {
            safePoint.restore();
            bindRequested = wasBound;
        }

        if (bindRequested && portalName.length)
            postCloseMessage();

        // callable that marshals one (possibly null) string parameter
        static struct StrMarshaller
        {
            Nullable!string str;
            this(Nullable!string v) { str = v; }

            int opCall(ubyte[] buf)
            {
                return marshalNullableStringField(buf, str);
            }
        }

        // input range yielding one StrMarshaller per parameter
        static struct MarshRange
        {
            Nullable!(string)[] params;
            int idx = 0;
            @property bool empty() { return idx >= params.length; }
            void popFront() { idx++; }
            @property StrMarshaller front()
            {
                return StrMarshaller(params[idx]);
            }
        }

        conn.putBindMessage!(FormatCode[], MarshRange, FormatCode[])(
            portalName, prepStmt.parsedName, null, MarshRange(args), null);
        bindRequested = true;
    }

    /** Write Close message to connection write buffer in order to
    explicitly destroy named portal.

    If successfully created, a named portal object lasts till the end of the
    current transaction, unless explicitly destroyed. An unnamed portal is
    destroyed at the end of the transaction, or as soon as the next Bind
    statement specifying the unnamed portal as destination is issued.
    (Note that a simple Query message also destroys the unnamed portal.)
    Named portals must be explicitly closed before they can be redefined
    by another Bind message, but this is not required for the unnamed portal.
    */
    final void postCloseMessage()
    {
        assert(bindRequested, "portal was never bound");
        assert(portalName.length, "no need to close unnamed portals");
        conn.putCloseMessage(StmtOrPortal.Portal, portalName);
        bindRequested = false;
    }

    /// Poll the message queue until BindComplete is seen.
    /// Throws: PsqlClientException if polling ends without BindComplete.
    final void ensureBindComplete()
    {
        bool is_bound = false;
        bool interceptor(Message msg, ref bool err, ref string errMsg)
        {
            with (BackendMessageType)
            switch (msg.type)
            {
                case BindComplete:
                    is_bound = true;
                    return true;
                default:
                    break;
            }
            return false;
        }
        conn.pollMessages(&interceptor, true);
        enforce!PsqlClientException(is_bound, "Bind was not confirmed");
    }

    /** Send Describe+Execute command.
    If describe is false, no RowDescription message will be requested
    from PSQL - useful for optimistic statically-typed querying.
    'maxRows' parameter is responsible for portal suspending and is
    conceptually inferior to simple TCP backpressure mechanisms or result set
    size limiting. */
    final void execute(bool describe = true, int maxRows = 0)
    {
        assert(bindRequested, "Portal was never bound");
        if (describe)
            conn.putDescribeMessage(StmtOrPortal.Portal, portalName);
        conn.putExecuteMessage(portalName, maxRows);
    }
}
333 
334 
335 /*
336 ////////////////////////////////////////
337 // Functions to work with query results
338 ////////////////////////////////////////
339 */
340 
341 
/** Generic result materializer, suitable for both simple and prepared queries.
Polls messages from the connection and builds a QueryResult from them.
Polling stops when ReadyForQuery message is received. Problems detected here
are reported to pollMessages through its `err`/`errMsg` arguments.

Each CommandComplete, EmptyQueryResponse and PortalSuspended message counts as
one completed command and closes the current RowBlock. The next
RowDescription or DataRow then starts a new block.

Params:
    conn = connection to poll.
    requireRowDescription = if true, a row block that starts with DataRow
        (no preceding RowDescription) is flagged as an error.
Returns: the collected QueryResult. */
QueryResult getQueryResults(ConnT)(ConnT conn, bool requireRowDescription = false)
{
    QueryResult res;
    // true while no row block is in progress
    bool newBlockAwaited = true;

    bool interceptor(Message msg, ref bool err, ref string errMsg) nothrow
    {
        with (BackendMessageType)
        switch (msg.type)
        {
            case EmptyQueryResponse:
                if (newBlockAwaited)
                {
                    RowBlock rb;
                    rb.emptyQuery = true;
                    res.blocks ~= rb;
                }
                res.commandsComplete++;
                newBlockAwaited = true;
                break;
            case CommandComplete:
                if (newBlockAwaited)
                    res.blocks ~= RowBlock();
                res.commandsComplete++;
                newBlockAwaited = true;
                break;
            case PortalSuspended:
                // Ends the current block the same way CommandComplete does.
                // If no block is in progress, open an empty one, so the flag
                // cannot land on a previous, already finished block (or
                // index an empty array).
                if (newBlockAwaited)
                    res.blocks ~= RowBlock();
                res.blocks[$-1].suspended = true;
                res.commandsComplete++;
                newBlockAwaited = true;
                break;
            case RowDescription:
                // RowDescription always precedes new row block data
                if (newBlockAwaited)
                {
                    RowBlock rb;
                    rb.rowDesc = dpeq.schema.RowDescription(msg.data);
                    res.blocks ~= rb;
                    newBlockAwaited = false;
                }
                else
                {
                    err = true;
                    errMsg = "Unexpected RowDescription in the middle of " ~
                        "row block";
                }
                break;
            case DataRow:
                if (newBlockAwaited)
                {
                    if (requireRowDescription)
                    {
                        err = true;
                        errMsg ~= "Got row without row description. ";
                    }
                    res.blocks ~= RowBlock();
                    newBlockAwaited = false;
                }
                res.blocks[$-1].dataRows ~= msg; // we simply save raw bytes
                break;
            default:
                break;
        }
        // never consume: let pollMessages handle every message as well
        return false;
    }

    conn.pollMessages(&interceptor, false);
    return res;
}
415 
416 
/** Poll messages from the connection and return one row block (the result of
one and only one query).

Polling stops once any of these arrives: CommandComplete, EmptyQueryResponse
(sets `emptyQuery`), PortalSuspended (sets `suspended`), or `rowCountLimit`
DataRow messages.

Params:
    conn = connection to poll.
    rowCountLimit = if positive, stop after collecting this many rows.
        Zero (the default) or a negative value means no limit.
    requireRowDescription = if true, a DataRow that arrives before any
        RowDescription is flagged as an error. The message is appended to
        `errMsg` only once, and such rows are not collected.
Returns: the collected RowBlock. */
RowBlock getOneRowBlock(ConnT)(ConnT conn, int rowCountLimit = 0,
    bool requireRowDescription = false)
{
    RowBlock result;
    // prevents repeating the same error text for every offending row
    bool missingDescReported = false;

    bool interceptor(Message msg, ref bool err, ref string errMsg) nothrow
    {
        with (BackendMessageType)
        switch (msg.type)
        {
            case EmptyQueryResponse:
                result.emptyQuery = true;
                return true;
            case CommandComplete:
                return true;
            case PortalSuspended:
                result.suspended = true;
                return true;
            case RowDescription:
                result.rowDesc = dpeq.schema.RowDescription(msg.data);
                requireRowDescription = false;
                break;
            case DataRow:
                if (requireRowDescription)
                {
                    err = true;
                    if (!missingDescReported)
                    {
                        errMsg ~= "Missing required RowDescription. ";
                        missingDescReported = true;
                    }
                    break;
                }
                result.dataRows ~= msg;
                if (rowCountLimit > 0)
                {
                    // client code requested early stop
                    rowCountLimit--;
                    if (rowCountLimit == 0)
                        return true;
                }
                break;
            default:
                break;
        }
        return false;
    }

    conn.pollMessages(&interceptor, true);
    return result;
}
466 
467 
468 
469 /*
470 /////////////////////////////////////////////////////////////////
471 // Functions used to transform query results into D types
472 /////////////////////////////////////////////////////////////////
473 */
474 
475 //import std.stdio;
476 
/** Returns RandomAccessRange of InputRanges of lazy-demarshalled variants.
Specific flavor of Variant is derived from Converter.demarshal call return type.
Look into marshalling.VariantConverter for demarshal implementation examples.

If `fieldDescs` is passed, the description of column i is stored into
(*fieldDescs)[i]. The array grows by appending when it is shorter than the
column count, so pre-sized and empty arrays are both accepted.

Each row's columns may be visited in any combination of front/popFront. A
column that is popped without front() is skipped correctly.

Throws: PsqlMarshallingException if the block has no row description.
Element access throws it for malformed (truncated) DataRow payloads. */
auto blockToVariants(alias Converter = VariantConverter!DefaultFieldMarshaller)
    (RowBlock block, FieldDescription[]* fieldDescs = null)
{
    alias VariantT = ReturnType!(Converter.demarshal);

    enforce!PsqlMarshallingException(block.rowDesc.isSet,
        "Cannot demarshal RowBlock without row description. " ~
        "Did you send Describe message?");
    short totalColumns = block.rowDesc.fieldCount;
    ObjectID[] typeArr = new ObjectID[totalColumns];
    FormatCode[] fcArr = new FormatCode[totalColumns];

    size_t i = 0;
    foreach (fdesc; block.rowDesc[]) // row description demarshalling happens here
    {
        if (fieldDescs)
        {
            if (i < (*fieldDescs).length)
                (*fieldDescs)[i] = fdesc;
            else
                (*fieldDescs) ~= fdesc;
        }
        fcArr[i] = fdesc.formatCode;
        typeArr[i] = fdesc.type;
        i++;
    }

    static struct RowDemarshaller
    {
    private:
        short column = 0;
        short totalCols;
        const(ubyte)[] buf;
        const(ObjectID)[] types;
        const(FormatCode)[] fcodes;
        bool parsed = false;

        // cache result to prevent repeated demarshalling on
        // front() call.
        VariantT res;

        // set once the 16-bit column count that prefixes the DataRow
        // payload has been dropped from buf
        bool headerSkipped = false;

        // Drop the column-count prefix (once), check that the current
        // column fits in buf, and return its slot size: the 4-byte length
        // prefix plus value bytes. A negative length (NULL per the protocol)
        // carries no value bytes. buf is NOT advanced past the column.
        size_t currentSlot(out int len)
        {
            if (!headerSkipped)
            {
                enforce!PsqlMarshallingException(buf.length >= 2,
                    "DataRow is too short to hold a column count");
                buf = buf[2 .. $];
                headerSkipped = true;
            }
            enforce!PsqlMarshallingException(buf.length >= 4,
                "DataRow ends before column %d".format(column));
            len = demarshalNumber(buf[0 .. 4]);
            immutable size_t slot = len < 0 ? 4 : 4 + cast(size_t) len;
            enforce!PsqlMarshallingException(buf.length >= slot,
                "Column %d needs %d bytes, only %d left in DataRow".format(
                    column, slot - 4, buf.length - 4));
            return slot;
        }
    public:
        @property bool empty() const { return column >= totalCols; }

        void popFront()
        {
            // A column never demarshalled via front() still occupies bytes
            // that must be skipped, or the next column would read them.
            if (!parsed && !empty)
            {
                int len;
                buf = buf[currentSlot(len) .. $];
            }
            parsed = false;
            column++;
        }

        @property VariantT front()
        {
            if (!parsed)
            {
                int len;
                immutable slot = currentSlot(len);
                res = Converter.demarshal(buf[4 .. slot], types[column],
                    fcodes[column], len);
                // advance only after a successful demarshal, so a retry
                // after an exception reads the same column again
                buf = buf[slot .. $];
                parsed = true;
            }
            return res;
        }
    }

    static struct RowsRange
    {
    private:
        Message[] dataRows;
        ObjectID[] columnTypes;
        FormatCode[] fcodes;
        short totalColumns;
    public:
        @property size_t length() { return dataRows.length; }
        @property bool empty() { return dataRows.empty; }
        @property RowDemarshaller front()
        {
            return RowDemarshaller(0, totalColumns, dataRows[0].data,
                columnTypes, fcodes);
        }
        @property RowDemarshaller back()
        {
            return RowDemarshaller(0, totalColumns, dataRows[$-1].data,
                columnTypes, fcodes);
        }
        RowDemarshaller opIndex(size_t i)
        {
            return RowDemarshaller(0, totalColumns, dataRows[i].data,
                columnTypes, fcodes);
        }
        void popFront() { dataRows = dataRows[1 .. $]; }
        void popBack() { dataRows = dataRows[0 .. $-1]; }
        RowsRange save()
        {
            return RowsRange(dataRows, columnTypes, fcodes, totalColumns);
        }
    }

    return RowsRange(block.dataRows, typeArr, fcArr, totalColumns);
}
579 
580 
581 
582 
/// Native std.typecons.Tuple type for row spec `spec`. It has one field per
/// FieldSpec, each typed by SpecMapper!Demarshaller.Func.
template TupleForSpec(FieldSpec[] spec, alias Demarshaller = DefaultFieldMarshaller)
{
    alias TupleForSpec = Tuple!(staticMap!(SpecMapper!Demarshaller.Func, aliasSeqOf!spec));
}
592 
/// Func!spec resolves to Demarshaller!spec.type, the D type that corresponds
/// to a FieldSpec. If the demarshaller defines no `type` for that spec,
/// compilation fails with a message that names the type OID.
template SpecMapper(alias Demarshaller)
{
    template Func(FieldSpec spec)
    {
        static if (!is(Demarshaller!spec.type))
            static assert(0, "Demarshaller doesn't support type with oid " ~
                spec.typeId.to!string);
        else
            alias Func = Demarshaller!spec.type;
    }
}
605 
606 
/** Returns RandomAccessRange of lazily-demarshalled tuples.
Customizable with Demarshaller template. The row description is validated
against `spec`: the column count and every column type OID must match.

If `fieldDescs` is passed, the description of column i is stored into
(*fieldDescs)[i]. The array grows by appending when it is shorter than the
column count, so pre-sized and empty arrays are both accepted.

Throws: PsqlMarshallingException if the block has no row description or the
description does not match `spec`. Element access throws it for malformed
DataRow payloads. */
auto blockToTuples
    (FieldSpec[] spec, alias Demarshaller = DefaultFieldMarshaller)
    (RowBlock block, FieldDescription[]* fieldDescs = null)
{
    alias ResTuple = TupleForSpec!(spec, Demarshaller);
    debug pragma(msg, "Resulting tuple from spec: ", ResTuple);
    enforce!PsqlMarshallingException(block.rowDesc.isSet,
        "Cannot demarshal RowBlock without row description. " ~
        "Did you send describe message?");
    short totalColumns = block.rowDesc.fieldCount;
    enforce!PsqlMarshallingException(totalColumns == spec.length,
        "Expected %d columns in a row, got %d".format(spec.length, totalColumns));
    FormatCode[] fcArr = new FormatCode[totalColumns];

    // Copy the expected type OIDs out of the compile-time spec once. Indexing
    // 'spec' with a runtime index would build a new array literal per access.
    typeof(FieldSpec.typeId)[spec.length] expectedTypes;
    foreach (k, colSpec; aliasSeqOf!spec)
        expectedTypes[k] = colSpec.typeId;

    size_t i = 0;
    foreach (fdesc; block.rowDesc[]) // row description demarshalling happens here
    {
        if (fieldDescs)
        {
            if (i < (*fieldDescs).length)
                (*fieldDescs)[i] = fdesc;
            else
                (*fieldDescs) ~= fdesc;
        }
        fcArr[i] = fdesc.formatCode;
        ObjectID colType = fdesc.type;
        enforce!PsqlMarshallingException(colType == expectedTypes[i],
            "Column %d type mismatch: expected %d, got %d".format(
                i, expectedTypes[i], colType));
        i++;
    }

    // Decode one DataRow payload: a 16-bit column count, then for each
    // column an Int32 length followed by that many bytes. A negative length
    // (NULL per the protocol) has no value bytes.
    static ResTuple demarshalRow(const(ubyte)[] from, const(FormatCode)[] fcodes)
    {
        ResTuple res;
        enforce!PsqlMarshallingException(from.length >= 2,
            "DataRow is too short to hold a column count");
        from = from[2 .. $];    // skip the 16-bit column count
        foreach (i, colSpec; aliasSeqOf!(spec))
        {
            enforce!PsqlMarshallingException(from.length >= 4,
                "DataRow ends before column %d".format(i));
            int len = demarshalNumber(from[0 .. 4]);
            immutable size_t fieldEnd = len < 0 ? 4 : 4 + cast(size_t) len;
            enforce!PsqlMarshallingException(from.length >= fieldEnd,
                "Column %d needs %d bytes, only %d left in DataRow".format(
                    i, fieldEnd - 4, from.length - 4));
            res[i] = Demarshaller!(colSpec).demarshal(from[4 .. fieldEnd],
                fcodes[i], len);
            from = from[fieldEnd .. $];
        }
        enforce!PsqlMarshallingException(from.length == 0,
            "%d bytes left in supposedly emptied row".format(from.length));
        return res;
    }

    static struct RowsRange
    {
    private:
        Message[] dataRows;
        FormatCode[] fcodes;
    public:
        @property size_t length() { return dataRows.length; }
        @property bool empty() { return dataRows.empty; }
        @property ResTuple front()
        {
            return demarshalRow(dataRows[0].data, fcodes);
        }
        @property ResTuple back()
        {
            return demarshalRow(dataRows[$-1].data, fcodes);
        }
        ResTuple opIndex(size_t i)
        {
            return demarshalRow(dataRows[i].data, fcodes);
        }
        void popFront() { dataRows = dataRows[1 .. $]; }
        void popBack() { dataRows = dataRows[0 .. $-1]; }
        RowsRange save()
        {
            return RowsRange(dataRows, fcodes);
        }
    }

    return RowsRange(block.dataRows, fcArr);
}
690 
691 
/// Static table `codes` of the format code that Demarshaller assigns to each
/// column of `spec`. It is filled by a static constructor, which D runs once
/// per thread.
class FormatCodesOfSpec(FieldSpec[] spec, alias Demarshaller)
{
    static const(FormatCode)[spec.length] codes;

    static this()
    {
        // unrolled at compile time: one assignment per column spec
        foreach (columnIdx, columnSpec; aliasSeqOf!spec)
            codes[columnIdx] = Demarshaller!columnSpec.formatCode;
    }
}
702 
703 
/** Returns RandomAccessRange of lazy-demarshalled tuples built from raw
DataRow messages. Customizable with Demarshaller template. This version does
not require RowDescription, but cannot validate row types reliably. Each
column's format code is taken from the spec via FCodeOfFSpec.

Rows are demarshalled anew on every element access.
Throws: PsqlMarshallingException from element access if a row is truncated
or holds bytes beyond the columns described by `spec`. */
auto blockToTuples
    (FieldSpec[] spec, alias Demarshaller = DefaultFieldMarshaller)
    (Message[] data)
{
    alias ResTuple = TupleForSpec!(spec, Demarshaller);
    debug pragma(msg, "Resulting tuple from spec: ", ResTuple);

    // Decode one DataRow payload: a 16-bit column count, then for each
    // column an Int32 length followed by that many bytes. A negative length
    // (NULL per the protocol) has no value bytes.
    static ResTuple demarshalRow(const(ubyte)[] from)
    {
        ResTuple res;
        enforce!PsqlMarshallingException(from.length >= 2,
            "DataRow is too short to hold a column count");
        from = from[2 .. $];    // skip the 16-bit (2-byte) column count
        foreach (i, colSpec; aliasSeqOf!(spec))
        {
            // No row description: a row that does not match the spec is
            // caught here instead of reading past the buffer.
            enforce!PsqlMarshallingException(from.length >= 4,
                "DataRow ends before column %d".format(i));
            int len = demarshalNumber(from[0 .. 4]);
            immutable size_t fieldEnd = len < 0 ? 4 : 4 + cast(size_t) len;
            enforce!PsqlMarshallingException(from.length >= fieldEnd,
                "Column %d needs %d bytes, only %d left in DataRow".format(
                    i, fieldEnd - 4, from.length - 4));
            FormatCode fcode = FCodeOfFSpec!(Demarshaller).F!(colSpec);
            res[i] = Demarshaller!(colSpec).demarshal(from[4 .. fieldEnd],
                fcode, len);
            from = from[fieldEnd .. $];
        }
        enforce!PsqlMarshallingException(from.length == 0,
            "%d bytes left in supposedly emptied row".format(from.length));
        return res;
    }

    static struct RowsRange
    {
    private:
        Message[] dataRows;
    public:
        @property size_t length() { return dataRows.length; }
        @property bool empty() { return dataRows.empty; }
        @property ResTuple front()
        {
            return demarshalRow(dataRows[0].data);
        }
        @property ResTuple back()
        {
            return demarshalRow(dataRows[$-1].data);
        }
        ResTuple opIndex(size_t i)
        {
            return demarshalRow(dataRows[i].data);
        }
        void popFront() { dataRows = dataRows[1 .. $]; }
        void popBack() { dataRows = dataRows[0 .. $-1]; }
        RowsRange save()
        {
            return RowsRange(dataRows);
        }
    }

    return RowsRange(data);
}