diff --git a/programs/fileio.c b/programs/fileio.c index 78fb9a26..b384c3c2 100644 --- a/programs/fileio.c +++ b/programs/fileio.c @@ -31,6 +31,11 @@ #include /* clock */ #include /* errno */ +#if defined (_MSC_VER) +# include +# include +#endif + #include "mem.h" #include "fileio.h" #define ZSTD_STATIC_LINKING_ONLY /* ZSTD_magicNumber, ZSTD_frameHeaderSize_max */ @@ -188,6 +193,18 @@ void FIO_setOverlapLog(unsigned overlapLog){ /*-************************************* * Functions ***************************************/ +/** FIO_remove() : + * @result : Unlink `fileName`, even if it's read-only */ +static int FIO_remove(const char* path) +{ +#if defined(_WIN32) || defined(WIN32) + /* windows doesn't allow remove read-only files, so try to make it + * writable first */ + chmod(path, _S_IWRITE); +#endif + return remove(path); +} + /** FIO_openSrcFile() : * condition : `dstFileName` must be non-NULL. * @result : FILE* to `dstFileName`, or NULL if it fails */ @@ -230,23 +247,29 @@ static FILE* FIO_openDstFile(const char* dstFileName) if (g_sparseFileSupport == 1) { g_sparseFileSupport = ZSTD_SPARSE_DEFAULT; } - if (!g_overwrite && strcmp (dstFileName, nulmark)) { /* Check if destination file already exists */ + if (strcmp (dstFileName, nulmark)) { /* Check if destination file already exists */ f = fopen( dstFileName, "rb" ); if (f != 0) { /* dest file exists, prompt for overwrite authorization */ fclose(f); - if (g_displayLevel <= 1) { - /* No interaction possible */ - DISPLAY("zstd: %s already exists; not overwritten \n", dstFileName); - return NULL; - } - DISPLAY("zstd: %s already exists; do you wish to overwrite (y/N) ? ", dstFileName); - { int ch = getchar(); - if ((ch!='Y') && (ch!='y')) { - DISPLAY(" not overwritten \n"); + if (!g_overwrite) { + if (g_displayLevel <= 1) { + /* No interaction possible */ + DISPLAY("zstd: %s already exists; not overwritten \n", dstFileName); return NULL; } - while ((ch!=EOF) && (ch!='\n')) ch = getchar(); /* flush rest of input line */ - } } } + DISPLAY("zstd: %s already exists; do you wish to overwrite (y/N) ? ", dstFileName); + { int ch = getchar(); + if ((ch!='Y') && (ch!='y')) { + DISPLAY(" not overwritten \n"); + return NULL; + } + while ((ch!=EOF) && (ch!='\n')) ch = getchar(); /* flush rest of input line */ + } + } + + /* need to unlink */ + FIO_remove(dstFileName); + } } f = fopen( dstFileName, "wb" ); if (f==NULL) DISPLAYLEVEL(1, "zstd: %s: %s\n", dstFileName, strerror(errno)); } diff --git a/tests/playTests.sh b/tests/playTests.sh index 5abbb14e..266bbd91 100755 --- a/tests/playTests.sh +++ b/tests/playTests.sh @@ -109,6 +109,14 @@ $ZSTD -q tmp && die "overwrite check failed!" $ECHO "test : force overwrite" $ZSTD -q -f tmp $ZSTD -q --force tmp +$ECHO "test : overwrite readonly file" +rm -f tmpro tmpro.zst +$ECHO foo > tmpro.zst +$ECHO foo > tmpro +chmod 400 tmpro.zst +$ZSTD -q tmpro && die "should have refused to overwrite read-only file" +$ZSTD -q -f tmpro +rm -f tmpro tmpro.zst $ECHO "test : file removal" $ZSTD -f --rm tmp ls tmp && die "tmp should no longer be present"