1 module nats.client;
2 
3 import vibe.core.core;
4 import vibe.core.log;
5 import core.time;
6 import std.exception;
7 
8 public import nats.interface_;
9 import nats.parser;
10 
/// Client version string; it is also sent to the server as `version` in the CONNECT payload.
enum string VERSION = "nats_v0.5.1";

// Debug logging from the nats client is compiled in by default.
// Build with -version=NatsClientQuiet to compile it out.
version (NatsClientQuiet)
{
    // quiet build: NatsClientLogging stays undefined
}
else
{
    version = NatsClientLogging;
}
16 
/// NATS protocol client built on vibe.d tasks.
///
/// `connect()` starts a connector task that drives the connection state
/// (`_connState`): INIT/RECONNECTING -> CONNECTING -> CONNECTED, dropping to
/// DISCONNECTED (followed by a delayed reconnect) on errors, until `close()`
/// moves it to CLOSED. While a connection is up, a listener task parses the
/// incoming protocol stream and a heartbeater task PINGs the server whenever
/// the link has been idle for a heartbeat interval.
final class Nats
{
    import std.format: sformat;
    import std.socket: Address, parseAddress;
    import std.traits: Unqual;
    import eventcore.core: IOMode;
    import vibe.core.net: TCPConnection, connectTCP, WaitForDataStatus;
    import vibe.core.sync: TaskCondition, RecursiveTaskMutex, TaskMutex;
    import vibe.data.json: Json;
    import nbuff: Nbuff;

    public:

    /// Client version string (the module-level `VERSION`).
    static natsClientVersion = VERSION;

    /// Fields serialised to JSON and sent in the CONNECT command.
    static struct ConnectInfo
    {
        string name;
        string lang = "dlang";
        string version_ = VERSION;
        string user;
        string pass;
        bool verbose;
        bool pedantic;
    }

    /// Snapshot of the client's traffic counters, returned by `statistics()`.
    static struct Statistics
    {
        ulong bytesReceived;
        ulong bytesSent;
        ulong msgsReceived;
        ulong msgsSent;
    }

    /// Convenience constructor: builds a NatsClientConfig from a server URI and
    /// an optional client name, then delegates to the config constructor.
    this(string natsUri = "nats://127.0.0.1:4222", string name = null) @trusted
    {
        NatsClientConfig config;
        config.natsUri = natsUri;
        config.clientId = name;
        this(config);
    }

    /// Configure the client from `config`. Does not connect; call `connect()`.
    ///
    /// A `nats:` or `tls:` scheme prefix is rewritten to `tcp:` before the URI
    /// is parsed; `tls:` sets `_useTls`. Credentials embedded in the URI become
    /// the CONNECT user/pass, and a missing port defaults to 4222.
    this(NatsClientConfig config) @trusted
    {
        import std.algorithm.searching: startsWith;
        import std.format: format;
        import vibe.inet.url: URL;

        immutable ushort defaultPort = 4222;
        uint schemaSkip;

        if (config.natsUri.startsWith("nats:"))
        {
            _useTls = false;
            schemaSkip = 4;
        }
        else if (config.natsUri.startsWith("tls:"))
        {
            _useTls = true;
            schemaSkip = 3;
        }

        string uri = schemaSkip ? format!"tcp%s"(config.natsUri[schemaSkip..$]) : config.natsUri;
        auto immutable url = URL(uri);

        _connectInfo.name = config.clientId;
        _connectInfo.user = url.username;
        _connectInfo.pass = url.password;
        _host = url.host;
        _port = (url.port == 0) ? defaultPort : url.port;
        _heartbeatInterval = config.heartbeatInterval;
        _reconnectInterval = config.reconnectInterval;
        _connectTimeout = config.connectTimeout;

        _connectMutex = new TaskMutex;
        _connStateChange = new TaskCondition(_connectMutex);
        _flushMutex = new TaskMutex;
        _writeMutex = new RecursiveTaskMutex;
        _flushSync = new TaskCondition(_flushMutex);
        // reserve a sensible minimum amount of space for new subscriptions
        _subs.reserve(16);
    }

