from __future__ import annotations

import collections
import dataclasses
import typing

from yt_dlp.extractor.youtube.jsc._builtin.ejs import _EJS_WIKI_URL
from yt_dlp.extractor.youtube.jsc._registry import (
    _jsc_preferences,
    _jsc_providers,
)
from yt_dlp.extractor.youtube.jsc.provider import (
    JsChallengeProvider,
    JsChallengeProviderError,
    JsChallengeProviderRejectedRequest,
    JsChallengeProviderResponse,
    JsChallengeRequest,
    JsChallengeResponse,
    JsChallengeType,
    NChallengeInput,
    NChallengeOutput,
    SigChallengeInput,
    SigChallengeOutput,
)
from yt_dlp.extractor.youtube.pot._director import YoutubeIEContentProviderLogger, provider_display_list
from yt_dlp.extractor.youtube.pot._provider import (
    IEContentProviderLogger,
)
from yt_dlp.extractor.youtube.pot.provider import (
    provider_bug_report_message,
)

if typing.TYPE_CHECKING:
    from collections.abc import Iterable

    from yt_dlp.extractor.youtube.jsc._builtin.ejs import _SkippedComponent
    from yt_dlp.extractor.youtube.jsc.provider import Preference as JsChallengePreference


class JsChallengeRequestDirector:
    """Dispatches JS challenge requests to registered providers in preference order.

    Providers are stored by their ``PROVIDER_KEY``; preferences are callables that
    are invoked as ``pref(provider, requests)`` and whose results are summed to
    rank the providers for a given batch of requests.
    """

    def __init__(self, logger: IEContentProviderLogger):
        """Create an empty director that reports through *logger*."""
        self.providers: dict[str, JsChallengeProvider] = {}
        self.preferences: list[JsChallengePreference] = []
        self.logger = logger

    def register_provider(self, provider: JsChallengeProvider):
        """Register *provider* under its ``PROVIDER_KEY``, replacing any previous entry."""
        self.providers[provider.PROVIDER_KEY] = provider

    def register_preference(self, preference: JsChallengePreference):
        """Append *preference* to the list used to rank providers."""
        self.preferences.append(preference)

    def _get_providers(self, requests: list[JsChallengeRequest]) -> Iterable[JsChallengeProvider]:
        """Sorts available providers by preference, given a request

        The preference scores are computed eagerly; ``is_available()`` is only
        evaluated lazily as the returned generator is consumed.
        """
        preferences = {
            provider: sum(pref(provider, requests) for pref in self.preferences)
            for provider in self.providers.values()
        }
        if self.logger.log_level <= self.logger.LogLevel.TRACE:
            # calling is_available() for every JS Challenge provider upfront may have some overhead
            self.logger.trace(f'JS Challenge Providers: {provider_display_list(self.providers.values())}')
            self.logger.trace('JS Challenge Provider preferences for this request: {}'.format(', '.join(
                f'{provider.PROVIDER_NAME}={pref}' for provider, pref in preferences.items())))

        return (
            provider for provider in sorted(
                self.providers.values(), key=preferences.get, reverse=True)
            if provider.is_available()
        )

    def _handle_error(self, e: Exception, provider: JsChallengeProvider, requests: list[JsChallengeRequest]):
        """Log a failure of *provider* for *requests*; never raises *e* itself.

        - ``JsChallengeProviderRejectedRequest`` is logged at trace level.
        - ``JsChallengeProviderError`` is logged as a warning; the bug report
          message is only included when ``e.expected`` is falsy.
        - Any other exception is logged as an error with *e* as the cause.
        """
        if isinstance(e, JsChallengeProviderRejectedRequest):
            self.logger.trace(
                f'JS Challenge Provider "{provider.PROVIDER_NAME}" rejected '
                f'{"this request" if len(requests) == 1 else f"{len(requests)} requests"}, '
                f'trying next available provider. Reason: {e}',
            )
        elif isinstance(e, JsChallengeProviderError):
            if len(requests) == 1:
                self.logger.warning(
                    f'Error solving {requests[0].type.value} challenge request using "{provider.PROVIDER_NAME}" provider: {e}.\n'
                    f' input = {requests[0].input}\n'
                    f' {(provider_bug_report_message(provider, before="") if not e.expected else "")}')
            else:
                self.logger.warning(
                    f'Error solving {len(requests)} challenge requests using "{provider.PROVIDER_NAME}" provider: {e}.\n'
                    f' requests = {requests}\n'
                    f' {(provider_bug_report_message(provider, before="") if not e.expected else "")}')
        else:
            self.logger.error(
                f'Unexpected error solving {len(requests)} challenge request(s) using "{provider.PROVIDER_NAME}" provider: {e!r}\n'
                f' requests = {requests}\n'
                f' {provider_bug_report_message(provider, before="")}', cause=e)

    def bulk_solve(self, requests: list[JsChallengeRequest]) -> list[tuple[JsChallengeRequest, JsChallengeResponse]]:
        """Solves multiple JS Challenges in bulk, returning a list of responses

        Each available provider, in preference order, is handed copies of the
        requests that are still unsolved. A response is accepted only if it is
        structurally valid, carries no error, passes ``validate_response`` and
        refers to a request that is still outstanding.

        Returns a list of ``(request, response)`` tuples for the solved
        requests only; the input list is not modified. Provider exceptions are
        logged via ``_handle_error`` and the next provider is tried.
        """
        if not self.providers:
            self.logger.trace('No JS Challenge providers registered')
            return []

        results = []
        # Work on a copy so the caller's list is left untouched
        next_requests = requests[:]

        skipped_components = []
        for provider in self._get_providers(next_requests):
            if not next_requests:
                break
            self.logger.trace(
                f'Attempting to solve {len(next_requests)} challenges using "{provider.PROVIDER_NAME}" provider')
            try:
                # Providers get shallow copies of the request objects
                for response in provider.bulk_solve([dataclasses.replace(request) for request in next_requests]):
                    if not validate_provider_response(response):
                        self.logger.warning(
                            f'JS Challenge Provider "{provider.PROVIDER_NAME}" returned an invalid response:\n'
                            f' response = {response!r}\n'
                            f' {provider_bug_report_message(provider, before="")}')
                        continue
                    if response.error:
                        self._handle_error(response.error, provider, [response.request])
                        continue
                    if (vr_msg := validate_response(response.response, response.request)) is not True:
                        self.logger.warning(
                            f'Invalid JS Challenge response received from "{provider.PROVIDER_NAME}" provider: {vr_msg or ""}\n'
                            f' response = {response.response}\n'
                            f' request = {response.request}\n'
                            f' {provider_bug_report_message(provider, before="")}')
                        continue
                    try:
                        # Also rejects a second response for an already solved request
                        next_requests.remove(response.request)
                    except ValueError:
                        self.logger.warning(
                            f'JS Challenge Provider "{provider.PROVIDER_NAME}" returned a response for an unknown request:\n'
                            f' request = {response.request}\n'
                            f' {provider_bug_report_message(provider, before="")}')
                        continue
                    results.append((response.request, response.response))
            except Exception as e:
                # Collect skipped components so they are reported once, after all providers were tried
                if isinstance(e, JsChallengeProviderRejectedRequest) and e._skipped_components:
                    skipped_components.extend(e._skipped_components)
                self._handle_error(e, provider, next_requests)
                continue

        if skipped_components:
            self.__report_skipped_components(skipped_components)

        if len(results) != len(requests):
            self.logger.trace(
                f'Not all JS Challenges were solved, expected {len(requests)} responses, got {len(results)}')
            self.logger.trace(f'Unsolved requests: {next_requests}')
        else:
            self.logger.trace(f'Solved all {len(requests)} requested JS Challenges')
        return results

    def __report_skipped_components(self, components: list[_SkippedComponent], /):
        """Emit a single warning summarising the remote components that were skipped.

        Components are grouped by ``component`` with the sorted ``runtime``
        values listed next to each, together with the ``--remote-components``
        flag that corresponds to each component.
        """
        runtime_components = collections.defaultdict(list)
        for component in components:
            runtime_components[component.component].append(component.runtime)
        for runtimes in runtime_components.values():
            runtimes.sort()

        description_lookup = {
            'ejs:npm': 'NPM package',
            'ejs:github': 'challenge solver script',
        }

        # descriptions and flags are built from the same iteration order,
        # so index i of one corresponds to index i of the other
        descriptions = [
            f'{description_lookup.get(component, component)} ({", ".join(runtimes)})'
            for component, runtimes in runtime_components.items()
            if runtimes
        ]
        flags = [
            f' --remote-components {f"{component} (recommended)" if component == "ejs:github" else f"{component} "}'
            for component, runtimes in runtime_components.items()
            if runtimes
        ]

        def join_parts(parts, joiner):
            # "a, b <joiner> c" style enumeration
            if not parts:
                return ''
            if len(parts) == 1:
                return parts[0]
            return f'{", ".join(parts[:-1])} {joiner} {parts[-1]}'

        if len(descriptions) == 1:
            msg = (
                f'Remote component {descriptions[0]} was skipped. '
                f'It may be required to solve JS challenges. '
                f'You can enable the download with {flags[0]}')
        else:
            msg = (
                f'Remote components {join_parts(descriptions, "and")} were skipped. '
                f'These may be required to solve JS challenges. '
                f'You can enable these downloads with {join_parts(flags, "or")}, respectively')

        self.logger.warning(f'{msg}. For more information and alternatives, refer to {_EJS_WIKI_URL}')

    def close(self):
        """Call ``close()`` on every registered provider."""
        for provider in self.providers.values():
            provider.close()


