--- a/net/sctp/socket.c
+++ b/net/sctp/socket.c
@@ -1796,7 +1796,7 @@ static int sctp_sendmsg_to_asoc(struct sctp_association *asoc,
 	bool wait_connect = false;
 	struct sctp_chunk *chunk;
 	long timeo;
-	int err;
+	int err, mem = 0;
 
 	if (sinfo->sinfo_stream >= asoc->stream.outcnt) {
 		err = -EINVAL;
@@ -1807,11 +1807,12 @@ static int sctp_sendmsg_to_asoc(struct sctp_association *asoc,
 		err = sctp_stream_init_ext(&asoc->stream, sinfo->sinfo_stream);
 		if (err)
 			goto err;
+		mem = 1;
 	}
 
 	if (sp->disable_fragments && msg_len > asoc->frag_point) {
 		err = -EMSGSIZE;
-		goto err;
+		goto free_mem;
 	}
 
 	if (asoc->pmtu_pending) {
@@ -1883,6 +1884,12 @@ static int sctp_sendmsg_to_asoc(struct sctp_association *asoc,
 
 	err = msg_len;
+	goto err;
 
+free_mem:
+	/* Failure-only cleanup: the success path jumps over this block. */
+	if (mem)
+		sctp_stream_prio_free(&asoc->stream);
+
 err:
 	return err;
 }
--- a/include/net/sctp/structs.h
+++ b/include/net/sctp/structs.h
@@ -396,6 +396,8 @@ void sctp_stream_free(struct sctp_stream *stream);
 void sctp_stream_clear(struct sctp_stream *stream);
 void sctp_stream_update(struct sctp_stream *stream, struct sctp_stream *new);
 
+void sctp_stream_prio_free(struct sctp_stream *stream);
+
 /* What is the current SSN number for this stream? */
 #define sctp_ssn_peek(stream, type, sid) \
 	(sctp_stream_##type((stream), (sid))->ssn)
--- a/net/sctp/stream.c
+++ b/net/sctp/stream.c
@@ -182,12 +182,19 @@ int sctp_stream_init_ext(struct sctp_stream *stream, __u16 sid)
 	return ret;
 }
 
-void sctp_stream_free(struct sctp_stream *stream)
+/* Call the free hook of the scheduler ops looked up for @stream. */
+void sctp_stream_prio_free(struct sctp_stream *stream)
 {
 	struct sctp_sched_ops *sched = sctp_sched_ops_from_stream(stream);
+	sched->free(stream);
+}
+
+void sctp_stream_free(struct sctp_stream *stream)
+{
 	int i;
 
-	sched->free(stream);
+	sctp_stream_prio_free(stream);
+
 	for (i = 0; i < stream->outcnt; i++)
 		kfree(SCTP_SO(stream, i)->ext);
 	genradix_free(&stream->out);
--- a/net/sctp/stream_sched_prio.c
+++ b/net/sctp/stream_sched_prio.c
@@ -30,7 +30,7 @@ static struct sctp_stream_priorities *sctp_sched_prio_new_head(
 {
 	struct sctp_stream_priorities *p;
 
-	p = kmalloc(sizeof(*p), gfp);
+	p = kzalloc(sizeof(*p), gfp);
 	if (!p)
 		return NULL;
 