    unittest
    {
        NatsClientConfig testConfig;
        testConfig.natsUri = "nats://kookman:pass@127.0.0.1:4222";
        auto nats = new Nats(testConfig);
        assert (nats._host == "127.0.0.1");
        assert (nats._port == 4222);
        assert (nats._connectInfo.user == "kookman");
        assert (nats._connectInfo.pass == "pass");
    }

    ~this() @trusted
    {
        version (NatsClientLogging) logDebug("nats client object destroyed!");
    }

    /// Start the connector task, which establishes the connection and keeps
    /// reconnecting until `close()` is called. Returns immediately.
    void connect() @safe
    {
        _connector = runTask(&connector);
    }

    /// True once the CONNECT handshake and subscription set-up have completed.
    bool connected() @safe const nothrow
    {
        return (_connState == NatsState.CONNECTED);
    }

    /// All subscriptions, indexed by sid. Once set up, index 0 is the
    /// connection-specific _INBOX_ subscription used by `publishRequest`.
    const(Subscription[]) subscriptions() @safe const nothrow
    {
        return _subs;
    }

    /// Snapshot of the byte and message counters.
    const(Statistics) statistics() @safe const nothrow
    {
        Statistics stats;

        stats.bytesReceived = _bytesReceived;
        stats.bytesSent = _bytesSent;
        stats.msgsReceived = _msgRecv;
        stats.msgsSent = _msgSent;
        return stats;
    }

    /// Subscribe `handler` to `subject`, optionally as a member of `queueGroup`.
    ///
    /// The subscription's sid is its index in `_subs`; sid 0 is reserved for the
    /// _INBOX_ subscription. If no socket is connected yet, the SUB is sent when
    /// the connection is set up.
    /// Returns: the new Subscription.
    Subscription subscribe(string subject, NatsHandler handler,
        string queueGroup = null) @safe
    {
        if (_subs.length == 0)
        {
            // setupNatsConnection always installs the _INBOX_ subscription at _subs[0].
            // Hold that slot now so a subscription made before the first connect does
            // not receive sid 0 and get overwritten by the inbox later.
            auto inboxPlaceholder = new Subscription;
            inboxPlaceholder.sid = 0;
            inboxPlaceholder.handler = &inboxHandler;
            _subs ~= inboxPlaceholder;
        }

        auto s = new Subscription;
        s.subject = subject;
        s.sid = cast(uint)_subs.length;
        s.handler = handler;
        s.queueGroup = queueGroup;
        _subs ~= s;
        // setupNatsConnection re-sends every open subscription on (re)connect, so
        // only write now if there is a live socket to write to.
        if (_conn.connected)
            sendSubscribe(s);
        return s;
    }

    /// Unsubscribe `s`.
    ///
    /// With `msgsToDrain > 0` the server is asked to unsubscribe after that many
    /// further messages; locally the subscription is marked closed once
    /// `msgsReceived` exceeds `msgsToExpire` (checked in processNatsStream).
    /// Otherwise it is closed immediately.
    /// Throws: whatever the socket write throws.
    void unsubscribe(Subscription s, uint msgsToDrain = 0) @safe
    {
        char[OUTCMD_BUFSIZE] buffer = void;
        char[] cmd;

        if (msgsToDrain > 0)
        {
            s.msgsToExpire = s.msgsReceived + msgsToDrain;
            cmd = buffer.sformat!"UNSUB %s %s"(s.sid, msgsToDrain);
        }
        else
        {
            s.closed = true;
            cmd = buffer.sformat!"UNSUB %s"(s.sid);
        }
        write(cmd);
    }