EXTRACTOR_ARG_PREFIX = 'youtubejsc'


def initialize_jsc_director(ie):
    """Build a ``JsChallengeRequestDirector`` for the extractor *ie*.

    The log level is TRACE when the ``jsc_trace`` youtube extractor argument is
    ``true``, DEBUG when the ``verbose`` param is set, and INFO otherwise.
    Every provider in ``_jsc_providers`` is instantiated with its own logger and
    its ``youtubejsc-<provider key>`` extractor args, every preference in
    ``_jsc_preferences`` is registered, and ``director.close`` is added as a
    close hook on the downloader.
    """
    assert ie._downloader is not None, 'Downloader not set'

    enable_trace = ie._configuration_arg(
        'jsc_trace', ['false'], ie_key='youtube', casesense=False)[0] == 'true'

    if enable_trace:
        log_level = IEContentProviderLogger.LogLevel.TRACE
    elif ie.get_param('verbose', False):
        log_level = IEContentProviderLogger.LogLevel.DEBUG
    else:
        log_level = IEContentProviderLogger.LogLevel.INFO

    def get_provider_logger_and_settings(provider, logger_key):
        # Returns (logger, settings) where settings are the provider-specific extractor args
        logger_prefix = f'{logger_key}:{provider.PROVIDER_NAME}'
        extractor_key = f'{EXTRACTOR_ARG_PREFIX}-{provider.PROVIDER_KEY.lower()}'
        return (
            YoutubeIEContentProviderLogger(ie, logger_prefix, log_level=log_level),
            ie.get_param('extractor_args', {}).get(extractor_key, {}))

    director = JsChallengeRequestDirector(
        logger=YoutubeIEContentProviderLogger(ie, 'jsc', log_level=log_level),
    )

    ie._downloader.add_close_hook(director.close)

    for provider in _jsc_providers.value.values():
        logger, settings = get_provider_logger_and_settings(provider, 'jsc')
        director.register_provider(provider(ie, logger, settings))

    for preference in _jsc_preferences.value:
        director.register_preference(preference)

    if director.logger.log_level <= director.logger.LogLevel.DEBUG:
        # calling is_available() for every JS Challenge provider upfront may have some overhead
        director.logger.debug(f'JS Challenge Providers: {provider_display_list(director.providers.values())}')
        director.logger.trace(f'Registered {len(director.preferences)} JS Challenge provider preferences')

    return director


