M src/TLSClient.cpp => src/TLSClient.cpp +123 -5
@@ 26,8 26,6 @@
#include <cmake.h>
-#ifdef HAVE_LIBGNUTLS
-
#include <TLSClient.h>
#include <iostream>
#include <unistd.h>
@@ 44,13 42,133 @@
#endif
#include <sys/types.h>
#include <sys/socket.h>
+#include <sys/un.h>
#include <netdb.h>
-#include <gnutls/x509.h>
-#include <text.h>
-#include <i18n.h>
+#include <i18n.h>
+#include <text.h>
#define MAX_BUF 16384
+
+LocalSocketClient::LocalSocketClient ()
+{
+ fd = 0;
+}
+
+LocalSocketClient::~LocalSocketClient ()
+{
+ if (fd)
+ {
+ shutdown (fd, SHUT_RDWR);
+ close (fd);
+ fd = 0;
+ }
+}
+
+void LocalSocketClient::connect (const std::string& host)
+{
+ struct sockaddr_un serv_addr;
+ memset(&serv_addr, 0, sizeof(serv_addr));
+ serv_addr.sun_family = AF_UNIX;
+ serv_addr.sun_path[0] = '\0';
+ strncpy(serv_addr.sun_path+1, host.c_str(), host.length());
+ fd = socket(PF_UNIX, SOCK_STREAM, 0);
+ if (-1 == fd)
+ {
+ // Failed to connect
+ throw format (STRING_CMD_SYNC_CONNECT, host, "Local");
+ }
+ int err = ::connect(fd, (struct sockaddr*) &serv_addr, offsetof(struct sockaddr_un, sun_path) + 1 + host.length());
+ if (0 != err)
+ {
+ // Failed to connect
+ throw format ("Failed to connect to local socket: {1}", ::strerror (err) );
+ }
+}
+
+void LocalSocketClient::send (const std::string& data)
+{
+ std::string packet = "XXXX" + data;
+
+ // Encode the length.
+ unsigned long l = packet.length ();
+ packet[0] = l >>24;
+ packet[1] = l >>16;
+ packet[2] = l >>8;
+ packet[3] = l;
+
+ unsigned int total = 0;
+ unsigned int remaining = packet.length ();
+
+ while (total < packet.length ())
+ {
+ int err;
+ err = ::send (fd, packet.c_str () + total, remaining, 0);
+
+ if (err < 0)
+ {
+ throw format ("Failed to send data: {1}", ::strerror (err) );
+ }
+
+ total += (unsigned int) err;
+ remaining -= (unsigned int) err;
+ }
+}
+
+////////////////////////////////////////////////////////////////////////////////
+void LocalSocketClient::recv (std::string& data)
+{
+ data = ""; // No appending of data.
+ int received = 0;
+
+ // Get the encoded length.
+ unsigned char header[4] = {0};
+ received = ::recv (fd, header, 4, 0);
+
+ if (0 > received)
+ {
+ throw format ("Failed to receive data: {1}", ::strerror (received) );
+ }
+
+ int total = received;
+
+ // Decode the length.
+ unsigned long expected = (header[0]<<24) |
+ (header[1]<<16) |
+ (header[2]<<8) |
+ header[3];
+
+ // Arbitrary buffer size.
+ char buffer[MAX_BUF];
+
+ // Keep reading until no more data. Concatenate chunks of data if a) the
+ // read was interrupted by a signal, and b) if there is more data than
+ // fits in the buffer.
+ do
+ {
+ received = ::recv (fd, buffer, MAX_BUF - 1, 0);
+
+ // Other end closed the connection.
+ if (received == 0)
+ {
+ break;
+ }
+
+ if (received < 0)
+ throw format ("Failed to receive data: {1}", ::strerror (received) );
+
+ buffer [received] = '\0';
+ data += buffer;
+ total += received;
+
+ }
+ while (received > 0 && total < (int) expected);
+}
+#ifdef HAVE_LIBGNUTLS
+
+#include <gnutls/x509.h>
+
+
static int verify_certificate_callback (gnutls_session_t);
////////////////////////////////////////////////////////////////////////////////
M src/TLSClient.h => src/TLSClient.h +15 -1
@@ 26,9 26,23 @@
#ifndef INCLUDED_TLSCLIENT
#define INCLUDED_TLSCLIENT
+#include <string>
+class LocalSocketClient
+{
+public:
+ LocalSocketClient ();
+ ~LocalSocketClient ();
+
+ void connect (const std::string&);
+
+ void send (const std::string&);
+ void recv (std::string&);
+
+private:
+ int fd;
+};
#ifdef HAVE_LIBGNUTLS
-#include <string>
#include <gnutls/gnutls.h>
class TLSClient
M src/commands/CmdSync.cpp => src/commands/CmdSync.cpp +74 -37
@@ 58,7 58,6 @@ CmdSync::CmdSync ()
int CmdSync::execute (std::string& output)
{
int status = 0;
-#ifdef HAVE_LIBGNUTLS
std::stringstream out;
Filter filter;
@@ 82,8 81,9 @@ int CmdSync::execute (std::string& output)
// If no server is set up, quit.
std::string connection = context.config.get ("taskd.server");
- if (connection == "" ||
- connection.rfind (':') == std::string::npos)
+ std::string socket = context.config.get ("taskd.socket");
+ bool local_socket_conn = socket != "";
+ if (!local_socket_conn && (connection == "" || connection.rfind (':') == std::string::npos))
throw std::string (STRING_CMD_SYNC_NO_SERVER);
// Obtain credentials.
@@ 96,35 96,6 @@ int CmdSync::execute (std::string& output)
if (credentials.size () != 3)
throw std::string (STRING_CMD_SYNC_BAD_CRED);
- // This was a Boolean value in 2.3.0, and is a tri-state since 2.4.0.
- std::string trust_value = context.config.get ("taskd.trust");
- if (trust_value != "strict" &&
- trust_value != "ignore hostname" &&
- trust_value != "allow all")
- throw std::string (STRING_CMD_SYNC_TRUST_OBS);
-
- enum TLSClient::trust_level trust = TLSClient::strict;
- if (trust_value == "allow all")
- trust = TLSClient::allow_all;
- else if (trust_value == "ignore hostname")
- trust = TLSClient::ignore_hostname;
-
- // CA must exist, if provided.
- File ca (context.config.get ("taskd.ca"));
- if (ca._data != "" && ! ca.exists ())
- throw std::string (STRING_CMD_SYNC_BAD_CA);
-
- if (trust == TLSClient::allow_all && ca._data != "")
- throw std::string (STRING_CMD_SYNC_TRUST_CA);
-
- File certificate (context.config.get ("taskd.certificate"));
- if (! certificate.exists ())
- throw std::string (STRING_CMD_SYNC_BAD_CERT);
-
- File key (context.config.get ("taskd.key"));
- if (! key.exists ())
- throw std::string (STRING_CMD_SYNC_BAD_KEY);
-
// If this is a first-time initialization, send pending.data and
// completed.data, but not backlog.data.
std::string payload = "";
@@ 183,7 154,50 @@ int CmdSync::execute (std::string& output)
signal (SIGUSR2, SIG_IGN);
Msg response;
- if (send (connection, ca._data, certificate._data, key._data, trust, request, response))
+ bool send_result = false;
+ if (!local_socket_conn) {
+#ifdef HAVE_LIBGNUTLS
+ // This was a Boolean value in 2.3.0, and is a tri-state since 2.4.0.
+ std::string trust_value = context.config.get ("taskd.trust");
+ if (trust_value != "strict" &&
+ trust_value != "ignore hostname" &&
+ trust_value != "allow all")
+ throw std::string (STRING_CMD_SYNC_TRUST_OBS);
+
+ enum TLSClient::trust_level trust = TLSClient::strict;
+ if (trust_value == "allow all")
+ trust = TLSClient::allow_all;
+ else if (trust_value == "ignore hostname")
+ trust = TLSClient::ignore_hostname;
+
+ // CA must exist, if provided.
+ File ca (context.config.get ("taskd.ca"));
+ if (ca._data != "" && ! ca.exists ())
+ throw std::string (STRING_CMD_SYNC_BAD_CA);
+
+ if (trust == TLSClient::allow_all && ca._data != "")
+ throw std::string (STRING_CMD_SYNC_TRUST_CA);
+
+ File certificate (context.config.get ("taskd.certificate"));
+ if (! certificate.exists ())
+ throw std::string (STRING_CMD_SYNC_BAD_CERT);
+
+ File key (context.config.get ("taskd.key"));
+ if (! key.exists ())
+ throw std::string (STRING_CMD_SYNC_BAD_KEY);
+
+ send_result = send (connection, ca._data, certificate._data, key._data, trust, request, response);
+#else
+ // Normal socket connection but not supported platform
+ throw std::string (STRING_CMD_SYNC_NO_TLS);
+#endif
+ } else {
+ out << format ("Will sync using local socket: {1}", socket)
+ << "\n";
+ send_result = sendLocal(socket, request, response);
+ }
+
+ if (send_result)
{
std::string code = response.get ("code");
if (code == "200")
@@ 337,13 351,36 @@ int CmdSync::execute (std::string& output)
signal (SIGUSR1, SIG_DFL);
signal (SIGUSR2, SIG_DFL);
-#else
- // Without GnuTLS found at compile time, there is no working sync command.
- throw std::string (STRING_CMD_SYNC_NO_TLS);
-#endif
return status;
}
+bool CmdSync::sendLocal (
+ const std::string& to,
+ const Msg& request,
+ Msg& response)
+{
+ try
+ {
+ LocalSocketClient client;
+
+ client.connect (to);
+ client.send (request.serialize () + "\n");
+
+ std::string incoming;
+ client.recv (incoming);
+
+ response.parse (incoming);
+ return true;
+ }
+
+ catch (std::string& error)
+ {
+ context.error (error);
+ }
+
+ // Indicate message failed.
+ return false;
+}
#ifdef HAVE_LIBGNUTLS
////////////////////////////////////////////////////////////////////////////////
bool CmdSync::send (
M src/commands/CmdSync.h => src/commands/CmdSync.h +2 -1
@@ 38,8 38,9 @@ public:
CmdSync ();
int execute (std::string&);
-#ifdef HAVE_LIBGNUTLS
private:
+ bool sendLocal (const std::string&, const Msg&, Msg&);
+#ifdef HAVE_LIBGNUTLS
bool send (const std::string&, const std::string&, const std::string&, const std::string&, const enum TLSClient::trust_level, const Msg&, Msg&);
#endif
};