#include "./SocketAddr.hpp"

#include <array>
#include <cerrno>
#include <charconv>
#include <cstddef>
#include <cstring>
#include <optional>
#include <sstream>
#include <string>
#include <system_error>

using nixnet::errno_t;
using std::array;
using std::nullopt;
using std::optional;
using std::ostream;
using std::string;
using std::string_view;
using std::stringstream;
using std::system_error;
using std::unexpected;

namespace nixnet {

// Parses the textual address `raw` of the given family into a SocketAddr
// and attaches `port_num`.
//
// `family`   must be AF_INET or AF_INET6; anything else yields AFNOSUPPORT.
// `raw`      any string_view; it does NOT need to be NUL-terminated. Text
//            that inet_pton rejects, or that contains an embedded NUL,
//            yields INVAL.
// `port_num` port in host byte order, stored via SocketAddr::set_port().
result<SocketAddr> pton_helper(sa_family_t family,
                               string_view raw,
                               uint16_t port_num) {
  SocketAddr result_sock;
  sockaddr* sa_raw = result_sock.data();
  void* sa_raw_addr = nullptr;
  if (family == AF_INET) {
    sa_raw_addr = &(reinterpret_cast<sockaddr_in*>(sa_raw)->sin_addr);
  } else if (family == AF_INET6) {
    sa_raw_addr = &(reinterpret_cast<sockaddr_in6*>(sa_raw)->sin6_addr);
  } else {
    return std::unexpected(errno_t::AFNOSUPPORT);
  }

  // An embedded NUL would make inet_pton silently parse only a prefix of the
  // input (e.g. "127.0.0.1\0junk"), so reject it outright.
  if (raw.find('\0') != string_view::npos) {
    return std::unexpected(errno_t::INVAL);
  }
  // inet_pton reads a NUL-terminated C string, but string_view::data() is not
  // guaranteed to be terminated (the view may point into a larger buffer such
  // as "127.0.0.1:8080"). Copy into an owning string so it never over-reads.
  const std::string raw_str{raw};
  const int res = inet_pton(family, raw_str.c_str(), sa_raw_addr);
  if (res == -1) {
    // inet_pton returns -1 when it does not support the address family.
    return std::unexpected(errno_t::AFNOSUPPORT);
  } else if (res != 1) {
    // 0: the text is not a valid address for `family`.
    return std::unexpected(errno_t::INVAL);
  }
  result_sock.m_data.ss_family = family;
  result_sock.set_port(port_num);
  return result_sock;
}

// Parses a decimal port number.
//
// The entire `port_str` must be consumed. An empty string, a sign, leading or
// trailing non-digit characters, or a value that does not fit in 16 bits all
// produce nullopt.
static optional<uint16_t> parse_port_num(string_view port_str) {
  const char* const first = port_str.data();
  const char* const last = first + port_str.size();

  uint16_t parsed{};
  const auto [stop, err] = std::from_chars(first, last, parsed);

  // Success means no conversion error AND every character was used.
  const bool fully_parsed = (err == std::errc{}) && (stop == last);
  if (!fully_parsed) {
    return nullopt;
  }
  return parsed;
}

// Builds an IPv4 SocketAddr from dotted text and a host-order port.
// All parsing and validation is done by pton_helper.
result<SocketAddr> SocketAddr::v4_from_str(string_view raw, uint16_t port_num) {
  const sa_family_t ipv4_family = AF_INET;
  return pton_helper(ipv4_family, raw, port_num);
}

// Builds an IPv4 SocketAddr from address text and port text.
// The port is validated first; an unparsable port is reported as INVAL
// without looking at the address.
result<SocketAddr> SocketAddr::v4_from_str(string_view raw,
                                           string_view port_str) {
  if (const auto port = parse_port_num(port_str)) {
    return pton_helper(AF_INET, raw, *port);
  }
  return unexpected(errno_t::INVAL);
}

// Builds an IPv6 SocketAddr from address text and a host-order port.
// All parsing and validation is done by pton_helper.
result<SocketAddr> SocketAddr::v6_from_str(string_view raw, uint16_t port_num) {
  const sa_family_t ipv6_family = AF_INET6;
  return pton_helper(ipv6_family, raw, port_num);
}

// Builds an IPv6 SocketAddr from address text and port text.
// An unparsable port is reported as INVAL before the address is parsed.
result<SocketAddr> SocketAddr::v6_from_str(string_view raw,
                                           string_view port_str) {
  const auto port = parse_port_num(port_str);
  if (port.has_value()) {
    return pton_helper(AF_INET6, raw, *port);
  }
  return unexpected(errno_t::INVAL);
}

// Builds a SocketAddr by copying a C socket address.
//
// `addr` points to a sockaddr whose sa_family selects how many bytes are
// copied: sizeof(sockaddr_in) for AF_INET, sizeof(sockaddr_in6) for AF_INET6.
// The caller must ensure the pointed-to object is at least that large.
//
// Returns the copied address, INVAL if `addr` is null, or AFNOSUPPORT for
// any other address family.
result<SocketAddr> SocketAddr::from_csocket(const sockaddr* addr) {
  // Dereferencing a null pointer below would be undefined behavior.
  if (addr == nullptr) {
    return unexpected(errno_t::INVAL);
  }

  std::size_t copy_len = 0;
  switch (addr->sa_family) {
    case AF_INET:
      copy_len = sizeof(sockaddr_in);
      break;
    case AF_INET6:
      copy_len = sizeof(sockaddr_in6);
      break;
    default:
      return unexpected(errno_t::AFNOSUPPORT);
  }

  SocketAddr res;
  std::memcpy(&res.m_data, addr, copy_len);
  return res;
}

// Mutable view of the stored address as a generic sockaddr, suitable for
// passing to socket APIs that take `sockaddr*`.
sockaddr* SocketAddr::data() {
  void* const storage = &m_data;
  return static_cast<sockaddr*>(storage);
}

// Read-only view of the stored address as a generic sockaddr.
const sockaddr* SocketAddr::data() const {
  const void* const storage = &m_data;
  return static_cast<const sockaddr*>(storage);
}

// True exactly when the stored family tag is AF_INET.
bool SocketAddr::is_ipv4() const {
  const auto family = m_data.ss_family;
  return AF_INET == family;
}

// True exactly when the stored family tag is AF_INET6.
bool SocketAddr::is_ipv6() const {
  const auto family = m_data.ss_family;
  return AF_INET6 == family;
}

// Stores `port_num` (given in host byte order) in network byte order.
//
// IPv6 addresses are written through sockaddr_in6::sin6_port. Every other
// family, including an address whose family has not been set yet, uses
// sockaddr_in::sin_port as before. This avoids assuming that sin_port and
// sin6_port share an offset, which POSIX does not guarantee.
void SocketAddr::set_port(uint16_t port_num) {
  const uint16_t net_port = htons(port_num);
  if (m_data.ss_family == AF_INET6) {
    reinterpret_cast<sockaddr_in6*>(&m_data)->sin6_port = net_port;
  } else {
    reinterpret_cast<sockaddr_in*>(&m_data)->sin_port = net_port;
  }
}

// Returns the stored port in host byte order.
//
// Reads sockaddr_in6::sin6_port for IPv6 addresses and sockaddr_in::sin_port
// otherwise. This does not rely on the two fields sharing an offset.
uint16_t SocketAddr::port() const {
  uint16_t net_port = 0;
  if (m_data.ss_family == AF_INET6) {
    net_port = reinterpret_cast<const sockaddr_in6*>(&m_data)->sin6_port;
  } else {
    net_port = reinterpret_cast<const sockaddr_in*>(&m_data)->sin_port;
  }
  return ntohs(net_port);
}

string to_string(const SocketAddr& value) {
  array<char, INET6_ADDRSTRLEN> addrbuf;
  auto* sa_raw = value.data();
  const void* sa_raw_addr = nullptr;
  if (value.is_ipv4()) {
    sa_raw_addr = &(reinterpret_cast<const sockaddr_in*>(sa_raw)->sin_addr);
  } else {
    sa_raw_addr = &(reinterpret_cast<const sockaddr_in6*>(sa_raw)->sin6_addr);
  }

  auto res = inet_ntop(sa_raw->sa_family, sa_raw_addr, addrbuf.data(),
                       INET6_ADDRSTRLEN);
  if (res == nullptr) {
    // really should not error, but just in case, we have this code.
    // If your code is throwing an error please let me know
    throw system_error(std::make_error_code(static_cast<std::errc>(errno)),
                       "Error while converting an address to a string.");
  }
  stringstream ss{};
  ss << addrbuf.data() << ":" << value.port();
  return ss.str();
}

// Streams the same "host:port" text that to_string(value) produces.
ostream& operator<<(ostream& out, const SocketAddr& value) {
  return out << to_string(value);
}

};  // namespace nixnet