    /// Publish `payload` on `subject`. Failures are logged and trigger
    /// `disconnect()` rather than being thrown.
    void publish(T)(scope const(char)[] subject, scope const(T)[] payload) @safe nothrow
        if (is(Unqual!T == ubyte) || is(Unqual!T == char))
    {
        sendPublish(subject, payload);
    }

    /// Publish a request on `subject` with a reply subject in this connection's
    /// inbox. `handler` is called once with the first reply, then forgotten.
    void publishRequest(T)(scope const(char)[] subject, scope const(T)[] payload, NatsHandler handler) @safe nothrow
        if (is(Unqual!T == ubyte) || is(Unqual!T == char))
    {
        import std.conv: to;

        // Allocate the id from a dedicated counter before anything can yield.
        // _msgSent only advances after the write completes, so it cannot keep
        // requests that race on the write mutex from sharing an inbox.
        immutable inboxId = _nextInboxId++;
        _inboxes[inboxId] = handler;
        auto replyInbox = _inboxPrefix ~ to!string(inboxId);
        if (!sendPublish(subject, payload, replyInbox))
            _inboxes.remove(inboxId); // the request never went out, so no reply will arrive
    }

    /// Send a PING and wait up to `timeout` for the server's PONG.
    /// Returns: the elapsed time; when no PONG arrived in time this is at least
    /// `timeout` (a warning is logged).
    Duration flush(Duration timeout = 5.seconds) @safe nothrow
    {
        Duration rtt;
        if (!flushWithTimeout(timeout, rtt))
            logWarn("nats.client: Flush did not complete within %s.", timeout);
        return rtt;
    }

    /// Mark the session DISCONNECTED and wake the connector, which waits the
    /// reconnect interval and reconnects. Does not itself close the socket.
    void disconnect() @safe nothrow
    {
        if (_connState != NatsState.DISCONNECTED) {
            logWarn("nats.client: Disconnecting from Nats.");
            _connState = NatsState.DISCONNECTED;
            _connStateChange.notify();
        }
    }

    /// Close the TCP connection and move to CLOSED. This ends the connector's loop,
    /// so no reconnect is attempted.
    void close() @safe nothrow
    {
        if (_connState != NatsState.CLOSED) {
            logInfo("nats.client: Closing TCP connection to Nats server.");
            try
                _conn.close();
            catch (Exception e) {
                logError("nats.client: Error (%s) closing TCP connection!", e.msg);
            }
            _connState = NatsState.CLOSED;
            _connStateChange.notify();
        }
    }

    private:

    enum OUTCMD_BUFSIZE = 256;

    TCPConnection        _conn;
    RecursiveTaskMutex   _writeMutex;       // held around socket writes (keeps PUB header + payload together)
    TaskMutex            _connectMutex;
    TaskCondition        _connStateChange;  // notified by disconnect() and close()
    TaskMutex            _flushMutex;
    TaskCondition        _flushSync;        // notified on every PONG received
    Duration             _lastFlushRTT;
    Duration             _heartbeatInterval;
    Duration             _reconnectInterval;
    Duration             _connectTimeout;
    Task                 _connector;
    Task                 _heartbeater;
    Task                 _listener;
    ulong                _bytesSent;
    ulong                _bytesReceived;
    uint                 _msgSent;
    uint                 _msgRecv;
    uint                 _pingSent;
    uint                 _pongRecv;
    uint                 _pingRecv;
    uint                 _pongSent;
    uint                 _nextInboxId;      // id for the next publishRequest reply inbox
    Subscription[]       _subs;             // indexed by sid
    ConnectInfo          _connectInfo;
    Json                 _info;             // last INFO payload received from the server
    string               _host;
    ushort               _port;
    bool                 _useTls;
    NatsState            _connState;
    string               _inboxPrefix;
    NatsHandler[uint]    _inboxes;          // publishRequest reply handlers keyed by inbox id

    /// Timeout used for the TCP connect and the initial handshake PING.
    /// Falls back to 5 seconds when the configured value is not positive.
    Duration effectiveConnectTimeout() @safe const nothrow
    {
        return _connectTimeout > Duration.zero ? _connectTimeout : 5.seconds;
    }

