Last active
July 29, 2023 20:01
-
-
Save vedroboev/1f76d1a4347a54df96956b147e7f26fb to your computer and use it in GitHub Desktop.
A custom script for running a Stable Horde bridge inside AUTOMATIC1111's WebUI.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| import base64 | |
| import json | |
| import logging | |
| import time | |
| from io import BytesIO | |
| import gradio as gr | |
| import modules.scripts as scripts | |
| import modules.sd_samplers as samplers | |
| import modules.shared as shared | |
| import requests | |
| import webui | |
| from modules.processing import StableDiffusionProcessingTxt2Img, process_images | |
# Give the root logger a handler (basicConfig is a no-op if one already
# exists), then expose a dedicated bridge logger that emits everything
# from DEBUG upwards.
logging.basicConfig()
logger = logging.getLogger(name="HORDE-BRIDGE")
logger.setLevel(level=logging.DEBUG)
class HordeState:
    """Run-control flags shared by the UI callbacks and the bridge loop."""

    # Class-level defaults; calling the methods below shadows them with
    # per-instance values.
    interrupted: bool = False
    stopped: bool = False

    def interrupt_horde(self):
        """Flag the bridge loop to pause before it polls for more work."""
        self.interrupted = True

    def stop_horde(self):
        """Flag the bridge loop to exit on its next iteration."""
        self.stopped = True

    def start_horde(self):
        """Reset both flags so a bridge loop can run from a clean state."""
        self.interrupted = False
        self.stopped = False


# Single process-wide instance read by run_bridge and the UI handlers.
horde_state = HordeState()
def gr_show(visible=True):
    """Return a Gradio update dict that sets a component's ``visible`` flag."""
    return dict(visible=visible, __type__="update")
def get_sampler_index_by_name(name: str):
    """Translate a horde sampler name into a WebUI sampler index.

    The lookup table is the lower-cased first alias of every entry in
    ``samplers.samplers_data_k_diffusion``, followed by "ddim" and "plms".
    A missing/empty name selects "k_lms"; an unknown name is logged and
    also falls back to "k_lms".
    """
    known = [entry.aliases[0].lower() for entry in samplers.samplers_data_k_diffusion]
    known.extend(("ddim", "plms"))
    wanted = (name or "k_lms").lower()
    if wanted in known:
        return known.index(wanted)
    logger.error(f'Sampler "{name}" not found, defaulting to k_lms...')
    return known.index("k_lms")
def toggle_nsfw(value):
    """Mirror the "Censor NSFW" checkbox into the WebUI's ``filter_nsfw`` option."""
    options = shared.opts
    options.filter_nsfw = value
# Seconds to back off after a server, payload or generation error.
_ERROR_WAIT_SECONDS = 10
# How many times a finished generation is re-submitted before it is dropped.
_MAX_SUBMIT_RETRIES = 10


def _split_names(value):
    """Normalize a comma-separated string (or list) of names into a clean list.

    Whitespace around each entry is stripped and empty entries are dropped, so
    the textbox value ``"alice, bob,"`` becomes ``["alice", "bob"]``.
    """
    if not value:
        return []
    if isinstance(value, str):
        value = value.split(",")
    names = (str(item).strip() for item in value)
    return [name for name in names if name]


def _pop_job(horde_url, gen_dict, headers, interval, timeout):
    """Ask the horde for one job.

    Returns ``(job, wait_seconds)``: ``job`` is the decoded pop response when it
    carries a job ``id`` and ``None`` otherwise; ``wait_seconds`` is how long the
    caller should sleep before polling again.
    """
    try:
        pop_req = requests.post(
            horde_url + "/api/v2/generate/pop",
            json=gen_dict,
            headers=headers,
            timeout=timeout,
        )
    except requests.exceptions.RequestException:
        # Covers connection errors and timeouts alike.
        logger.warning(
            f"Server {horde_url} unavailable during pop. Waiting 10 seconds..."
        )
        return None, _ERROR_WAIT_SECONDS
    try:
        pop = pop_req.json()
    except ValueError:
        # Response.json() raises a ValueError subclass (json / simplejson /
        # requests JSONDecodeError depending on the installed requests).
        logger.error(
            f"Could not decode response from {horde_url} as json. Please inform its administrator!"
        )
        return None, interval
    if not isinstance(pop, dict):
        logger.error(
            f"Something has gone wrong with {horde_url}. Please inform its administrator!"
        )
        return None, interval
    if not pop_req.ok:
        logger.warning(
            f"During gen pop, server {horde_url} responded with status code {pop_req.status_code}: {pop.get('message')}. Waiting for 10 seconds..."
        )
        if "errors" in pop:
            logger.warning(f"Detailed Request Errors: {pop['errors']}")
        return None, _ERROR_WAIT_SECONDS
    if not pop.get("id"):
        skipped = pop.get("skipped")
        skipped_info = f" Skipped Info: {skipped}." if skipped else ""
        logger.debug(
            f"Server {horde_url} has no valid generations to do for us.{skipped_info}"
        )
        return None, interval
    return pop, interval


