Merge branch 'v1.9.0'

This commit is contained in:
Mike Fährmann 2019-06-29 15:39:52 +02:00
commit 40da44b17f
No known key found for this signature in database
GPG Key ID: 5680CA389D365A88
19 changed files with 414 additions and 242 deletions

View File

@ -1,5 +1,7 @@
# Changelog
## Unreleased
## 1.8.7 - 2019-06-28
### Additions
- Support for

View File

@ -945,6 +945,16 @@ Description Enable/Disable this downloader module.
=========== =====
downloader.*.mtime
------------------
=========== =====
Type ``bool``
Default ``true``
Description Use |Last-Modified|_ HTTP response headers
to set file modification times.
=========== =====
downloader.*.part
-----------------
=========== =====
@ -1508,7 +1518,7 @@ Logging Configuration
=========== =====
Type ``object``
Example .. code::
Examples .. code::
{
"format": "{asctime} {name}: {message}",
@ -1517,10 +1527,21 @@ Example .. code::
"encoding": "ascii"
}
{
"level": "debug",
"format": {
"debug" : "debug: {message}",
"info" : "[{name}] {message}",
"warning": "Warning: {message}",
"error" : "ERROR: {message}"
}
}
Description Extended logging output configuration.
* format
* Format string for logging messages
* General format string for logging messages
or a dictionary with format strings for each loglevel.
In addition to the default
`LogRecord attributes <https://docs.python.org/3/library/logging.html#logrecord-attributes>`__,
@ -1589,6 +1610,7 @@ Description An object with the ``name`` of a post-processor and its options.
.. |webbrowser.open()| replace:: ``webbrowser.open()``
.. |datetime.max| replace:: ``datetime.max``
.. |Path| replace:: ``Path``
.. |Last-Modified| replace:: ``Last-Modified``
.. |Logging Configuration| replace:: ``Logging Configuration``
.. |Postprocessor Configuration| replace:: ``Postprocessor Configuration``
.. |strptime| replace:: strftime() and strptime() Behavior
@ -1604,6 +1626,7 @@ Description An object with the ``name`` of a post-processor and its options.
.. _requests.request(): https://docs.python-requests.org/en/master/api/#requests.request
.. _timeout: https://docs.python-requests.org/en/latest/user/advanced/#timeouts
.. _verify: https://docs.python-requests.org/en/master/user/advanced/#ssl-cert-verification
.. _Last-Modified: https://www.w3.org/Protocols/rfc2616/rfc2616-sec14.html#sec14.29
.. _`Requests' proxy documentation`: http://docs.python-requests.org/en/master/user/advanced/#proxies
.. _format string: https://docs.python.org/3/library/string.html#formatstrings
.. _format strings: https://docs.python.org/3/library/string.html#formatstrings

View File

