Crawler
from selenium import webdriver
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.common.exceptions import NoSuchElementException, StaleElementReferenceException, TimeoutException
import time
import random
import os, os.path as osp
import tempfile
from tqdm import tqdm
import argparse

from kn_util.utils.logger import setup_logger_loguru
from kn_util.utils.io import save_csv, load_csv

setup_logger_loguru()
from loguru import logger

# Setup options for Edge
ch_options = webdriver.EdgeOptions()
ch_options.add_argument(
    "user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
)
prefs = {"profile.managed_default_content_settings.images": 2}
ch_options.add_experimental_option("prefs", prefs)
ch_options.add_experimental_option("excludeSwitches", ["enable-automation"])
ch_options.add_experimental_option("useAutomationExtension", False)
ch_options.add_argument("--disable-infobars")
ch_options.add_argument("--incognito")
ch_options.add_argument("--headless")

# Setup WebDriver (make sure the path to your WebDriver is correct)
driver = webdriver.Edge(options=ch_options)

class BufferedWriter:

    def __init__(self, file):
        self.file = open(file, "w")
        self.buffer = []

    def write(self, data):
        self.buffer.append(data)
        if len(self.buffer) >= 100:
            self.flush()

    def flush(self):
        # Append a trailing newline so consecutive flushes do not glue the
        # last line of one batch to the first line of the next.
        if self.buffer:
            self.file.write("\n".join(self.buffer) + "\n")
        self.buffer = []

    def close(self):
        self.flush()
        self.file.close()

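# NOTE: BufferedWriter is not used below; the crawl loop writes its TSV file
# directly. A minimal usage sketch (the file name is hypothetical):
#   writer = BufferedWriter("links.tsv")
#   writer.write("0\t1\thttps://example.com/a.mp4\tlabel")
#   writer.close()  # flushes any remaining buffered lines
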
def has_next_button():
    try:
        WebDriverWait(driver, 10).until(
            EC.presence_of_element_located(
                (By.CSS_SELECTOR, "a.MuiButtonBase-root.MuiButton-root.MuiButton-outlined.MuiButton-outlinedPrimary")
            )
        )
        return True
    except (NoSuchElementException, TimeoutException):
        # WebDriverWait raises TimeoutException (not NoSuchElementException)
        # when the button never appears.
        return False

def get_video_url(link):
    vid_str = link.rsplit("/", 1)[1]
    split_str = vid_str.split("-", 2)
    vid = split_str[1]
    name = split_str[2]
    link = f"https://ak.picdn.net/shutterstock/videos/{vid}/preview/{name}.mp4"
    return link

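# get_video_url assumes catalog links of the form
#   https://www.shutterstock.com/zh/video/clip-<id>-<slug>
# For a hypothetical link ".../video/clip-1234567-ocean-waves" it returns
#   https://ak.picdn.net/shutterstock/videos/1234567/preview/ocean-waves.mp4
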
def wait_and_get_element(term):
    try:
        element = WebDriverWait(driver, 20).until(EC.presence_of_element_located((By.CSS_SELECTOR, term)))
        return element
    except TimeoutException:
        logger.error(f"Element {term} not found")
        time.sleep(10)
        driver.refresh()
        # Retry once after the refresh; give up (return None) if it still fails.
        try:
            return WebDriverWait(driver, 20).until(EC.presence_of_element_located((By.CSS_SELECTOR, term)))
        except TimeoutException:
            return None

def _parse_duration(duration):
    return int(duration.split("-")[0]), int(duration.split("-")[1])


def check_available(driver):
    if driver.title.startswith("ERROR"):
        logger.error("Shutterstock is not available, please try again later")
        exit(1)

def get_query_meta(query=None, duration_span="0-1", truncate=True):
    base_url = "https://www.shutterstock.com/zh/video/search"
    if query is not None:
        base_url += f"/{query}"
    url = base_url + f"?duration={duration_span}"
    driver.get(url)
    logger.info(f"Opening {url}")
    check_available(driver)

    if driver.title.startswith("0"):
        return 0, 0

    total_page_term = "span.MuiTypography-root.MuiTypography-subtitle2.mui-10a3ukw-totalPages-centerPagination"
    element = wait_and_get_element(total_page_term)
    total_page = int(element.get_attribute("aria-label").split()[1].replace(",", ""))
    if truncate and total_page > 999:
        logger.info(f"Query {query} Duration {duration_span} has more than 999 pages ({total_page}), only crawling 999 pages")
        total_page = 999

    total_element_term = ".MuiTypography-root.MuiTypography-body2.mui-1xvvg7z-subtitle"
    element = wait_and_get_element(total_element_term)
    total_elements = int(element.text.split()[0].replace(",", ""))
    return total_elements, total_page

def crawl_shutterstock_part(query=None, duration=0, page=1, expected_count=100):
    base_url = "https://www.shutterstock.com/zh/video/search"
    if query is not None:
        base_url += f"/{query}"

    link_term = "a.mui-t7xql4-a-inherit-link"
    url = base_url + f"?duration={duration}-{duration+1}&page={page}"
    driver.get(url)
    logger.info(f"Opening {url}")
    check_available(driver)

    if driver.title.startswith("0"):
        # no results for this duration
        return None

    # Simulate more human-like scrolling and poll until the lazily loaded
    # cards appear; the loop is bounded so a short page cannot hang forever.
    WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.CSS_SELECTOR, link_term)))
    video_count = 0
    for _ in range(30):
        driver.execute_script("window.scrollBy(0, window.innerHeight * 100);")
        video_count = len(driver.find_elements(By.CSS_SELECTOR, link_term))
        if video_count >= expected_count:
            break
        time.sleep(0.5)

    ret = dict()
    for i in range(video_count):
        try:
            # Re-find the element to avoid staleness
            video = driver.find_elements(By.CSS_SELECTOR, link_term)[i]
            link = video.get_attribute("href")
            link = get_video_url(link)
            label = video.get_attribute("aria-label").encode("unicode_escape").decode("utf-8")
            ret[link] = (duration, page, label)
        except StaleElementReferenceException:
            continue  # Skip this element and continue with the next one
    return ret

def get_filename(query, duration_span, rank, world_size):
    return f"{query}.{duration_span}.{rank}_{world_size}"

def build_index(query, duration_span, rank, world_size, output_dir="output"):
    st, ed = _parse_duration(duration_span)
    query_alias = query.replace(" ", "-") if query is not None else "all"
    filename = get_filename(query_alias, duration_span, rank, world_size)
    index_file = osp.join(output_dir, f"{filename}.index.csv")

    # collect finished (duration, page) pairs and already crawled urls
    finished = set()
    urls = set()
    query_file = osp.join(output_dir, f"{filename}.tsv")
    if osp.exists(query_file):
        with open(query_file, "r") as f:
            lines = f.readlines()[1:]
        for line in lines:
            duration, page, link, label = line.strip().split("\t")
            finished.add((int(duration), int(page)))
            urls.add(link)

    if osp.exists(index_file):
        index_list = load_csv(index_file)
        for i in range(len(index_list)):
            index_list[i]["duration"] = int(index_list[i]["duration"])
            index_list[i]["page"] = int(index_list[i]["page"])
            index_list[i]["expected_count"] = int(index_list[i]["expected_count"])
        logger.info(f"[{rank}/{world_size}] Index file {index_file} exists, loading from file")
        index_list = index_list[rank::world_size]
        index_list = sorted([x for x in index_list if (x["duration"], x["page"]) not in finished], key=lambda x: (x["duration"], x["page"]))
        return index_list, urls

    index_list = []
    expected_count_all = 0
    if rank == 0:
        for duration in tqdm(range(st, ed), desc="Building index"):
            total_elements, total_pages = get_query_meta(query, duration_span=f"{duration}-{duration+1}")
            logger.info(f"pages for Duration {duration}-{duration+1}: {total_pages}")
            expected_count_by_duration = 0
            for i in range(1, total_pages + 1):
                # Full pages hold 100 results; the last page holds the remainder
                # (or a full 100 when the total divides evenly).
                expected_count = 100 if i < total_pages else (total_elements % 100 or 100)
                index_list.append({"duration": duration, "page": i, "expected_count": expected_count})
                expected_count_all += expected_count
                expected_count_by_duration += expected_count
            logger.info(f"Total expected count for Duration {duration}-{duration+1}: {expected_count_by_duration}")
            time.sleep(random.uniform(1, 2))
        logger.info(f"Total expected count: {expected_count_all}")
        logger.info(f"[{rank}/{world_size}] Saving index to file {index_file}")
        save_csv(index_list, index_file)
    else:
        logger.info(f"[{rank}/{world_size}] Waiting for index file {index_file}")
        while not osp.exists(index_file):
            time.sleep(1)
        logger.info(f"[{rank}/{world_size}] Index file {index_file} exists, loading from file")

    index_list = load_csv(index_file)
    # load_csv yields strings; normalize back to int so the (duration, page)
    # filter below matches the `finished` set and page arithmetic works.
    for i in range(len(index_list)):
        index_list[i]["duration"] = int(index_list[i]["duration"])
        index_list[i]["page"] = int(index_list[i]["page"])
        index_list[i]["expected_count"] = int(index_list[i]["expected_count"])
    index_list = index_list[rank::world_size]
    index_list = sorted([x for x in index_list if (x["duration"], x["page"]) not in finished], key=lambda x: (x["duration"], x["page"]))
    return index_list, urls

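# Sharding sketch: each worker keeps every world_size-th entry, so with
# world_size=3 the slices index_list[0::3], index_list[1::3], index_list[2::3]
# partition the (duration, page) index without overlap. For example:
#   index_list = [{"page": p} for p in range(1, 7)]
#   index_list[1::3]  # -> pages 2 and 5 for rank 1
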
# Main crawling function
def crawl_shutterstock(
    query=None,
    duration_span="0-600",
    output_dir="output",
    rank=0,
    world_size=1,
):
    indexs, finished_urls = build_index(
        query,
        duration_span=duration_span,
        rank=rank,
        world_size=world_size,
        output_dir=output_dir,
    )

    query_alias = query.replace(" ", "-") if query is not None else "all"
    query_file = osp.join(output_dir, f"{query_alias}.{duration_span}.{rank}_{world_size}.tsv")
    if not osp.exists(query_file):
        f = open(query_file, "w")
        f.write("duration\tpage\tlink\tlabel\n")
    else:
        f = open(query_file, "a")

    for v in tqdm(indexs):
        duration_part, page, expected_count = v["duration"], v["page"], v["expected_count"]
        ret = crawl_shutterstock_part(
            query,
            duration=duration_part,
            page=page,
            expected_count=expected_count,
        )
        if ret is None:
            # no results for this duration; nothing to record
            continue

        # filter out urls that are already crawled
        urls = set(ret.keys()) - finished_urls
        ret = {k: v for k, v in ret.items() if k in urls}
        finished_urls.update(urls)

        written = ""
        for url, (duration, page, label) in ret.items():
            written += f"{duration}\t{page}\t{url}\t{label}\n"
        f.write(written)
        f.flush()
        time.sleep(random.uniform(1, 2))

    f.close()

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--query", type=str, default=None)
    parser.add_argument("--rank", type=int, default=0)
    parser.add_argument("--world_size", type=int, default=1)
    parser.add_argument("--duration", type=str, default="0-600")
    args = parser.parse_args()

    os.makedirs("output", exist_ok=True)

    crawl_shutterstock(
        query=args.query,
        duration_span=args.duration,
        rank=args.rank,
        world_size=args.world_size,
    )
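# Usage sketch (the file name crawler.py is an assumption; the query value is
# only an example). Each rank crawls a disjoint slice of the page index and
# writes its own output/<query>.<duration>.<rank>_<world_size>.tsv:
#   python crawler.py --query nature --duration 0-600 --rank 0 --world_size 2
#   python crawler.py --query nature --duration 0-600 --rank 1 --world_size 2
# The proxy-pool helper below is a separate file in this gist.
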
import requests
import random
import time
from loguru import logger
import os, os.path as osp

from kn_util.utils.multiproc import map_async_with_thread
from kn_util.utils.logger import setup_logger_loguru

setup_logger_loguru(logger=logger)

HTTPS_PROXY_URLS = [
    "https://raw.githubusercontent.com/zloi-user/hideip.me/main/https.txt",
]
HTTP_PROXY_URLS = [
    "https://raw.githubusercontent.com/zloi-user/hideip.me/main/http.txt",
]
# Each parser keeps only "ip:port" by dropping the trailing ":"-separated
# field that these lists append to every line.
HTTP_PROXY_PARSERS = [
    lambda x: [_.rsplit(":", 1)[0] for _ in x.decode("unicode_escape").splitlines()],
]
HTTPS_PROXY_PARSERS = [
    lambda x: [_.rsplit(":", 1)[0] for _ in x.decode("unicode_escape").splitlines()],
]

class ProxyPool:

    def __init__(
        self,
        proxy_urls=HTTPS_PROXY_URLS,
        proxy_parsers=HTTPS_PROXY_PARSERS,
        target_url="www.baidu.com",
        domain="https",
        exclude_file="failed_proxies.txt",
    ):
        assert not target_url.startswith("http"), "Please remove http or https from target_url"
        self.proxy_urls = proxy_urls
        self.proxy_parsers = proxy_parsers
        self.refresh_timestamp = time.time()
        self.target_url = target_url
        self.domain = domain

        self.failed_proxies = set()
        if osp.exists(exclude_file):
            with open(exclude_file, "r") as f:
                self.failed_proxies = set(f.read().splitlines())

        self.refresh_proxies()

    def refresh_proxies(self):
        self.proxies = []
        for i in range(len(self.proxy_urls)):
            try:
                response = requests.get(self.proxy_urls[i])
                proxies = self.proxy_parsers[i](response.content)
            except Exception as e:
                logger.warning(f"Failed to fetch proxies from {self.proxy_urls[i]}: {e}")
                continue  # skip this source instead of reusing stale data
            self.proxies.extend(proxies)
        self.validate_proxies()

    def validate_proxy(self, proxy, url):
        if proxy in self.failed_proxies:
            return False
        try:
            response = requests.get(
                url,
                proxies={self.domain: f"{self.domain}://{proxy}"},
                allow_redirects=True,
                verify=False,
                timeout=60,
            )
            if response.status_code == 200:
                print(f"Validated proxy: {proxy}")
                return True
        except Exception:
            self.failed_proxies.add(proxy)
        return False

    def validate_proxies(self, url=None):
        if url is None:
            url = self.target_url
        url = f"{self.domain}://{url}"
        validate_results = map_async_with_thread(
            iterable=self.proxies,
            func=lambda proxy: self.validate_proxy(proxy=proxy, url=url),
            num_thread=32,
            # test_flag=True,
        )
        self.proxies = [proxy for proxy, result in zip(self.proxies, validate_results) if result]
        logger.info(f"Validated {len(self.proxies)} proxies")

    def get(self):
        # Refresh when the pool is stale (older than an hour) or running low,
        # then sample a proxy without mutating the pool itself.
        time_elapsed = time.time() - self.refresh_timestamp
        if time_elapsed > 60 * 60 or len(self.proxies) < 100:
            self.refresh_proxies()
            self.refresh_timestamp = time.time()
        return random.choice(self.proxies)

    def remove(self, proxy):
        self.proxies.remove(proxy)
        self.failed_proxies.add(proxy)

    def update_exclude_file(self, exclude_file):
        with open(exclude_file, "w") as f:
            f.write("\n".join(self.failed_proxies))

    def __len__(self):
        return len(self.proxies)

if __name__ == "__main__":
    proxy_pool = ProxyPool(
        target_url="www.shutterstock.com/zh/video/search",
        domain="https",
        exclude_file="failed_proxies.txt",
    )
    proxy_pool.update_exclude_file("failed_proxies.txt")
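# Usage sketch of how a caller might consume the pool (the request URL is an
# assumption for illustration, not part of the gist):
#   pool = ProxyPool()
#   proxy = pool.get()
#   resp = requests.get("https://www.shutterstock.com",
#                       proxies={"https": f"https://{proxy}"}, timeout=30)
#   if resp.status_code != 200:
#       pool.remove(proxy)  # drop bad proxies so they are not sampled again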