/* * Example ECDHE with Curve25519, based on ecdh_curve25519.c from mbed TLS * * Copyright (C) 2018, Samuel Kupka, All Rights Reserved * SPDX-License-Identifier: Apache-2.0 * * Original file (ecdh_curve25519.c) license: * * Copyright (C) 2006-2015, ARM Limited, All Rights Reserved * SPDX-License-Identifier: Apache-2.0 * * Licensed under the Apache License, Version 2.0 (the "License"); you may * not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * */ /* Compilation: * g++ -O2 -Wall -Wextra -Werror -std=c++14 ecdh.cpp -o ecdh -lmbedcrypto * And run: * ./ecdh */ #include <string> #include <exception> #include <iostream> #include <random> #include <algorithm> #include <cstdio> #include <cstdlib> #include <cstdarg> #include <cinttypes> #include <mbedtls/ecdh.h> #if !defined(MBEDTLS_ECDH_C) || !defined(MBEDTLS_ECP_DP_CURVE25519_ENABLED) #error Missing important component #endif /* Just simple Error exception class, copied from here: https://stackoverflow.com/a/12261993 */ class Error : public std::exception { private: char text[2048]; public: Error (char const* fmt, ...) __attribute__((format(printf,2,3))) { va_list ap; ::va_start (ap, fmt); ::vsnprintf (text, sizeof (text), fmt, ap); ::va_end (ap); } char const* what () const throw () override { return text; } }; /* This is a random generator suitable for mbedtls functions using mt19937 from C++ std library */ static int _cpp_random_for_mbedtls (void *p_rng, unsigned char *output, size_t output_len) { (void) p_rng; static std::random_device rd; static std::mt19937 mte (rd ()); std::uniform_int_distribution<unsigned char> dist (0x00, 0xff); std::generate_n (output, output_len, [&] (void) -> unsigned char { return dist(mte); }); return 0; } class OneSide { private: const std::string _name; mbedtls_ecdh_context _ctx; std::string _public_key; std::string _secret; char _text[2048]; unsigned char _output[32]; void print (char const* fmt, ...); public: explicit OneSide (const std::string &name); ~OneSide (); void calc_public_key (void); void calc_secret (const std::string &other_public); std::string get_public_key (void) const; std::string get_secret (void) const; }; void OneSide::print (char const* fmt, ...) { va_list ap; ::va_start (ap, fmt); ::vsnprintf (_text, sizeof (_text), fmt, ap); ::va_end (ap); std::cout << _name << " : " << _text << std::endl; } OneSide::OneSide (const std::string &name) : _name (name) { /* Initialize ECDH context */ mbedtls_ecdh_init (&_ctx); } OneSide::~OneSide () { /* Free ECDH context */ mbedtls_ecdh_free (&_ctx); } void OneSide::calc_public_key (void) { int ret; /* Inialize context and generate keypair */ ret = mbedtls_ecp_group_load (&_ctx.grp, MBEDTLS_ECP_DP_CURVE25519); if (ret != 0) { throw Error ("mbedtls_ecp_group_load = %d", ret); } ret = mbedtls_ecdh_gen_public (&_ctx.grp, &_ctx.d, &_ctx.Q, _cpp_random_for_mbedtls, nullptr); if (ret != 0) { throw Error ("mbedtls_ecdh_gen_public = %d", ret); } /* Convert public key to big-endian binary format */ ret = mbedtls_mpi_write_binary (&_ctx.Q.X, _output, sizeof (_output)); if (ret != 0) { throw Error ("mbedtls_mpi_write_binary = %d", ret); } /* Convert public key to std::string. It uses 32 bytes. */ _public_key.assign (reinterpret_cast<const char *> (_output), sizeof (_output)); print ("my public key is %02" PRIX8 "...%02" PRIX8 " (%zu bytes)", static_cast<uint8_t> (_public_key.front ()), static_cast<uint8_t> (_public_key.back ()), _public_key.size ()); } void OneSide::calc_secret (const std::string &other_public) { int ret; print ("got other public key %02" PRIX8 "...%02" PRIX8 " (%zu bytes)", static_cast<uint8_t> (other_public.front ()), static_cast<uint8_t> (other_public.back ()), other_public.size ()); /* Set value from integer */ ret = mbedtls_mpi_lset (&_ctx.Qp.Z, 1); if (ret != 0) { throw Error ("mbedtls_mpi_lset = %d", ret); } /* Read value from big-endian binary string */ ret = mbedtls_mpi_read_binary (&_ctx.Qp.X, reinterpret_cast<const unsigned char *> (other_public.data ()), other_public.size ()); if (ret != 0) { throw Error ("mbedtls_mpi_read_binary = %d", ret); } /* Compute shared secret */ ret = mbedtls_ecdh_compute_shared (&_ctx.grp, &_ctx.z, &_ctx.Qp, &_ctx.d, _cpp_random_for_mbedtls, nullptr); if (ret != 0) { throw Error ("mbedtls_ecdh_compute_shared = %d", ret); } /* Convert shared secret to big-endian binary format. It uses 32 bytes. */ ret = mbedtls_mpi_write_binary (&_ctx.z, _output, sizeof (_output)); if (ret != 0) { throw Error ("mbedtls_mpi_write_binary = %d", ret); } /* Convert shared secret to std::string */ _secret.assign (reinterpret_cast<const char *> (_output), sizeof (_output)); print ("shared secret is %02" PRIX8 "...%02" PRIX8 " (%zu bytes)", static_cast<uint8_t> (_secret.front ()), static_cast<uint8_t> (_secret.back ()), _secret.size ()); } std::string OneSide::get_public_key (void) const { return _public_key; } std::string OneSide::get_secret (void) const { return _secret; } int main (int argc, char *argv[]) { (void) argc; (void) argv; try { /* Create both sides */ OneSide side_a ("Side A"); OneSide side_b ("Side B"); /* Calculate public key on both sides */ side_a.calc_public_key (); side_b.calc_public_key (); /* Pass public keys from one side to the other and calc shared secret */ side_a.calc_secret (side_b.get_public_key ()); side_b.calc_secret (side_a.get_public_key ()); /* Get both shared secrets and compare them */ const std::string secret_a = side_a.get_secret (); const std::string secret_b = side_b.get_secret (); if (secret_a.compare (secret_b) != 0) { throw Error ("secrets from A and from B differ"); } std::cout << "OK" << std::endl; } catch (const std::exception &e) { std::cerr << "Error: " << e.what () << std::endl; } catch (...) { std::cerr << "Fatal Error" << std::endl; } }