@@ 8,6 8,7 @@ use errors;
use fmt;
use format::ssh;
use fs;
+use himitsu::client;
use himitsu::query;
use io;
use log;
@@ 22,6 23,10 @@ use unix::signal;
let running: bool = true;
+type server = struct {
+ sock: (void | net::socket),
+};
+
export fn main() void = {
// TODO: Parse options (foreground/background, socket path)
// TODO: Fork to background
@@ 46,6 51,14 @@ export fn main() void = {
signal::handle(signal::sig::INT, &handle_signal);
signal::handle(signal::sig::TERM, &handle_signal);
+ // make write/read from borken sockets cause an error
+ signal::ignore(signal::sig::PIPE);
+
+ let server = server {
+ sock = void,
+ };
+ defer server_finish(&server);
+
for (running) {
const client = match (net::accept(socket)) {
case errors::interrupted =>
@@ 55,7 68,7 @@ export fn main() void = {
case let fd: net::socket =>
yield fd;
};
- serve(client);
+ serve(client, &server);
};
log::println("Terminated.");
@@ 65,10 78,20 @@ fn handle_signal(sig: signal::sig, info: *signal::siginfo, ucontext: *opaque) vo
running = false;
};
-fn serve(conn: net::socket) void = {
+fn serve(conn: net::socket, server: *server) void = {
const agent = agent::new(conn);
defer agent::agent_finish(&agent);
+ const himitsu = match (himitsu_connection(server)) {
+ case let s: net::socket =>
+ yield s;
+ case let e: net::error =>
+ log::println("himitsu: could not connect:", net::strerror(e));
+ const answer: agent::message = agent::agent_failure;
+ agent::writemsg(&agent, &answer)!;
+ return;
+ };
+
for (true) {
const msg = match (agent::readmsg(&agent)) {
case io::EOF => break;
@@ 83,9 106,9 @@ fn serve(conn: net::socket) void = {
const res = match (msg) {
case agent::request_identities =>
- yield handle_req_ident(&agent);
+ yield handle_req_ident(&agent, himitsu);
case let msg: agent::sign_request =>
- yield handle_sign_request(&agent, &msg);
+ yield handle_sign_request(server, &agent, &msg, himitsu);
case agent::extension =>
const answer: agent::message = agent::extension_failure;
agent::writemsg(&agent, &answer)!;
@@ 107,9 130,10 @@ fn serve(conn: net::socket) void = {
};
};
-fn handle_req_ident(agent: *agent::agent) (void | agent::error | net::error) = {
- const himitsu = himitsu_connect()?;
- defer io::close(himitsu)!;
+fn handle_req_ident(
+ agent: *agent::agent,
+ himitsu: net::socket
+) (void | agent::error | net::error) = {
fmt::fprintln(himitsu, "query proto=ssh")?;
let idents: agent::identities_answer = [];
@@ 215,8 239,10 @@ fn handle_req_ident(agent: *agent::agent) (void | agent::error | net::error) = {
};
fn handle_sign_request(
+ s: *server,
agent: *agent::agent,
msg: *agent::sign_request,
+ himitsu: net::socket,
) (void | agent::error | net::error) = {
const source = memio::fixed(msg.key);
const key = match (ssh::decodepublic(&source)) {
@@ 237,8 263,6 @@ fn handle_sign_request(
io::close(&b64en)!;
fmt::fprintln(&req)!;
- const himitsu = himitsu_connect()?;
- defer io::close(himitsu)!;
io::writeall(himitsu, memio::buffer(&req))?;
let comment = "";
@@ 258,6 282,9 @@ fn handle_sign_request(
if (bytes::equal(buf, strings::toutf8("end"))) {
break;
+ } else if (found) {
+ // continue until end
+ continue;
} else if (bytes::hasprefix(buf, strings::toutf8("error "))) {
break;
} else if (bytes::hasprefix(buf, strings::toutf8("key "))) {
@@ 315,7 342,6 @@ fn handle_sign_request(
yield;
};
found = true;
- break;
};
if (!found) {
@@ 336,10 362,39 @@ fn handle_sign_request(
log::printfln("Signed challenge with key {}", comment);
};
-fn himitsu_connect() (net::socket | net::error) = {
+fn himitsu_connection(s: *server) (net::socket | net::error) = {
+ let sock = match (s.sock) {
+ case void =>
+ return himitsu_connect(s);
+ case let s: net::socket =>
+ yield s;
+ };
+
+ // check if the socket is still working.
+ match (client::get_state(sock)) {
+ case client::state =>
+ return sock;
+ case let e: client::hierror =>
+ return sock;
+ case =>
+ return himitsu_connect(s);
+ };
+};
+
+fn himitsu_connect(s: *server) (net::socket | net::error) = {
let path = path::init()!;
const sockpath = path::set(&path, dirs::runtime()!, "himitsu")!;
- return unix::connect(sockpath);
+ let sock = unix::connect(sockpath)?;
+ s.sock = sock;
+ return sock;
+};
+
+fn server_finish(s: *server) void = {
+ match (s.sock) {
+ case void => void;
+ case let s: net::socket =>
+ net::close(s)!;
+ };
};
// XXX: This really belongs in the himitsu library