@ -148,8 +148,13 @@
{
"mode": "terminal",
"log": {
"format": "{name}: {message}",
"level": "info"
"level": "info",
"format": {
"debug" : "\u001b[0;37m{name}: {message}\u001b[0m",
"info" : "\u001b[1;37m{name}: {message}\u001b[0m",
"warning": "\u001b[1;33m{name}: {message}\u001b[0m",
"error" : "\u001b[1;31m{name}: {message}\u001b[0m"
}
},
"logfile": {
"path": "~/gallery-dl/log.txt",

View File

@ -152,6 +152,17 @@
"http":
{
"mtime": true,
"rate": null,
"retries": 5,
"timeout": 30.0,
"verify": true
},
"ytdl":
{
"format": null,
"mtime": true,
"rate": null,
"retries": 5,
"timeout": 30.0,
@ -164,6 +175,7 @@
"mode": "auto",
"progress": true,
"shorten": true,
"log": "[{name}][{levelname}] {message}",
"logfile": null,
"unsupportedfile": null
},

View File

@ -46,7 +46,7 @@ ImageFap https://imagefap.com/ Images from Users, Gall
imgbox https://imgbox.com/ Galleries, individual Images
imgth https://imgth.com/ Galleries
imgur https://imgur.com/ Albums, individual Images
Instagram https://www.instagram.com/ Images from Users, individual Images, Tag-Searches
Instagram https://www.instagram.com/ Images from Users, individual Images, Tag-Searches Optional
Jaimini's Box https://jaiminisbox.com/reader/ Chapters, Manga
Joyreactor http://joyreactor.cc/ |joyreactor-C|
Keenspot http://www.keenspot.com/ Comics

View File

@ -22,15 +22,23 @@ def find(scheme):
try:
return _cache[scheme]
except KeyError:
klass = None
pass
klass = None
if scheme == "https":
scheme = "http"
if scheme in modules: # prevent unwanted imports
try:
if scheme in modules: # prevent unwanted imports
module = importlib.import_module("." + scheme, __package__)
klass = module.__downloader__
except (ImportError, AttributeError, TypeError):
module = importlib.import_module("." + scheme, __package__)
klass = module.__downloader__
except ImportError:
pass
if scheme == "http":
_cache["http"] = _cache["https"] = klass
else:
_cache[scheme] = klass
return klass
return klass
# --------------------------------------------------------------------

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2018 Mike Fährmann
# Copyright 2014-2019 Mike Fährmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
@ -9,23 +9,18 @@
"""Common classes and constants used by downloader modules."""
import os
import time
import logging
from .. import config, util, exception
from requests.exceptions import RequestException
from ssl import SSLError
from .. import config, util
class DownloaderBase():
"""Base class for downloaders"""
scheme = ""
retries = 1
def __init__(self, extractor, output):
self.session = extractor.session
self.out = output
self.log = logging.getLogger("downloader." + self.scheme)
self.downloading = False
self.part = self.config("part", True)
self.partdir = self.config("part-directory")
@ -34,137 +29,8 @@ class DownloaderBase():
os.makedirs(self.partdir, exist_ok=True)
def config(self, key, default=None):
"""Interpolate config value for 'key'"""
"""Interpolate downloader config value for 'key'"""
return config.interpolate(("downloader", self.scheme, key), default)
def download(self, url, pathfmt):
"""Download the resource at 'url' and write it to a file-like object"""
try:
return self.download_impl(url, pathfmt)
except Exception:
print()
raise
finally:
# remove file from incomplete downloads
if self.downloading and not self.part:
try:
os.remove(pathfmt.temppath)
except (OSError, AttributeError):
pass
def download_impl(self, url, pathfmt):
"""Actual implementaion of the download process"""
adj_ext = None
tries = 0
msg = ""
if self.part:
pathfmt.part_enable(self.partdir)
while True:
self.reset()
if tries:
self.log.warning("%s (%d/%d)", msg, tries, self.retries)
if tries >= self.retries:
return False
time.sleep(tries)
tries += 1
# check for .part file
filesize = pathfmt.part_size()
# connect to (remote) source
try:
offset, size = self.connect(url, filesize)
except exception.DownloadRetry as exc:
msg = exc
continue
except exception.DownloadComplete:
break
except Exception as exc:
self.log.warning(exc)
return False
# check response
if not offset:
mode = "w+b"
if filesize:
self.log.info("Unable to resume partial download")
else:
mode = "r+b"
self.log.info("Resuming download at byte %d", offset)
# set missing filename extension
if not pathfmt.has_extension:
pathfmt.set_extension(self.get_extension())
if pathfmt.exists():
pathfmt.temppath = ""
return True
self.out.start(pathfmt.path)
self.downloading = True
with pathfmt.open(mode) as file:
if offset:
file.seek(offset)
# download content
try:
self.receive(file)
except (RequestException, SSLError) as exc:
msg = exc
print()
continue
# check filesize
if size and file.tell() < size:
msg = "filesize mismatch ({} < {})".format(
file.tell(), size)
continue
# check filename extension
adj_ext = self._check_extension(file, pathfmt)
break
self.downloading = False
if adj_ext:
pathfmt.set_extension(adj_ext)
return True
def connect(self, url, offset):
"""Connect to 'url' while respecting 'offset' if possible
Returns a 2-tuple containing the actual offset and expected filesize.
If the returned offset-value is greater than zero, all received data
will be appended to the existing .part file.
Return '0' as second tuple-field to indicate an unknown filesize.
"""
def receive(self, file):
"""Write data to 'file'"""
def reset(self):
"""Reset internal state / cleanup"""
def get_extension(self):
"""Return a filename extension appropriate for the current request"""
@staticmethod
def _check_extension(file, pathfmt):
"""Check filename extension against fileheader"""
extension = pathfmt.keywords["extension"]
if extension in FILETYPE_CHECK:
file.seek(0)
header = file.read(8)
if len(header) >= 8 and not FILETYPE_CHECK[extension](header):
for ext, check in FILETYPE_CHECK.items():
if ext != extension and check(header):
return ext
return None
FILETYPE_CHECK = {
"jpg": lambda h: h[0:2] == b"\xff\xd8",
"png": lambda h: h[0:8] == b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a",
"gif": lambda h: h[0:4] == b"GIF8" and h[5] == 97,
}
"""Write data from 'url' into the file specified by 'pathfmt'"""

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2018 Mike Fährmann
# Copyright 2014-2019 Mike Fährmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
@ -8,11 +8,13 @@
"""Downloader module for http:// and https:// URLs"""
import os
import time
import mimetypes
from requests.exceptions import ConnectionError, Timeout
from requests.exceptions import RequestException, ConnectionError, Timeout
from ssl import SSLError
from .common import DownloaderBase
from .. import text, exception
from .. import text
class HttpDownloader(DownloaderBase):
@ -20,11 +22,12 @@ class HttpDownloader(DownloaderBase):
def __init__(self, extractor, output):
DownloaderBase.__init__(self, extractor, output)
self.response = None
self.retries = self.config("retries", extractor._retries)
self.timeout = self.config("timeout", extractor._timeout)
self.verify = self.config("verify", extractor._verify)
self.mtime = self.config("mtime", True)
self.rate = self.config("rate")
self.downloading = False
self.chunk_size = 16384
if self.rate:
@ -34,41 +37,133 @@ class HttpDownloader(DownloaderBase):
elif self.rate < self.chunk_size:
self.chunk_size = self.rate
def connect(self, url, offset):
headers = {}
if offset:
headers["Range"] = "bytes={}-".format(offset)
def download(self, url, pathfmt):
try:
self.response = self.session.request(
"GET", url, stream=True, headers=headers, allow_redirects=True,
timeout=self.timeout, verify=self.verify)
except (ConnectionError, Timeout) as exc:
raise exception.DownloadRetry(exc)
return self._download_impl(url, pathfmt)
except Exception:
print()
raise
finally:
# remove file from incomplete downloads
if self.downloading and not self.part:
try:
os.unlink(pathfmt.temppath)
except (OSError, AttributeError):
pass
code = self.response.status_code
if code == 200: # OK
offset = 0
size = self.response.headers.get("Content-Length")
elif code == 206: # Partial Content
size = self.response.headers["Content-Range"].rpartition("/")[2]
elif code == 416: # Requested Range Not Satisfiable
raise exception.DownloadComplete()
elif code == 429 or 500 <= code < 600: # Server Error
raise exception.DownloadRetry(
"{} Server Error: {} for url: {}".format(
code, self.response.reason, url))
else:
self.response.raise_for_status()
def _download_impl(self, url, pathfmt):
response = None
adj_ext = None
tries = 0
msg = ""
return offset, text.parse_int(size)
if self.part:
pathfmt.part_enable(self.partdir)
def receive(self, file):
while True:
if tries:
if response:
response.close()
self.log.warning("%s (%d/%d)", msg, tries, self.retries)
if tries >= self.retries:
return False
time.sleep(tries)
tries += 1
# check for .part file
filesize = pathfmt.part_size()
if filesize:
headers = {"Range": "bytes={}-".format(filesize)}
else:
headers = None
# connect to (remote) source
try:
response = self.session.request(
"GET", url, stream=True, headers=headers,
timeout=self.timeout, verify=self.verify)
except (ConnectionError, Timeout) as exc:
msg = str(exc)
continue
except Exception as exc:
self.log.warning("%s", exc)
return False
# check response
code = response.status_code
if code == 200: # OK
offset = 0
size = response.headers.get("Content-Length")
elif code == 206: # Partial Content
offset = filesize
size = response.headers["Content-Range"].rpartition("/")[2]
elif code == 416: # Requested Range Not Satisfiable
break
else:
msg = "{}: {} for url: {}".format(code, response.reason, url)
if code == 429 or 500 <= code < 600: # Server Error
continue
self.log.warning("%s", msg)
return False
size = text.parse_int(size)
# set missing filename extension
if not pathfmt.has_extension:
pathfmt.set_extension(self.get_extension(response))
if pathfmt.exists():
pathfmt.temppath = ""
return True
# set open mode
if not offset:
mode = "w+b"
if filesize:
self.log.info("Unable to resume partial download")
else:
mode = "r+b"
self.log.info("Resuming download at byte %d", offset)
# start downloading
self.out.start(pathfmt.path)
self.downloading = True
with pathfmt.open(mode) as file:
if offset:
file.seek(offset)
# download content
try:
self.receive(response, file)
except (RequestException, SSLError) as exc:
msg = str(exc)
print()
continue
# check filesize
if size and file.tell() < size:
msg = "filesize mismatch ({} < {})".format(
file.tell(), size)
continue
# check filename extension
adj_ext = self.check_extension(file, pathfmt)
break
self.downloading = False
if adj_ext:
pathfmt.set_extension(adj_ext)
if self.mtime:
filetime = response.headers.get("Last-Modified")
if filetime:
pathfmt.keywords["_filetime"] = filetime
return True
def receive(self, response, file):
if self.rate:
total = 0 # total amount of bytes received
start = time.time() # start time
for data in self.response.iter_content(self.chunk_size):
for data in response.iter_content(self.chunk_size):
file.write(data)
if self.rate:
@ -79,13 +174,8 @@ class HttpDownloader(DownloaderBase):
# sleep if less time passed than expected
time.sleep(expected - delta)
def reset(self):
if self.response:
self.response.close()
self.response = None
def get_extension(self):
mtype = self.response.headers.get("Content-Type", "image/jpeg")
def get_extension(self, response):
mtype = response.headers.get("Content-Type", "image/jpeg")
mtype = mtype.partition(";")[0]
if mtype in MIMETYPE_MAP:
@ -100,6 +190,26 @@ class HttpDownloader(DownloaderBase):
"No filename extension found for MIME type '%s'", mtype)
return "txt"
@staticmethod
def check_extension(file, pathfmt):
"""Check filename extension against fileheader"""
extension = pathfmt.keywords["extension"]
if extension in FILETYPE_CHECK:
file.seek(0)
header = file.read(8)
if len(header) >= 8 and not FILETYPE_CHECK[extension](header):
for ext, check in FILETYPE_CHECK.items():
if ext != extension and check(header):
return ext
return None
FILETYPE_CHECK = {
"jpg": lambda h: h[0:2] == b"\xff\xd8",
"png": lambda h: h[0:8] == b"\x89\x50\x4e\x47\x0d\x0a\x1a\x0a",
"gif": lambda h: h[0:4] == b"GIF8" and h[5] == 97,
}
MIMETYPE_MAP = {
"image/jpeg": "jpg",

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
# Copyright 2014-2018 Mike Fährmann
# Copyright 2014-2019 Mike Fährmann
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License version 2 as
@ -14,24 +14,13 @@ from .common import DownloaderBase
class TextDownloader(DownloaderBase):
scheme = "text"
def __init__(self, extractor, output):
DownloaderBase.__init__(self, extractor, output)
self.content = b""
def connect(self, url, offset):
data = url.encode()
self.content = data[offset + 5:]
return offset, len(data) - 5
def receive(self, file):
file.write(self.content)
def reset(self):
self.content = b""
@staticmethod
def get_extension():
return "txt"
def download(self, url, pathfmt):
if self.part:
pathfmt.part_enable(self.partdir)
self.out.start(pathfmt.path)
with pathfmt.open("wb") as file:
file.write(url.encode()[5:])
return True
__downloader__ = TextDownloader

View File

@ -27,6 +27,7 @@ class YoutubeDLDownloader(DownloaderBase):
"socket_timeout": self.config("timeout", extractor._timeout),
"nocheckcertificate": not self.config("verify", extractor._verify),
"nopart": not self.part,
"updatetime": self.config("mtime", True),
}
options.update(self.config("raw-options") or {})
@ -36,6 +37,9 @@ class YoutubeDLDownloader(DownloaderBase):
self.ytdl = YoutubeDL(options)
def download(self, url, pathfmt):
for cookie in self.session.cookies:
self.ytdl.cookiejar.set_cookie(cookie)
try:
info_dict = self.ytdl.extract_info(url[5:], download=False)
except Exception:

View File

@ -11,7 +11,8 @@
import hashlib
import json
from .common import Extractor, Message
from .. import text
from .. import text, exception
from ..cache import cache
class InstagramExtractor(Extractor):
@ -21,11 +22,14 @@ class InstagramExtractor(Extractor):
filename_fmt = "{sidecar_media_id:?/_/}{media_id}.{extension}"
archive_fmt = "{media_id}"
root = "https://www.instagram.com"
cookiedomain = ".instagram.com"
cookienames = ("sessionid",)
def get_metadata(self):
return {}
def items(self):
self.login()
yield Message.Version, 1
metadata = self.get_metadata()
@ -40,6 +44,46 @@ class InstagramExtractor(Extractor):
yield Message.Url, \
'ytdl:{}/p/{}/'.format(self.root, data['shortcode']), data
def login(self):
if self._check_cookies(self.cookienames):
return
username, password = self._get_auth_info()
if username:
self.session.cookies.set("ig_cb", "1", domain="www.instagram.com")
self._update_cookies(self._login_impl(username, password))
@cache(maxage=360*24*3600, keyarg=1)
def _login_impl(self, username, password):
self.log.info("Logging in as %s", username)
page = self.request(self.root + "/accounts/login/").text
headers = {
"Referer" : self.root + "/accounts/login/",
"X-IG-App-ID" : "936619743392459",
"X-Requested-With": "XMLHttpRequest",
}
response = self.request(self.root + "/web/__mid/", headers=headers)
headers["X-CSRFToken"] = response.cookies["csrftoken"]
headers["X-Instagram-AJAX"] = text.extract(
page, '"rollout_hash":"', '"')[0]
url = self.root + "/accounts/login/ajax/"
data = {
"username" : username,
"password" : password,
"queryParams" : "{}",
"optIntoOneTap": "true",
}
response = self.request(url, method="POST", headers=headers, data=data)
if not response.json().get("authenticated"):
raise exception.AuthenticationError()
return {
key: self.session.cookies.get(key)
for key in ("sessionid", "mid", "csrftoken")
}
def _extract_shared_data(self, page):
return json.loads(text.extract(page,
'window._sharedData = ', ';</script>')[0])

View File

@ -281,20 +281,22 @@ class DownloadJob(Job):
def get_downloader(self, scheme):
"""Return a downloader suitable for 'scheme'"""
if scheme == "https":
scheme = "http"
try:
return self.downloaders[scheme]
except KeyError:
pass
klass = downloader.find(scheme)
if klass and config.get(("downloader", scheme, "enabled"), True):
if klass and config.get(("downloader", klass.scheme, "enabled"), True):
instance = klass(self.extractor, self.out)
else:
instance = None
self.log.error("'%s:' URLs are not supported/enabled", scheme)
self.downloaders[scheme] = instance
if klass.scheme == "http":
self.downloaders["http"] = self.downloaders["https"] = instance
else:
self.downloaders[scheme] = instance
return instance
def initialize(self, keywords=None):

View File

@ -182,6 +182,12 @@ def build_parser():
dest="part", nargs=0, action=ConfigConstAction, const=False,
help="Do not use .part files",
)
downloader.add_argument(
"--no-mtime",
dest="mtime", nargs=0, action=ConfigConstAction, const=False,
help=("Do not set file modification times according to "
"Last-Modified HTTP response headers")
)
downloader.add_argument(
"--no-check-certificate",
dest="verify", nargs=0, action=ConfigConstAction, const=False,

View File

@ -35,6 +35,30 @@ class Logger(logging.Logger):
return rv
class Formatter(logging.Formatter):
"""Custom formatter that supports different formats per loglevel"""
def __init__(self, fmt, datefmt):
if not isinstance(fmt, dict):
fmt = {"debug": fmt, "info": fmt, "warning": fmt, "error": fmt}
self.formats = fmt
self.datefmt = datefmt
def format(self, record):
record.message = record.getMessage()
fmt = self.formats[record.levelname]
if "{asctime" in fmt:
record.asctime = self.formatTime(record, self.datefmt)
msg = fmt.format_map(record.__dict__)
if record.exc_info and not record.exc_text:
record.exc_text = self.formatException(record.exc_info)
if record.exc_text:
msg = msg + "\n" + record.exc_text
if record.stack_info:
msg = msg + "\n" + record.stack_info
return msg
def initialize_logging(loglevel):
"""Setup basic logging functionality before configfiles have been loaded"""
# convert levelnames to lowercase
@ -46,7 +70,7 @@ def initialize_logging(loglevel):
logging.Logger.manager.setLoggerClass(Logger)
# setup basic logging to stderr
formatter = logging.Formatter(LOG_FORMAT, LOG_FORMAT_DATE, "{")
formatter = Formatter(LOG_FORMAT, LOG_FORMAT_DATE)
handler = logging.StreamHandler()
handler.setFormatter(formatter)
handler.setLevel(loglevel)
@ -80,13 +104,11 @@ def setup_logging_handler(key, fmt=LOG_FORMAT, lvl=LOG_LEVEL):
"%s: missing or invalid path (%s)", key, exc)
return None
level = opts.get("level", lvl)
logfmt = opts.get("format", fmt)
datefmt = opts.get("format-date", LOG_FORMAT_DATE)
formatter = logging.Formatter(logfmt, datefmt, "{")
handler.setFormatter(formatter)
handler.setLevel(level)
handler.setLevel(opts.get("level", lvl))
handler.setFormatter(Formatter(
opts.get("format", fmt),
opts.get("format-date", LOG_FORMAT_DATE),
))
return handler
@ -100,10 +122,10 @@ def configure_logging_handler(key, handler):
if handler.level == LOG_LEVEL and "level" in opts:
handler.setLevel(opts["level"])
if "format" in opts or "format-date" in opts:
logfmt = opts.get("format", LOG_FORMAT)
datefmt = opts.get("format-date", LOG_FORMAT_DATE)
formatter = logging.Formatter(logfmt, datefmt, "{")
handler.setFormatter(formatter)
handler.setFormatter(Formatter(
opts.get("format", LOG_FORMAT),
opts.get("format-date", LOG_FORMAT_DATE),
))
# --------------------------------------------------------------------