    /// Send a PING and wait until the server's matching PONG has been counted
    /// or `timeout` has elapsed.
    /// Params:
    ///     timeout = upper bound on the wait for the PONG.
    ///     rtt     = receives the time elapsed since the PING was issued.
    /// Returns: true if the PONG arrived in time. Returns false on timeout, or if
    /// the PING could not be written (which also calls disconnect()).
    bool flushWithTimeout(Duration timeout, out Duration rtt) @safe nothrow
    {
        _flushMutex.lock();
        scope(exit) _flushMutex.unlock();

        auto start = MonoTime.currTime();
        // Count the PING before writing it: write() can yield, and the listener
        // may process the PONG before control returns here.
        immutable target = ++_pingSent;
        bool answered = true;
        try
            write(PING);
        catch (Exception e) {
            logError("nats.client: Flush write error (%s)", e.msg);
            disconnect();
            answered = false;
        }
        // The listener counts PONGs and wakes every waiter. Re-check in a loop so
        // each caller only returns once its own PING has been answered.
        while (answered && _pongRecv < target)
        {
            immutable remaining = timeout - (MonoTime.currTime() - start);
            if (remaining <= Duration.zero)
                answered = false;
            else
                _flushSync.wait(remaining);
        }
        _lastFlushRTT = MonoTime.currTime() - start;
        rtt = _lastFlushRTT;
        return answered;
    }

    /// Perform the NATS handshake on a freshly connected socket.
    ///
    /// Sends CONNECT, waits for a PING round trip, and installs the
    /// connection-specific _INBOX_ subscription at _subs[0]. It then re-sends every
    /// open subscription. On success it moves CONNECTING -> CONNECTED; on failure
    /// it calls disconnect().
    void setupNatsConnection() @safe nothrow
    {
        import std.format: format;
        import vibe.data.json: serializeToJsonString;

        _conn.keepAlive(true);
        _conn.tcpNoDelay(true);
        // A new connection has no PONGs outstanding; drop any deficit left by a
        // PING that was lost with the previous connection.
        _pongRecv = _pingSent;
        // send the CONNECT string
        try {
            _conn.readTimeout(10.seconds);
            // The JSON embeds the user-supplied name and credentials, so build it on
            // the GC heap: sformat into a fixed buffer throws a RangeError on overflow.
            immutable cmd = format!"CONNECT %s"(serializeToJsonString(_connectInfo));
            version (NatsClientLogging) logDebug("nats.client: Socket connected. Sending: %s", cmd);
            write(cmd);
        }
        catch (Exception e) {
            logError("nats.client: Failed to send initial CONNECT string (%s).", e.msg);
            disconnect();
            return;
        }
        Duration rtt;
        if (!flushWithTimeout(effectiveConnectTimeout(), rtt)) {
            logError("nats.client: No PONG from Nats server within %s of CONNECT.", rtt);
            disconnect();
            return;
        }
        version (NatsClientLogging) logDebug("nats.client: Flush roundtrip (%s) completed.", rtt);
        // create a connection specific inbox subscription
        _inboxPrefix = "_INBOX_" ~ _conn.localAddress.toString() ~ "_.";
        auto inbox = new Subscription;
        inbox.subject = _inboxPrefix ~ "*";
        inbox.sid = 0;
        inbox.handler = &inboxHandler;
        //_subs[0] is always my _INBOX_ so just replace any existing one
        if (_subs.length == 0)
            _subs ~= inbox;
        else
            _subs[0] = inbox;
        version (NatsClientLogging) logDebug("nats.client: Setting up inbox subscription: %s", inbox.subject);
        sendSubscribe(inbox);
        if (_subs.length > 1)
        {
            version (NatsClientLogging) logDebug("nats.client: Re-sending active subscriptions after reconnection to Nats server.");
            foreach (priorSubscription; _subs[1..$]) {
                if (!priorSubscription.closed) sendSubscribe(priorSubscription);
            }
        }
        // a failed write above has already moved us to DISCONNECTED; don't override it
        if (_connState == NatsState.CONNECTING)
            _connState = NatsState.CONNECTED;
    }

