gemini_call.py
import os.path as osp
import time
from queue import Queue, Empty
from typing import Dict, Optional, Tuple

import google.generativeai as genai


def upload_to_gemini(path, mime_type=None):
    """Upload the given file to Gemini.

    See https://ai.google.dev/gemini-api/docs/prompting_with_media
    """
    return genai.upload_file(path, mime_type=mime_type)


def call_gemini_flash_think(
    image_path: str,
    question: str,
    prompt: str,
    token_queue: Queue,
) -> Tuple[Optional[Dict[str, str]], str, Optional[str]]:
    """Ask the Flash Thinking model a question about an image.

    Blocks until an API key is available on `token_queue`, then returns
    `(result, token, error)`: on success `result` holds the model's thoughts
    and answer and `error` is None; on failure `result` is None and `error`
    is a sentinel string or the exception message.
    """
    # Poll the shared pool until some worker returns an API key.
    while True:
        try:
            token = token_queue.get_nowait()
            break
        except Empty:
            time.sleep(1)

    print(f"Using token: {token}")
    genai.configure(api_key=token)
    try:
        generation_config = {
            "temperature": 0.2,
            "top_p": 0.95,
            "max_output_tokens": 1024 * 8,
            "response_mime_type": "text/plain",
        }
        model = genai.GenerativeModel(
            model_name="gemini-2.0-flash-thinking-exp-1219",
            generation_config=generation_config,
            system_instruction=prompt,
        )
        chat_session = model.start_chat()
        image_file = upload_to_gemini(image_path)
        response = chat_session.send_message([image_file, f"### Question: {question}"])
    except Exception as e:
        # Map quota errors to sentinel strings so the caller can react.
        if "Generate Content API requests per minute" in str(e):
            print(f"Rate limit exceeded for {token}")
            return None, token, "<RATE_LIMIT>"
        elif "Resource has been exhausted" in str(e):
            print(f"Resource exhausted for {token}")
            return None, token, "<RESOURCE_EXHAUSTED>"
        else:
            print(f"Error for {token}: {e}")
            return None, token, str(e)

    # The thinking model returns two parts: the chain of thought, then the answer.
    return (
        {
            "inner_thoughts": response.candidates[0].content.parts[0].text,
            "real_response": response.candidates[0].content.parts[1].text,
        },
        token,
        None,
    )


if __name__ == "__main__":
    from prompts import general_prompt, doc_table_demos

    prompt = general_prompt + doc_table_demos
    token_file = "/mnt/amlfs-01/home/jingwang/PROJECTS/mmo1/data_bin/gemini_tokens.txt"
    tokens = open(token_file).read().strip().split("\n")

    # Smoke-test each token on a single sample.
    for token in tokens:
        token_queue = Queue()
        token_queue.put(token)
        image_path = osp.join(
            "../data_bin/mmo1", "AtomMath-sft/images/tabmwp/tables/880.png"
        )
        question = "Look at the following schedule. Which event begins at 12.25 P.M.?"
        result, _, error = call_gemini_flash_think(
            image_path=image_path,
            question=question,
            prompt=prompt,
            token_queue=token_queue,
        )
        if result is None:
            print(f"FAILED: {error}")
        else:
            print("SUCCESS")
import argparse
import os
import os.path as osp
import threading
import time
import multiprocessing as mp
from hashlib import sha256

from tqdm import tqdm

import prompts
from gemini_call import call_gemini_flash_think
from dev_engine.data.io import save_json, load_json
from dev_engine.multiproc import imap_async
from dev_engine.system import listdir_fd
from dev_engine.debug import setup_debugpy

setup_debugpy()

# Map each annotation file to the name of its few-shot demo block in prompts.
file_to_demo_type = {
    "chart-figure.json": "chart_figure_demos",
    "doc-table.json": "doc_table_demos",
    "general.json": "general_demos",
    "math-geometry.json": "math_geometry_demos",
    "science.json": "science_demos",
}


def get_args():
    parser = argparse.ArgumentParser(description="Call Gemini API")
    parser.add_argument(
        "--token_file", type=str, help="Path to the file containing the API tokens"
    )
    parser.add_argument(
        "--data_dir", type=str, help="Path to the directory containing the data"
    )
    parser.add_argument(
        "--output_dir", type=str, help="Path to the directory to save the output"
    )
    return parser.parse_args()


def load_data_list(data_dir, output_dir):
    """Collect (sample, output_path) pairs, skipping samples already cached."""
    data_list = []
    os.makedirs(output_dir, exist_ok=True)
    cache_files = listdir_fd(output_dir, postfix="json")
    cache_files = [osp.relpath(file, output_dir) for file in cache_files]
    print(f"Found {len(cache_files)} cached files")

    for file in file_to_demo_type:
        data = load_json(osp.join(data_dir, file))
        for dataset_name, dataset in data.items():

            def parse_item(sample):
                # Deduplicate on a hash of the question plus its answer.
                _id = sha256(
                    (sample["question"] + sample["final_answer"]).encode()
                ).hexdigest()
                if f"{dataset_name}/{_id}.json" in cache_files:
                    return None, None
                sample["_id"] = _id
                sample["image_path"] = osp.join(
                    "/mnt/amlfs-01/home/jingwang/PROJECTS/mmo1/data_bin/mmo1",
                    sample["image_path"],
                )
                sample["dataset"] = dataset_name
                # Prepend the shared system prompt to the file-specific demos.
                sample["prompt"] = prompts.general_prompt + getattr(
                    prompts, file_to_demo_type[file]
                )
                output_path = osp.join(output_dir, f"{dataset_name}/{_id}.json")
                return sample, output_path

            for sample, output_path in tqdm(
                map(parse_item, dataset),
                desc=f"Processing {dataset_name}",
                total=len(dataset),
            ):
                if sample is not None and output_path is not None:
                    data_list.append((sample, output_path))
    return data_list


usable_tokens = 0


def output_status(output_dir):
    """Periodically report cache progress and how many tokens remain usable."""
    global usable_tokens
    while True:
        print(
            "\t".join(
                [
                    "Cached Files: " + str(len(listdir_fd(output_dir, postfix="json"))),
                    "Usable Tokens: " + str(usable_tokens),
                ]
            )
        )
        time.sleep(10)


def main(args):
    global usable_tokens
    tokens = open(args.token_file).read().strip().split("\n")
    usable_tokens = len(tokens)
    # A manager queue so worker processes can share one token pool.
    token_queue = mp.Manager().Queue()
    for token in tokens:
        token_queue.put(token)

    data_list = load_data_list(args.data_dir, output_dir=args.output_dir)

    def delayed_put(token, timeout_secs=2):
        # Return the token to the pool after a cooldown to stay under rate limits.
        time.sleep(timeout_secs)
        token_queue.put(token)

    imap_wrapper = imap_async(
        iterable=data_list,
        func=lambda x: call_gemini_flash_think(
            image_path=x[0]["image_path"],
            question=x[0]["question"],
            prompt=x[0]["prompt"],
            token_queue=token_queue,
        ),
        num_workers=min(32, len(tokens)),
        mode="process",
    )

    pbar = tqdm(total=len(data_list))
    # Daemon thread so the status loop does not keep the process alive at exit.
    threading.Thread(
        target=output_status, args=(args.output_dir,), daemon=True
    ).start()
    for inputs, (response, token, error) in imap_wrapper:
        if error is not None:
            # Requeue the failed sample so another worker retries it.
            imap_wrapper.task_queue.put(inputs)
            if error == "<RESOURCE_EXHAUSTED>":
                # Quota is gone for good: retire this token instead of returning it.
                usable_tokens -= 1
            else:
                threading.Thread(target=delayed_put, args=(token, 10)).start()
            continue
        threading.Thread(target=delayed_put, args=(token, 10)).start()

        sample, output_path = inputs
        os.makedirs(osp.dirname(output_path), exist_ok=True)
        sample["response"] = response
        save_json(sample, output_path)
        pbar.update(1)
    pbar.close()


if __name__ == "__main__":
    args = get_args()
    main(args)
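Assuming the filename above, a typical invocation looks like this (the paths are placeholders, not from the gist):

python run_gemini_batch.py --token_file gemini_tokens.txt --data_dir ./annotations --output_dir ./gemini_outputs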