def _generate_image(payload):
    """Render a horde payload with txt2img.

    Returns ``(generation, seed)`` where ``generation`` is the first image as
    base64-encoded WebP text and ``seed`` is the seed reported by the WebUI.
    Raises RuntimeError if no image was produced.
    """
    prompt = payload.get("prompt") or ""
    negative_prompt = None
    if "###" in prompt:
        # The horde encodes "prompt###negative prompt" in a single string.
        prompt, negative_prompt = prompt.split("###")[0:2]
    try:
        seed = int(payload.get("seed", -1))
    except (TypeError, ValueError):
        # A null or non-numeric seed means "random".
        seed = -1
    steps = payload.get("ddim_steps", 50)
    # TODO upscaling!!!
    processing_data = StableDiffusionProcessingTxt2Img(
        sd_model=shared.sd_model,
        seed=seed,
        prompt=prompt,
        negative_prompt=negative_prompt,
        steps=steps,
        cfg_scale=payload.get("cfg_scale", 5.0),
        height=payload.get("height", 512),
        width=payload.get("width", 512),
        n_iter=payload.get("n_iter", 1),
        sampler_index=get_sampler_index_by_name(payload.get("sampler_name")),
        do_not_save_samples=True,
        do_not_save_grid=True,
    )
    shared.total_tqdm.updateTotal(steps)
    try:
        processed = process_images(processing_data)
    finally:
        # Reset the console progress bar even if generation fails.
        shared.total_tqdm.clear()
    if not processed.images:
        raise RuntimeError("image generation produced no images")
    buffer = BytesIO()
    # We send as WebP to avoid using all the horde bandwidth
    processed.images[0].save(buffer, format="WebP", quality=90)
    return base64.b64encode(buffer.getvalue()).decode("utf8"), processed.seed


def _submit_job(horde_url, submit_dict, headers, interval, timeout):
    """Upload a finished generation to the horde.

    Retries up to ``_MAX_SUBMIT_RETRIES`` times on network errors, undecodable
    replies and non-OK statuses. Returns True if the horde accepted the result,
    False if it was dropped (stale job or retries exhausted).
    """
    current_id = submit_dict["id"]
    for attempt in range(_MAX_SUBMIT_RETRIES + 1):
        if attempt:
            logger.debug(
                f"Retrying ({attempt}/{_MAX_SUBMIT_RETRIES}) for generation id {current_id}..."
            )
        try:
            submit_req = requests.post(
                horde_url + "/api/v2/generate/submit",
                json=submit_dict,
                headers=headers,
                timeout=timeout,
            )
        except requests.exceptions.RequestException:
            logger.warning(
                f"Server {horde_url} unavailable during submit. Waiting 10 seconds... (Retry {attempt}/{_MAX_SUBMIT_RETRIES})"
            )
            time.sleep(_ERROR_WAIT_SECONDS)
            continue
        if submit_req.status_code == 404:
            logger.warning("The generation we were working on got stale. Aborting!")
            return False
        try:
            submit = submit_req.json()
        except ValueError:
            logger.error(
                f"Something has gone wrong with {horde_url} during submit. Please inform its administrator! (Retry {attempt}/{_MAX_SUBMIT_RETRIES})"
            )
            time.sleep(interval)
            continue
        if not isinstance(submit, dict):
            submit = {}
        if not submit_req.ok:
            logger.warning(
                f"During gen submit, server {horde_url} responded with status code {submit_req.status_code}: {submit.get('message')}. Waiting for 10 seconds... (Retry {attempt}/{_MAX_SUBMIT_RETRIES})"
            )
            if "errors" in submit:
                logger.warning(f"Detailed Request Errors: {submit['errors']}")
            time.sleep(_ERROR_WAIT_SECONDS)
            continue
        logger.info(
            f"Submitted generation with id {current_id} and contributed for {submit.get('reward')}"
        )
        return True
    logger.error(
        f"Exceeded retry count {_MAX_SUBMIT_RETRIES} for generation id {current_id}. Aborting generation!"
    )
    return False