    /// Connector task: runs the connection state machine until CLOSED, holding
    /// _connectMutex (the mutex of _connStateChange) throughout.
    void connector() @safe nothrow
    {
        import std.random: uniform;
        auto reconnectTimer = createTimer(null);
        _connectMutex.lock();
        scope(exit) _connectMutex.unlock();

        connector:
        while (_connState != NatsState.CLOSED) {
            final switch (_connState) {
                case NatsState.INIT:
                case NatsState.RECONNECTING:
                    version (NatsClientLogging) logDiagnostic(
                        "nats.client: Establishing connection to NATS server (%s) port %s ...", _host, _port);
                    try {
                        if (_conn.connected)
                            _conn.close();
                        _connState = NatsState.CONNECTING;
                        _conn = connectTCP(_host, _port, null, 0, effectiveConnectTimeout());
                    }
                    catch (Exception e) {
                        logWarn("nats.client: Exception whilst attempting connect to Nats server, msg: %s", e.msg);
                    }
                    break;

                case NatsState.CONNECTING:
                    if (_conn.connected) {
                        _listener = runTask(&listener);
                        setupNatsConnection();
                    } else {
                        logWarn("nats.client Connector: Connection attempt to %s failed.", _host);
                        _connState = NatsState.DISCONNECTED;
                    }
                    break;

                case NatsState.DISCONNECTED:
                    // wait for an interval + random extra delay to avoid thundering herd on Nats server down
                    try {
                        Duration delay = _reconnectInterval + uniform(0, 1000).msecs;
                        logInfo("nats.client: Waiting %s before attempting reconnect.", delay);
                        reconnectTimer.rearm(delay, false);
                        reconnectTimer.wait();
                    }
                    catch (Exception e) {
                        logWarn("nats.client: Reconnect timer interrupted.");
                        break;
                    }
                    // close() may have run while we were waiting; don't resurrect the session
                    if (_connState == NatsState.DISCONNECTED)
                        _connState = NatsState.RECONNECTING;
                    break;

                case NatsState.CONNECTED:
                    logInfo("nats.client: Nats connection ready.");
                    _heartbeater = runTask(&heartbeater);
                    // sleep until disconnect() or close() actually moves us out of CONNECTED
                    while (_connState == NatsState.CONNECTED)
                        _connStateChange.wait();
                    break;

                case NatsState.CLOSED:
                    break connector;
            }
        }
        // close() may have run while connectTCP was in progress, leaving a new socket open
        if (_conn.connected) {
            try
                _conn.close();
            catch (Exception e) {
                logWarn("nats.client: Error (%s) closing TCP connection on connector exit.", e.msg);
            }
        }
        version (NatsClientLogging) logDiagnostic("nats.client: Connector task terminating.");
    }

