~kf5jwc/dns-server-rs

50d06f269fbdca915fa5dd711cb02451473ad456 — Kyle Jones 6 months ago f3f0b04
Implement a proxy resolver server! Neato! :D
1 files changed, 170 insertions(+), 0 deletions(-)

A src/bin/proxy_resolver.rs
A src/bin/proxy_resolver.rs => src/bin/proxy_resolver.rs +170 -0
@@ 0,0 1,170 @@
extern crate log;
extern crate env_logger;
extern crate rand;
extern crate structopt;
extern crate dns_server;

use std::io;
use std::net::{IpAddr, SocketAddr, UdpSocket};
use log::{warn, info, debug, trace};
use rand::Rng;
use structopt::StructOpt;
use dns_server::{
    BytePacketBuffer,
    DnsPacket,
    DnsQuestion,
    QueryType,
    ResultCode,
};

#[derive(Debug, StructOpt)]
#[structopt()]
struct Args {
    #[structopt(short="h", parse(try_from_str), default_value="0.0.0.0")]
    server_host: IpAddr,
    #[structopt(short="p", default_value="5655")]
    server_port: u16,
    #[structopt(short="u", parse(try_from_str), default_value="192.168.1.1")]
    upstream_server: IpAddr,
}

fn stage_1(server_socket: &UdpSocket) -> Result<(SocketAddr, BytePacketBuffer), io::Error> {
    let mut request_buffer = BytePacketBuffer::default();
    match server_socket.recv_from(&mut request_buffer.buf) {
        Ok((_, src)) => Ok((src, request_buffer)),
        Err(e) => Err(e),
    }
}

fn main() {
    env_logger::init();
    let args = Args::from_args();

    info!("Starting server on {}:{}", args.server_host, args.server_port);
    let server_socket = UdpSocket::bind((args.server_host, args.server_port)).expect("Opening server listen socket");

    loop {
        let (request_source, request_packet): (_, DnsPacket) = match stage_1(&server_socket) {
            Ok((src, request_buffer)) => {
                info!("Recieved request from {:?}", src);
                (src, request_buffer.into())
            },
            Err(e) => {
                warn!("An error occurred while reading request from socket: {}", e);
                continue;
            },
        };

        if request_packet.questions.is_empty() {
            info!("Request packet contains no questions!");
            send_failure_response(&server_socket, &request_source, &request_packet, ResultCode::FORMERR);
            continue;
        }

        let question = &request_packet.questions[0];
        debug!("Recv'd query: {:?}", question);

        let upstream_socket = UdpSocket::bind(("0.0.0.0", 43102)).expect("Opening upstream request socket");

        debug!("Preparing upstream request");
        let upstream_packet = build_dns_query_packet(question.name.clone(), question.qtype.clone());
        let upstream_buffer: BytePacketBuffer = upstream_packet.into();

        debug!("Sending upstream request");
        match upstream_socket.send_to(upstream_buffer.raw_buffer(), (args.upstream_server, 53)) {
            Ok(_) => {},
            Err(e) => {
                warn!("Unable to send the upstream request: {:?}", e);
                send_failure_response(&server_socket, &request_source, &request_packet, ResultCode::SERVFAIL);
                continue;
            },
        }

        debug!("Receiving upstream response");
        let mut upstream_response_buffer = BytePacketBuffer::default();
        match upstream_socket.recv_from(&mut upstream_response_buffer.buf) {
            Ok(_) => {},
            Err(e) => {
                warn!("Error occurred while attempting to recv an upstream response: {:?}", e);
                send_failure_response(&server_socket, &request_source, &request_packet, ResultCode::SERVFAIL);
                continue;
            },
        }

        debug!("Parsing upstream response");
        let upstream_response_packet: DnsPacket = upstream_response_buffer.into();

        debug!("Building response");
        let response_packet = build_response_packet(request_packet, upstream_response_packet);

        info!("Sending response packet");
        let response_buffer: BytePacketBuffer = response_packet.into();
        match server_socket.send_to(response_buffer.raw_buffer(), request_source) {
            Ok(_) => {},
            Err(e) => {
                warn!("Failed to send response: {:?}", e);
                continue;
            },
        }


    }

}

fn build_response_packet(request_packet: DnsPacket, upstream_response_packet: DnsPacket) -> DnsPacket {
    let mut response_packet = DnsPacket::default();
    response_packet.header.id = request_packet.header.id;
    response_packet.header.recursion_desired = true;
    response_packet.header.recursion_available = true;
    response_packet.header.response = true;

    for record in upstream_response_packet.questions {
        trace!("Answer record: {:?}", record);
        response_packet.questions.push(record);
    }

    for record in upstream_response_packet.answers {
        trace!("Answer record: {:?}", record);
        response_packet.answers.push(record);
    }

    for record in upstream_response_packet.authorities {
        trace!("Authority record: {:?}", record);
        response_packet.authorities.push(record);
    }

    for record in upstream_response_packet.resources {
        trace!("Resource record: {:?}", record);
        response_packet.resources.push(record);
    }

    return response_packet;
}

fn send_failure_response(server_socket: &UdpSocket, request_source: &SocketAddr, request_packet: &DnsPacket, result_code: ResultCode) {
    let mut response_packet = DnsPacket::default();
    response_packet.header.id = request_packet.header.id;
    response_packet.header.rescode = result_code;
    let response_buffer: BytePacketBuffer = response_packet.into();
    match server_socket.send_to(response_buffer.raw_buffer(), request_source) {
        Ok(_) => {},
        Err(e) => warn!("Failed to send response buffer: {:?}", e),
    };
}

fn build_dns_query_packet(name: String, qtype: QueryType) -> DnsPacket {
    let mut rng = rand::thread_rng();
    let mut packet = DnsPacket::default();
    packet.header.id = rng.gen();
    packet.header.questions = 1;
    packet.header.recursion_desired = true;
    packet.questions.push({
        DnsQuestion{
            name: name,
            qtype: qtype,
        }
    });

    return packet;
}