@jingwangsg
Last active June 25, 2024 06:07
Crawler
from selenium import webdriver
from selenium.webdriver.common.keys import Keys
from selenium.webdriver.common.by import By
from selenium.webdriver.support.ui import WebDriverWait
from selenium.webdriver.support import expected_conditions as EC
from selenium.common.exceptions import NoSuchElementException, StaleElementReferenceException, TimeoutException
import time
import random
import os, os.path as osp
import tempfile
from tqdm import tqdm
import argparse
from kn_util.utils.logger import setup_logger_loguru
from kn_util.utils.io import save_csv, load_csv
setup_logger_loguru()
from loguru import logger
# Setup options for Edge
ch_options = webdriver.EdgeOptions()
ch_options.add_argument(
    "user-agent=Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.36"
)
prefs = {"profile.managed_default_content_settings.images": 2}
ch_options.add_experimental_option("prefs", prefs)
ch_options.add_experimental_option("excludeSwitches", ["enable-automation"])
ch_options.add_experimental_option("useAutomationExtension", False)
ch_options.add_argument("--disable-infobars")
ch_options.add_argument("--incognito")
ch_options.add_argument("--headless")
# Setup WebDriver (make sure the path to your WebDriver is correct)
driver = webdriver.Edge(options=ch_options)
class BufferedWriter:
    def __init__(self, file):
        self.file = open(file, "w")
        self.buffer = []

    def write(self, data):
        self.buffer.append(data)
        if len(self.buffer) >= 100:
            self.flush()

    def flush(self):
        # add a trailing newline so consecutive flushes do not glue lines together
        if self.buffer:
            self.file.write("\n".join(self.buffer) + "\n")
            self.buffer = []

    def close(self):
        self.flush()
        self.file.close()
def has_next_button():
    try:
        WebDriverWait(driver, 10).until(
            EC.presence_of_element_located(
                (By.CSS_SELECTOR, "a.MuiButtonBase-root.MuiButton-root.MuiButton-outlined.MuiButton-outlinedPrimary")
            )
        )
        driver.find_element(By.CSS_SELECTOR, "a.MuiButtonBase-root.MuiButton-root.MuiButton-outlined.MuiButton-outlinedPrimary")
        return True
    except (NoSuchElementException, TimeoutException):
        # WebDriverWait raises TimeoutException (not NoSuchElementException) when the button never appears
        return False
def get_video_url(link):
    vid_str = link.rsplit("/", 1)[1]
    split_str = vid_str.split("-", 2)
    vid = split_str[1]
    name = split_str[2]
    link = f"https://ak.picdn.net/shutterstock/videos/{vid}/preview/{name}.mp4"
    return link
def wait_and_get_element(term):
    try:
        element = WebDriverWait(driver, 20).until(EC.presence_of_element_located((By.CSS_SELECTOR, term)))
        return element
    except TimeoutException:
        logger.error(f"Element {term} not found")
        time.sleep(10)
        driver.refresh()


def _parse_duration(duration):
    return int(duration.split("-")[0]), int(duration.split("-")[1])


def check_available(driver):
    if driver.title.startswith("ERROR"):
        logger.error("Shutterstock is not available, please try again later")
        exit(1)
def get_query_meta(query=None, duration_span="0-1", truncate=True):
    base_url = "https://www.shutterstock.com/zh/video/search"
    if query is not None:
        base_url += f"/{query}"
    url = base_url + f"?duration={duration_span}"
    driver.get(url)
    logger.info(f"Opening {url}")
    check_available(driver)

    if driver.title.startswith("0"):
        return 0, 0

    total_page_term = "span.MuiTypography-root.MuiTypography-subtitle2.mui-10a3ukw-totalPages-centerPagination"
    element = wait_and_get_element(total_page_term)
    total_page = int(element.get_attribute("aria-label").split()[1].replace(",", ""))
    if truncate and total_page > 999:
        logger.info(f"Query {query} Duration {duration_span} has more than 999 pages ({total_page}), only crawling 999 pages")
        total_page = 999

    total_element_term = ".MuiTypography-root.MuiTypography-body2.mui-1xvvg7z-subtitle"
    element = wait_and_get_element(total_element_term)
    total_elements = int(element.text.split()[0].replace(",", ""))
    return total_elements, total_page
def crawl_shutterstock_part(query=None, duration=0, page=1, expected_count=100):
    base_url = "https://www.shutterstock.com/zh/video/search"
    if query is not None:
        base_url += f"/{query}"

    if driver.title.startswith("0"):
        # no results for this duration
        return None

    link_term = "a.mui-t7xql4-a-inherit-link"
    url = base_url + f"?duration={duration}-{duration+1}&page={page}"
    driver.get(url)
    logger.info(f"Opening {url}")
    check_available(driver)

    # Simulate more human-like scrolling
    video_count = 0
    WebDriverWait(driver, 10).until(EC.presence_of_element_located((By.CSS_SELECTOR, link_term)))
    driver.execute_script("window.scrollBy(0, window.innerHeight * 100);")
    while video_count < expected_count:
        video_count = len(driver.find_elements(By.CSS_SELECTOR, link_term))

    ret = dict()
    for i in range(video_count):
        try:
            # Re-find the element to avoid staleness
            video = driver.find_elements(By.CSS_SELECTOR, link_term)[i]
            link = video.get_attribute("href")
            link = get_video_url(link)
            label = video.get_attribute("aria-label").encode("unicode_escape").decode("utf-8")
            ret[link] = (duration, page, label)
        except StaleElementReferenceException:
            continue  # Skip this element and continue with the next one
    return ret
def get_filename(query, duration_span, rank, world_size):
    return f"{query}.{duration_span}.{rank}_{world_size}"


def build_index(query, duration_span, rank, world_size, output_dir="output"):
    st, ed = _parse_duration(duration_span)
    query_alias = query.replace(" ", "-") if query is not None else "all"
    filename = get_filename(query_alias, duration_span, rank, world_size)
    index_file = osp.join(output_dir, f"{filename}.index.csv")

    # calculate finished
    finished = set()
    urls = set()
    query_file = osp.join(output_dir, f"{filename}.tsv")
    if osp.exists(query_file):
        with open(query_file, "r") as f:
            lines = f.readlines()[1:]
        for line in lines:
            duration, page, link, label = line.strip().split("\t")
            finished.add((int(duration), int(page)))
            urls.add(link)

    if osp.exists(index_file):
        index_list = load_csv(index_file)
        for i in range(len(index_list)):
            index_list[i]["duration"] = int(index_list[i]["duration"])
            index_list[i]["page"] = int(index_list[i]["page"])
            index_list[i]["expected_count"] = int(index_list[i]["expected_count"])
        logger.info(f"[{rank}/{world_size}] Index file {index_file} exists, loading from file")
        index_list = index_list[rank::world_size]
        index_list = sorted([x for x in index_list if (x["duration"], x["page"]) not in finished], key=lambda x: (x["duration"], x["page"]))
        return index_list, urls

    index_list = []
    expected_count_all = 0
    if rank == 0:
        # _, total_pages_all = get_query_meta(query, duration_span, truncate=False)
        for duration in tqdm(range(st, ed), desc="Building index"):
            total_elements, total_pages = get_query_meta(query, duration_span=f"{duration}-{duration+1}")
            logger.info(f"pages for Duration {duration}-{duration+1}: {total_pages}")
            expected_count_by_duration = 0
            for i in range(1, total_pages + 1):
                expected_count = 100 if i < total_pages else total_elements % 100
                index_list.append({"duration": duration, "page": i, "expected_count": expected_count})
                expected_count_all += expected_count
                expected_count_by_duration += expected_count
            logger.info(f"Total expected count for Duration {duration}-{duration+1}: {expected_count_by_duration}")
            time.sleep(random.uniform(1, 2))
        logger.info(f"Total expected count: {expected_count_all}")
        logger.info(f"[{rank}/{world_size}] Saving index to file {index_file}")
        save_csv(index_list, index_file)
    else:
        logger.info(f"[{rank}/{world_size}] Waiting for index file {index_file}")
        while not osp.exists(index_file):
            time.sleep(1)
        logger.info(f"[{rank}/{world_size}] Index file {index_file} exists, loading from file")
        index_list = load_csv(index_file)

    index_list = index_list[rank::world_size]
    index_list = sorted([x for x in index_list if (x["duration"], x["page"]) not in finished], key=lambda x: (x["duration"], x["page"]))
    return index_list, urls
# Main crawling function
def crawl_shutterstock(
    query=None,
    duration_span="0-600",
    output_dir="output",
    rank=0,
    world_size=1,
):
    indexs, finished_urls = build_index(
        query,
        duration_span=duration_span,
        rank=rank,
        world_size=world_size,
        output_dir=output_dir,
    )

    query_alias = query.replace(" ", "-") if query is not None else "all"
    query_file = osp.join(output_dir, f"{query_alias}.{duration_span}.{rank}_{world_size}.tsv")
    if not osp.exists(query_file):
        f = open(query_file, "w")
        f.write("duration\tpage\tlink\tlabel\n")
    else:
        f = open(query_file, "a")

    for v in tqdm(indexs):
        duration_part, page, expected_count = v["duration"], v["page"], v["expected_count"]
        ret = crawl_shutterstock_part(
            query,
            duration=duration_part,
            page=page,
            expected_count=expected_count,
        )
        if ret is None:
            # crawl_shutterstock_part returns None when the page reports no results
            continue

        # filter out urls that are already crawled
        urls = set(ret.keys())
        # if len(urls - finished_urls) < len(urls):
        #     import ipdb; ipdb.set_trace()
        urls = urls - finished_urls
        ret = {k: v for k, v in ret.items() if k in urls}
        finished_urls.update(urls)

        written = ""
        for url, (duration, page, label) in ret.items():
            written += f"{duration}\t{page}\t{url}\t{label}\n"
        f.write(written)
        f.flush()

        time.sleep(random.uniform(1, 2))

    f.close()
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--query", type=str, default=None)
    parser.add_argument("--rank", type=int, default=0)
    parser.add_argument("--world_size", type=int, default=1)
    parser.add_argument("--duration", type=str, default="0-600")
    args = parser.parse_args()

    os.makedirs("output", exist_ok=True)

    crawl_shutterstock(
        query=args.query,
        duration_span=args.duration,
        rank=args.rank,
        world_size=args.world_size,
    )
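# Example invocation (the script filename is assumed; run one process per rank to shard the work):
#   python crawl_shutterstock.py --query "nature" --duration 0-600 --rank 0 --world_size 4
# Each process appends to output/<query>.<duration>.<rank>_<world_size>.tsv, and already-crawled
# (duration, page) pairs are skipped on restart, so interrupted runs can be resumed.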
import requests
import random
import time
from loguru import logger
import os, os.path as osp
from kn_util.utils.multiproc import map_async_with_thread
from kn_util.utils.logger import setup_logger_loguru
setup_logger_loguru(logger=logger)
HTTPS_PROXY_URLS = [
    "https://raw.githubusercontent.com/zloi-user/hideip.me/main/https.txt",
]
HTTP_PROXY_URLS = [
    "https://raw.githubusercontent.com/zloi-user/hideip.me/main/http.txt",
]
HTTP_PROXY_PARSERS = [
    lambda x: [_.rsplit(":", 1)[0] for _ in x.decode("unicode_escape").splitlines()],
]
HTTPS_PROXY_PARSERS = [
    lambda x: [_.rsplit(":", 1)[0] for _ in x.decode("unicode_escape").splitlines()],
]
class ProxyPool:
    def __init__(
        self,
        proxy_urls=HTTPS_PROXY_URLS,
        proxy_parsers=HTTPS_PROXY_PARSERS,
        target_url="www.baidu.com",
        domain="https",
        exclude_file="failed_proxies.txt",
    ):
        assert not target_url.startswith("http"), "Please remove http:// or https:// from target_url"
        self.proxy_urls = proxy_urls
        self.proxy_parsers = proxy_parsers
        self.refresh_timestamp = time.time()
        self.target_url = target_url
        self.domain = domain
        self.failed_proxies = set()
        if osp.exists(exclude_file):
            with open(exclude_file, "r") as f:
                self.failed_proxies = set(f.read().splitlines())
        self.refresh_proxies()

    def refresh_proxies(self):
        self.proxies = []
        for i in range(len(self.proxy_urls)):
            try:
                response = requests.get(self.proxy_urls[i])
                proxies = self.proxy_parsers[i](response.content)
            except Exception as e:
                print(e)
                continue  # skip proxy sources that fail to download or parse
            self.proxies.extend(proxies)
        self.validate_proxies()
    def validate_proxy(self, proxy, url):
        if proxy in self.failed_proxies:
            return False
        try:
            response = requests.get(
                url,
                proxies={self.domain: f"{self.domain}://{proxy}"},
                allow_redirects=True,
                verify=False,
                timeout=60,
            )
            if response.status_code == 200:
                print(f"Validated proxy: {proxy}")
                return True
        except Exception:
            # print(f"Failed to validate proxy: {proxy}")
            pass
        self.failed_proxies.add(proxy)
        return False
    def validate_proxies(self, url=None):
        if url is None:
            url = self.target_url
        url = f"{self.domain}://{url}"
        validate_results = map_async_with_thread(
            iterable=self.proxies,
            func=lambda proxy: self.validate_proxy(proxy=proxy, url=url),
            num_thread=32,
            # test_flag=True,
        )
        self.proxies = [proxy for proxy, result in zip(self.proxies, validate_results) if result]
        logger.info(f"Validated {len(self.proxies)} proxies")
    def get(self):
        # refresh the pool when it is stale or has shrunk too much, then hand out a random proxy
        time_elapsed = time.time() - self.refresh_timestamp
        if time_elapsed > 60 * 60 or len(self.proxies) < 100:
            self.refresh_proxies()
            self.refresh_timestamp = time.time()
        return random.choice(self.proxies)
    def remove(self, proxy):
        self.proxies.remove(proxy)
        self.failed_proxies.add(proxy)

    def update_exclude_file(self, exclude_file):
        with open(exclude_file, "w") as f:
            f.write("\n".join(self.failed_proxies))

    def __len__(self):
        return len(self.proxies)
if __name__ == "__main__":
    proxy_pool = ProxyPool(
        target_url="www.shutterstock.com/zh/video/search",
        domain="https",
        exclude_file="failed_proxies.txt",
    )
    proxy_pool.update_exclude_file("failed_proxies.txt")

    import ipdb
    ipdb.set_trace()
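# Sketch of wiring the pool into a download, assuming get() returns a "host:port" string
# (this mirrors the proxies dict built in validate_proxy; the target URL is illustrative):
#   proxy = proxy_pool.get()
#   resp = requests.get("https://www.shutterstock.com/zh/video/search",
#                       proxies={"https": f"https://{proxy}"}, timeout=60)
#   if resp.status_code != 200:
#       proxy_pool.remove(proxy)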