    /// Listener task: reads from the socket while CONNECTING/CONNECTED and feeds
    /// the accumulated bytes to processNatsStream. Any exception ends the task
    /// via disconnect().
    void listener() @safe nothrow
    {
        version (NatsClientLogging) logDiagnostic("nats.client: listener task started.");
        try {
            enum size_t readSize = 2048;
            Nbuff buffer;

            while(_connState == NatsState.CONNECTING || _connState == NatsState.CONNECTED) {
                // task blocks here until we receive data. No timeout as that is handled
                // by the heartbeater task.
                auto result = _conn.waitForDataEx();
                if (result == WaitForDataStatus.dataAvailable) {
                    auto readBuffer = Nbuff.get(readSize);
                    auto readBufferSpace = () @trusted { return readBuffer.data(); }();
                    immutable bytesRead = _conn.read(readBufferSpace, IOMode.once);
                    _bytesReceived += bytesRead;
                    // append this newly read chunk to the buffer and process
                    buffer.append(readBuffer, bytesRead);
                    size_t consumed = processNatsStream(buffer);
                    version (NatsClientLogging) {
                        if (consumed < buffer.length)
                            logTrace("Fragment (length: %s) left in buffer. Consolidating with another read.", buffer.length - consumed);
                    }
                    // free fully processed messages from the buffer
                    buffer.pop(consumed);
                }
                else if (result == WaitForDataStatus.noMoreData) {
                    throw new Exception("WaitForDataStatus.noMoreData");
                }
            }
        }
        catch (Exception e) {
            logWarn("nats.client: Nats session disconnected! (%s).", e.msg);
            disconnect();
        }
        version (NatsClientLogging) logDiagnostic("nats.client: Listener task terminating. Connector will attempt reconnect.");
    }

    /// Heartbeater task: every heartbeat interval, PING the server if nothing was
    /// sent or received in that interval. Disconnect if a heartbeat PING goes
    /// unanswered. Runs while CONNECTED.
    void heartbeater() @safe nothrow
    {
        version (NatsClientLogging) logDiagnostic("nats.client: heartbeater task started.");

        auto timer = createTimer(null);
        bool flushPending = false;

        // One PING round trip, bounded by the heartbeat interval. flushPending
        // stays set while it is outstanding.
        void heartbeat() @safe nothrow
        {
            flushPending = true;
            scope(exit) flushPending = false;
            immutable pongTimeout = _heartbeatInterval > Duration.zero ? _heartbeatInterval : 5.seconds;
            auto t = runTask(() @safe nothrow {
                Duration rtt;
                if (flushWithTimeout(pongTimeout, rtt)) {
                    version (NatsClientLogging) logDebug("nats.client: Nats Heartbeat RTT: %s", rtt);
                }
                else if (_connState == NatsState.CONNECTED) {
                    // (a failed PING write has already disconnected us)
                    logError("nats.client: Heartbeat did not return within heartbeat interval (%s).", _heartbeatInterval);
                    disconnect();
                }
            });
            try
                t.join();
            catch (Exception e) {
                logWarn("nats.client: Heartbeat flush failed to sync.");
                disconnect();
            }
        }

        while(_connState == NatsState.CONNECTED) {
            immutable prevSent = _msgSent + _pingSent;
            immutable prevRecv = _msgRecv + _pingRecv;
            timer.rearm(_heartbeatInterval, false);
            try
                timer.wait();
            catch (Exception e) {
                logWarn("nats.client: Heartbeat timer interrupted.");
                break;
            }
            if (flushPending) {
                // the previous heartbeat PING is still unanswered after a full interval
                logError("nats.client: Heartbeat did not return within heartbeat interval (%s).", _heartbeatInterval);
                disconnect();
            }
            else if (_msgSent + _pingSent == prevSent && _msgRecv + _pingRecv == prevRecv
                && _connState == NatsState.CONNECTED) {
                version (NatsClientLogging) logDebugV("nats.client: Nats connection idle for %s. Sending heartbeat.",
                    _heartbeatInterval);
                runTask(&heartbeat);
            }
        }
        version (NatsClientLogging) logDiagnostic("nats.client: Nats session not connected! Heartbeater task terminating.");
    }