def run_bridge(
    api_key,
    horde_name="Awesome Instance",
    horde_max_power=8,
    horde_nsfw=True,
    horde_censor_nsfw=False,
    priority_usernames=None,
    horde_blacklist=None,
    horde_url="https://stablehorde.net",
    interval=1,
    request_timeout=60,
):
    """Serve Stable Horde txt2img jobs with the local model until stopped.

    A slightly adapted version of the run_bridge function from hlky's webui.
    Each iteration pops one job and renders it while holding
    ``webui.queue_lock``, then submits the result outside the lock so network
    retries never block the WebUI.

    Args:
        api_key: Horde API key, sent as the ``apikey`` header and in submits.
        horde_name: Worker name sent to the horde; empty falls back to the
            default name.
        horde_max_power: The worker advertises
            ``horde_max_power * 8 * 64 * 64`` as ``max_pixels``.
        horde_nsfw: Sent to the horde as the ``nsfw`` flag.
        horde_censor_nsfw: Not used by this loop (NSFW filtering is set via
            ``toggle_nsfw``); kept for interface compatibility.
        priority_usernames: Comma-separated string or list, sent as
            ``priority_usernames``.
        horde_blacklist: Comma-separated string or list, sent as ``blacklist``.
        horde_url: Base URL of the horde server.
        interval: Seconds between polls while things are healthy.
        request_timeout: Seconds before any HTTP request to the horde times out.
    """
    headers = {"apikey": api_key}
    horde_url = horde_url.rstrip("/")
    # Gradio sliders may deliver floats; the horde expects an integer.
    horde_max_pixels = int(horde_max_power) * 8 * 64 * 64
    # The pop request is identical on every iteration, so build it once.
    gen_dict = {
        "name": horde_name or "Awesome Instance",
        "max_pixels": horde_max_pixels,
        "priority_usernames": _split_names(priority_usernames),
        "nsfw": horde_nsfw,
        "blacklist": _split_names(horde_blacklist),
    }
    horde_state.start_horde()
    wait_time = interval
    while True:
        # Sleep outside the queue lock so the WebUI stays responsive.
        time.sleep(wait_time)
        wait_time = interval
        if horde_state.interrupted and not horde_state.stopped:
            logger.info("Horde interrupted, awaiting resume...")
            while horde_state.interrupted and not horde_state.stopped:
                time.sleep(interval)
            if not horde_state.stopped:
                logger.info("Interrupt finished, resuming horde...")
        if horde_state.stopped:
            logger.info("Finished working!")
            return

        # Pop and render while holding webui.queue_lock (presumably the lock the
        # WebUI uses to serialize its own generations — TODO confirm).
        with webui.queue_lock:
            job, wait_time = _pop_job(
                horde_url, gen_dict, headers, interval, request_timeout
            )
            if job is None:
                continue
            current_id = job["id"]
            logger.debug(f"Request with id {current_id} picked up. Initiating work...")
            payload = job.get("payload")
            if not isinstance(payload, dict) or (
                "toggles" in payload and payload["toggles"] is None
            ):
                logger.error(f"Received Bad payload: {job}")
                wait_time = _ERROR_WAIT_SECONDS
                continue
            # TODO: honour payload toggles (matrix handling, NSFW censoring)
            # once AUTOMATIC1111 supports them; they are ignored for now.
            try:
                generation, seed = _generate_image(payload)
            except Exception:
                # One failing job must not kill the long-running bridge loop.
                logger.exception(
                    f"Generation for id {current_id} failed. Aborting generation!"
                )
                wait_time = _ERROR_WAIT_SECONDS
                continue

        submit_dict = {
            "id": current_id,
            "generation": generation,
            "api_key": api_key,
            "seed": seed,
            "max_pixels": horde_max_pixels,
        }
        _submit_job(horde_url, submit_dict, headers, interval, request_timeout)
