#!/usr/bin/env python3
"""
|
|
ComfyUI Base Connector
|
|
Shared functionality for all ComfyUI tools.
|
|
"""
|
|
import json
|
|
import uuid
|
|
from typing import Dict, Any, Optional, List
|
|
from pathlib import Path
|
|
import httpx
|
|
import asyncio
|
|
from loguru import logger
|
|
|
|
from config import load_config_from_db, settings, get_workflows_dir
|
|
|
|
|
|
class ComfyUIClient:
    """Base client for ComfyUI API interactions.

    Resolves the ComfyUI host from the database-backed config, falling
    back to the static ``settings.comfyui_host`` default.
    """

    def __init__(self):
        # Delegate to reload_config() so the host-resolution logic lives
        # in exactly one place (it was previously duplicated here).
        self.reload_config()
|
|
|
|
def reload_config(self):
|
|
"""Reload configuration from database."""
|
|
config = load_config_from_db()
|
|
self.base_url = config.get("comfyui_host", settings.comfyui_host)
|
|
return config
|
|
|
|
def load_workflow(self, workflow_type: str) -> Optional[Dict[str, Any]]:
|
|
"""Load a workflow JSON file."""
|
|
workflows_dir = get_workflows_dir()
|
|
workflow_path = workflows_dir / f"{workflow_type}.json"
|
|
|
|
if not workflow_path.exists():
|
|
return None
|
|
|
|
with open(workflow_path, "r") as f:
|
|
return json.load(f)
|
|
|
|
async def queue_prompt(self, workflow: Dict[str, Any]) -> str:
|
|
"""Queue a workflow and return the prompt ID."""
|
|
client_id = str(uuid.uuid4())
|
|
|
|
payload = {
|
|
"prompt": workflow,
|
|
"client_id": client_id
|
|
}
|
|
|
|
async with httpx.AsyncClient(timeout=120.0) as client:
|
|
response = await client.post(
|
|
f"{self.base_url}/prompt",
|
|
json=payload
|
|
)
|
|
|
|
if response.status_code != 200:
|
|
raise Exception(f"Failed to queue prompt: {response.status_code}")
|
|
|
|
data = response.json()
|
|
return data.get("prompt_id", client_id)
|
|
|
|
async def get_history(self, prompt_id: str) -> Optional[Dict]:
|
|
"""Get the execution history for a prompt."""
|
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
|
response = await client.get(
|
|
f"{self.base_url}/history/{prompt_id}"
|
|
)
|
|
|
|
if response.status_code != 200:
|
|
return None
|
|
|
|
data = response.json()
|
|
return data.get(prompt_id)
|
|
|
|
async def wait_for_completion(
|
|
self,
|
|
prompt_id: str,
|
|
timeout: int = 300,
|
|
poll_interval: float = 1.0
|
|
) -> Optional[Dict]:
|
|
"""Wait for a prompt to complete and return the result."""
|
|
elapsed = 0
|
|
|
|
while elapsed < timeout:
|
|
history = await self.get_history(prompt_id)
|
|
|
|
if history:
|
|
outputs = history.get("outputs", {})
|
|
if outputs:
|
|
return outputs
|
|
|
|
await asyncio.sleep(poll_interval)
|
|
elapsed += poll_interval
|
|
|
|
raise TimeoutError(f"Prompt {prompt_id} did not complete within {timeout} seconds")
|
|
|
|
    def load_workflow(self, workflow_type: str) -> Optional[Dict[str, Any]]:
        """Load a workflow JSON file.

        Returns the parsed workflow dict, or None if the file is missing.

        NOTE(review): this is an exact duplicate of the load_workflow
        defined earlier in this class; Python silently keeps this later
        definition. One of the two should be removed.
        """
        workflows_dir = get_workflows_dir()
        workflow_path = workflows_dir / f"{workflow_type}.json"

        if not workflow_path.exists():
            return None

        with open(workflow_path, "r") as f:
            return json.load(f)