    /// Write a PUB command (with reply subject when one is given) followed by the
    /// payload, holding the write mutex so they are not interleaved.
    /// Returns: true if both writes succeeded. On failure it logs, calls
    /// disconnect() and returns false.
    bool sendPublish(T)(scope const(char)[] subject, scope const(T)[] payload, scope const(char)[] replySubject = null) @safe nothrow
        if (is(Unqual!T == ubyte) || is(Unqual!T == char))
    {
        char[OUTCMD_BUFSIZE] buffer = void;
        char[] cmd;

        try {
            // PUB <subject> [reply-to] <#bytes>: omit the reply field entirely when absent
            cmd = replySubject.length
                ? buffer.sformat!"PUB %s %s %s"(subject, replySubject, payload.length)
                : buffer.sformat!"PUB %s %s"(subject, payload.length);
            // ensure we don't interleave other writes between writing PUB command & payload
            _writeMutex.lock();
            scope(exit) _writeMutex.unlock();
            write(cmd);
            write(payload);
            _msgSent++;
            return true;
        }
        catch (Exception e) {
            logError("nats.client: Publish failed (%s), command (%s)", e.msg, cmd);
            disconnect();
            return false;
        }
    }

    /// Write a SUB command for `s`, including its queue group if set. Failures
    /// are logged and trigger disconnect().
    void sendSubscribe(Subscription s) @safe nothrow
    {
        char[OUTCMD_BUFSIZE] buffer = void;
        char[] cmd;

        try {
            if (s.queueGroup)
                cmd = buffer.sformat!"SUB %s %s %s"(s.subject, s.queueGroup, s.sid);
            else
                cmd = buffer.sformat!"SUB %s %s"(s.subject, s.sid);
            write(cmd);
        }
        catch (Exception e) {
            logError("nats.client: Error (%s) subscribing on (%s)", e.msg, s.subject);
            disconnect();
        }
    }

    /// Write `buffer` followed by CRLF to the socket under the write mutex and
    /// add the bytes written to _bytesSent.
    /// Throws: whatever the socket write throws.
    void write(T)(scope const(T)[] buffer) @safe
        if (is(Unqual!T == ubyte) || is(Unqual!T == char))
    {
        synchronized(_writeMutex)
        {
            version (NatsClientLogging) logTrace("nats.client write: %s", () @trusted { return cast(string)buffer; }());
            auto bytesWritten = _conn.write(cast(const(ubyte)[]) buffer, IOMode.all);
            // nats protocol requires all writes to be separated by CRLF
            bytesWritten += _conn.write(CRLF, IOMode.all);
            if (bytesWritten != buffer.length + 2)
            {
                logError("nats.client: Error writing to Nats connection! Socket write error.");
            }
            _bytesSent += bytesWritten;
        }
    }