# TODO Fix on script + ui reload.
class Script(scripts.Script):
    """WebUI custom script that adds Stable Horde bridge controls to txt2img."""

    def title(self):
        """Return the label shown for this script in the WebUI."""
        return "Run a StableHorde Bridge"

    def show(self, is_img2img):
        """Offer the script only on the txt2img tab."""
        return not is_img2img

    def ui(self, is_img2img):
        """Build the bridge controls and wire their event handlers.

        Every top-level control is created with ``visible=False``; presumably the
        WebUI toggles the returned components when this script is selected
        (TODO confirm against the WebUI's scripts module).

        Returns:
            The components handed back to the WebUI: only the two labels on
            img2img, otherwise the labels plus the bridge controls.
        """
        # TODO menu to put worker into maintenance.
        with gr.Box():
            with gr.Column():
                info_label = gr.Markdown(
                    "*A small tool which allows you to host a Stable Horde from AUTOMATIC1111's WebUi*",
                    visible=False,
                )
                warning_label = gr.Markdown(
                    "**WARNING** You can run txt2img and img2img with an active bridge. Your requests will "
                    "be prioritized. Running custom scripts *should* also be possible, but may cause errors. "
                    "\n\nInterrupt and skip buttons will interrupt horde request if it's running! Be careful.",
                    visible=False,
                )
            with gr.Column():
                api_key = gr.Textbox(label="API key", visible=False)
                horde_name = gr.Textbox(
                    label="Horde name", visible=False, placeholder="Awesome instance"
                )
            with gr.Column():
                horde_max_power = gr.Slider(
                    label="Horde max power",
                    minimum=2,
                    maximum=32,
                    value=8,
                    step=1,
                    visible=False,
                )
            with gr.Row():
                horde_show_extra_settings = gr.Checkbox(
                    label="Additional settings", visible=False
                )
            # Hidden until the "Additional settings" checkbox is ticked
            # (see the .change handler wired below).
            with gr.Box(visible=False) as horde_extra_settings:
                with gr.Row():
                    horde_nsfw = gr.Checkbox(value=True, label="Allow NSFW")
                    horde_censor_nsfw = gr.Checkbox(value=False, label="Censor NSFW")
                with gr.Column():
                    horde_priority_usernames = gr.Textbox(label="Priority usernames")
                    horde_blacklist = gr.Textbox(label="Blacklist")
                    # horde_censorlist = gr.Textbox(label="Censorlist")
            with gr.Row():
                horde_start = gr.Button("Start", visible=False)
                horde_stop = gr.Button("Stop", visible=False)
            with gr.Row():
                # TODO better run status.
                # No handler below writes to this box; status goes to the console log.
                run_status = gr.Textbox(
                    label="Status",
                    placeholder="See command prompt for status messages",
                    interactive=False,
                    visible=False,
                )
        # "Start" calls run_bridge, which loops until stop_horde() is called,
        # so this click handler stays busy for as long as the bridge runs.
        horde_start.click(
            run_bridge,
            inputs=[
                api_key,
                horde_name,
                horde_max_power,
                horde_nsfw,
                horde_censor_nsfw,
                horde_priority_usernames,
                horde_blacklist,
            ],
            outputs=[],
        )
        # "Stop" only sets a flag; the running loop exits on its next iteration.
        horde_stop.click(horde_state.stop_horde, inputs=[], outputs=[])
        horde_show_extra_settings.change(
            fn=lambda x: gr_show(x),
            inputs=[horde_show_extra_settings],
            outputs=[horde_extra_settings],
        )
        # Ticking "Censor NSFW" immediately sets the WebUI-wide filter_nsfw option.
        horde_censor_nsfw.change(toggle_nsfw, inputs=[horde_censor_nsfw], outputs=[])
        return (
            [info_label, warning_label]
            if is_img2img
            else [
                api_key,
                info_label,
                warning_label,
                horde_name,
                horde_max_power,
                horde_nsfw,
                horde_censor_nsfw,
                horde_show_extra_settings,
                horde_start,
                horde_stop,
                run_status,
            ]
        )

    def run(self, p, *args):
        """Do nothing when the script is selected and Generate is pressed.

        Returns None; presumably the WebUI then runs its normal processing
        for ``p`` (TODO confirm).
        """
        # TODO Add option to run local requests through the horde (?).
        pass
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment