#include "./UdpSocket.hpp"

using std::byte;
using std::pair;
using std::span;
using std::unexpected;
using std::chrono::duration_cast;
using std::chrono::floor;
using std::chrono::microseconds;
using std::chrono::seconds;

#include <unistd.h>

namespace nixnet {

/// Creates a UDP (SOCK_DGRAM) socket in the address family of `addr` and binds
/// it to `addr`.
///
/// @param addr  Local address to bind the new socket to.
/// @return The bound socket (owning its descriptor), or the errno reported by
///         `socket(2)` or `bind(2)`. On a bind failure the descriptor is closed
///         before returning.
result<UdpSocket> UdpSocket::bind(const SocketAddr& addr) {
  int fd = ::socket(addr.data()->sa_family, SOCK_DGRAM, 0);
  if (fd < 0) {
    return unexpected(errno_t(errno));
  }
  if (::bind(fd, addr.data(), addr.data_size()) != 0) {
    // Capture errno before close(): POSIX allows close() to modify errno, which
    // would otherwise replace the bind() failure reason we report.
    const int bind_errno = errno;
    ::close(fd);
    return unexpected(errno_t(bind_errno));
  }
  UdpSocket result_sock;
  result_sock.m_fd = fd;
  return result_sock;
}

/// Creates a UDP (SOCK_DGRAM) socket of the given address family without
/// binding it to any local address.
///
/// @param family  Address family passed to `socket(2)` (e.g. AF_INET).
/// @return The new socket (owning its descriptor), or the errno from
///         `socket(2)`.
result<UdpSocket> UdpSocket::unspecified_local_addr(sa_family_t family) {
  const int new_fd = ::socket(family, SOCK_DGRAM, 0);
  if (new_fd < 0) {
    return unexpected(errno_t(errno));
  }
  UdpSocket sock;
  sock.m_fd = new_fd;
  return sock;
}

/// Move constructor: takes over `other`'s descriptor and marks `other` as
/// empty (fd -1) so that its destructor does not close the descriptor.
UdpSocket::UdpSocket(UdpSocket&& other) : m_fd(-1) {
  m_fd = other.m_fd;
  other.m_fd = -1;
}

/// Move assignment: closes the descriptor this socket currently owns (if any),
/// then takes over `other`'s descriptor and leaves `other` empty (fd -1).
/// Self-assignment is a no-op.
UdpSocket& UdpSocket::operator=(UdpSocket&& other) {
  if (this == &other) {
    return *this;
  }
  close();
  m_fd = other.m_fd;
  other.m_fd = -1;
  return *this;
}

/// Destructor: closes the owned descriptor, if there is one (fd >= 0).
UdpSocket::~UdpSocket() {
  if (m_fd < 0) {
    return;
  }
  ::close(m_fd);
}

/// Closes the owned descriptor (if any) and marks the socket empty (fd -1).
/// Calling it again afterwards does nothing.
void UdpSocket::close() {
  if (m_fd < 0) {
    return;
  }
  ::close(m_fd);
  m_fd = -1;
}

/// Sends `bytes` as one datagram to `addr`.
///
/// Payloads larger than the per-family cap (68 bytes when `addr` is IPv4,
/// 1280 bytes otherwise) are rejected with MSGSIZE before anything is sent.
///
/// @param bytes  Payload to send.
/// @param addr   Destination address.
/// @return Number of bytes sent, or the errno from `sendto(2)`.
result<size_t> UdpSocket::sendto(const std::span<const std::byte> bytes,
                                 const SocketAddr& addr) const {
  // Artificial size cap. This project simulates sending data over a real
  // network, but the tests have the computer send data to itself, where the
  // practical limit is closer to 65507 bytes -- not something you can always
  // expect. Real code would drop this check.
  const size_t cap = addr.is_ipv4() ? size_t{68} : size_t{1280};
  if (bytes.size() > cap) {
    return unexpected(errno_t::MSGSIZE);
  }

  const ssize_t sent = ::sendto(m_fd, bytes.data(), bytes.size(), 0,
                                addr.data(), addr.data_size());
  if (sent == -1) {
    return unexpected(errno_t(errno));
  }
  return static_cast<size_t>(sent);
}

/// Receives one datagram into `bytes` and reports the sender's address.
///
/// @param bytes  Buffer that receives the payload.
/// @return The number of bytes received paired with the sender's address; or
///         the errno from `recvfrom(2)`; or the error returned by
///         `SocketAddr::from_csocket` if the sender address can't be
///         converted.
result<pair<size_t, SocketAddr>> UdpSocket::recvfrom(
    std::span<std::byte> bytes) const {
  // Value-initialized so that, if the kernel writes no source address
  // (ss_size comes back smaller than a sockaddr), from_csocket() reads a
  // zeroed family rather than indeterminate stack bytes.
  sockaddr_storage ss{};
  socklen_t ss_size = sizeof(ss);
  ssize_t res = ::recvfrom(m_fd, bytes.data(), bytes.size(), 0,
                           reinterpret_cast<sockaddr*>(&ss), &ss_size);
  if (res == -1) {
    return unexpected(errno_t(errno));
  }
  auto sockaddr_res =
      SocketAddr::from_csocket(reinterpret_cast<sockaddr*>(&ss));
  if (!sockaddr_res) {
    return unexpected(sockaddr_res.error());
  }
  return std::make_pair(static_cast<size_t>(res), sockaddr_res.value());
}

/// Associates this socket with the peer `addr` via `connect(2)`, so that
/// `send()` can be used without passing an address.
///
/// @param addr  Peer address.
/// @return errno_t::SUCCESS, or the errno from `connect(2)`.
errno_t UdpSocket::connect(const SocketAddr& addr) {
  const bool connected = ::connect(m_fd, addr.data(), addr.data_size()) == 0;
  return connected ? errno_t::SUCCESS : errno_t(errno);
}

/// Sends `bytes` as one datagram to the peer set by `connect()`.
///
/// The peer is looked up with `getpeername(2)` first, so that call's errno is
/// returned when it fails (e.g. the socket is not connected). Payloads larger
/// than the per-family cap (68 bytes for an AF_INET peer, 1280 bytes
/// otherwise) are rejected with MSGSIZE before anything is sent.
///
/// @param bytes  Payload to send.
/// @return Number of bytes sent, or the errno from `getpeername(2)` /
///         `send(2)`.
result<size_t> UdpSocket::send(const std::span<std::byte> bytes) const {
  // Artificial size cap (everything before the ::send call). This project
  // simulates sending data over a real network, but the tests have the
  // computer send data to itself, where the practical limit is closer to
  // 65507 bytes -- not something you can always expect. Real code would drop
  // this check.
  sockaddr_storage peer;
  socklen_t peer_len = sizeof(peer);
  auto* peer_sa = reinterpret_cast<sockaddr*>(&peer);
  if (::getpeername(m_fd, peer_sa, &peer_len) < 0) {
    return unexpected(errno_t(errno));
  }
  const size_t cap = (peer.ss_family == AF_INET) ? size_t{68} : size_t{1280};
  if (bytes.size() > cap) {
    return unexpected(errno_t::MSGSIZE);
  }

  const ssize_t sent = ::send(m_fd, bytes.data(), bytes.size(), 0);
  if (sent < 0) {
    return unexpected(errno_t(errno));
  }
  return static_cast<size_t>(sent);
}

/// Reads a datagram from the socket into `bytes`; the sender's address is not
/// reported (use `recvfrom` for that).
///
/// @param bytes  Buffer that receives the payload.
/// @return Number of bytes received, or the errno from `recv(2)`.
result<size_t> UdpSocket::recv(std::span<std::byte> bytes) const {
  const ssize_t received = ::recv(m_fd, bytes.data(), bytes.size(), 0);
  if (received >= 0) {
    return static_cast<size_t>(received);
  }
  return unexpected(errno_t(errno));
}

/// Converts a duration to a `timeval`: whole seconds in `tv_sec`, and the
/// remaining microseconds in `tv_usec`. Negative durations become a zero
/// `timeval`. Note: despite "ms" in the name, the unit is microseconds.
static timeval ms_to_timeval(microseconds time) {
  timeval out;
  if (time < microseconds::zero()) {
    out.tv_sec = 0;
    out.tv_usec = 0;
    return out;
  }

  const seconds whole = floor<seconds>(time);
  out.tv_sec = whole.count();
  out.tv_usec = (time - whole).count();  // sub-second remainder in us
  return out;
}

/// Sets the receive timeout (SO_RCVTIMEO) for this socket.
///
/// Negative durations are converted by ms_to_timeval into a zero timeval.
/// NOTE(review): on Linux a zero SO_RCVTIMEO means "never time out"; confirm
/// that is the intended meaning of a negative argument.
///
/// @param time  Timeout duration.
/// @return errno_t::SUCCESS, or the errno from `setsockopt(2)`.
errno_t UdpSocket::set_read_timeout(microseconds time) {
  const timeval timeout = ms_to_timeval(time);
  if (::setsockopt(m_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout,
                   sizeof(timeout)) == 0) {
    return errno_t::SUCCESS;
  }
  return errno_t(errno);
}

/// Sets the send timeout (SO_SNDTIMEO) for this socket.
///
/// Negative durations are converted by ms_to_timeval into a zero timeval.
/// NOTE(review): on Linux a zero SO_SNDTIMEO means "never time out"; confirm
/// that is the intended meaning of a negative argument.
///
/// @param time  Timeout duration.
/// @return errno_t::SUCCESS, or the errno from `setsockopt(2)`.
errno_t UdpSocket::set_write_timeout(microseconds time) {
  const timeval timeout = ms_to_timeval(time);
  if (::setsockopt(m_fd, SOL_SOCKET, SO_SNDTIMEO, &timeout,
                   sizeof(timeout)) == 0) {
    return errno_t::SUCCESS;
  }
  return errno_t(errno);
}

/// Converts a `timeval` into a single duration: tv_sec seconds plus tv_usec
/// microseconds. Note: despite "ms" in the name, the unit is microseconds.
static microseconds timeval_to_ms(timeval tv) {
  const seconds whole(tv.tv_sec);
  const microseconds fraction(tv.tv_usec);
  return fraction + whole;  // common type is microseconds
}

/// Returns the current receive timeout (SO_RCVTIMEO) as a duration.
///
/// @return The timeout, or the errno from `getsockopt(2)`.
result<microseconds> UdpSocket::read_timeout() const {
  timeval timeout;
  socklen_t timeout_len = sizeof(timeout);
  if (::getsockopt(m_fd, SOL_SOCKET, SO_RCVTIMEO, &timeout, &timeout_len) !=
      0) {
    return unexpected(errno_t(errno));
  }
  return timeval_to_ms(timeout);
}

/// Returns the current send timeout (SO_SNDTIMEO) as a duration.
///
/// @return The timeout, or the errno from `getsockopt(2)`.
result<microseconds> UdpSocket::write_timeout() const {
  timeval timeout;
  socklen_t timeout_len = sizeof(timeout);
  if (::getsockopt(m_fd, SOL_SOCKET, SO_SNDTIMEO, &timeout, &timeout_len) !=
      0) {
    return unexpected(errno_t(errno));
  }
  return timeval_to_ms(timeout);
}

int UdpSocket::raw_fd() {
  return m_fd;
}

/// Returns the local address this socket is bound to, as reported by
/// `getsockname(2)`.
///
/// @return The local address, or the errno from `getsockname(2)`, or the
///         error from `SocketAddr::from_csocket`.
result<SocketAddr> UdpSocket::local_addr() const {
  sockaddr_storage storage;
  socklen_t storage_len = sizeof(storage);
  auto* raw_addr = reinterpret_cast<sockaddr*>(&storage);
  if (::getsockname(m_fd, raw_addr, &storage_len) < 0) {
    return unexpected(errno_t(errno));
  }
  return SocketAddr::from_csocket(raw_addr);
}

/// Returns the peer address set by `connect()`, as reported by
/// `getpeername(2)`.
///
/// @return The peer address, or the errno from `getpeername(2)`, or the
///         error from `SocketAddr::from_csocket`.
result<SocketAddr> UdpSocket::peer_addr() const {
  sockaddr_storage storage;
  socklen_t storage_len = sizeof(storage);
  auto* raw_addr = reinterpret_cast<sockaddr*>(&storage);
  if (::getpeername(m_fd, raw_addr, &storage_len) < 0) {
    return unexpected(errno_t(errno));
  }
  return SocketAddr::from_csocket(raw_addr);
}

size_t UdpSocket::get_mtu() const {
  SocketAddr addr = local_addr().value();
  if (addr.is_ipv4()) {
    return 68;
  }
  return 1280;
}

};  // namespace nixnet