def validate_provider_response(response: JsChallengeProviderResponse) -> bool:
    """Return whether *response* is a structurally valid provider response.

    It must be a ``JsChallengeProviderResponse`` whose ``request`` is a
    ``JsChallengeRequest`` and that carries either a ``JsChallengeResponse`` or
    an ``Exception`` as its error.
    """
    return (
        isinstance(response, JsChallengeProviderResponse)
        and isinstance(response.request, JsChallengeRequest)
        and (
            isinstance(response.response, JsChallengeResponse)
            # isinstance() is already False for None, so no separate None check is needed
            or isinstance(response.error, Exception))
    )


def validate_response(response: JsChallengeResponse, request: JsChallengeRequest) -> bool | str:
    """Validate *response* against *request*.

    Returns ``True`` when valid, otherwise a string describing the problem.
    Requests of type ``JsChallengeType.N`` are checked as n challenges; every
    other type is checked as a sig challenge.
    """
    if not isinstance(response, JsChallengeResponse):
        return 'Response is not a JsChallengeResponse'

    if request.type == JsChallengeType.N:
        return validate_nsig_challenge_output(response.output, request.input)
    else:
        return validate_sig_challenge_output(response.output, request.input)


def validate_nsig_challenge_output(challenge_output: NChallengeOutput, challenge_input: NChallengeInput) -> bool | str:
    """Return ``True`` if *challenge_output* answers *challenge_input*, else an error string.

    The output must be an ``NChallengeOutput`` whose ``results`` maps str to
    str, has as many entries as there are input challenges and contains every
    input challenge as a key.
    """
    if not (
        isinstance(challenge_output, NChallengeOutput)
        and len(challenge_output.results) == len(challenge_input.challenges)
        and all(isinstance(k, str) and isinstance(v, str) for k, v in challenge_output.results.items())
        and all(challenge in challenge_output.results for challenge in challenge_input.challenges)
    ):
        return 'Invalid NChallengeOutput'

    # Validate n results are valid - if they end with the input challenge then the js function returned with an exception.
    for challenge, result in challenge_output.results.items():
        if result.endswith(challenge):
            return f'n result is invalid for {challenge!r}: {result!r}'
    return True


def validate_sig_challenge_output(challenge_output: SigChallengeOutput, challenge_input: SigChallengeInput) -> bool | str:
    """Return ``True`` if *challenge_output* answers *challenge_input*, else an error string.

    The output must be a ``SigChallengeOutput`` whose ``results`` maps str to
    str, has as many entries as there are input challenges and contains every
    input challenge as a key.
    """
    return (
        isinstance(challenge_output, SigChallengeOutput)
        and len(challenge_output.results) == len(challenge_input.challenges)
        and all(isinstance(k, str) and isinstance(v, str) for k, v in challenge_output.results.items())
        and all(challenge in challenge_output.results for challenge in challenge_input.challenges)
    ) or 'Invalid SigChallengeOutput'