|
|
|
|
def get_node_mappings(self, workflow_type: str) -> Dict[str, str]:
|
|
"""Get node ID mappings from config."""
|
|
config = load_config_from_db()
|
|
|
|
# Map config keys to workflow type
|
|
prefix = f"{workflow_type}_"
|
|
mappings = {}
|
|
|
|
for key, value in config.items():
|
|
if key.startswith(prefix) and key.endswith("_node"):
|
|
# Extract the node type (e.g., "image_prompt_node" -> "prompt")
|
|
node_type = key[len(prefix):-5] # Remove prefix and "_node"
|
|
if value: # Only include non-empty values
|
|
mappings[node_type] = value
|
|
|
|
return mappings
|
|
|
|
def modify_workflow(
|
|
self,
|
|
workflow: Dict[str, Any],
|
|
prompt: str,
|
|
workflow_type: str = "image",
|
|
**kwargs
|
|
) -> Dict[str, Any]:
|
|
"""
|
|
Modify a workflow with prompt and other parameters.
|
|
|
|
Uses node mappings from config to inject values into correct nodes.
|
|
"""
|
|
workflow = json.loads(json.dumps(workflow)) # Deep copy
|
|
config = self.reload_config()
|
|
|
|
# Get node mappings for this workflow type
|
|
mappings = self.get_node_mappings(workflow_type)
|
|
|
|
# Default values from config
|
|
defaults = {
|
|
"image": {
|
|
"default_size": config.get("image_default_size", "512x512"),
|
|
"default_steps": config.get("image_default_steps", 20),
|
|
},
|
|
"video": {
|
|
"default_frames": config.get("video_default_frames", 24),
|
|
},
|
|
"audio": {
|
|
"default_duration": config.get("audio_default_duration", 10),
|
|
}
|
|
}
|
|
|
|
# Inject prompt
|
|
prompt_node = mappings.get("prompt")
|
|
if prompt_node and prompt_node in workflow:
|
|
node = workflow[prompt_node]
|
|
if "inputs" in node:
|
|
if "text" in node["inputs"]:
|
|
node["inputs"]["text"] = prompt
|
|
elif "prompt" in node["inputs"]:
|
|
node["inputs"]["prompt"] = prompt
|
|
|
|
# Inject negative prompt
|
|
negative_prompt = kwargs.get("negative_prompt", "")
|
|
negative_node = mappings.get("negative_prompt")
|
|
if negative_node and negative_node in workflow and negative_prompt:
|
|
node = workflow[negative_node]
|
|
if "inputs" in node and "text" in node["inputs"]:
|
|
node["inputs"]["text"] = negative_prompt
|
|
|
|
# Inject seed
|
|
seed = kwargs.get("seed")
|
|
seed_node = mappings.get("seed")
|
|
if seed_node and seed_node in workflow:
|
|
node = workflow[seed_node]
|
|
if "inputs" in node:
|
|
# Common seed input names
|
|
for seed_key in ["seed", "noise_seed", "sampler_seed"]:
|
|
if seed_key in node["inputs"]:
|
|
node["inputs"][seed_key] = seed if seed else self._generate_seed()
|
|
break
|
|
|
|
# Inject steps
|
|
steps = kwargs.get("steps")
|
|
steps_node = mappings.get("steps")
|
|
if steps_node and steps_node in workflow:
|
|
node = workflow[steps_node]
|
|
if "inputs" in node and "steps" in node["inputs"]:
|
|
node["inputs"]["steps"] = steps if steps else defaults.get(workflow_type, {}).get("default_steps", 20)
|
|
|
|
# Inject width/height (for images)
|
|
if workflow_type == "image":
|
|
size = kwargs.get("size", defaults.get("image", {}).get("default_size", "512x512"))
|
|
if "x" in str(size):
|
|
width, height = map(int, str(size).split("x"))
|
|
else:
|
|
width = height = int(size)
|
|
|
|
width_node = mappings.get("width")
|
|
if width_node and width_node in workflow:
|
|
node = workflow[width_node]
|
|
if "inputs" in node and "width" in node["inputs"]:
|
|
node["inputs"]["width"] = width
|
|
|
|
height_node = mappings.get("height")
|
|
if height_node and height_node in workflow:
|
|
node = workflow[height_node]
|
|
if "inputs" in node and "height" in node["inputs"]:
|
|
node["inputs"]["height"] = height
|
|
|
|
# Inject frames (for video)
|
|
if workflow_type == "video":
|
|
frames = kwargs.get("frames", defaults.get("video", {}).get("default_frames", 24))
|
|
frames_node = mappings.get("frames")
|
|
if frames_node and frames_node in workflow:
|
|
node = workflow[frames_node]
|
|
if "inputs" in node:
|
|
for key in ["frames", "frame_count", "length"]:
|
|
if key in node["inputs"]:
|
|
node["inputs"][key] = frames
|
|
break
|
|
|
|
# Inject duration (for audio)
|
|
if workflow_type == "audio":
|
|
duration = kwargs.get("duration", defaults.get("audio", {}).get("default_duration", 10))
|
|
duration_node = mappings.get("duration")
|
|
if duration_node and duration_node in workflow:
|
|
node = workflow[duration_node]
|
|
if "inputs" in node:
|
|
for key in ["duration", "length", "seconds"]:
|
|
if key in node["inputs"]:
|
|
node["inputs"][key] = duration
|
|
break
|
|
|
|
# Inject CFG scale (for images)
|
|
if workflow_type == "image":
|
|
cfg = kwargs.get("cfg_scale", 7.0)
|
|
cfg_node = mappings.get("cfg")
|
|
if cfg_node and cfg_node in workflow:
|
|
node = workflow[cfg_node]
|
|
if "inputs" in node:
|
|
for key in ["cfg", "cfg_scale", "guidance_scale"]:
|
|
if key in node["inputs"]:
|
|
node["inputs"][key] = cfg
|
|
break
|
|
|
|
return workflow
|
|
|
|
def _generate_seed(self) -> int:
|
|
"""Generate a random seed."""
|
|
import random
|
|
return random.randint(0, 2**32 - 1)
|
|
|
|
async def get_output_images(self, outputs: Dict) -> list:
|
|
"""Retrieve output images from ComfyUI."""
|
|
images = []
|
|
|
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
|
for node_id, output in outputs.items():
|
|
if "images" in output:
|
|
for image in output["images"]:
|
|
filename = image.get("filename")
|
|
subfolder = image.get("subfolder", "")
|
|
|
|
params = {
|
|
"filename": filename,
|
|
"type": "output"
|
|
}
|
|
if subfolder:
|
|
params["subfolder"] = subfolder
|
|
|
|
response = await client.get(
|
|
f"{self.base_url}/view",
|
|
params=params
|
|
)
|
|
|
|
if response.status_code == 200:
|
|
images.append({
|
|
"filename": filename,
|
|
"data": response.content
|
|
})
|
|
|
|
return images
|
|
|
|
async def get_output_files(self, outputs: Dict, file_type: str = "videos") -> list:
|
|
"""Retrieve output files from ComfyUI (videos or audio)."""
|
|
files = []
|
|
|
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
|
for node_id, output in outputs.items():
|
|
if file_type in output:
|
|
for item in output[file_type]:
|
|
filename = item.get("filename")
|
|
subfolder = item.get("subfolder", "")
|
|
|
|
params = {
|
|
"filename": filename,
|
|
"type": "output"
|
|
}
|
|
if subfolder:
|
|
params["subfolder"] = subfolder
|
|
|
|
response = await client.get(
|
|
f"{self.base_url}/view",
|
|
params=params
|
|
)
|
|
|
|
if response.status_code == 200:
|
|
files.append({
|
|
"filename": filename,
|
|
"data": response.content
|
|
})
|
|
|
|
# Also check for images (some workflows output frames)
|
|
if file_type == "videos" and "images" in output:
|
|
for image in output["images"]:
|
|
files.append({
|
|
"filename": image.get("filename"),
|
|
"type": "image"
|
|
})
|
|
|
|
return files
|