View File

@ -12,6 +12,7 @@ import re
import os
import sys
import json
import time
import shutil
import string
import _string
@ -19,6 +20,7 @@ import sqlite3
import datetime
import operator
import itertools
import email.utils
import urllib.parse
from . import text, exception
@ -629,17 +631,23 @@ class PathFormat():
os.unlink(self.temppath)
return
if self.temppath == self.realpath:
return
if self.temppath != self.realpath:
# move temp file to its actual location
try:
os.replace(self.temppath, self.realpath)
except OSError:
shutil.copyfile(self.temppath, self.realpath)
os.unlink(self.temppath)
try:
os.replace(self.temppath, self.realpath)
return
except OSError:
pass
shutil.copyfile(self.temppath, self.realpath)
os.unlink(self.temppath)
if "_filetime" in self.keywords:
# try to set file times
try:
filetime = email.utils.mktime_tz(email.utils.parsedate_tz(
self.keywords["_filetime"]))
if filetime:
os.utime(self.realpath, (time.time(), filetime))
except Exception:
pass
@staticmethod
def adjust_path(path):

View File

@ -6,4 +6,4 @@
# it under the terms of the GNU General Public License version 2 as
# published by the Free Software Foundation.
__version__ = "1.8.7"
__version__ = "1.9.0-dev"

View File

@ -108,6 +108,7 @@ AUTH_MAP = {
"exhentai" : "Optional",
"flickr" : "Optional (OAuth)",
"idolcomplex": "Optional",
"instagram" : "Optional",
"luscious" : "Optional",
"mangoxo" : "Optional",
"nijie" : "Required",

View File

@ -8,13 +8,16 @@
# published by the Free Software Foundation.
import re
import sys
import base64
import os.path
import tempfile
import unittest
import threading
import http.server
import unittest
from unittest.mock import Mock, MagicMock, patch
import gallery_dl.downloader as downloader
import gallery_dl.extractor as extractor
import gallery_dl.config as config
@ -23,6 +26,73 @@ from gallery_dl.output import NullOutput
from gallery_dl.util import PathFormat
class MockDownloaderModule(Mock):
__downloader__ = "mock"
class TestDownloaderModule(unittest.TestCase):
@classmethod
def setUpClass(cls):
# allow import of ytdl downloader module without youtube_dl installed
sys.modules["youtube_dl"] = MagicMock()
@classmethod
def tearDownClass(cls):
del sys.modules["youtube_dl"]
def tearDown(self):
downloader._cache.clear()
def test_find(self):
cls = downloader.find("http")
self.assertEqual(cls.__name__, "HttpDownloader")
self.assertEqual(cls.scheme , "http")
cls = downloader.find("https")
self.assertEqual(cls.__name__, "HttpDownloader")
self.assertEqual(cls.scheme , "http")
cls = downloader.find("text")
self.assertEqual(cls.__name__, "TextDownloader")
self.assertEqual(cls.scheme , "text")
cls = downloader.find("ytdl")
self.assertEqual(cls.__name__, "YoutubeDLDownloader")
self.assertEqual(cls.scheme , "ytdl")
self.assertEqual(downloader.find("ftp"), None)
self.assertEqual(downloader.find("foo"), None)
self.assertEqual(downloader.find(1234) , None)
self.assertEqual(downloader.find(None) , None)
@patch("importlib.import_module")
def test_cache(self, import_module):
import_module.return_value = MockDownloaderModule()
downloader.find("http")
downloader.find("text")
downloader.find("ytdl")
self.assertEqual(import_module.call_count, 3)
downloader.find("http")
downloader.find("text")
downloader.find("ytdl")
self.assertEqual(import_module.call_count, 3)
@patch("importlib.import_module")
def test_cache_http(self, import_module):
import_module.return_value = MockDownloaderModule()
downloader.find("http")
downloader.find("https")
self.assertEqual(import_module.call_count, 1)
@patch("importlib.import_module")
def test_cache_https(self, import_module):
import_module.return_value = MockDownloaderModule()
downloader.find("https")
downloader.find("http")
self.assertEqual(import_module.call_count, 1)
class TestDownloaderBase(unittest.TestCase):
@classmethod
@ -134,9 +204,6 @@ class TestTextDownloader(TestDownloaderBase):
def test_text_offset(self):
self._run_test("text:foobar", "foo", "foobar", "txt", "txt")
def test_text_extension(self):
self._run_test("text:foobar", None, "foobar", None, "txt")
def test_text_empty(self):
self._run_test("text:", None, "", "txt", "txt")

View File

@ -161,9 +161,10 @@ class ResultJob(job.DownloadJob):
self.hash_keyword = hashlib.sha1()
self.hash_archive = hashlib.sha1()
self.hash_content = hashlib.sha1()
if content:
self.fileobj = TestPathfmt(self.hash_content)
self.get_downloader("http")._check_extension = lambda a, b: None
self.get_downloader("http").check_extension = lambda a, b: None
self.format_directory = TestFormatter(
"".join(self.extractor.directory_fmt))
@ -217,6 +218,7 @@ class TestPathfmt():
self.hashobj = hashobj
self.path = ""
self.size = 0
self.keywords = {}
self.has_extension = True
def __enter__(self):
@ -279,9 +281,10 @@ def setup_test_config():
config.set(("extractor", "password"), name)
config.set(("extractor", "nijie", "username"), email)
config.set(("extractor", "seiga", "username"), email)
config.set(("extractor", "danbooru", "username"), None)
config.set(("extractor", "twitter" , "username"), None)
config.set(("extractor", "mangoxo" , "password"), "VZ8DL3983u")
config.set(("extractor", "danbooru" , "username"), None)
config.set(("extractor", "instagram", "username"), None)
config.set(("extractor", "twitter" , "username"), None)
config.set(("extractor", "mangoxo" , "password"), "VZ8DL3983u")
config.set(("extractor", "deviantart", "client-id"), "7777")
config.set(("extractor", "deviantart", "client-secret"),