dtls_get_message changes for state machine move

Create a dtls_get_message function similar to the old dtls1_get_message but
in the format required for the new state machine code. The old function will
eventually be deleted in later commits.

Reviewed-by: Tim Hudson <tjh@openssl.org>
Reviewed-by: Richard Levitte <levitte@openssl.org>
This commit is contained in:
Matt Caswell 2015-08-11 11:41:03 +01:00
parent f6a2f2da58
commit 76af303761
5 changed files with 154 additions and 68 deletions

View file

@ -1928,6 +1928,7 @@ void ERR_load_SSL_strings(void);
# define SSL_F_DTLS1_SEND_SERVER_HELLO 266 # define SSL_F_DTLS1_SEND_SERVER_HELLO 266
# define SSL_F_DTLS1_SEND_SERVER_KEY_EXCHANGE 267 # define SSL_F_DTLS1_SEND_SERVER_KEY_EXCHANGE 267
# define SSL_F_DTLS1_WRITE_APP_DATA_BYTES 268 # define SSL_F_DTLS1_WRITE_APP_DATA_BYTES 268
# define SSL_F_DTLS_GET_REASSEMBLED_MESSAGE 370
# define SSL_F_READ_STATE_MACHINE 352 # define SSL_F_READ_STATE_MACHINE 352
# define SSL_F_SSL3_ACCEPT 128 # define SSL_F_SSL3_ACCEPT 128
# define SSL_F_SSL3_ADD_CERT_TO_BUF 296 # define SSL_F_SSL3_ADD_CERT_TO_BUF 296

View file