    /// Parse and dispatch every complete protocol message in `natsResponse`.
    ///
    /// Messages handled on earlier iterations are popped from the buffer here.
    /// The count consumed by the final iteration is returned for the caller to
    /// pop. A MSG whose payload extends past the buffered data has the remainder
    /// read directly from the socket.
    /// Throws: Exception if that remainder cannot be read in full.
    size_t processNatsStream(ref Nbuff natsResponse) @safe
    {
        import std.algorithm.comparison: min;

        size_t consumed = 0;
        Subscription subscription;

        loop:
        while(consumed < natsResponse.length)
        {
            Msg msg;
            Nbuff msgPayload;

            if (consumed > 0)
                natsResponse.pop(consumed);
            auto responseData = () @trusted { return natsResponse.data().data(); }();
            version (NatsClientLogging)
                logTrace("Sending response slice length %d to parseNats.", responseData.length);
            consumed = parseNats(responseData, msg);

            final switch (msg.type)
            {
                case NatsResponse.FRAGMENT:
                    if (consumed > 0) {
                        break;
                    } else {
                        break loop;
                    }

                case NatsResponse.MSG:
                case NatsResponse.MSG_REPLY:
                {
                    immutable alreadyRead = min(msg.length, natsResponse.length - consumed);
                    if (msg.length > alreadyRead)
                    {
                        version (NatsClientLogging)
                            logTrace("MSG payload exceeds initial buffer (length: %s).", msg.length);
                        auto msgPayloadBuffer = Nbuff.get(msg.length - alreadyRead);
                        auto remainingPayload = () @trusted { return msgPayloadBuffer.data(); }();
                        immutable bytesRead = _conn.read(remainingPayload[0 .. msg.length - alreadyRead], IOMode.all);
                        if (bytesRead < msg.length - alreadyRead)
                        {
                            logError("nats.client: Message payload incomplete! (expected: %s bytes)", msg.length);
                            // slicing the missing bytes below would be out of range: fail the session instead
                            throw new Exception("nats.client: Message payload incomplete.");
                        }
                        natsResponse.append(msgPayloadBuffer, bytesRead);
                        _bytesReceived += bytesRead;
                    }
                    msgPayload = natsResponse[consumed .. consumed + msg.length];
                    consumed += msg.length;
                    msg.payload = () @trusted { return msgPayload.data().data(); }();
                    _msgRecv++;
                    if (msg.sid >= _subs.length)
                    {
                        logWarn("nats.client: Discarding message for unknown subscription id %s (msg.subject: %s)",
                            msg.sid, msg.subject);
                        continue loop;
                    }
                    subscription = _subs[msg.sid];
                    subscription.msgsReceived++;
                    subscription.bytesReceived += msg.payload.length;
                    if (subscription.msgsReceived > subscription.msgsToExpire)
                        subscription.closed = true;
                    if (subscription.closed)
                        logWarn("nats.client: Discarding message received on closed subscription! (msg.subject: %s, subscription: %s)",
                            msg.subject, subscription.subject);
                    else
                        // now call the message handler
                        // note: this is a synchronous callback - don't block the event loop for too long
                        // if necessary, copy what is needed out of the msg and send to a new Task or Thread
                        subscription.handler(msg);
                    continue loop;
                }

                case NatsResponse.PING:
                    // server PINGs count as link activity for the heartbeater
                    _pingRecv++;
                    write(PONG);
                    _pongSent++;
                    version (NatsClientLogging) logDebugV("Pong sent.");
                    continue loop;

                case NatsResponse.PONG:
                    _pongRecv++;
                    // wake every flush waiter; each re-checks whether its own PING is answered
                    _flushSync.notifyAll();
                    continue loop;

                case NatsResponse.OK:
                    version (NatsClientLogging) logDebug("Ok received.");
                    continue loop;

                case NatsResponse.INFO:
                    version (NatsClientLogging) logDebug("Server INFO msg: %s", msg.payloadAsString);
                    processServerInfo(msg);
                    continue loop;

                case NatsResponse.ERR:
                    logError("nats.client: Server ERR msg: %s", msg.payloadAsString);
                    continue loop;
            }
        }
        return consumed;
    }

    /// Store the JSON payload of a server INFO message in _info.
    void processServerInfo(scope Msg msg) @safe
    {
        import vibe.data.json: parseJsonString;

        _info = parseJsonString(msg.payloadAsString);
    }

    /// Handler for the _INBOX_ subscription: takes the numeric inbox id after
    /// the "_." separator in the subject. It invokes and then removes the
    /// matching publishRequest handler.
    void inboxHandler(scope Msg msg) @safe nothrow
    {
        import std.algorithm.searching: findSplitAfter;
        import std.conv: to;

        version (NatsClientLogging)
            logDebugV("Inbox %s handler called with msg: %s", msg.subject, msg.payloadAsString);
        uint inbox;
        bool badResponse = false;
        try {
            auto findInbox = msg.subject.findSplitAfter("_.");
            if (!findInbox)
                badResponse = true;
            else
                inbox = findInbox[1].to!uint;
        }
        catch (Exception e) {
            badResponse = true;
        }
        if (badResponse) {
            logWarn("nats.client: Unexpected msg (%s) received in inbox. Discarding.", msg.subject);
            return;
        }
        auto p_handler = (inbox in _inboxes);
        if (p_handler !is null) {
            (*p_handler)(msg);
            _inboxes.remove(inbox);
        }
        else {
            logWarn("nats.client: No response handler for response in inbox: %s. Response discarded.",
                msg.subject);
        }
    }
}