[tls] Use our own ASN.1 routines for certificate parsing
[people/asdlkf/gpxe.git] / src / net / tls.c
index fa4b58d..f5bff7a 100644 (file)
@@ -36,6 +36,8 @@
 #include <gpxe/xfer.h>
 #include <gpxe/open.h>
 #include <gpxe/filter.h>
+#include <gpxe/asn1.h>
+#include <gpxe/x509.h>
 #include <gpxe/tls.h>
 
 static int tls_send_plaintext ( struct tls_session *tls, unsigned int type,
@@ -43,6 +45,33 @@ static int tls_send_plaintext ( struct tls_session *tls, unsigned int type,
 static void tls_clear_cipher ( struct tls_session *tls,
                               struct tls_cipherspec *cipherspec );
 
+/******************************************************************************
+ *
+ * Utility functions
+ *
+ ******************************************************************************
+ */
+
+/**
+ * Extract 24-bit field value
+ *
+ * @v field24          24-bit big-endian field
+ * @ret value          Field value
+ *
+ * TLS uses 24-bit big-endian integers, which are awkward to parse in C.
+ */
+static unsigned long tls_uint24 ( const uint8_t field24[3] ) {
+       return ( ( ( unsigned long ) field24[0] << 16 ) |
+                ( ( unsigned long ) field24[1] << 8 ) | field24[2] );
+}
+
+/******************************************************************************
+ *
+ * Cleanup functions
+ *
+ ******************************************************************************
+ */
+
 /**
  * Free TLS session
  *
@@ -57,8 +86,7 @@ static void free_tls ( struct refcnt *refcnt ) {
        tls_clear_cipher ( tls, &tls->tx_cipherspec_pending );
        tls_clear_cipher ( tls, &tls->rx_cipherspec );
        tls_clear_cipher ( tls, &tls->rx_cipherspec_pending );
-       free ( tls->rsa_mod );
-       free ( tls->rsa_pub_exp );
+       x509_free_rsa_public_key ( &tls->rsa );
        free ( tls->rx_data );
 
        /* Free TLS structure itself */
@@ -622,8 +650,8 @@ static int tls_send_client_hello ( struct tls_session *tls ) {
 static int tls_send_client_key_exchange ( struct tls_session *tls ) {
        /* FIXME: Hack alert */
        RSA_CTX *rsa_ctx;
-       RSA_pub_key_new ( &rsa_ctx, tls->rsa_mod, tls->rsa_mod_len,
-                         tls->rsa_pub_exp, tls->rsa_pub_exp_len );
+       RSA_pub_key_new ( &rsa_ctx, tls->rsa.modulus, tls->rsa.modulus_len,
+                         tls->rsa.exponent, tls->rsa.exponent_len );
        struct {
                uint32_t type_length;
                uint16_t encrypted_pre_master_secret_len;
@@ -641,8 +669,8 @@ static int tls_send_client_key_exchange ( struct tls_session *tls ) {
        DBGC ( tls, "RSA encrypting plaintext, modulus, exponent:\n" );
        DBGC_HD ( tls, &tls->pre_master_secret,
                  sizeof ( tls->pre_master_secret ) );
-       DBGC_HD ( tls, tls->rsa_mod, tls->rsa_mod_len );
-       DBGC_HD ( tls, tls->rsa_pub_exp, tls->rsa_pub_exp_len );
+       DBGC_HD ( tls, tls->rsa.modulus, tls->rsa.modulus_len );
+       DBGC_HD ( tls, tls->rsa.exponent, tls->rsa.exponent_len );
        RSA_encrypt ( rsa_ctx, ( const uint8_t * ) &tls->pre_master_secret,
                      sizeof ( tls->pre_master_secret ),
                      key_xchg.encrypted_pre_master_secret, 0 );
@@ -761,17 +789,16 @@ static int tls_new_alert ( struct tls_session *tls, void *data, size_t len ) {
 }
 
 /**
- * Receive new Server Hello record
+ * Receive new Server Hello handshake record
  *
  * @v tls              TLS session
- * @v data             Plaintext record
- * @v len              Length of plaintext record
+ * @v data             Plaintext handshake record
+ * @v len              Length of plaintext handshake record
  * @ret rc             Return status code
  */
 static int tls_new_server_hello ( struct tls_session *tls,
                                  void *data, size_t len ) {
        struct {
-               uint32_t type_length;
                uint16_t version;
                uint8_t random[32];
                uint8_t session_id_len;
@@ -818,72 +845,74 @@ static int tls_new_server_hello ( struct tls_session *tls,
 }
 
 /**
- * Receive new Certificate record
+ * Receive new Certificate handshake record
  *
  * @v tls              TLS session
- * @v data             Plaintext record
- * @v len              Length of plaintext record
+ * @v data             Plaintext handshake record
+ * @v len              Length of plaintext handshake record
  * @ret rc             Return status code
  */
 static int tls_new_certificate ( struct tls_session *tls,
                                 void *data, size_t len ) {
        struct {
-               uint32_t type_length;
                uint8_t length[3];
-               uint8_t first_cert_length[3];
-               uint8_t asn1_start[0];
+               uint8_t certificates[0];
        } __attribute__ (( packed )) *certificate = data;
-       uint8_t *cert = certificate->asn1_start;
-       int offset = 0;
-
-       /* FIXME */
-       (void) len;
-
-       if (asn1_next_obj(cert, &offset, ASN1_SEQUENCE) < 0 ||
-           asn1_next_obj(cert, &offset, ASN1_SEQUENCE) < 0 ||
-            asn1_skip_obj(cert, &offset, ASN1_EXPLICIT_TAG) ||
-            asn1_skip_obj(cert, &offset, ASN1_INTEGER) ||
-            asn1_skip_obj(cert, &offset, ASN1_SEQUENCE) ||
-            asn1_skip_obj(cert, &offset, ASN1_SEQUENCE) ||
-            asn1_skip_obj(cert, &offset, ASN1_SEQUENCE) ||
-            asn1_skip_obj(cert, &offset, ASN1_SEQUENCE) ||
-           asn1_next_obj(cert, &offset, ASN1_SEQUENCE) < 0 ||
-            asn1_skip_obj(cert, &offset, ASN1_SEQUENCE) ||
-            asn1_next_obj(cert, &offset, ASN1_BIT_STRING) < 0) {
-               DBGC ( tls, "TLS %p invalid certificate\n", tls );
-               DBGC_HD ( tls, cert + offset, 64 );
-               return -EPERM;
-       }
-       
-       offset++;
-       
-       if (asn1_next_obj(cert, &offset, ASN1_SEQUENCE) < 0) {
-               DBGC ( tls, "TLS %p invalid certificate\n", tls );
-               DBGC_HD ( tls, cert + offset, 64 );
-               return -EPERM;
+       struct {
+               uint8_t length[3];
+               uint8_t certificate[0];
+       } __attribute__ (( packed )) *element =
+                 ( ( void * ) certificate->certificates );
+       size_t elements_len = tls_uint24 ( certificate->length );
+       void *end = ( certificate->certificates + elements_len );
+       struct asn1_cursor cursor;
+       int rc;
+
+       /* Sanity check */
+       if ( end != ( data + len ) ) {
+               DBGC ( tls, "TLS %p received overlength Server Certificate\n",
+                      tls );
+               DBGC_HD ( tls, data, len );
+               return -EINVAL;
        }
-       
-       tls->rsa_mod_len = asn1_get_int(cert, &offset, &tls->rsa_mod);
-       tls->rsa_pub_exp_len = asn1_get_int(cert, &offset, &tls->rsa_pub_exp);
-       
-       DBGC_HD ( tls, tls->rsa_mod, tls->rsa_mod_len );
-       DBGC_HD ( tls, tls->rsa_pub_exp, tls->rsa_pub_exp_len );
 
-       return 0;
+       /* Traverse certificate chain */
+       do {
+               cursor.data = element->certificate;
+               cursor.len = tls_uint24 ( element->length );
+               if ( ( cursor.data + cursor.len ) > end ) {
+                       DBGC ( tls, "TLS %p received corrupt Server "
+                              "Certificate\n", tls );
+                       DBGC_HD ( tls, data, len );
+                       return -EINVAL;
+               }
+
+               // HACK
+               if ( ( rc = x509_rsa_public_key ( &cursor,
+                                                 &tls->rsa ) ) != 0 ) {
+                       DBGC ( tls, "TLS %p cannot determine RSA public key: "
+                              "%s\n", tls, strerror ( rc ) );
+                       return rc;
+               }
+               return 0;
+
+               element = ( cursor.data + cursor.len );
+       } while ( element != end );
+
+       return -EINVAL;
 }
 
 /**
- * Receive new Server Hello Done record
+ * Receive new Server Hello Done handshake record
  *
  * @v tls              TLS session
- * @v data             Plaintext record
- * @v len              Length of plaintext record
+ * @v data             Plaintext handshake record
+ * @v len              Length of plaintext handshake record
  * @ret rc             Return status code
  */
 static int tls_new_server_hello_done ( struct tls_session *tls,
                                       void *data, size_t len ) {
        struct {
-               uint32_t type_length;
                char next[0];
        } __attribute__ (( packed )) *hello_done = data;
        void *end = hello_done->next;
@@ -910,11 +939,11 @@ static int tls_new_server_hello_done ( struct tls_session *tls,
 }
 
 /**
- * Receive new Finished record
+ * Receive new Finished handshake record
  *
  * @v tls              TLS session
- * @v data             Plaintext record
- * @v len              Length of plaintext record
+ * @v data             Plaintext handshake record
+ * @v len              Length of plaintext handshake record
  * @ret rc             Return status code
  */
 static int tls_new_finished ( struct tls_session *tls,
@@ -937,33 +966,47 @@ static int tls_new_finished ( struct tls_session *tls,
  */
 static int tls_new_handshake ( struct tls_session *tls,
                               void *data, size_t len ) {
-       uint8_t *type = data;
+       struct {
+               uint8_t type;
+               uint8_t length[3];
+               uint8_t payload[0];
+       } __attribute__ (( packed )) *handshake = data;
+       void *payload = &handshake->payload;
+       size_t payload_len = tls_uint24 ( handshake->length );
+       void *end = ( payload + payload_len );
        int rc;
 
-       switch ( *type ) {
+       /* Sanity check */
+       if ( end != ( data + len ) ) {
+               DBGC ( tls, "TLS %p received overlength Handshake\n", tls );
+               DBGC_HD ( tls, data, len );
+               return -EINVAL;
+       }
+
+       switch ( handshake->type ) {
        case TLS_SERVER_HELLO:
-               rc = tls_new_server_hello ( tls, data, len );
+               rc = tls_new_server_hello ( tls, payload, payload_len );
                break;
        case TLS_CERTIFICATE:
-               rc = tls_new_certificate ( tls, data, len );
+               rc = tls_new_certificate ( tls, payload, payload_len );
                break;
        case TLS_SERVER_HELLO_DONE:
-               rc = tls_new_server_hello_done ( tls, data, len );
+               rc = tls_new_server_hello_done ( tls, payload, payload_len );
                break;
        case TLS_FINISHED:
-               rc = tls_new_finished ( tls, data, len );
+               rc = tls_new_finished ( tls, payload, payload_len );
                break;
        default:
                DBGC ( tls, "TLS %p ignoring handshake type %d\n",
-                      tls, *type );
+                      tls, handshake->type );
                rc = 0;
                break;
        }
 
        /* Add to handshake digest (except for Hello Requests, which
-        * are explicitly excludede).
+        * are explicitly excluded).
         */
-       if ( *type != TLS_HELLO_REQUEST )
+       if ( handshake->type != TLS_HELLO_REQUEST )
                tls_add_handshake ( tls, data, len );
 
        return rc;