Moved the safe_read() and safe_write() functions into the sock.c file

since they're now used in other places.

Added support for a true upstream proxy connection. This involved some
rewriting of the handle_connection() function and some of the support
functions so that they do perform the domain filtering and anonymous
filtering while still connecting to the upstream proxy. I think the code
should be cleaned up further.
master
Robert James Kaes 2001-09-16 20:10:19 +00:00
parent a8f0272ace
commit 08baf6b01b
1 changed files with 136 additions and 144 deletions

View File

@ -1,4 +1,4 @@
/* $Id: reqs.c,v 1.27 2001-09-15 21:26:14 rjkaes Exp $ /* $Id: reqs.c,v 1.28 2001-09-16 20:10:19 rjkaes Exp $
* *
* This is where all the work in tinyproxy is actually done. Incoming * This is where all the work in tinyproxy is actually done. Incoming
* connections have a new thread created for them. The thread then * connections have a new thread created for them. The thread then
@ -42,36 +42,6 @@
#define LINE_LENGTH (MAXBUFFSIZE / 3) #define LINE_LENGTH (MAXBUFFSIZE / 3)
/*
* Write the buffer to the socket. If an EINTR occurs, pick up and try
* again.
*/
static ssize_t safe_write(int fd, const void *buffer, size_t count)
{
ssize_t len;
do {
len = write(fd, buffer, count);
} while (len < 0 && errno == EINTR);
return len;
}
/*
* Matched pair for safe_write(). If an EINTR occurs, pick up and try
* again.
*/
static ssize_t safe_read(int fd, void *buffer, size_t count)
{
ssize_t len;
do {
len = read(fd, buffer, count);
} while (len < 0 && errno == EINTR);
return len;
}
/* /*
* Remove any new lines or carriage returns from the end of a string. * Remove any new lines or carriage returns from the end of a string.
*/ */
@ -131,11 +101,25 @@ static char *read_request_line(struct conn_s *connptr)
* This structure holds the information pulled from a URL request. * This structure holds the information pulled from a URL request.
*/ */
struct request_s { struct request_s {
char *method;
char *protocol;
char *host; char *host;
char *path; char *path;
int port; int port;
}; };
static void free_request_struct(struct request_s *request)
{
safefree(request->method);
safefree(request->protocol);
safefree(request->host);
safefree(request->path);
safefree(request);
}
/* /*
* Pull the information out of the URL line. * Pull the information out of the URL line.
*/ */
@ -199,14 +183,12 @@ static int extract_ssl_url(const char *url, struct request_s *request)
* Create a connection for HTTP connections. * Create a connection for HTTP connections.
*/ */
static inline int establish_http_connection(struct conn_s *connptr, static inline int establish_http_connection(struct conn_s *connptr,
const char *method,
const char *protocol,
struct request_s *request) struct request_s *request)
{ {
/* /*
* Send the request line * Send the request line
*/ */
if (safe_write(connptr->server_fd, method, strlen(method)) < 0) if (safe_write(connptr->server_fd, request->method, strlen(request->method)) < 0)
return -1; return -1;
if (safe_write(connptr->server_fd, " ", 1) < 0) if (safe_write(connptr->server_fd, " ", 1) < 0)
return -1; return -1;
@ -266,44 +248,43 @@ static inline int send_ssl_response(struct conn_s *connptr)
* Break the request line apart and figure out where to connect and * Break the request line apart and figure out where to connect and
* build a new request line. Finally connect to the remote server. * build a new request line. Finally connect to the remote server.
*/ */
static int process_request(struct conn_s *connptr, char *request_line) static struct request_s *process_request(struct conn_s *connptr,
char *request_line)
{ {
char *method;
char *url; char *url;
char *protocol; struct request_s *request;
struct request_s request;
int ret; int ret;
size_t request_len; size_t request_len;
/* NULL out all the fields so free's don't cause segfaults. */ /* NULL out all the fields so free's don't cause segfaults. */
memset(&request, 0, sizeof(struct request_s)); request = safecalloc(1, sizeof(struct request_s));
if (!request)
return NULL;
request_len = strlen(request_line) + 1; request_len = strlen(request_line) + 1;
method = safemalloc(request_len); request->method = safemalloc(request_len);
url = safemalloc(request_len); url = safemalloc(request_len);
protocol = safemalloc(request_len); request->protocol = safemalloc(request_len);
if (!method || !url || !protocol) { if (!request->method || !url || !request->protocol) {
safefree(method);
safefree(url); safefree(url);
safefree(protocol); free_request_struct(request);
return -1;
return NULL;
} }
ret = sscanf(request_line, "%[^ ] %[^ ] %[^ ]", method, url, protocol); ret = sscanf(request_line, "%[^ ] %[^ ] %[^ ]", request->method, url, request->protocol);
if (ret < 2) { if (ret < 2) {
log_message(LOG_ERR, "Bad Request on file descriptor %d", connptr->client_fd); log_message(LOG_ERR, "Bad Request on file descriptor %d", connptr->client_fd);
httperr(connptr, 400, "Bad Request. No request found."); httperr(connptr, 400, "Bad Request. No request found.");
safefree(method);
safefree(url); safefree(url);
safefree(protocol); free_request_struct(request);
return -1; return NULL;
} else if (ret == 2) { } else if (ret == 2) {
connptr->simple_req = TRUE; connptr->simple_req = TRUE;
} }
@ -312,47 +293,43 @@ static int process_request(struct conn_s *connptr, char *request_line)
log_message(LOG_ERR, "Null URL on file descriptor %d", connptr->client_fd); log_message(LOG_ERR, "Null URL on file descriptor %d", connptr->client_fd);
httperr(connptr, 400, "Bad Request. Null URL."); httperr(connptr, 400, "Bad Request. Null URL.");
safefree(method);
safefree(url); safefree(url);
safefree(protocol); free_request_struct(request);
return -1; return NULL;
} }
if (strncasecmp(url, "http://", 7) == 0) { if (strncasecmp(url, "http://", 7) == 0) {
/* Make sure the first four characters are lowercase */ /* Make sure the first four characters are lowercase */
memcpy(url, "http", 4); memcpy(url, "http", 4);
if (extract_http_url(url, &request) < 0) { if (extract_http_url(url, request) < 0) {
httperr(connptr, 400, "Bad Request. Could not parse URL."); httperr(connptr, 400, "Bad Request. Could not parse URL.");
safefree(method);
safefree(url); safefree(url);
safefree(protocol); free_request_struct(request);
return -1; return NULL;
} }
connptr->ssl = FALSE; connptr->ssl = FALSE;
} else if (strcmp(method, "CONNECT") == 0) { } else if (strcmp(request->method, "CONNECT") == 0) {
if (extract_ssl_url(url, &request) < 0) { if (extract_ssl_url(url, request) < 0) {
httperr(connptr, 400, "Bad Request. Could not parse URL."); httperr(connptr, 400, "Bad Request. Could not parse URL.");
safefree(method);
safefree(url); safefree(url);
safefree(protocol); free_request_struct(request);
return -1; return NULL;
} }
connptr->ssl = TRUE; connptr->ssl = TRUE;
} else { } else {
log_message(LOG_ERR, "Unknown URL type on file descriptor %d", connptr->client_fd); log_message(LOG_ERR, "Unknown URL type on file descriptor %d", connptr->client_fd);
httperr(connptr, 400, "Bad Request. Unknown URL type."); httperr(connptr, 400, "Bad Request. Unknown URL type.");
safefree(method);
safefree(url); safefree(url);
safefree(protocol); free_request_struct(request);
return -1; return NULL;
} }
safefree(url); safefree(url);
@ -362,17 +339,15 @@ static int process_request(struct conn_s *connptr, char *request_line)
* Filter restricted domains * Filter restricted domains
*/ */
if (config.filter) { if (config.filter) {
if (filter_url(request.host)) { if (filter_url(request->host)) {
log_message(LOG_ERR, "Proxying refused on filtered domain \"%s\"", request.host); update_stats(STAT_DENIED);
log_message(LOG_ERR, "Proxying refused on filtered domain \"%s\"", request->host);
httperr(connptr, 404, "Connection to filtered domain is now allowed."); httperr(connptr, 404, "Connection to filtered domain is now allowed.");
safefree(request.host); free_request_struct(request);
safefree(request.path);
safefree(method); return NULL;
safefree(url);
return -1;
} }
} }
#endif #endif
@ -380,53 +355,16 @@ static int process_request(struct conn_s *connptr, char *request_line)
/* /*
* Check to see if they're requesting the stat host * Check to see if they're requesting the stat host
*/ */
if (!config.stathost && strcmp(config.stathost, request.host) == 0) { if (config.stathost && strcmp(config.stathost, request->host) == 0) {
safefree(request.host); log_message(LOG_NOTICE, "tinyproxy stathost request.");
safefree(request.path);
safefree(method); free_request_struct(request);
safefree(protocol);
showstats(connptr); showstats(connptr);
return 0; return NULL;
} }
/* return request;
* Connect to the remote server.
*/
connptr->server_fd = opensock(request.host, request.port);
if (connptr->server_fd < 0) {
httperr(connptr, 500, HTTP500ERROR);
safefree(request.host);
safefree(request.path);
safefree(method);
safefree(protocol);
return -1;
}
if (!connptr->ssl) {
if (establish_http_connection(connptr, method, protocol, &request) < 0) {
safefree(method);
safefree(protocol);
safefree(request.host);
safefree(request.path);
return -1;
}
}
safefree(method);
safefree(protocol);
safefree(request.host);
safefree(request.path);
return 0;
} }
/* /*
@ -478,7 +416,7 @@ static int pull_client_data(struct conn_s *connptr, unsigned long int length)
return -1; return -1;
} }
if (!connptr->output_message) { if (!connptr->send_message) {
if (safe_write(connptr->server_fd, buffer, len) < 0) { if (safe_write(connptr->server_fd, buffer, len) < 0) {
safefree(buffer); safefree(buffer);
return -1; return -1;
@ -547,7 +485,7 @@ static int process_client_headers(struct conn_s *connptr)
break; break;
} }
if (connptr->output_message) if (connptr->send_message)
continue; continue;
/* /*
@ -582,7 +520,7 @@ static int process_client_headers(struct conn_s *connptr)
} }
} }
if (!connptr->output_message && !connptr->ssl) { if (!connptr->send_message && !connptr->ssl) {
#ifdef XTINYPROXY_ENABLE #ifdef XTINYPROXY_ENABLE
if (config.my_domain if (config.my_domain
&& add_xtinyproxy_header(connptr) < 0) { && add_xtinyproxy_header(connptr) < 0) {
@ -744,7 +682,7 @@ static void initialize_conn(struct conn_s *connptr)
connptr->cbuffer = new_buffer(); connptr->cbuffer = new_buffer();
connptr->sbuffer = new_buffer(); connptr->sbuffer = new_buffer();
connptr->output_message = NULL; connptr->send_message = FALSE;
connptr->simple_req = FALSE; connptr->simple_req = FALSE;
connptr->ssl = FALSE; connptr->ssl = FALSE;
@ -764,7 +702,6 @@ static void destroy_conn(struct conn_s *connptr)
if (connptr->sbuffer) if (connptr->sbuffer)
delete_buffer(connptr->sbuffer); delete_buffer(connptr->sbuffer);
safefree(connptr->output_message);
safefree(connptr); safefree(connptr);
update_stats(STAT_CLOSE); update_stats(STAT_CLOSE);
@ -782,10 +719,12 @@ static void destroy_conn(struct conn_s *connptr)
void handle_connection(int fd) void handle_connection(int fd)
{ {
struct conn_s *connptr; struct conn_s *connptr;
struct request_s *request;
char peer_ipaddr[PEER_IP_LENGTH]; char peer_ipaddr[PEER_IP_LENGTH];
char peer_string[PEER_STRING_LENGTH]; char peer_string[PEER_STRING_LENGTH];
char *request_line; char *request_line = NULL;
log_message(LOG_CONN, "Connect (file descriptor %d): %s [%s]", log_message(LOG_CONN, "Connect (file descriptor %d): %s [%s]",
fd, fd,
@ -793,12 +732,8 @@ void handle_connection(int fd)
getpeer_ip(fd, peer_ipaddr)); getpeer_ip(fd, peer_ipaddr));
connptr = safemalloc(sizeof(struct conn_s)); connptr = safemalloc(sizeof(struct conn_s));
if (!connptr) { if (!connptr)
log_message(LOG_ERR,
"Could not allocate memory for request from [%s]",
peer_ipaddr);
return; return;
}
initialize_conn(connptr); initialize_conn(connptr);
connptr->client_fd = fd; connptr->client_fd = fd;
@ -811,19 +746,20 @@ void handle_connection(int fd)
#ifdef TUNNEL_SUPPORT #ifdef TUNNEL_SUPPORT
/* /*
* If an upstream proxy has been configured then redirect any * If tunnel has been configured then redirect any connections to
* connections to it. If we cannot connect to the upstream, see if * it. I know I used GOTOs, but it seems to me to be the best way
* we can handle it ourselves. I know I used GOTOs, but it seems to * of handling this situations. So sue me. :)
* me to be the best way of handling this situations. So sue me. :)
* - rjkaes * - rjkaes
*/ */
if (config.tunnel_name && config.tunnel_port != -1) { if (config.tunnel_name && config.tunnel_port != -1) {
log_message(LOG_INFO, "Redirecting to %s:%d", log_message(LOG_INFO, "Redirecting to %s:%d",
config.tunnel_name, config.tunnel_port); config.tunnel_name, config.tunnel_port);
connptr->server_fd = opensock(config.tunnel_name, config.tunnel_port); connptr->server_fd = opensock(config.tunnel_name, config.tunnel_port);
if (connptr->server_fd < 0) { if (connptr->server_fd < 0) {
log_message(LOG_WARNING, "Could not connect to tunnel's end, see if we can handle it ourselves."); log_message(LOG_WARNING, "Could not connect to tunnel.");
httperr(connptr, 404, "Unable to connect to tunnel.");
goto internal_proxy; goto internal_proxy;
} }
@ -839,31 +775,86 @@ void handle_connection(int fd)
internal_proxy: internal_proxy:
request_line = read_request_line(connptr); request_line = read_request_line(connptr);
if (!request_line) { if (!request_line) {
update_stats(STAT_BADCONN);
destroy_conn(connptr); destroy_conn(connptr);
return; return;
} }
if (process_request(connptr, request_line) < 0) { request = process_request(connptr, request_line);
safefree(request_line);
destroy_conn(connptr);
return;
}
safefree(request_line); safefree(request_line);
if (!request) {
update_stats(STAT_BADCONN);
if (!connptr->send_message) {
destroy_conn(connptr);
return;
}
} else {
#ifdef UPSTREAM_SUPPORT
if (config.upstream_name && config.upstream_port != -1) {
connptr->server_fd = opensock(config.upstream_name, config.upstream_port);
if (connptr->server_fd < 0) {
log_message(LOG_WARNING, "Could not connect to upstream proxy.");
httperr(connptr, 404, "Unable to connect to upstream proxy.");
goto send_error;
}
/*
* Send a new request line, plus the Host and
* Connection headers. The reason for the new request
* line is that we need to specify the HTTP/1.0
* protocol.
*/
safe_write(connptr->server_fd, request->method, strlen(request->method));
safe_write(connptr->server_fd, " http://", 8);
safe_write(connptr->server_fd, request->host, strlen(request->host));
if (request->port != 80) {
char port_string[16];
sprintf(port_string, ":%d", request->port);
safe_write(connptr->server_fd, port_string, strlen(port_string));
}
safe_write(connptr->server_fd, request->path, strlen(request->path));
safe_write(connptr->server_fd, " HTTP/1.0\r\n", 11);
safe_write(connptr->server_fd, "Host: ", 6);
safe_write(connptr->server_fd, request->host, strlen(request->host));
safe_write(connptr->server_fd, "\r\nConnection: close\r\n", 21);
free_request_struct(request);
} else {
#endif
connptr->server_fd = opensock(request->host, request->port);
if (connptr->server_fd < 0) {
httperr(connptr, 500, HTTP500ERROR);
free_request_struct(request);
goto send_error;
}
if (!connptr->ssl)
establish_http_connection(connptr, request);
free_request_struct(request);
#ifdef UPSTREAM_SUPPORT
}
#endif
}
send_error: send_error:
if (!connptr->simple_req) { if (!connptr->simple_req) {
if (process_client_headers(connptr) < 0) { if (process_client_headers(connptr) < 0) {
update_stats(STAT_BADCONN); update_stats(STAT_BADCONN);
destroy_conn(connptr); if (!connptr->send_message) {
return; destroy_conn(connptr);
return;
}
} }
} }
if (connptr->output_message) { if (connptr->send_message) {
safe_write(connptr->client_fd, connptr->output_message,
strlen(connptr->output_message));
destroy_conn(connptr); destroy_conn(connptr);
return; return;
} }
@ -877,6 +868,7 @@ send_error:
} else { } else {
if (send_ssl_response(connptr) < 0) { if (send_ssl_response(connptr) < 0) {
log_message(LOG_ERR, "Could not send SSL greeting to client."); log_message(LOG_ERR, "Could not send SSL greeting to client.");
update_stats(STAT_BADCONN);
destroy_conn(connptr); destroy_conn(connptr);
return; return;
} }