test/moxie/tools/comfyui/base.py
2026-03-24 04:07:54 +00:00

326 lines
12 KiB
Python
Executable File

"""
ComfyUI Base Connector
Shared functionality for all ComfyUI tools.
"""
import json
import uuid
from typing import Dict, Any, Optional, List
from pathlib import Path
import httpx
import asyncio
from loguru import logger
from config import load_config_from_db, settings, get_workflows_dir
class ComfyUIClient:
    """Base client for ComfyUI API interactions.

    Wraps the HTTP endpoints used by this module (``/prompt``,
    ``/history/<prompt_id>`` and ``/view``) and the logic that injects
    user-supplied parameters into a stored workflow JSON document.

    Attributes:
        base_url: Value of the ``comfyui_host`` config entry, falling back to
            ``settings.comfyui_host`` when the database has no such entry.
    """

    def __init__(self):
        config = load_config_from_db()
        self.base_url = config.get("comfyui_host", settings.comfyui_host)

    def reload_config(self):
        """Reload configuration from the database and refresh ``base_url``.

        Returns:
            The freshly loaded configuration mapping.
        """
        config = load_config_from_db()
        self.base_url = config.get("comfyui_host", settings.comfyui_host)
        return config

    def load_workflow(self, workflow_type: str) -> Optional[Dict[str, Any]]:
        """Load ``<workflow_type>.json`` from the workflows directory.

        Args:
            workflow_type: File stem of the workflow, e.g. ``"image"``.

        Returns:
            The parsed JSON document, or ``None`` when no such file exists.

        Raises:
            json.JSONDecodeError: If the file exists but is not valid JSON.
        """
        workflow_path = get_workflows_dir() / f"{workflow_type}.json"
        try:
            # Explicit encoding: the platform default is locale dependent.
            with open(workflow_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            return None

    async def queue_prompt(self, workflow: Dict[str, Any]) -> str:
        """Queue a workflow and return the prompt ID.

        Args:
            workflow: Workflow document to submit as the ``prompt`` field.

        Returns:
            The ``prompt_id`` from the response, or the generated client id
            when the response carries no ``prompt_id``.

        Raises:
            RuntimeError: If the server answers with a non-200 status.
        """
        client_id = str(uuid.uuid4())
        payload = {
            "prompt": workflow,
            "client_id": client_id,
        }
        async with httpx.AsyncClient(timeout=120.0) as client:
            response = await client.post(f"{self.base_url}/prompt", json=payload)
        if response.status_code != 200:
            # RuntimeError is still caught by callers handling ``Exception``.
            raise RuntimeError(f"Failed to queue prompt: {response.status_code}")
        return response.json().get("prompt_id", client_id)

    async def get_history(self, prompt_id: str) -> Optional[Dict]:
        """Get the execution history for a prompt.

        Returns:
            The history entry stored under ``prompt_id``, or ``None`` when the
            request fails or the response has no entry for that id.
        """
        async with httpx.AsyncClient(timeout=30.0) as client:
            response = await client.get(f"{self.base_url}/history/{prompt_id}")
        if response.status_code != 200:
            return None
        return response.json().get(prompt_id)

    async def wait_for_completion(
        self,
        prompt_id: str,
        timeout: int = 300,
        poll_interval: float = 1.0
    ) -> Optional[Dict]:
        """Poll the history until the prompt has outputs.

        Args:
            prompt_id: Identifier returned by :meth:`queue_prompt`.
            timeout: Maximum number of seconds to wait.
            poll_interval: Seconds to sleep between two polls.

        Returns:
            The non-empty ``outputs`` mapping of the history entry.

        Raises:
            TimeoutError: If no outputs appear within ``timeout`` seconds.
        """
        # A monotonic deadline also accounts for the time spent inside each
        # history request, and terminates even when ``poll_interval`` is 0.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + timeout
        while loop.time() < deadline:
            history = await self.get_history(prompt_id)
            if history:
                outputs = history.get("outputs", {})
                if outputs:
                    return outputs
            await asyncio.sleep(poll_interval)
        raise TimeoutError(f"Prompt {prompt_id} did not complete within {timeout} seconds")

    def get_node_mappings(
        self,
        workflow_type: str,
        config: Optional[Dict[str, Any]] = None
    ) -> Dict[str, str]:
        """Get node ID mappings from config.

        Config keys of the form ``<workflow_type>_<node_type>_node`` are turned
        into ``{node_type: value}``; e.g. ``"image_prompt_node"`` yields the
        ``"prompt"`` entry for workflow type ``"image"``.

        Args:
            workflow_type: Prefix selecting the relevant config keys.
            config: Already loaded configuration. When omitted the
                configuration is loaded from the database.

        Returns:
            Mapping of node type to configured node id. Entries with an empty
            value or an empty node type are left out.
        """
        if config is None:
            config = load_config_from_db()
        prefix = f"{workflow_type}_"
        suffix = "_node"
        mappings = {}
        for key, value in config.items():
            if not (key.startswith(prefix) and key.endswith(suffix)):
                continue
            node_type = key[len(prefix):-len(suffix)]
            # An empty node type occurs when prefix and suffix overlap
            # (e.g. "image_node"); such keys describe no node.
            if node_type and value:
                mappings[node_type] = value
        return mappings

    def modify_workflow(
        self,
        workflow: Dict[str, Any],
        prompt: str,
        workflow_type: str = "image",
        **kwargs
    ) -> Dict[str, Any]:
        """Return a copy of ``workflow`` with prompt and parameters injected.

        Uses node mappings from config to inject values into the correct
        nodes. The input ``workflow`` is not mutated.

        Args:
            workflow: Workflow document keyed by node id.
            prompt: Positive prompt text.
            workflow_type: ``"image"``, ``"video"`` or ``"audio"``; selects the
                config keys and the type-specific parameters.
            **kwargs: Optional ``negative_prompt``, ``seed``, ``steps``,
                ``size`` (``"WxH"`` or a single side length), ``frames``,
                ``duration`` and ``cfg_scale``.

        Returns:
            The modified deep copy of the workflow.

        Raises:
            ValueError: If ``size`` cannot be parsed for an image workflow.
        """
        workflow = json.loads(json.dumps(workflow))  # Deep copy
        config = self.reload_config()
        # Reuse the config loaded above instead of hitting the database again.
        mappings = self.get_node_mappings(workflow_type, config)
        self._inject_parameters(workflow, mappings, config, prompt, workflow_type, kwargs)
        return workflow

    def _inject_parameters(
        self,
        workflow: Dict[str, Any],
        mappings: Dict[str, Any],
        config: Dict[str, Any],
        prompt: str,
        workflow_type: str,
        options: Dict[str, Any]
    ) -> None:
        """Write prompt and options into ``workflow`` in place."""

        def option(name: str, default: Any) -> Any:
            # A missing option and an explicit ``None`` both mean "use default".
            value = options.get(name)
            return default if value is None else value

        self._set_first_input(workflow, mappings.get("prompt"), ("text", "prompt"), prompt)

        negative_prompt = options.get("negative_prompt", "")
        if negative_prompt:
            self._set_first_input(
                workflow, mappings.get("negative_prompt"), ("text",), negative_prompt
            )

        seed_node = mappings.get("seed")
        if seed_node:
            seed = options.get("seed")
            # ``is None`` rather than truthiness: 0 is a legitimate seed.
            if seed is None:
                seed = self._generate_seed()
            self._set_first_input(
                workflow, seed_node, ("seed", "noise_seed", "sampler_seed"), seed
            )

        steps = options.get("steps")
        if not steps:
            # Only image workflows have a configurable default step count.
            steps = config.get("image_default_steps", 20) if workflow_type == "image" else 20
        self._set_first_input(workflow, mappings.get("steps"), ("steps",), steps)

        if workflow_type == "image":
            size = option("size", config.get("image_default_size", "512x512"))
            width, height = self._parse_size(size)
            self._set_first_input(workflow, mappings.get("width"), ("width",), width)
            self._set_first_input(workflow, mappings.get("height"), ("height",), height)
            self._set_first_input(
                workflow,
                mappings.get("cfg"),
                ("cfg", "cfg_scale", "guidance_scale"),
                option("cfg_scale", 7.0),
            )
        elif workflow_type == "video":
            self._set_first_input(
                workflow,
                mappings.get("frames"),
                ("frames", "frame_count", "length"),
                option("frames", config.get("video_default_frames", 24)),
            )
        elif workflow_type == "audio":
            self._set_first_input(
                workflow,
                mappings.get("duration"),
                ("duration", "length", "seconds"),
                option("duration", config.get("audio_default_duration", 10)),
            )

    @staticmethod
    def _set_first_input(workflow: Dict[str, Any], node_id: Any, candidate_keys, value: Any) -> bool:
        """Assign ``value`` to the first of ``candidate_keys`` the node already has.

        Returns ``True`` when an input was written, ``False`` when the node id
        is empty, the node is absent, or none of the candidate inputs exist.
        """
        if not node_id:
            return False
        # After the JSON round trip every workflow key is a string, so a node
        # id stored as an int in config must be compared as text.
        node = workflow.get(str(node_id))
        if not isinstance(node, dict):
            return False
        inputs = node.get("inputs")
        if not isinstance(inputs, dict):
            return False
        for key in candidate_keys:
            if key in inputs:
                inputs[key] = value
                return True
        return False

    @staticmethod
    def _parse_size(size: Any) -> tuple:
        """Parse ``"WxH"`` (case-insensitive) or a single side length into ``(width, height)``.

        Raises:
            ValueError: If a component is not an integer.
        """
        text = str(size).strip().lower()
        if "x" in text:
            width_text, height_text = text.split("x", 1)
            return int(width_text), int(height_text)
        side = int(text)
        return side, side

    def _generate_seed(self) -> int:
        """Generate a random seed in the range ``0 .. 2**32 - 1``."""
        import random
        return random.randint(0, 2**32 - 1)

    async def _download_entries(self, client, entries) -> List[Dict[str, Any]]:
        """Fetch each entry through ``/view`` and return the successful downloads.

        Each result is ``{"filename": ..., "data": <response body bytes>}``.
        Entries whose download does not return status 200 are skipped.
        """
        downloaded = []
        for entry in entries:
            filename = entry.get("filename")
            params = {
                "filename": filename,
                "type": "output",
            }
            subfolder = entry.get("subfolder", "")
            if subfolder:
                params["subfolder"] = subfolder
            response = await client.get(f"{self.base_url}/view", params=params)
            if response.status_code != 200:
                # Best effort: keep going, but leave a trace of the failure.
                logger.warning(
                    "Failed to fetch {} from ComfyUI: HTTP {}",
                    filename,
                    response.status_code,
                )
                continue
            downloaded.append({
                "filename": filename,
                "data": response.content,
            })
        return downloaded

    async def get_output_images(self, outputs: Dict) -> List[Dict[str, Any]]:
        """Retrieve output images from ComfyUI.

        Args:
            outputs: Mapping of node id to that node's output record, as
                returned by :meth:`wait_for_completion`.

        Returns:
            A list of ``{"filename", "data"}`` dicts, one per image that was
            downloaded successfully.
        """
        images = []
        async with httpx.AsyncClient(timeout=30.0) as client:
            for output in outputs.values():
                images.extend(
                    await self._download_entries(client, output.get("images") or ())
                )
        return images

    async def get_output_files(
        self,
        outputs: Dict,
        file_type: str = "videos"
    ) -> List[Dict[str, Any]]:
        """Retrieve output files from ComfyUI (videos or audio).

        Args:
            outputs: Mapping of node id to that node's output record.
            file_type: Key of the output record listing the files to fetch.

        Returns:
            A list of ``{"filename", "data"}`` dicts for downloaded files.
            When ``file_type`` is ``"videos"``, images listed by a node are
            appended as ``{"filename", "type": "image"}`` records without
            their data (some workflows output frames).
        """
        files = []
        async with httpx.AsyncClient(timeout=30.0) as client:
            for output in outputs.values():
                files.extend(
                    await self._download_entries(client, output.get(file_type) or ())
                )
                # Also check for images (some workflows output frames)
                if file_type == "videos" and "images" in output:
                    for image in output["images"]:
                        files.append({
                            "filename": image.get("filename"),
                            "type": "image",
                        })
        return files