Skip to content
Snippets Groups Projects
Commit 0b3d6c31 authored by Martin Balao's avatar Martin Balao Committed by Christoph Langer
Browse files

8263031: HttpClient throws Exception if it receives a Push Promise that is too large

Reviewed-by: abakhtin, mbaesken
Backport-of: 4d2cd26ab5092ad0a169e4239164a869a4255bd3
parent e59e2e25
No related branches found
No related tags found
No related merge requests found
...@@ -263,6 +263,8 @@ class Http2Connection { ...@@ -263,6 +263,8 @@ class Http2Connection {
private final Decoder hpackIn; private final Decoder hpackIn;
final SettingsFrame clientSettings; final SettingsFrame clientSettings;
private volatile SettingsFrame serverSettings; private volatile SettingsFrame serverSettings;
private record PushContinuationState(HeaderDecoder pushContDecoder, PushPromiseFrame pushContFrame) {}
private volatile PushContinuationState pushContinuationState;
private final String key; // for HttpClientImpl.connections map private final String key; // for HttpClientImpl.connections map
private final FramesDecoder framesDecoder; private final FramesDecoder framesDecoder;
private final FramesEncoder framesEncoder = new FramesEncoder(); private final FramesEncoder framesEncoder = new FramesEncoder();
...@@ -773,8 +775,8 @@ class Http2Connection { ...@@ -773,8 +775,8 @@ class Http2Connection {
} }
if (!(frame instanceof ResetFrame)) { if (!(frame instanceof ResetFrame)) {
if (frame instanceof DataFrame) { if (frame instanceof DataFrame df) {
dropDataFrame((DataFrame)frame); dropDataFrame(df);
} }
if (isServerInitiatedStream(streamid)) { if (isServerInitiatedStream(streamid)) {
if (streamid < nextPushStream) { if (streamid < nextPushStream) {
...@@ -791,21 +793,38 @@ class Http2Connection { ...@@ -791,21 +793,38 @@ class Http2Connection {
} }
return; return;
} }
if (frame instanceof PushPromiseFrame) {
PushPromiseFrame pp = (PushPromiseFrame)frame; // While push frame is not null, the only acceptable frame on this
// stream is a Continuation frame
if (pushContinuationState != null) {
if (frame instanceof ContinuationFrame cf) {
try {
handlePushContinuation(stream, cf);
} catch (UncheckedIOException e) {
debug.log("Error handling Push Promise with Continuation: " + e.getMessage(), e);
protocolError(ErrorFrame.PROTOCOL_ERROR, e.getMessage());
return;
}
} else {
pushContinuationState = null;
protocolError(ErrorFrame.PROTOCOL_ERROR, "Expected a Continuation frame but received " + frame);
return;
}
} else {
if (frame instanceof PushPromiseFrame pp) {
try { try {
handlePushPromise(stream, pp); handlePushPromise(stream, pp);
} catch (UncheckedIOException e) { } catch (UncheckedIOException e) {
protocolError(ResetFrame.PROTOCOL_ERROR, e.getMessage()); protocolError(ErrorFrame.PROTOCOL_ERROR, e.getMessage());
return; return;
} }
} else if (frame instanceof HeaderFrame) { } else if (frame instanceof HeaderFrame hf) {
// decode headers (or continuation) // decode headers
try { try {
decodeHeaders((HeaderFrame) frame, stream.rspHeadersConsumer()); decodeHeaders(hf, stream.rspHeadersConsumer());
} catch (UncheckedIOException e) { } catch (UncheckedIOException e) {
debug.log("Error decoding headers: " + e.getMessage(), e); debug.log("Error decoding headers: " + e.getMessage(), e);
protocolError(ResetFrame.PROTOCOL_ERROR, e.getMessage()); protocolError(ErrorFrame.PROTOCOL_ERROR, e.getMessage());
return; return;
} }
stream.incoming(frame); stream.incoming(frame);
...@@ -814,6 +833,7 @@ class Http2Connection { ...@@ -814,6 +833,7 @@ class Http2Connection {
} }
} }
} }
}
final void dropDataFrame(DataFrame df) { final void dropDataFrame(DataFrame df) {
if (closed) return; if (closed) return;
...@@ -841,11 +861,34 @@ class Http2Connection { ...@@ -841,11 +861,34 @@ class Http2Connection {
{ {
// always decode the headers as they may affect connection-level HPACK // always decode the headers as they may affect connection-level HPACK
// decoding state // decoding state
assert pushContinuationState == null;
HeaderDecoder decoder = new HeaderDecoder(); HeaderDecoder decoder = new HeaderDecoder();
decodeHeaders(pp, decoder); decodeHeaders(pp, decoder);
int promisedStreamid = pp.getPromisedStream();
if (pp.endHeaders()) {
completePushPromise(promisedStreamid, parent, decoder.headers());
} else {
pushContinuationState = new PushContinuationState(decoder, pp);
}
}
private <T> void handlePushContinuation(Stream<T> parent, ContinuationFrame cf)
throws IOException {
var pcs = pushContinuationState;
decodeHeaders(cf, pcs.pushContDecoder);
// if all continuations are sent, set pushWithContinuation to null
if (cf.endHeaders()) {
completePushPromise(pcs.pushContFrame.getPromisedStream(), parent,
pcs.pushContDecoder.headers());
pushContinuationState = null;
}
}
private <T> void completePushPromise(int promisedStreamid, Stream<T> parent, HttpHeaders headers)
throws IOException {
// Perhaps the following checks could be moved to handlePushPromise()
// to reset the PushPromise stream earlier?
HttpRequestImpl parentReq = parent.request; HttpRequestImpl parentReq = parent.request;
int promisedStreamid = pp.getPromisedStream();
if (promisedStreamid != nextPushStream) { if (promisedStreamid != nextPushStream) {
resetStream(promisedStreamid, ResetFrame.PROTOCOL_ERROR); resetStream(promisedStreamid, ResetFrame.PROTOCOL_ERROR);
return; return;
...@@ -856,7 +899,6 @@ class Http2Connection { ...@@ -856,7 +899,6 @@ class Http2Connection {
nextPushStream += 2; nextPushStream += 2;
} }
HttpHeaders headers = decoder.headers();
HttpRequestImpl pushReq = HttpRequestImpl.createPushRequest(parentReq, headers); HttpRequestImpl pushReq = HttpRequestImpl.createPushRequest(parentReq, headers);
Exchange<T> pushExch = new Exchange<>(pushReq, parent.exchange.multi); Exchange<T> pushExch = new Exchange<>(pushReq, parent.exchange.multi);
Stream.PushedStream<T> pushStream = createPushStream(parent, pushExch); Stream.PushedStream<T> pushStream = createPushStream(parent, pushExch);
......
/*
* Copyright (c) 2022, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* This code is free software; you can redistribute it and/or modify it
* under the terms of the GNU General Public License version 2 only, as
* published by the Free Software Foundation.
*
* This code is distributed in the hope that it will be useful, but WITHOUT
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
* version 2 for more details (a copy is included in the LICENSE file that
* accompanied this code).
*
* You should have received a copy of the GNU General Public License version
* 2 along with this work; if not, write to the Free Software Foundation,
* Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
*
* Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
* or visit www.oracle.com if you need additional information or have any
* questions.
*/
/*
* @test
* @bug 8263031
* @summary Tests that the HttpClient can correctly receive a Push Promise
* Frame with the END_HEADERS flag unset followed by one or more
* Continuation Frames.
* @library /test/lib /test/jdk/java/net/httpclient/lib
* @build jdk.test.lib.net.SimpleSSLContext jdk.httpclient.test.lib.http2.Http2TestServer
* jdk.httpclient.test.lib.http2.BodyOutputStream
* jdk.httpclient.test.lib.http2.OutgoingPushPromise
* @run testng/othervm PushPromiseContinuation
*/
import jdk.internal.net.http.common.HttpHeadersBuilder;
import jdk.internal.net.http.frame.ContinuationFrame;
import jdk.internal.net.http.frame.HeaderFrame;
import org.testng.TestException;
import org.testng.annotations.AfterTest;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.BeforeTest;
import org.testng.annotations.Test;
import javax.net.ssl.SSLSession;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpHeaders;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentMap;
import java.util.function.BiPredicate;
import jdk.httpclient.test.lib.http2.Http2TestServer;
import jdk.httpclient.test.lib.http2.Http2TestExchange;
import jdk.httpclient.test.lib.http2.Http2TestExchangeImpl;
import jdk.httpclient.test.lib.http2.Http2Handler;
import jdk.httpclient.test.lib.http2.BodyOutputStream;
import jdk.httpclient.test.lib.http2.OutgoingPushPromise;
import jdk.httpclient.test.lib.http2.Http2TestServerConnection;
import static java.nio.charset.StandardCharsets.UTF_8;
import static org.testng.Assert.assertEquals;
public class PushPromiseContinuation {
static volatile HttpHeaders testHeaders;
static volatile HttpHeadersBuilder testHeadersBuilder;
static volatile int continuationCount;
static final String mainPromiseBody = "Main Promise Body";
static final String mainResponseBody = "Main Response Body";
Http2TestServer server;
URI uri;
// Set up simple client-side push promise handler
ConcurrentMap<HttpRequest, CompletableFuture<HttpResponse<String>>> pushPromiseMap = new ConcurrentHashMap<>();
HttpResponse.PushPromiseHandler<String> pph = (initial, pushRequest, acceptor) -> {
HttpResponse.BodyHandler<String> s = HttpResponse.BodyHandlers.ofString(UTF_8);
pushPromiseMap.put(pushRequest, acceptor.apply(s));
};
@BeforeMethod
public void beforeMethod() {
pushPromiseMap = new ConcurrentHashMap<>();
}
@BeforeTest
public void setup() throws Exception {
server = new Http2TestServer(false, 0);
server.addHandler(new ServerPushHandler(), "/");
// Need to have a custom exchange supplier to manage the server's push
// promise with continuation flow
server.setExchangeSupplier(Http2LPPTestExchangeImpl::new);
System.err.println("PushPromiseContinuation: Server listening on port " + server.getAddress().getPort());
server.start();
int port = server.getAddress().getPort();
uri = new URI("http://localhost:" + port + "/");
}
@AfterTest
public void teardown() {
pushPromiseMap = null;
server.stop();
}
/**
* Tests that when the client receives PushPromise Frame with the END_HEADERS
* flag set to 0x0 and subsequently receives a continuation frame, no exception
* is thrown and all headers from the PushPromise and Continuation Frames sent
* by the server arrive at the client.
*/
@Test
public void testOneContinuation() {
continuationCount = 1;
HttpClient client = HttpClient.newHttpClient();
// Carry out request
HttpRequest hreq = HttpRequest.newBuilder(uri).version(HttpClient.Version.HTTP_2).GET().build();
CompletableFuture<HttpResponse<String>> cf =
client.sendAsync(hreq, HttpResponse.BodyHandlers.ofString(UTF_8), pph);
HttpResponse<String> resp = cf.join();
// Verify results
verify(resp);
}
/**
* Same as above, but tests for the case where two Continuation Frames are sent
* with the END_HEADERS flag set only on the last frame.
*/
@Test
public void testTwoContinuations() {
continuationCount = 2;
HttpClient client = HttpClient.newHttpClient();
// Carry out request
HttpRequest hreq = HttpRequest.newBuilder(uri).version(HttpClient.Version.HTTP_2).GET().build();
CompletableFuture<HttpResponse<String>> cf =
client.sendAsync(hreq, HttpResponse.BodyHandlers.ofString(UTF_8), pph);
HttpResponse<String> resp = cf.join();
// Verify results
verify(resp);
}
@Test
public void testThreeContinuations() {
continuationCount = 3;
HttpClient client = HttpClient.newHttpClient();
// Carry out request
HttpRequest hreq = HttpRequest.newBuilder(uri).version(HttpClient.Version.HTTP_2).GET().build();
CompletableFuture<HttpResponse<String>> cf =
client.sendAsync(hreq, HttpResponse.BodyHandlers.ofString(UTF_8), pph);
HttpResponse<String> resp = cf.join();
// Verify results
verify(resp);
}
private void verify(HttpResponse<String> resp) {
assertEquals(resp.statusCode(), 200);
assertEquals(resp.body(), mainResponseBody);
if (pushPromiseMap.size() > 1) {
System.err.println(pushPromiseMap.entrySet());
throw new TestException("Results map size is greater than 1");
} else {
// This will only iterate once
for (HttpRequest r : pushPromiseMap.keySet()) {
HttpResponse<String> serverPushResp = pushPromiseMap.get(r).join();
// Received headers should be the same as the combined PushPromise
// frame headers combined with the Continuation frame headers
assertEquals(testHeaders, r.headers());
// Check status code and push promise body are as expected
assertEquals(serverPushResp.statusCode(), 200);
assertEquals(serverPushResp.body(), mainPromiseBody);
}
}
}
static class Http2LPPTestExchangeImpl extends Http2TestExchangeImpl {
HttpHeadersBuilder pushPromiseHeadersBuilder;
List<ContinuationFrame> cfs;
Http2LPPTestExchangeImpl(int streamid, String method, HttpHeaders reqheaders,
HttpHeadersBuilder rspheadersBuilder, URI uri, InputStream is,
SSLSession sslSession, BodyOutputStream os,
Http2TestServerConnection conn, boolean pushAllowed) {
super(streamid, method, reqheaders, rspheadersBuilder, uri, is, sslSession, os, conn, pushAllowed);
}
private void setPushHeaders(String name, String value) {
pushPromiseHeadersBuilder.setHeader(name, value);
testHeadersBuilder.setHeader(name, value);
}
private void assembleContinuations() {
for (int i = 0; i < continuationCount; i++) {
HttpHeadersBuilder builder = new HttpHeadersBuilder();
for (int j = 0; j < 10; j++) {
String name = "x-cont-" + i + "-" + j;
builder.setHeader(name, "data_" + j);
testHeadersBuilder.setHeader(name, "data_" + j);
}
ContinuationFrame cf = new ContinuationFrame(streamid, 0x0, conn.encodeHeaders(builder.build()));
// If this is the last Continuation Frame, set the END_HEADERS flag.
if (i >= continuationCount - 1) {
cf.setFlag(HeaderFrame.END_HEADERS);
}
cfs.add(cf);
}
}
@Override
public void serverPush(URI uri, HttpHeaders headers, InputStream content) {
pushPromiseHeadersBuilder = new HttpHeadersBuilder();
testHeadersBuilder = new HttpHeadersBuilder();
cfs = new ArrayList<>();
setPushHeaders(":method", "GET");
setPushHeaders(":scheme", uri.getScheme());
setPushHeaders(":authority", uri.getAuthority());
setPushHeaders(":path", uri.getPath());
for (Map.Entry<String,List<String>> entry : headers.map().entrySet()) {
for (String value : entry.getValue()) {
setPushHeaders(entry.getKey(), value);
}
}
for (int i = 0; i < 10; i++) {
setPushHeaders("x-push-header-" + i, "data_" + i);
}
// Create the Continuation Frame/s, done before Push Promise Frame for test purposes
// as testHeaders contains all headers used in all frames
assembleContinuations();
HttpHeaders pushPromiseHeaders = pushPromiseHeadersBuilder.build();
testHeaders = testHeadersBuilder.build();
// Create the Push Promise Frame
OutgoingPushPromise pp = new OutgoingPushPromise(streamid, uri, pushPromiseHeaders, content);
// Indicates to the client that a continuation should be expected
pp.setFlag(0x0);
try {
// Schedule push promise and continuation for sending
conn.addToOutputQ(pp);
System.err.println("Server: Scheduled a Push Promise to Send");
for (ContinuationFrame cf : cfs) {
conn.addToOutputQ(cf);
System.err.println("Server: Scheduled a Continuation to Send");
}
} catch (IOException ex) {
System.err.println("Server: pushPromise exception: " + ex);
}
}
}
static class ServerPushHandler implements Http2Handler {
public void handle(Http2TestExchange exchange) throws IOException {
System.err.println("Server: handle " + exchange);
try (InputStream is = exchange.getRequestBody()) {
is.readAllBytes();
}
if (exchange.serverPushAllowed()) {
pushPromise(exchange);
}
// response data for the main response
try (OutputStream os = exchange.getResponseBody()) {
byte[] bytes = mainResponseBody.getBytes(UTF_8);
exchange.sendResponseHeaders(200, bytes.length);
os.write(bytes);
}
}
static final BiPredicate<String,String> ACCEPT_ALL = (x, y) -> true;
private void pushPromise(Http2TestExchange exchange) throws IOException {
URI requestURI = exchange.getRequestURI();
URI uri = requestURI.resolve("/promise");
InputStream is = new ByteArrayInputStream(mainPromiseBody.getBytes(UTF_8));
Map<String, List<String>> map = new HashMap<>();
map.put("x-promise", List.of("promise-header"));
HttpHeaders headers = HttpHeaders.of(map, ACCEPT_ALL);
exchange.serverPush(uri, headers, is);
System.err.println("Server: Push Promise complete");
}
}
}
...@@ -23,19 +23,20 @@ ...@@ -23,19 +23,20 @@
package jdk.httpclient.test.lib.http2; package jdk.httpclient.test.lib.http2;
import jdk.internal.net.http.common.HttpHeadersBuilder;
import jdk.internal.net.http.frame.HeaderFrame;
import jdk.internal.net.http.frame.HeadersFrame;
import javax.net.ssl.SSLSession;
import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.io.IOException;
import java.net.URI;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.URI;
import java.net.http.HttpHeaders; import java.net.http.HttpHeaders;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import javax.net.ssl.SSLSession;
import jdk.internal.net.http.common.HttpHeadersBuilder;
import jdk.internal.net.http.frame.HeaderFrame;
import jdk.internal.net.http.frame.HeadersFrame;
public class Http2TestExchangeImpl implements Http2TestExchange { public class Http2TestExchangeImpl implements Http2TestExchange {
...@@ -193,6 +194,7 @@ public class Http2TestExchangeImpl implements Http2TestExchange { ...@@ -193,6 +194,7 @@ public class Http2TestExchangeImpl implements Http2TestExchange {
} }
HttpHeaders combinedHeaders = headersBuilder.build(); HttpHeaders combinedHeaders = headersBuilder.build();
OutgoingPushPromise pp = new OutgoingPushPromise(streamid, uri, combinedHeaders, content); OutgoingPushPromise pp = new OutgoingPushPromise(streamid, uri, combinedHeaders, content);
pp.setFlag(HeaderFrame.END_HEADERS);
try { try {
conn.outputQ.put(pp); conn.outputQ.put(pp);
......
...@@ -23,32 +23,63 @@ ...@@ -23,32 +23,63 @@
package jdk.httpclient.test.lib.http2; package jdk.httpclient.test.lib.http2;
import jdk.internal.net.http.common.HttpHeadersBuilder;
import jdk.internal.net.http.frame.DataFrame;
import jdk.internal.net.http.frame.ErrorFrame;
import jdk.internal.net.http.frame.FramesDecoder;
import jdk.internal.net.http.frame.FramesEncoder;
import jdk.internal.net.http.frame.GoAwayFrame;
import jdk.internal.net.http.frame.HeaderFrame;
import jdk.internal.net.http.frame.HeadersFrame;
import jdk.internal.net.http.frame.Http2Frame;
import jdk.internal.net.http.frame.PingFrame;
import jdk.internal.net.http.frame.PushPromiseFrame;
import jdk.internal.net.http.frame.ResetFrame;
import jdk.internal.net.http.frame.SettingsFrame;
import jdk.internal.net.http.frame.WindowUpdateFrame;
import jdk.internal.net.http.hpack.Decoder;
import jdk.internal.net.http.hpack.DecodingCallback;
import jdk.internal.net.http.hpack.Encoder;
import sun.net.www.http.ChunkedInputStream;
import sun.net.www.http.HttpClient;
import javax.net.ssl.SNIHostName;
import javax.net.ssl.SNIMatcher;
import javax.net.ssl.SNIServerName;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.StandardConstants;
import java.io.BufferedInputStream; import java.io.BufferedInputStream;
import java.io.BufferedOutputStream; import java.io.BufferedOutputStream;
import java.io.Closeable; import java.io.Closeable;
import java.io.IOException; import java.io.IOException;
import java.io.UncheckedIOException;
import java.io.InputStream; import java.io.InputStream;
import java.io.OutputStream; import java.io.OutputStream;
import java.io.UncheckedIOException;
import java.net.InetAddress;
import java.net.Socket; import java.net.Socket;
import java.net.URI; import java.net.URI;
import java.net.InetAddress;
import javax.net.ssl.*;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.net.http.HttpHeaders; import java.net.http.HttpHeaders;
import java.nio.ByteBuffer; import java.nio.ByteBuffer;
import java.util.*; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Properties;
import java.util.Random;
import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.ConcurrentLinkedQueue; import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.ExecutorService;
import java.util.function.Consumer; import java.util.function.Consumer;
import jdk.internal.net.http.common.HttpHeadersBuilder;
import jdk.internal.net.http.frame.*;
import jdk.internal.net.http.hpack.Decoder;
import jdk.internal.net.http.hpack.DecodingCallback;
import jdk.internal.net.http.hpack.Encoder;
import sun.net.www.http.ChunkedInputStream;
import sun.net.www.http.HttpClient;
import static java.nio.charset.StandardCharsets.ISO_8859_1; import static java.nio.charset.StandardCharsets.ISO_8859_1;
import static java.nio.charset.StandardCharsets.UTF_8; import static java.nio.charset.StandardCharsets.UTF_8;
import static jdk.internal.net.http.frame.SettingsFrame.HEADER_TABLE_SIZE; import static jdk.internal.net.http.frame.SettingsFrame.HEADER_TABLE_SIZE;
...@@ -912,7 +943,7 @@ public class Http2TestServerConnection { ...@@ -912,7 +943,7 @@ public class Http2TestServerConnection {
private void handlePush(OutgoingPushPromise op) throws IOException { private void handlePush(OutgoingPushPromise op) throws IOException {
int promisedStreamid = nextPushStreamId; int promisedStreamid = nextPushStreamId;
PushPromiseFrame pp = new PushPromiseFrame(op.parentStream, PushPromiseFrame pp = new PushPromiseFrame(op.parentStream,
HeaderFrame.END_HEADERS, op.getFlags(),
promisedStreamid, promisedStreamid,
encodeHeaders(op.headers), encodeHeaders(op.headers),
0); 0);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment