Skip to content

Instantly share code, notes, and snippets.

@vedroboev
Last active July 29, 2023 20:01
Show Gist options
  • Select an option

  • Save vedroboev/1f76d1a4347a54df96956b147e7f26fb to your computer and use it in GitHub Desktop.

Select an option

Save vedroboev/1f76d1a4347a54df96956b147e7f26fb to your computer and use it in GitHub Desktop.
A custom script for running a Stable Horde bridge inside AUTOMATIC1111's WebUI
import base64
import json
import logging
import time
from io import BytesIO
import gradio as gr
import modules.scripts as scripts
import modules.sd_samplers as samplers
import modules.shared as shared
import requests
import webui
from modules.processing import StableDiffusionProcessingTxt2Img, process_images
# Give the root logger a default handler (no-op when one is already present)
# so the bridge's messages reach the console.
logging.basicConfig()
# Dedicated logger for the bridge; it emits everything from DEBUG upwards.
logger = logging.getLogger(name="HORDE-BRIDGE")
logger.setLevel("DEBUG")
class HordeState:
    """Run-control flags shared between the UI callbacks and ``run_bridge``.

    ``stopped`` makes the bridge loop return on its next poll, and
    ``interrupted`` makes it pause until the flag is cleared again.
    """

    # Class-level defaults; the methods below shadow them per instance.
    interrupted = False
    stopped = False

    def interrupt_horde(self):
        """Ask the bridge loop to pause."""
        self.interrupted = True

    def stop_horde(self):
        """Ask the bridge loop to exit."""
        self.stopped = True

    def start_horde(self):
        """Clear any pending stop or pause request."""
        self.stopped = self.interrupted = False


# Module-wide instance read by run_bridge and driven by the UI buttons.
horde_state = HordeState()
def gr_show(visible=True):
    """Return a Gradio component-update dict that sets visibility to *visible*."""
    return dict(visible=visible, __type__="update")
def get_sampler_index_by_name(name: "str | None"):
    """Map a horde sampler name (e.g. ``"k_euler_a"``) to a WebUI sampler index.

    The index is the position of the name in a list built from the first
    alias of every k-diffusion sampler, followed by ``"ddim"`` and ``"plms"``.
    Matching is case-insensitive, and a missing/empty *name* means ``"k_lms"``.

    Unknown names log an error and fall back to ``"k_lms"``. If even that
    alias is unavailable, index 0 is returned instead of raising.
    """
    # A sampler without aliases contributes "" so that the positions of all
    # following samplers still line up with their indices.
    aliases = [
        sampler.aliases[0].lower() if sampler.aliases else ""
        for sampler in samplers.samplers_data_k_diffusion
    ]
    aliases += ["ddim", "plms"]
    wanted = (name or "k_lms").lower()
    if wanted in aliases:
        return aliases.index(wanted)
    logger.error(f'Sampler "{name}" not found, defaulting to k_lms...')
    return aliases.index("k_lms") if "k_lms" in aliases else 0
def toggle_nsfw(value):
    """Set WebUI's ``filter_nsfw`` option to *value* (wired to the "Censor NSFW" checkbox)."""
    opts = shared.opts
    opts.filter_nsfw = value
# Retry budget and error back-off (seconds) used by the bridge loop below.
_HORDE_MAX_RETRIES = 10
_HORDE_ERROR_WAIT = 10


def _split_names(value):
    """Normalise a name list given as a comma-separated string or a list.

    ``None``/empty input yields ``[]``. For strings, whitespace around each
    entry is stripped and empty entries are dropped ("a, b," -> ["a", "b"]).
    """
    if not value:
        return []
    if isinstance(value, str):
        return [name.strip() for name in value.split(",") if name.strip()]
    return list(value)


def _pop_job(horde_url, headers, gen_dict, timeout):
    """Ask ``/api/v2/generate/pop`` for one job.

    Returns ``(pop, backoff)``. ``pop`` is the decoded response dict when it
    carries a job id, otherwise ``None``. ``backoff`` is True when polling
    failed and the caller should wait longer before the next attempt.
    """
    try:
        pop_req = requests.post(
            horde_url + "/api/v2/generate/pop",
            json=gen_dict,
            headers=headers,
            timeout=timeout,
        )
    except requests.exceptions.RequestException:
        # Covers connection errors, timeouts and other transport failures.
        logger.warning(
            f"Server {horde_url} unavailable during pop. Waiting {_HORDE_ERROR_WAIT} seconds..."
        )
        return None, True
    try:
        pop = pop_req.json()
    except ValueError:
        # The concrete decode error depends on the requests/simplejson
        # versions installed, but it is always a ValueError subclass.
        logger.error(
            f"Could not decode response from {horde_url} as json. Please inform its administrator!"
        )
        return None, True
    if not isinstance(pop, dict):
        logger.error(
            f"Something has gone wrong with {horde_url}. Please inform its administrator!"
        )
        return None, True
    if not pop_req.ok:
        logger.warning(
            f"During gen pop, server {horde_url} responded with status code {pop_req.status_code}: {pop.get('message')}. Waiting for {_HORDE_ERROR_WAIT} seconds..."
        )
        if "errors" in pop:
            logger.warning(f"Detailed Request Errors: {pop['errors']}")
        return None, True
    if not pop.get("id"):
        skipped = pop.get("skipped")
        skipped_info = f" Skipped Info: {skipped}." if skipped else ""
        logger.debug(
            f"Server {horde_url} has no valid generations to do for us.{skipped_info}"
        )
        return None, False
    return pop, False


def _generate_image(payload):
    """Render one horde txt2img *payload* and return ``(image, seed)``.

    Only the first produced image is returned. Raises ``RuntimeError`` if
    processing produced no images. The caller holds ``webui.queue_lock``.
    """
    prompt = payload.get("prompt") or ""
    negative_prompt = None
    # Text after "###" is used as the negative prompt.
    if "###" in prompt:
        prompt, negative_prompt = prompt.split("###")[0:2]
    try:
        seed = int(payload.get("seed", -1))
    except (TypeError, ValueError):
        # Missing (None) or non-numeric seeds fall back to -1.
        seed = -1
    steps = payload.get("ddim_steps", 50)
    # TODO upscaling!!!
    processing_data = StableDiffusionProcessingTxt2Img(
        sd_model=shared.sd_model,
        seed=seed,
        prompt=prompt,
        negative_prompt=negative_prompt,
        steps=steps,
        cfg_scale=payload.get("cfg_scale", 5.0),
        height=payload.get("height", 512),
        width=payload.get("width", 512),
        n_iter=payload.get("n_iter", 1),
        sampler_index=get_sampler_index_by_name(payload.get("sampler_name")),
        do_not_save_samples=True,
        do_not_save_grid=True,
    )
    shared.total_tqdm.updateTotal(steps)
    try:
        processed = process_images(processing_data)
    finally:
        # Reset the progress bar even if generation fails.
        shared.total_tqdm.clear()
    if not processed.images:
        raise RuntimeError("process_images returned no images")
    return processed.images[0], processed.seed


def _submit_generation(horde_url, headers, submit_dict, timeout):
    """POST a finished generation to ``/api/v2/generate/submit``.

    Transient failures (transport errors, undecodable or non-OK responses)
    are retried up to ``_HORDE_MAX_RETRIES`` times, ``_HORDE_ERROR_WAIT``
    seconds apart. Returns once the horde accepts the result, reports it
    stale (HTTP 404), or the retry budget is exhausted.
    """
    gen_id = submit_dict["id"]
    for attempt in range(1, _HORDE_MAX_RETRIES + 1):
        try:
            submit_req = requests.post(
                horde_url + "/api/v2/generate/submit",
                json=submit_dict,
                headers=headers,
                timeout=timeout,
            )
        except requests.exceptions.RequestException:
            logger.warning(
                f"Server {horde_url} unavailable during submit. Waiting {_HORDE_ERROR_WAIT} seconds... (Retry {attempt}/{_HORDE_MAX_RETRIES})"
            )
            time.sleep(_HORDE_ERROR_WAIT)
            continue
        # Check staleness before decoding so a non-JSON 404 body still aborts
        # instead of being retried forever.
        if submit_req.status_code == 404:
            logger.warning("The generation we were working on got stale. Aborting!")
            return
        try:
            submit = submit_req.json()
        except ValueError:
            logger.error(
                f"Something has gone wrong with {horde_url} during submit. Please inform its administrator! (Retry {attempt}/{_HORDE_MAX_RETRIES})"
            )
            time.sleep(_HORDE_ERROR_WAIT)
            continue
        if not isinstance(submit, dict):
            submit = {}
        if not submit_req.ok:
            logger.warning(
                f"During gen submit, server {horde_url} responded with status code {submit_req.status_code}: {submit.get('message')}. Waiting for {_HORDE_ERROR_WAIT} seconds... (Retry {attempt}/{_HORDE_MAX_RETRIES})"
            )
            if "errors" in submit:
                logger.warning(f"Detailed Request Errors: {submit['errors']}")
            time.sleep(_HORDE_ERROR_WAIT)
            continue
        logger.info(
            f"Submitted generation with id {gen_id} and contributed for {submit.get('reward')}"
        )
        return
    logger.error(
        f"Exceeded retry count {_HORDE_MAX_RETRIES} while submitting generation id {gen_id}. Aborting generation!"
    )


def run_bridge(
    api_key,
    horde_name="Awesome Instance",
    horde_max_power=8,
    horde_nsfw=True,
    horde_censor_nsfw=False,
    priority_usernames=None,
    horde_blacklist=None,
    horde_url="https://stablehorde.net",
    interval=1,
    request_timeout=60,
):
    """Poll a Stable Horde server for txt2img jobs, render them locally and submit them.

    A slightly adapted version of the run_bridge function from hlky's webui.
    Loops until ``horde_state.stop_horde()`` is called, then returns ``None``.

    Args:
        api_key: Horde API key, sent as the ``apikey`` header and in submissions.
        horde_name: Worker name sent with every pop request.
        horde_max_power: Size multiplier; the worker advertises
            ``horde_max_power * 8 * 64 * 64`` as ``max_pixels``.
        horde_nsfw: Value sent as the ``nsfw`` flag of pop requests.
        horde_censor_nsfw: Accepted for interface compatibility; not read here.
        priority_usernames: Comma-separated string or list, sent as
            ``priority_usernames``.
        horde_blacklist: Comma-separated string or list, sent as ``blacklist``.
        horde_url: Base URL of the horde server.
        interval: Base sleep in seconds between polls and while paused.
        request_timeout: Seconds before an HTTP request to the horde is abandoned.
    """
    headers = {"apikey": api_key}
    horde_state.start_horde()
    # Slider values may arrive as floats; the horde expects an integer pixel count.
    horde_max_pixels = int(horde_max_power) * 8 * 64 * 64
    # The worker's advertised capabilities do not change while the bridge runs.
    gen_dict = {
        "name": horde_name,
        "max_pixels": horde_max_pixels,
        "priority_usernames": _split_names(priority_usernames),
        "nsfw": horde_nsfw,
        "blacklist": _split_names(horde_blacklist),
    }
    current_id = None
    current_payload = None
    loop_retry = 0
    wait_time = interval
    while True:
        time.sleep(wait_time)
        wait_time = interval
        if horde_state.stopped:
            logger.info("Finished working!")
            return None
        if horde_state.interrupted:
            logger.info("Horde interrupted, awaiting resume...")
            # Honour a stop request while paused, otherwise the bridge could
            # never exit from this state.
            while horde_state.interrupted and not horde_state.stopped:
                time.sleep(interval)
            logger.info("Interrupt finished, resuming horde...")
            continue

        if current_id and loop_retry > _HORDE_MAX_RETRIES:
            logger.error(
                f"Exceeded retry count {loop_retry} for generation id {current_id}. Aborting generation!"
            )
            current_id = None
            current_payload = None
            loop_retry = 0

        if current_id:
            # The previous generation attempt for this job failed; try again.
            logger.debug(
                f"Retrying ({loop_retry}/{_HORDE_MAX_RETRIES}) for generation id {current_id}..."
            )
            loop_retry += 1
        else:
            pop, backoff = _pop_job(horde_url, headers, gen_dict, request_timeout)
            if pop is None:
                if backoff:
                    wait_time = _HORDE_ERROR_WAIT
                continue
            current_id = pop["id"]
            logger.debug(
                f"Request with id {current_id} picked up. Initiating work..."
            )
            current_payload = pop.get("payload")
            if not isinstance(current_payload, dict) or (
                "toggles" in current_payload and current_payload["toggles"] is None
            ):
                logger.error(f"Received Bad payload: {pop}")
                current_id = None
                current_payload = None
                loop_retry = 0
                wait_time = _HORDE_ERROR_WAIT
                continue
            # TODO: honour payload toggles (matrix, NSFW censoring) in AUTOMATIC1111.

        # Only local generation needs the WebUI queue lock. Network I/O and
        # back-off sleeps stay outside it so local requests are not blocked.
        try:
            with webui.queue_lock:
                image, seed = _generate_image(current_payload)
        except Exception:
            # Bridge boundary: log it and keep current_id so the next
            # iteration retries this job (bounded by _HORDE_MAX_RETRIES).
            logger.exception(f"Generation failed for id {current_id}.")
            wait_time = _HORDE_ERROR_WAIT
            continue

        buffer = BytesIO()
        # We send as WebP to avoid using all the horde bandwidth
        image.save(buffer, format="WebP", quality=90)
        submit_dict = {
            "id": current_id,
            "generation": base64.b64encode(buffer.getvalue()).decode("utf8"),
            "api_key": api_key,
            "seed": seed,
            "max_pixels": horde_max_pixels,
        }
        _submit_generation(horde_url, headers, submit_dict, request_timeout)
        # Accepted, stale or abandoned: either way this job is finished.
        current_id = None
        current_payload = None
        loop_retry = 0
# TODO Fix on script + ui reload.
class Script(scripts.Script):
    """WebUI custom script that hosts the Stable Horde bridge controls.

    The actual work is done by ``run_bridge``, started from the "Start"
    button built in ``ui``; the ``run`` hook itself does nothing.
    """

    def title(self):
        """Return this script's display title."""
        return "Run a StableHorde Bridge"

    def show(self, is_img2img):
        """Offer the script only when ``is_img2img`` is False (never for img2img)."""
        return not is_img2img

    def ui(self, is_img2img):
        """Build the bridge controls, wire their events and return the components.

        NOTE: Gradio places components in creation order inside the enclosing
        ``with`` layout contexts, so the statement order below defines the UI.
        """
        # TODO menu to put worker into maintenance.
        with gr.Box():
            with gr.Column():
                info_label = gr.Markdown(
                    "*A small tool which allows you to host a Stable Horde from AUTOMATIC1111's WebUi*",
                    visible=False,
                )
                warning_label = gr.Markdown(
                    "**WARNING** You can run txt2img and img2img with an active bridge. Your requests will "
                    "be prioritized. Running custom scripts *should* also be possible, but may cause errors. "
                    "\n\nInterrupt and skip buttons will interrupt horde request if it's running! Be careful.",
                    visible=False,
                )
            with gr.Column():
                api_key = gr.Textbox(label="API key", visible=False)
                horde_name = gr.Textbox(
                    label="Horde name", visible=False, placeholder="Awesome instance"
                )
            with gr.Column():
                # Multiplier for the advertised max_pixels
                # (run_bridge computes power * 8 * 64 * 64).
                horde_max_power = gr.Slider(
                    label="Horde max power",
                    minimum=2,
                    maximum=32,
                    value=8,
                    step=1,
                    visible=False,
                )
            with gr.Row():
                horde_show_extra_settings = gr.Checkbox(
                    label="Additional settings", visible=False
                )
            # Hidden until "Additional settings" is ticked (see the .change
            # handler wired below).
            with gr.Box(visible=False) as horde_extra_settings:
                with gr.Row():
                    horde_nsfw = gr.Checkbox(value=True, label="Allow NSFW")
                    horde_censor_nsfw = gr.Checkbox(value=False, label="Censor NSFW")
                with gr.Column():
                    # Comma-separated lists; run_bridge splits them on ",".
                    horde_priority_usernames = gr.Textbox(label="Priority usernames")
                    horde_blacklist = gr.Textbox(label="Blacklist")
                    # horde_censorlist = gr.Textbox(label="Censorlist")
            with gr.Row():
                horde_start = gr.Button("Start", visible=False)
                horde_stop = gr.Button("Stop", visible=False)
            with gr.Row():
                # TODO better run status.
                # NOTE: nothing writes to run_status yet; progress is only logged.
                run_status = gr.Textbox(
                    label="Status",
                    placeholder="See command prompt for status messages",
                    interactive=False,
                    visible=False,
                )
        # run_bridge loops until horde_state.stop_horde() is called, so this
        # click handler does not return while the bridge is running.
        horde_start.click(
            run_bridge,
            inputs=[
                api_key,
                horde_name,
                horde_max_power,
                horde_nsfw,
                horde_censor_nsfw,
                horde_priority_usernames,
                horde_blacklist,
            ],
            outputs=[],
        )
        # Sets the stop flag that run_bridge checks on each poll.
        horde_stop.click(horde_state.stop_horde, inputs=[], outputs=[])
        # Show or hide the extra-settings box to match the checkbox.
        horde_show_extra_settings.change(
            fn=lambda x: gr_show(x),
            inputs=[horde_show_extra_settings],
            outputs=[horde_extra_settings],
        )
        # Applies WebUI's filter_nsfw option as soon as the checkbox changes.
        horde_censor_nsfw.change(toggle_nsfw, inputs=[horde_censor_nsfw], outputs=[])
        # show() returns False for img2img, so the first branch is presumably
        # never taken — verify against how WebUI calls ui().
        return (
            [info_label, warning_label]
            if is_img2img
            else [
                api_key,
                info_label,
                warning_label,
                horde_name,
                horde_max_power,
                horde_nsfw,
                horde_censor_nsfw,
                horde_show_extra_settings,
                horde_start,
                horde_stop,
                run_status,
            ]
        )

    def run(self, p, *args):
        """No-op processing hook; the bridge is driven from the UI buttons instead."""
        # TODO Add option to run local requests through the horde (?).
        pass
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment