# gemini_call.py
import os.path as osp
import time
from queue import Queue, Empty
from typing import Dict, Optional, Tuple

import google.generativeai as genai

from prompts import *


def upload_to_gemini(path, mime_type=None):
    """Upload the given file to Gemini.

    See https://ai.google.dev/gemini-api/docs/prompting_with_media
    """
    file = genai.upload_file(path, mime_type=mime_type)
    # print(f"Uploaded file '{file.display_name}' as: {file.uri}")
    return file
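
# Note (an assumption, not part of the original script): uploads persist on the
# Files API and count against per-project storage. Recent google-generativeai
# releases expose genai.delete_file, so a caller could clean up afterwards:
#
#   image_file = upload_to_gemini("table.png")
#   ...  # send the request
#   genai.delete_file(image_file.name)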


def call_gemini_flash_think(
    image_path: str,
    question: str,
    prompt: str,
    token_queue: Queue,
) -> Tuple[Optional[Dict[str, str]], str, Optional[str]]:
    # Block until an API token is available in the shared pool.
    while True:
        try:
            token = token_queue.get_nowait()
            break
        except Empty:
            time.sleep(1)
            continue
    print(f"Using token: {token}")
    genai.configure(api_key=token)
    try:
        # Create the model.
        generation_config = {
            "temperature": 0.2,
            "top_p": 0.95,
            "max_output_tokens": 1024 * 8,
            "response_mime_type": "text/plain",
        }
        model = genai.GenerativeModel(
            model_name="gemini-2.0-flash-thinking-exp-1219",
            generation_config=generation_config,
            system_instruction=prompt,
        )
        chat_session = model.start_chat()
        image_file = upload_to_gemini(image_path)
        response = chat_session.send_message([image_file, f"### Question: {question}"])
    except Exception as e:
        # Map known quota errors to sentinel strings the caller can dispatch
        # on; anything else is passed through verbatim.
        if "Generate Content API requests per minute" in str(e):
            print(f"Rate limit exceeded for {token}")
            return None, token, "<RATE_LIMIT>"
        elif "Resource has been exhausted" in str(e):
            print(f"Resource exhausted for {token}")
            return None, token, "<RESOURCE_EXHAUSTED>"
        else:
            print(f"Error for {token}: {e}")
            return None, token, str(e)
    # The thinking model returns two parts: the chain of thought, then the
    # user-facing answer.
    return (
        {
            "inner_thoughts": response.candidates[0].content.parts[0].text,
            "real_response": response.candidates[0].content.parts[1].text,
        },
        token,
        None,
    )
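
# Note: indexing parts[0]/parts[1] above assumes the thinking model always
# returns exactly two parts. A more defensive extraction (an assumption, not
# part of the original script) would be:
#
#   parts = response.candidates[0].content.parts
#   inner_thoughts = parts[0].text if len(parts) > 1 else ""
#   real_response = parts[-1].text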


if __name__ == "__main__":
    from prompts import general_prompt, doc_table_demos

    # Smoke test: run a single question through the model.
    prompt = general_prompt + doc_table_demos
    token_file = "/mnt/amlfs-01/home/jingwang/PROJECTS/mmo1/data_bin/gemini_tokens.txt"
    tokens = open(token_file).read().strip().split("\n")
    token_queue = Queue()
    for token in tokens:
        token_queue.put(token)
    image_path = osp.join(
        "../data_bin/mmo1", "AtomMath-sft/images/tabmwp/tables/880.png"
    )
    question = "Look at the following schedule. Which event begins at 12.25 P.M.?"
    # call_gemini_flash_think returns a (response, token, error) tuple, so it
    # must be unpacked before checking for failure.
    response, token, error = call_gemini_flash_think(
        image_path=image_path,
        question=question,
        prompt=prompt,
        token_queue=token_queue,
    )
    if response is None:
        print("FAILED")
    else:
        print("SUCCESS")


# ---- main driver script (imports gemini_call above) ----
import argparse
import os
import os.path as osp
import threading
import time
import multiprocessing as mp
from hashlib import sha256

from tqdm import tqdm

import prompts
from gemini_call import call_gemini_flash_think
from dev_engine.data.io import save_json, load_json
from dev_engine.multiproc import imap_async
from dev_engine.system import listdir_fd
from dev_engine.debug import setup_debugpy

setup_debugpy()

# Each input file carries a different set of in-context demos from prompts.py.
file_to_demo_type = {
    "chart-figure.json": "chart_figure_demos",
    "doc-table.json": "doc_table_demos",
    "general.json": "general_demos",
    "math-geometry.json": "math_geometry_demos",
    "science.json": "science_demos",
}
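
# Expected input layout (inferred from load_data_list below, not stated in the
# original): each JSON file in --data_dir maps dataset names to lists of
# samples, e.g.
#
#   {"tabmwp": [{"question": "...", "final_answer": "...",
#                "image_path": "AtomMath-sft/images/..."}]}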


def get_args():
    parser = argparse.ArgumentParser(description="Call Gemini API")
    parser.add_argument(
        "--token_file",
        type=str,
        help="Path to the file containing the API tokens (one per line)",
    )
    parser.add_argument(
        "--data_dir", type=str, help="Path to the directory containing the data"
    )
    parser.add_argument(
        "--output_dir", type=str, help="Path to the directory to save the output"
    )
    return parser.parse_args()


def load_data_list(data_dir, output_dir):
    data_list = []
    os.makedirs(output_dir, exist_ok=True)
    # Anything already present in output_dir is treated as done and skipped.
    cache_files = listdir_fd(output_dir, postfix="json")
    cache_files = [osp.relpath(file, output_dir) for file in cache_files]
    print(f"Found {len(cache_files)} cached files")
    for file in file_to_demo_type.keys():
        data = load_json(osp.join(data_dir, file))
        for dataset_name, dataset in data.items():

            def parse_item(sample):
                # Stable ID: hash of question + answer, used both as the
                # cache key and as the output filename.
                _id = sha256(
                    (sample["question"] + sample["final_answer"]).encode()
                ).hexdigest()
                if f"{dataset_name}/{_id}.json" in cache_files:
                    return None, None
                sample["_id"] = _id
                sample["image_path"] = osp.join(
                    "/mnt/amlfs-01/home/jingwang/PROJECTS/mmo1/data_bin/mmo1",
                    sample["image_path"],
                )
                sample["dataset"] = dataset_name
                # System prompt = shared general prompt + category-specific demos.
                sample["prompt"] = prompts.general_prompt + getattr(
                    prompts, file_to_demo_type[file]
                )
                output_path = osp.join(output_dir, f"{dataset_name}/{_id}.json")
                return sample, output_path

            for sample, output_path in tqdm(
                map(parse_item, dataset),
                desc=f"Processing {dataset_name}",
                total=len(dataset),
            ):
                if sample is not None and output_path is not None:
                    data_list.append((sample, output_path))
    return data_list
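
# Illustrative example (values are hypothetical): a sample in doc-table.json
# under dataset "tabmwp" with question "Q" and final_answer "A" is written to
#   <output_dir>/tabmwp/<sha256("QA")>.json
# and is skipped on the next run, so a crashed run can simply be restarted.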


usable_tokens = 0


def output_status(output_dir):
    """Print a progress heartbeat every 10 seconds."""
    global usable_tokens
    while True:
        print(
            "\t".join(
                [
                    "Cached Files:" + str(len(listdir_fd(output_dir, postfix="json"))),
                    "Usable Tokens:" + str(usable_tokens),
                ]
            )
        )
        time.sleep(10)


def main(args):
    global usable_tokens
    tokens = open(args.token_file).read().strip().split("\n")
    usable_tokens = len(tokens)
    # A managed queue so all worker processes share one token pool.
    token_queue = mp.Manager().Queue()
    for token in tokens:
        token_queue.put(token)
    data_list = load_data_list(args.data_dir, output_dir=args.output_dir)

    def delayed_put(token, timeout_secs=2):
        # Return a token to the pool after a cooldown, spreading requests out
        # so a single key is not driven past its per-minute rate limit.
        time.sleep(timeout_secs)
        token_queue.put(token)

    imap_wrapper = imap_async(
        iterable=data_list,
        func=lambda x: call_gemini_flash_think(
            image_path=x[0]["image_path"],
            question=x[0]["question"],
            prompt=x[0]["prompt"],
            token_queue=token_queue,
        ),
        num_workers=min(32, len(tokens)),
        mode="process",
    )
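    # imap_async is a dev_engine helper; the loop below assumes it yields
    # (inputs, result) pairs as workers finish and exposes its task_queue so
    # failed inputs can be requeued.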
    pbar = tqdm(total=len(data_list))
    # Daemon thread so the heartbeat does not keep the process alive at exit.
    threading.Thread(
        target=output_status, args=(args.output_dir,), daemon=True
    ).start()
    for inputs, (response, token, error) in imap_wrapper:
        if error is not None:
            # Put the failed sample back for another attempt.
            imap_wrapper.task_queue.put(inputs)
            if error == "<RESOURCE_EXHAUSTED>":
                # This token's quota is spent; retire it rather than requeue it.
                usable_tokens -= 1
            else:
                # Transient error: give the token back after a cooldown.
                threading.Thread(target=delayed_put, args=(token, 10)).start()
            continue
        threading.Thread(target=delayed_put, args=(token, 10)).start()
        output_path = inputs[1]
        os.makedirs(osp.dirname(output_path), exist_ok=True)
        sample = inputs[0]
        sample["response"] = response
        save_json(sample, output_path)
        pbar.update(1)
    pbar.close()


if __name__ == "__main__":
    args = get_args()
    main(args)
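
# Example invocation (filename and paths are illustrative, not from the gist):
#
#   python gemini_label.py \
#       --token_file data_bin/gemini_tokens.txt \
#       --data_dir data_bin/mmo1/questions \
#       --output_dir data_bin/mmo1/gemini_responses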