diff --git a/programs/fileio.c b/programs/fileio.c index 3cae2166..76e5cbf5 100644 --- a/programs/fileio.c +++ b/programs/fileio.c @@ -330,6 +330,7 @@ struct FIO_ctx_s { /* file i/o info */ int nbFilesTotal; + int hasStdinInput; /* file i/o state */ int currFileIdx; @@ -386,6 +387,7 @@ FIO_ctx_t* FIO_createContext(void) if (!ret) EXM_THROW(21, "Allocation error : not enough memory"); ret->currFileIdx = 0; + ret->hasStdinInput = 0; ret->nbFilesTotal = 1; ret->nbFilesProcessed = 0; ret->totalBytesInput = 0; @@ -539,6 +541,16 @@ void FIO_setNbFilesTotal(FIO_ctx_t* const fCtx, int value) fCtx->nbFilesTotal = value; } +void FIO_determineHasStdinInput(FIO_ctx_t* const fCtx, const FileNamesTable* const filenames) { + size_t i = 0; + for ( ; i < filenames->tableSize; ++i) { + if (!strcmp(stdinmark, filenames->fileNames[i])) { + fCtx->hasStdinInput = 1; + return; + } + } +} + /*-************************************* * Functions ***************************************/ @@ -603,8 +615,8 @@ static FILE* FIO_openSrcFile(const char* srcFileName) * condition : `dstFileName` must be non-NULL. * @result : FILE* to `dstFileName`, or NULL if it fails */ static FILE* -FIO_openDstFile(FIO_prefs_t* const prefs, - const char* srcFileName, const char* dstFileName) +FIO_openDstFile(FIO_ctx_t* fCtx, FIO_prefs_t* const prefs, + const char* srcFileName, const char* dstFileName) { if (prefs->testMode) return NULL; /* do not open file in test mode */ @@ -650,7 +662,7 @@ FIO_openDstFile(FIO_prefs_t* const prefs, return NULL; } DISPLAY("zstd: %s already exists; ", dstFileName); - if (UTIL_requireUserConfirmation("overwrite (y/n) ? ", "Not overwritten \n", "yY")) + if (UTIL_requireUserConfirmation("overwrite (y/n) ? ", "Not overwritten \n", "yY", fCtx->hasStdinInput)) return NULL; } /* need to unlink */ @@ -847,7 +859,7 @@ static int FIO_removeMultiFilesWarning(FIO_ctx_t* const fCtx, const FIO_prefs_t* } DISPLAYLEVEL(2, "\nThe concatenated output CANNOT regenerate the original directory tree. ") if (prefs->removeSrcFile) { - error = g_display_prefs.displayLevel > displayLevelCutoff && UTIL_requireUserConfirmation("This is a destructive operation. Proceed? (y/n): ", "Aborting...", "yY"); + error = g_display_prefs.displayLevel > displayLevelCutoff && UTIL_requireUserConfirmation("This is a destructive operation. Proceed? (y/n): ", "Aborting...", "yY", fCtx->hasStdinInput); } } DISPLAY("\n"); @@ -1574,7 +1586,7 @@ static int FIO_compressFilename_dstFile(FIO_ctx_t* const fCtx, if (ress.dstFile == NULL) { closeDstFile = 1; DISPLAYLEVEL(6, "FIO_compressFilename_dstFile: opening dst: %s \n", dstFileName); - ress.dstFile = FIO_openDstFile(prefs, srcFileName, dstFileName); + ress.dstFile = FIO_openDstFile(fCtx, prefs, srcFileName, dstFileName); if (ress.dstFile==NULL) return 1; /* could not open dstFileName */ /* Must only be added after FIO_openDstFile() succeeds. * Otherwise we may delete the destination file if it already exists, @@ -1781,7 +1793,7 @@ int FIO_compressMultipleFilenames(FIO_ctx_t* const fCtx, FIO_freeCResources(ress); return 1; } - ress.dstFile = FIO_openDstFile(prefs, NULL, outFileName); + ress.dstFile = FIO_openDstFile(fCtx, prefs, NULL, outFileName); if (ress.dstFile == NULL) { /* could not open outFileName */ error = 1; } else { @@ -2475,7 +2487,7 @@ static int FIO_decompressDstFile(FIO_ctx_t* const fCtx, if ((ress.dstFile == NULL) && (prefs->testMode==0)) { releaseDstFile = 1; - ress.dstFile = FIO_openDstFile(prefs, srcFileName, dstFileName); + ress.dstFile = FIO_openDstFile(fCtx, prefs, srcFileName, dstFileName); if (ress.dstFile==NULL) return 1; /* Must only be added after FIO_openDstFile() succeeds. @@ -2708,7 +2720,7 @@ FIO_decompressMultipleFilenames(FIO_ctx_t* const fCtx, return 1; } if (!prefs->testMode) { - ress.dstFile = FIO_openDstFile(prefs, NULL, outFileName); + ress.dstFile = FIO_openDstFile(fCtx, prefs, NULL, outFileName); if (ress.dstFile == 0) EXM_THROW(19, "cannot open %s", outFileName); } for (; fCtx->currFileIdx < fCtx->nbFilesTotal; fCtx->currFileIdx++) { diff --git a/programs/fileio.h b/programs/fileio.h index 1a0a35de..e290efcc 100644 --- a/programs/fileio.h +++ b/programs/fileio.h @@ -107,6 +107,7 @@ void FIO_setContentSize(FIO_prefs_t* const prefs, int value); /* FIO_ctx_t functions */ void FIO_setNbFilesTotal(FIO_ctx_t* const fCtx, int value); +void FIO_determineHasStdinInput(FIO_ctx_t* const fCtx, const FileNamesTable* const filenames); /*-************************************* * Single File functions diff --git a/programs/util.c b/programs/util.c index d828dc42..980ab5a4 100644 --- a/programs/util.c +++ b/programs/util.c @@ -88,8 +88,14 @@ UTIL_STATIC void* UTIL_realloc(void *ptr, size_t size) int g_utilDisplayLevel; int UTIL_requireUserConfirmation(const char* prompt, const char* abortMsg, - const char* acceptableLetters) { + const char* acceptableLetters, int hasStdinInput) { int ch, result; + + if (hasStdinInput) { + UTIL_DISPLAY("stdin is an input - not proceeding.\n"); + return 1; + } + UTIL_DISPLAY("%s", prompt); ch = getchar(); result = 0; diff --git a/programs/util.h b/programs/util.h index eeb6a15e..25fa3f53 100644 --- a/programs/util.h +++ b/programs/util.h @@ -96,8 +96,9 @@ extern int g_utilDisplayLevel; /** * Displays a message prompt and returns success (0) if first character from stdin * matches any from acceptableLetters. Otherwise, returns failure (1) and displays abortMsg. + * If any of the inputs are stdin itself, then automatically return failure (1). */ -int UTIL_requireUserConfirmation(const char* prompt, const char* abortMsg, const char* acceptableLetters); +int UTIL_requireUserConfirmation(const char* prompt, const char* abortMsg, const char* acceptableLetters, int hasStdinInput); /*-**************************************** diff --git a/programs/zstdcli.c b/programs/zstdcli.c index 5d1c09de..3df91c10 100644 --- a/programs/zstdcli.c +++ b/programs/zstdcli.c @@ -1286,6 +1286,7 @@ int main(int const argCount, const char* argv[]) /* IO Stream/File */ FIO_setNbFilesTotal(fCtx, (int)filenames->tableSize); + FIO_determineHasStdinInput(fCtx, filenames); FIO_setNotificationLevel(g_displayLevel); FIO_setPatchFromMode(prefs, patchFromDictFileName != NULL); if (memLimit == 0) { diff --git a/tests/playTests.sh b/tests/playTests.sh index b0eb08ca..7f75d850 100755 --- a/tests/playTests.sh +++ b/tests/playTests.sh @@ -310,6 +310,23 @@ test -f precompressedFilterTestDir/input.6.zst.zst println "Test completed" + +println "\n===> warning prompts should not occur if stdin is an input" +println "y" > tmpPrompt +println "hello world" >> tmpPrompt +zstd tmpPrompt -f +zstd < tmpPrompt -o tmpPrompt.zst && die "should have aborted immediately and failed to overwrite" +zstd < tmpPrompt -o tmpPrompt.zst -f # should successfully overwrite with -f +zstd -q -d -f tmpPrompt.zst -o tmpPromptRegenerated +$DIFF tmpPromptRegenerated tmpPrompt # the first 'y' character should not be swallowed + +echo 'yes' | zstd tmpPrompt -o tmpPrompt.zst # accept piped "y" input to force overwrite when using files +echo 'yes' | zstd < tmpPrompt -o tmpPrompt.zst && die "should have aborted immediately and failed to overwrite" +zstd tmpPrompt - < tmpPrompt -o tmpPromp.zst --rm && die "should have aborted immediately and failed to remove" + +println "Test completed" + + println "\n===> recursive mode test " # combination of -r with empty list of input file zstd -c -r < tmp > tmp.zst