@ -161,7 +161,8 @@ static void dtls1_set_message_header_int(SSL *s, unsigned char mt,
unsigned long frag_off, unsigned long frag_off,
unsigned long frag_len); unsigned long frag_len);
static long dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, static long dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt,
long max, int *ok); int *ok);
static int dtls_get_reassembled_message(SSL *s, long *len);
static hm_fragment *dtls1_hm_fragment_new(unsigned long frag_len, static hm_fragment *dtls1_hm_fragment_new(unsigned long frag_len,
int reassembly) int reassembly)
@ -481,7 +482,7 @@ long dtls1_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
memset(msg_hdr, 0, sizeof(*msg_hdr)); memset(msg_hdr, 0, sizeof(*msg_hdr));
again: again:
i = dtls1_get_message_fragment(s, st1, stn, mt, max, ok); i = dtls1_get_message_fragment(s, st1, stn, mt, ok);
if (i == DTLS1_HM_BAD_FRAGMENT || i == DTLS1_HM_FRAGMENT_RETRY) { if (i == DTLS1_HM_BAD_FRAGMENT || i == DTLS1_HM_FRAGMENT_RETRY) {
/* bad fragment received */ /* bad fragment received */
goto again; goto again;
@ -523,6 +524,12 @@ long dtls1_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
msg_len += DTLS1_HM_HEADER_LENGTH; msg_len += DTLS1_HM_HEADER_LENGTH;
} }
if (msg_len > (unsigned long)max) {
al = SSL_AD_ILLEGAL_PARAMETER;
SSLerr(SSL_F_DTLS1_GET_MESSAGE, SSL_R_EXCESSIVE_MESSAGE_SIZE);
goto f_err;
}
ssl3_finish_mac(s, p, msg_len); ssl3_finish_mac(s, p, msg_len);
if (s->msg_callback) if (s->msg_callback)
s->msg_callback(0, s->version, SSL3_RT_HANDSHAKE, s->msg_callback(0, s->version, SSL3_RT_HANDSHAKE,
@ -542,8 +549,72 @@ long dtls1_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok)
return -1; return -1;
} }
static int dtls1_preprocess_fragment(SSL *s, struct hm_header_st *msg_hdr, int dtls_get_message(SSL *s, int *mt, unsigned long *len)
int max) {
struct hm_header_st *msg_hdr;
unsigned char *p;
unsigned long msg_len;
int ok;
long tmplen;
msg_hdr = &s->d1->r_msg_hdr;
memset(msg_hdr, 0, sizeof(*msg_hdr));
again:
ok = dtls_get_reassembled_message(s, &tmplen);
if (tmplen == DTLS1_HM_BAD_FRAGMENT
|| tmplen == DTLS1_HM_FRAGMENT_RETRY) {
/* bad fragment received */
goto again;
} else if (tmplen <= 0 && !ok) {
return 0;
}
*mt = s->s3->tmp.message_type;
p = (unsigned char *)s->init_buf->data;
if (*mt == SSL3_MT_CHANGE_CIPHER_SPEC) {
if (s->msg_callback) {
s->msg_callback(0, s->version, SSL3_RT_CHANGE_CIPHER_SPEC,
p, 1, s, s->msg_callback_arg);
}
/*
* This isn't a real handshake message so skip the processing below.
*/
return 1;
}
msg_len = msg_hdr->msg_len;
/* reconstruct message header */
*(p++) = msg_hdr->type;
l2n3(msg_len, p);
s2n(msg_hdr->seq, p);
l2n3(0, p);
l2n3(msg_len, p);
if (s->version != DTLS1_BAD_VER) {
p -= DTLS1_HM_HEADER_LENGTH;
msg_len += DTLS1_HM_HEADER_LENGTH;
}
ssl3_finish_mac(s, p, msg_len);
if (s->msg_callback)
s->msg_callback(0, s->version, SSL3_RT_HANDSHAKE,
p, msg_len, s, s->msg_callback_arg);
memset(msg_hdr, 0, sizeof(*msg_hdr));
s->d1->handshake_read_seq++;
s->init_msg = s->init_buf->data + DTLS1_HM_HEADER_LENGTH;
*len = s->init_num;
return 1;
}
static int dtls1_preprocess_fragment(SSL *s, struct hm_header_st *msg_hdr)
{ {
size_t frag_off, frag_len, msg_len; size_t frag_off, frag_len, msg_len;
@ -557,11 +628,6 @@ static int dtls1_preprocess_fragment(SSL *s, struct hm_header_st *msg_hdr,
return SSL_AD_ILLEGAL_PARAMETER; return SSL_AD_ILLEGAL_PARAMETER;
} }
if ((frag_off + frag_len) > (unsigned long)max) {
SSLerr(SSL_F_DTLS1_PREPROCESS_FRAGMENT, SSL_R_EXCESSIVE_MESSAGE_SIZE);
return SSL_AD_ILLEGAL_PARAMETER;
}
if (s->d1->r_msg_hdr.frag_off == 0) { /* first fragment */ if (s->d1->r_msg_hdr.frag_off == 0) { /* first fragment */
/* /*
* msg_len is limited to 2^24, but is effectively checked against max * msg_len is limited to 2^24, but is effectively checked against max
@ -590,7 +656,7 @@ static int dtls1_preprocess_fragment(SSL *s, struct hm_header_st *msg_hdr,
return 0; /* no error */ return 0; /* no error */
} }
static int dtls1_retrieve_buffered_fragment(SSL *s, long max, int *ok) static int dtls1_retrieve_buffered_fragment(SSL *s, int *ok)
{ {
/*- /*-
* (0) check whether the desired fragment is available * (0) check whether the desired fragment is available
@ -617,7 +683,7 @@ static int dtls1_retrieve_buffered_fragment(SSL *s, long max, int *ok)
unsigned long frag_len = frag->msg_header.frag_len; unsigned long frag_len = frag->msg_header.frag_len;
pqueue_pop(s->d1->buffered_messages); pqueue_pop(s->d1->buffered_messages);
al = dtls1_preprocess_fragment(s, &frag->msg_header, max); al = dtls1_preprocess_fragment(s, &frag->msg_header);
if (al == 0) { /* no alert */ if (al == 0) { /* no alert */
unsigned char *p = unsigned char *p =
@ -859,19 +925,44 @@ dtls1_process_out_of_seq_message(SSL *s, const struct hm_header_st *msg_hdr,
} }
static long static long
dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, long max, int *ok) dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, int *ok)
{
long len;
do {
*ok = dtls_get_reassembled_message(s, &len);
/* A CCS isn't a real handshake message, so if we get one there is no
* message sequence number to give us confidence that this was really
* intended to be at this point in the handshake sequence. Therefore we
* only allow this if we were explicitly looking for it (i.e. if |mt|
* is -1 we still don't allow it). If we get one when we're not
* expecting it then probably something got re-ordered or this is a
* retransmit. We should drop this and try again.
*/
} while (*ok && mt != SSL3_MT_CHANGE_CIPHER_SPEC
&& s->s3->tmp.message_type == SSL3_MT_CHANGE_CIPHER_SPEC);
if (*ok)
s->state = stn;
return len;
}
static int dtls_get_reassembled_message(SSL *s, long *len)
{ {
unsigned char wire[DTLS1_HM_HEADER_LENGTH]; unsigned char wire[DTLS1_HM_HEADER_LENGTH];
unsigned long len, frag_off, frag_len; unsigned long mlen, frag_off, frag_len;
int i, al, recvd_type; int i, al, recvd_type;
struct hm_header_st msg_hdr; struct hm_header_st msg_hdr;
int ok;
redo: redo:
/* see if we have the required fragment already */ /* see if we have the required fragment already */
if ((frag_len = dtls1_retrieve_buffered_fragment(s, max, ok)) || *ok) { if ((frag_len = dtls1_retrieve_buffered_fragment(s, &ok)) || ok) {
if (*ok) if (ok)
s->init_num = frag_len; s->init_num = frag_len;
return frag_len; *len = frag_len;
return ok;
} }
/* read handshake message header */ /* read handshake message header */
@ -879,53 +970,37 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, long max, int *ok)
DTLS1_HM_HEADER_LENGTH, 0); DTLS1_HM_HEADER_LENGTH, 0);
if (i <= 0) { /* nbio, or an error */ if (i <= 0) { /* nbio, or an error */
s->rwstate = SSL_READING; s->rwstate = SSL_READING;
*ok = 0; *len = i;
return i; return 0;
} }
if(recvd_type == SSL3_RT_CHANGE_CIPHER_SPEC) { if(recvd_type == SSL3_RT_CHANGE_CIPHER_SPEC) {
/* This isn't a real handshake message - its a CCS. if (wire[0] != SSL3_MT_CCS) {
* There is no message sequence number in a CCS to give us confidence al = SSL_AD_UNEXPECTED_MESSAGE;
* that this was really intended to be at this point in the handshake SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE,
* sequence. Therefore we only allow this if we were explicitly looking SSL_R_BAD_CHANGE_CIPHER_SPEC);
* for it (i.e. if |mt| is -1 we still don't allow it). goto f_err;
*/
if(mt == SSL3_MT_CHANGE_CIPHER_SPEC) {
if (wire[0] != SSL3_MT_CCS) {
al = SSL_AD_UNEXPECTED_MESSAGE;
SSLerr(SSL_F_DTLS1_GET_MESSAGE_FRAGMENT, SSL_R_BAD_CHANGE_CIPHER_SPEC);
goto f_err;
}
memcpy(s->init_buf->data, wire, i);
s->init_num = i - 1;
s->init_msg = s->init_buf->data + 1;
s->s3->tmp.message_type = SSL3_MT_CHANGE_CIPHER_SPEC;
s->s3->tmp.message_size = i - 1;
s->state = stn;
*ok = 1;
return i-1;
} else {
/*
* We weren't expecting a CCS yet. Probably something got
* re-ordered or this is a retransmit. We should drop this and try
* again.
*/
s->init_num = 0;
goto redo;
} }
memcpy(s->init_buf->data, wire, i);
s->init_num = i - 1;
s->init_msg = s->init_buf->data + 1;
s->s3->tmp.message_type = SSL3_MT_CHANGE_CIPHER_SPEC;
s->s3->tmp.message_size = i - 1;
*len = i - 1;
return 1;
} }
/* Handshake fails if message header is incomplete */ /* Handshake fails if message header is incomplete */
if (i != DTLS1_HM_HEADER_LENGTH) { if (i != DTLS1_HM_HEADER_LENGTH) {
al = SSL_AD_UNEXPECTED_MESSAGE; al = SSL_AD_UNEXPECTED_MESSAGE;
SSLerr(SSL_F_DTLS1_GET_MESSAGE_FRAGMENT, SSL_R_UNEXPECTED_MESSAGE); SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL_R_UNEXPECTED_MESSAGE);
goto f_err; goto f_err;
} }
/* parse the message fragment header */ /* parse the message fragment header */
dtls1_get_message_header(wire, &msg_hdr); dtls1_get_message_header(wire, &msg_hdr);
len = msg_hdr.msg_len; mlen = msg_hdr.msg_len;
frag_off = msg_hdr.frag_off; frag_off = msg_hdr.frag_off;
frag_len = msg_hdr.frag_len; frag_len = msg_hdr.frag_len;
@ -935,7 +1010,7 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, long max, int *ok)
*/ */
if (frag_len > RECORD_LAYER_get_rrec_length(&s->rlayer)) { if (frag_len > RECORD_LAYER_get_rrec_length(&s->rlayer)) {
al = SSL3_AD_ILLEGAL_PARAMETER; al = SSL3_AD_ILLEGAL_PARAMETER;
SSLerr(SSL_F_DTLS1_GET_MESSAGE_FRAGMENT, SSL_R_BAD_LENGTH); SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL_R_BAD_LENGTH);
goto f_err; goto f_err;
} }
@ -945,11 +1020,15 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, long max, int *ok)
* While listening, we accept seq 1 (ClientHello with cookie) * While listening, we accept seq 1 (ClientHello with cookie)
* although we're still expecting seq 0 (ClientHello) * although we're still expecting seq 0 (ClientHello)
*/ */
if (msg_hdr.seq != s->d1->handshake_read_seq) if (msg_hdr.seq != s->d1->handshake_read_seq) {
return dtls1_process_out_of_seq_message(s, &msg_hdr, ok); *len = dtls1_process_out_of_seq_message(s, &msg_hdr, &ok);
return ok;
}
if (frag_len && frag_len < len) if (frag_len && frag_len < mlen) {
return dtls1_reassemble_fragment(s, &msg_hdr, ok); *len = dtls1_reassemble_fragment(s, &msg_hdr, &ok);
return ok;
}
if (!s->server && s->d1->r_msg_hdr.frag_off == 0 && if (!s->server && s->d1->r_msg_hdr.frag_off == 0 &&
wire[0] == SSL3_MT_HELLO_REQUEST) { wire[0] == SSL3_MT_HELLO_REQUEST) {
@ -969,13 +1048,13 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, long max, int *ok)
} else { /* Incorrectly formated Hello request */ } else { /* Incorrectly formated Hello request */
al = SSL_AD_UNEXPECTED_MESSAGE; al = SSL_AD_UNEXPECTED_MESSAGE;
SSLerr(SSL_F_DTLS1_GET_MESSAGE_FRAGMENT, SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE,
SSL_R_UNEXPECTED_MESSAGE); SSL_R_UNEXPECTED_MESSAGE);
goto f_err; goto f_err;
} }
} }
if ((al = dtls1_preprocess_fragment(s, &msg_hdr, max))) if ((al = dtls1_preprocess_fragment(s, &msg_hdr)))
goto f_err; goto f_err;
if (frag_len > 0) { if (frag_len > 0) {
@ -991,8 +1070,8 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, long max, int *ok)
*/ */
if (i <= 0) { if (i <= 0) {
s->rwstate = SSL_READING; s->rwstate = SSL_READING;
*ok = 0; *len = i;
return i; return 0;
} }
} else } else
i = 0; i = 0;
@ -1003,28 +1082,24 @@ dtls1_get_message_fragment(SSL *s, int st1, int stn, int mt, long max, int *ok)
*/ */
if (i != (int)frag_len) { if (i != (int)frag_len) {
al = SSL3_AD_ILLEGAL_PARAMETER; al = SSL3_AD_ILLEGAL_PARAMETER;
SSLerr(SSL_F_DTLS1_GET_MESSAGE_FRAGMENT, SSL3_AD_ILLEGAL_PARAMETER); SSLerr(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE, SSL3_AD_ILLEGAL_PARAMETER);
goto f_err; goto f_err;
} }
*ok = 1;
s->state = stn;
/* /*
* Note that s->init_num is *not* used as current offset in * Note that s->init_num is *not* used as current offset in
* s->init_buf->data, but as a counter summing up fragments' lengths: as * s->init_buf->data, but as a counter summing up fragments' lengths: as
* soon as they sum up to handshake packet length, we assume we have got * soon as they sum up to handshake packet length, we assume we have got
* all the fragments. * all the fragments.
*/ */
s->init_num = frag_len; *len = s->init_num = frag_len;
return frag_len; return 1;
f_err: f_err:
ssl3_send_alert(s, SSL3_AL_FATAL, al); ssl3_send_alert(s, SSL3_AL_FATAL, al);
s->init_num = 0; s->init_num = 0;
*len = -1;
*ok = 0; return 0;
return (-1);
} }
/*- /*-

View file

@ -112,6 +112,8 @@ static ERR_STRING_DATA SSL_str_functs[] = {
{ERR_FUNC(SSL_F_DTLS1_SEND_SERVER_KEY_EXCHANGE), {ERR_FUNC(SSL_F_DTLS1_SEND_SERVER_KEY_EXCHANGE),
"dtls1_send_server_key_exchange"}, "dtls1_send_server_key_exchange"},
{ERR_FUNC(SSL_F_DTLS1_WRITE_APP_DATA_BYTES), "dtls1_write_app_data_bytes"}, {ERR_FUNC(SSL_F_DTLS1_WRITE_APP_DATA_BYTES), "dtls1_write_app_data_bytes"},
{ERR_FUNC(SSL_F_DTLS_GET_REASSEMBLED_MESSAGE),
"DTLS_GET_REASSEMBLED_MESSAGE"},
{ERR_FUNC(SSL_F_READ_STATE_MACHINE), "READ_STATE_MACHINE"}, {ERR_FUNC(SSL_F_READ_STATE_MACHINE), "READ_STATE_MACHINE"},
{ERR_FUNC(SSL_F_SSL3_ACCEPT), "ssl3_accept"}, {ERR_FUNC(SSL_F_SSL3_ACCEPT), "ssl3_accept"},
{ERR_FUNC(SSL_F_SSL3_ADD_CERT_TO_BUF), "SSL3_ADD_CERT_TO_BUF"}, {ERR_FUNC(SSL_F_SSL3_ADD_CERT_TO_BUF), "SSL3_ADD_CERT_TO_BUF"},

View file

@ -2236,6 +2236,7 @@ long dtls1_ctrl(SSL *s, int cmd, long larg, void *parg);
__owur int dtls1_shutdown(SSL *s); __owur int dtls1_shutdown(SSL *s);
__owur long dtls1_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok); __owur long dtls1_get_message(SSL *s, int st1, int stn, int mt, long max, int *ok);
__owur int dtls_get_message(SSL *s, int *mt, unsigned long *len);
__owur int dtls1_dispatch_alert(SSL *s); __owur int dtls1_dispatch_alert(SSL *s);
__owur int ssl_init_wbio_buffer(SSL *s, int push); __owur int ssl_init_wbio_buffer(SSL *s, int push);

View file

@ -464,7 +464,14 @@ static enum SUB_STATE_RETURN read_state_machine(SSL *s) {
case READ_STATE_HEADER: case READ_STATE_HEADER:
s->init_num = 0; s->init_num = 0;
/* Get the state the peer wants to move to */ /* Get the state the peer wants to move to */
ret = tls_get_message_header(s, &mt); if (SSL_IS_DTLS(s)) {
/*
* In DTLS we get the whole message in one go - header and body
*/
ret = dtls_get_message(s, &mt, &len);
} else {
ret = tls_get_message_header(s, &mt);
}
if (ret == 0) { if (ret == 0) {
/* Could be non-blocking IO */ /* Could be non-blocking IO */