120 lines
3.8 KiB
Python
Executable File
120 lines
3.8 KiB
Python
Executable File
"""
|
|
Audio Generation Tool
|
|
Generate audio using ComfyUI.
|
|
"""
|
|
import asyncio
from typing import Dict, Any, Optional

from loguru import logger

from tools.base import BaseTool, ToolResult
from tools.comfyui.base import ComfyUIClient
|
|
|
|
|
|
class AudioGenerationTool(BaseTool):
    """Generate audio using ComfyUI.

    Loads the workflow registered under the ``"audio"`` type from
    ``ComfyUIClient``, injects the caller's parameters, queues it, waits
    for completion and reports the names of the produced files.
    """

    # Seconds to wait for ComfyUI to finish one audio job (5 minutes).
    # Class-level so deployments or subclasses can override it.
    COMPLETION_TIMEOUT_SECONDS = 300

    def __init__(self, config: Optional[Dict] = None):
        """Create the tool and its ComfyUI client.

        Args:
            config: Optional tool configuration, forwarded unchanged to
                ``BaseTool.__init__``.
        """
        # The client is created before ``super().__init__`` runs.
        # NOTE(review): presumably so BaseTool's initializer can already
        # rely on ``self.client``; keep this order unless confirmed harmless.
        self.client = ComfyUIClient()
        super().__init__(config)

    @property
    def name(self) -> str:
        """Identifier under which this tool is exposed."""
        return "generate_audio"

    @property
    def description(self) -> str:
        """Human-readable summary of what the tool does."""
        return "Generate audio from a text description. Creates sound effects, music, or speech."

    @property
    def parameters(self) -> Dict[str, Any]:
        """JSON-Schema-style description of the arguments of ``execute``."""
        return {
            "type": "object",
            "properties": {
                "prompt": {
                    "type": "string",
                    "description": "Description of the audio to generate"
                },
                "negative_prompt": {
                    "type": "string",
                    "description": "What to avoid in the audio (optional)",
                    "default": ""
                },
                "duration": {
                    "type": "number",
                    "description": "Duration in seconds",
                    "default": 10.0
                },
                "seed": {
                    "type": "integer",
                    "description": "Random seed for reproducibility (optional)"
                }
            },
            "required": ["prompt"]
        }

    async def execute(
        self,
        prompt: str,
        negative_prompt: str = "",
        duration: float = 10.0,
        seed: Optional[int] = None,
        **kwargs
    ) -> ToolResult:
        """Generate audio from ``prompt`` via the configured ComfyUI workflow.

        Args:
            prompt: Text description of the audio to generate. Only its
                first 100 characters are sent to the execution log.
            negative_prompt: What to avoid in the audio; forwarded to the
                workflow.
            duration: Requested length in seconds; forwarded to the workflow.
            seed: Optional seed, forwarded as-is (including ``None``); how
                ``None`` is treated is up to ``ComfyUIClient.modify_workflow``.
            **kwargs: Accepted and ignored.

        Returns:
            ``ToolResult(success=True, data=...)`` whose data lists each
            output file's ``filename`` (``"audio"`` when absent), or
            ``ToolResult(success=False, error=...)`` when the workflow is not
            configured, no audio was produced, the job timed out, or any
            other ``Exception`` was raised. Exceptions that do not derive
            from ``Exception`` (e.g. ``KeyboardInterrupt``) still propagate.
        """
        self._log_execution({"prompt": prompt[:100], "duration": duration})

        try:
            # Reload config to get latest settings (e.g. a workflow uploaded
            # through the admin panel since the previous call). Done inside
            # the try so a config/load failure is reported as a ToolResult
            # like every other failure instead of escaping as an exception.
            self.client.reload_config()

            workflow = self.client.load_workflow("audio")
            if not workflow:
                return ToolResult(
                    success=False,
                    error="Audio generation workflow not configured. Please upload a workflow JSON in the admin panel."
                )

            # Inject the caller's parameters into the workflow graph.
            modified_workflow = self.client.modify_workflow(
                workflow,
                prompt=prompt,
                workflow_type="audio",
                negative_prompt=negative_prompt,
                duration=duration,
                seed=seed,
            )

            prompt_id = await self.client.queue_prompt(modified_workflow)
            # loguru formats "{}" lazily with the positional argument.
            logger.info("Queued audio generation: {}", prompt_id)

            outputs = await self.client.wait_for_completion(
                prompt_id,
                timeout=self.COMPLETION_TIMEOUT_SECONDS,
            )

            audio_files = await self.client.get_output_files(outputs, "audio")
            if not audio_files:
                return ToolResult(success=False, error="No audio was generated")

            result = "Successfully generated audio:\n" + "\n".join(
                f" - {a.get('filename', 'audio')}" for a in audio_files
            )
            self._log_success(result)
            return ToolResult(success=True, data=result)

        # Before Python 3.11 asyncio.TimeoutError is a separate class from the
        # builtin TimeoutError; catch both so an asyncio-based wait is still
        # reported as a timeout rather than an empty generic error.
        except (TimeoutError, asyncio.TimeoutError) as e:
            self._log_error(str(e) or "Audio generation timed out")
            return ToolResult(success=False, error="Audio generation timed out")

        except Exception as e:
            # Some exceptions stringify to "", which would yield a blank
            # error; fall back to the exception class name in that case.
            message = str(e) or type(e).__name__
            self._log_error(message)
            return ToolResult(success=False, error=message)