diff --git a/src/landppt/services/image/providers/stable_diffusion_provider.py b/src/landppt/services/image/providers/stable_diffusion_provider.py new file mode 100644 index 0000000..9bd41d9 --- /dev/null +++ b/src/landppt/services/image/providers/stable_diffusion_provider.py @@ -0,0 +1,372 @@ +""" +Stable Diffusion 图片生成提供者 +""" + +import asyncio +import logging +import time +import base64 +from typing import Dict, Any, Optional, List +from pathlib import Path +import aiohttp +import json + +from .base import ImageGenerationProvider +from ..models import ( + ImageInfo, ImageGenerationRequest, ImageOperationResult, + ImageProvider, ImageSourceType, ImageFormat, ImageMetadata, ImageTag +) + +logger = logging.getLogger(__name__) + + +class StableDiffusionProvider(ImageGenerationProvider): + """Stable Diffusion 图片生成提供者""" + + def __init__(self, config: Dict[str, Any]): + super().__init__(ImageProvider.STABLE_DIFFUSION, config) + + # API配置 + self.api_key = config.get('api_key') + self.api_base = config.get('api_base', 'https://api.stability.ai/v1') + self.engine_id = config.get('engine_id', 'stable-diffusion-xl-1024-v1-0') + + # 默认参数 + self.default_width = config.get('default_width', 1024) + self.default_height = config.get('default_height', 1024) + self.default_steps = config.get('default_steps', 30) + self.default_cfg_scale = config.get('default_cfg_scale', 7.0) + self.default_sampler = config.get('default_sampler', 'K_DPM_2_ANCESTRAL') + + # 速率限制 + self.rate_limit_requests = config.get('rate_limit_requests', 150) + self.rate_limit_window = config.get('rate_limit_window', 60) + + # 请求历史 + self._request_history = [] + + if not self.api_key: + logger.warning("Stable Diffusion API key not configured") + + async def generate(self, request: ImageGenerationRequest) -> ImageOperationResult: + """生成图片""" + if not self.api_key: + return ImageOperationResult( + success=False, + message="Stable Diffusion API key not configured", + error_code="api_key_missing" + ) + + try: + # 检查速率限制 + if not await self._check_rate_limit(): + return ImageOperationResult( + success=False, + message="Rate limit exceeded", + error_code="rate_limit_exceeded" + ) + + # 准备API请求 + api_request = self._prepare_api_request(request) + + # 调用Stable Diffusion API + async with aiohttp.ClientSession() as session: + async with session.post( + f"{self.api_base}/generation/{self.engine_id}/text-to-image", + headers={ + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json", + "Accept": "application/json" + }, + json=api_request, + timeout=aiohttp.ClientTimeout(total=180) # 3分钟超时 + ) as response: + + if response.status != 200: + error_text = await response.text() + logger.error(f"Stable Diffusion API error {response.status}: {error_text}") + return ImageOperationResult( + success=False, + message=f"Stable Diffusion API error: {response.status}", + error_code="api_error" + ) + + result_data = await response.json() + + # 处理API响应 + return await self._process_api_response(result_data, request) + + except asyncio.TimeoutError: + logger.error("Stable Diffusion API request timeout") + return ImageOperationResult( + success=False, + message="Request timeout", + error_code="timeout" + ) + except Exception as e: + logger.error(f"Stable Diffusion generation failed: {e}") + return ImageOperationResult( + success=False, + message=f"Generation failed: {str(e)}", + error_code="generation_error" + ) + + def _prepare_api_request(self, request: ImageGenerationRequest) -> Dict[str, Any]: + """准备API请求""" + # 使用请求中的尺寸 + width, height = request.width, request.height + + api_request = { + "text_prompts": [ + { + "text": request.prompt, + "weight": 1.0 + } + ], + "width": width, + "height": height, + "steps": request.steps or self.default_steps, + "cfg_scale": request.guidance_scale or self.default_cfg_scale, + "sampler": self.default_sampler, # 使用默认采样器 + "samples": 1, + "seed": request.seed if request.seed is not None else 0 + } + + # 添加负面提示词 + if request.negative_prompt: + api_request["text_prompts"].append({ + "text": request.negative_prompt, + "weight": -1.0 + }) + + return api_request + + async def _process_api_response(self, + response_data: Dict[str, Any], + request: ImageGenerationRequest) -> ImageOperationResult: + """处理API响应""" + try: + if 'artifacts' not in response_data or not response_data['artifacts']: + return ImageOperationResult( + success=False, + message="No image data in response", + error_code="no_data" + ) + + artifact = response_data['artifacts'][0] + + if artifact.get('finishReason') != 'SUCCESS': + return ImageOperationResult( + success=False, + message=f"Generation failed: {artifact.get('finishReason')}", + error_code="generation_failed" + ) + + # 解码base64图片数据 + image_data = base64.b64decode(artifact['base64']) + + # 保存图片 + image_path = await self._save_image(image_data, request) + + # 创建图片信息 + image_info = self._create_image_info( + image_path, len(image_data), request, artifact + ) + + return ImageOperationResult( + success=True, + message="Image generated successfully", + image_info=image_info + ) + + except Exception as e: + logger.error(f"Failed to process Stable Diffusion response: {e}") + return ImageOperationResult( + success=False, + message=f"Failed to process response: {str(e)}", + error_code="response_processing_error" + ) + + async def _save_image(self, + image_data: bytes, + request: ImageGenerationRequest) -> Path: + """保存生成的图片""" + # 生成文件名 + timestamp = int(time.time()) + filename = f"sd_{timestamp}_{hash(request.prompt) % 10000}.png" + + # 创建保存路径 + save_dir = Path("temp/images_cache/ai_generated/stable_diffusion") + save_dir.mkdir(parents=True, exist_ok=True) + image_path = save_dir / filename + + # 保存图片 + def _save(): + with open(image_path, 'wb') as f: + f.write(image_data) + + await asyncio.get_event_loop().run_in_executor(None, _save) + + return image_path + + def _create_image_info(self, + image_path: Path, + image_size: int, + request: ImageGenerationRequest, + artifact: Dict[str, Any]) -> ImageInfo: + """创建图片信息""" + # 生成图片ID + image_id = f"sd_{int(time.time())}_{hash(request.prompt) % 10000}" + + # 使用请求中的尺寸 + width, height = request.width, request.height + + # 创建元数据 + metadata = ImageMetadata( + width=width, + height=height, + format=ImageFormat.PNG, + file_size=image_size, + color_mode="RGB", + has_transparency=False + ) + + # 创建标签 + tags = self._generate_tags_from_prompt(request.prompt) + + # 构建描述 + description_parts = [f"Generated by Stable Diffusion with prompt: {request.prompt}"] + if request.negative_prompt: + description_parts.append(f"Negative prompt: {request.negative_prompt}") + if request.seed is not None: + description_parts.append(f"Seed: {request.seed}") + + return ImageInfo( + image_id=image_id, + filename=image_path.name, + title=f"AI Generated: {request.prompt[:50]}...", + description=" | ".join(description_parts), + alt_text=request.prompt, + source_type=ImageSourceType.AI_GENERATED, + provider=ImageProvider.STABLE_DIFFUSION, + original_url="", + local_path=str(image_path), + metadata=metadata, + tags=tags, + keywords=self._extract_keywords_from_prompt(request.prompt), + usage_count=0, + created_at=time.time(), + updated_at=time.time() + ) + + def _generate_tags_from_prompt(self, prompt: str) -> List[ImageTag]: + """从提示词生成标签""" + # 简单的关键词提取和标签生成 + keywords = prompt.lower().replace(',', ' ').split() + + # 过滤常见词汇 + stop_words = { + 'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', + 'with', 'by', 'very', 'highly', 'extremely', 'detailed', 'realistic' + } + keywords = [word.strip('.,!?;:') for word in keywords if word not in stop_words and len(word) > 2] + + # 生成标签 + tags = [] + for i, keyword in enumerate(keywords[:15]): # 最多15个标签 + confidence = max(0.4, 1.0 - i * 0.05) # 递减的置信度 + tags.append(ImageTag( + name=keyword, + confidence=confidence, + category="ai_generated" + )) + + return tags + + def _extract_keywords_from_prompt(self, prompt: str) -> List[str]: + """从提示词提取关键词""" + words = prompt.lower().replace(',', ' ').split() + stop_words = { + 'a', 'an', 'the', 'and', 'or', 'but', 'in', 'on', 'at', 'to', 'for', 'of', + 'with', 'by', 'very', 'highly', 'extremely' + } + keywords = [word.strip('.,!?;:') for word in words if word not in stop_words and len(word) > 2] + + return keywords[:25] # 最多25个关键词 + + async def _check_rate_limit(self) -> bool: + """检查速率限制""" + current_time = time.time() + + # 清理过期的请求记录 + self._request_history = [ + req_time for req_time in self._request_history + if current_time - req_time < self.rate_limit_window + ] + + # 检查是否超过限制 + if len(self._request_history) >= self.rate_limit_requests: + return False + + # 记录当前请求 + self._request_history.append(current_time) + return True + + async def health_check(self) -> Dict[str, Any]: + """健康检查""" + if not self.api_key: + return { + 'status': 'unhealthy', + 'message': 'API key not configured', + 'provider': self.provider.value + } + + try: + # 检查引擎列表 + async with aiohttp.ClientSession() as session: + async with session.get( + f"{self.api_base}/engines/list", + headers={"Authorization": f"Bearer {self.api_key}"}, + timeout=aiohttp.ClientTimeout(total=10) + ) as response: + + if response.status == 200: + return { + 'status': 'healthy', + 'message': 'API accessible', + 'provider': self.provider.value, + 'engine_id': self.engine_id, + 'rate_limit_remaining': self.rate_limit_requests - len(self._request_history) + } + else: + return { + 'status': 'unhealthy', + 'message': f'API error: {response.status}', + 'provider': self.provider.value + } + + except Exception as e: + return { + 'status': 'unhealthy', + 'message': f'Health check failed: {str(e)}', + 'provider': self.provider.value + } + + async def get_available_models(self) -> List[Dict[str, Any]]: + """获取可用模型列表""" + return [ + { + 'id': 'stable-diffusion-xl-1024-v1-0', + 'name': 'Stable Diffusion XL 1.0', + 'description': 'High-quality image generation with 1024x1024 resolution', + 'max_resolution': '1024x1024', + 'supported_styles': ['photographic', 'digital-art', 'comic-book', 'fantasy-art', 'line-art', 'analog-film', 'neon-punk', 'isometric', 'low-poly', 'origami', 'modeling-compound', 'cinematic', 'anime', '3d-model', 'pixel-art', 'tile-texture'] + }, + { + 'id': 'stable-diffusion-v1-6', + 'name': 'Stable Diffusion 1.6', + 'description': 'Previous generation model with good quality', + 'max_resolution': '512x512', + 'supported_styles': ['photographic', 'digital-art', 'comic-book', 'fantasy-art'] + } + ]