From b9b09447f6ba84c794a2f9ac46a854dd201bc878 Mon Sep 17 00:00:00 2001
From: 13315423919 <13315423919@qq.com>
Date: Fri, 7 Nov 2025 09:05:29 +0800
Subject: [PATCH] Add File

---
 src/landppt/ai/providers.py | 692 ++++++++++++++++++++++++++++++++++++
 1 file changed, 692 insertions(+)
 create mode 100644 src/landppt/ai/providers.py

diff --git a/src/landppt/ai/providers.py b/src/landppt/ai/providers.py
new file mode 100644
index 0000000..9e0b775
--- /dev/null
+++ b/src/landppt/ai/providers.py
@@ -0,0 +1,692 @@
+"""
+AI provider implementations
+"""
+
+import asyncio
+import json
+import logging
+from typing import List, Dict, Any, Optional, AsyncGenerator, Union, Tuple
+
+from .base import AIProvider, AIMessage, AIResponse, MessageRole, TextContent, ImageContent, MessageContentType
+from ..core.config import ai_config
+
+logger = logging.getLogger(__name__)
+
+class OpenAIProvider(AIProvider):
+    """OpenAI API provider"""
+
+    def __init__(self, config: Dict[str, Any]):
+        super().__init__(config)
+        try:
+            import openai
+            self.client = openai.AsyncOpenAI(
+                api_key=config.get("api_key"),
+                base_url=config.get("base_url")
+            )
+        except ImportError:
+            logger.warning("OpenAI library not installed. Install with: pip install openai")
+            self.client = None
+
+    def _convert_message_to_openai(self, message: AIMessage) -> Dict[str, Any]:
+        """Convert AIMessage to OpenAI format, supporting multimodal content"""
+        openai_message = {"role": message.role.value}
+
+        if isinstance(message.content, str):
+            # Simple text message
+            openai_message["content"] = message.content
+        elif isinstance(message.content, list):
+            # Multimodal message
+            content_parts = []
+            for part in message.content:
+                if isinstance(part, TextContent):
+                    content_parts.append({
+                        "type": "text",
+                        "text": part.text
+                    })
+                elif isinstance(part, ImageContent):
+                    content_parts.append({
+                        "type": "image_url",
+                        "image_url": part.image_url
+                    })
+            openai_message["content"] = content_parts
+        else:
+            # Fallback to string representation
+            openai_message["content"] = str(message.content)
+
+        if message.name:
+            openai_message["name"] = message.name
+
+        return openai_message
+
+    async def chat_completion(self, messages: List[AIMessage], **kwargs) -> AIResponse:
+        """Generate chat completion using OpenAI"""
+        if not self.client:
+            raise RuntimeError("OpenAI client not available")
+
+        config = self._merge_config(**kwargs)
+
+        # Convert messages to OpenAI format with multimodal support
+        openai_messages = [
+            self._convert_message_to_openai(msg)
+            for msg in messages
+        ]
+
+        try:
+            response = await self.client.chat.completions.create(
+                model=config.get("model", self.model),
+                messages=openai_messages,
+                # max_tokens=config.get("max_tokens", 2000),
+                temperature=config.get("temperature", 0.7),
+                top_p=config.get("top_p", 1.0)
+            )
+
+            choice = response.choices[0]
+
+            return AIResponse(
+                content=choice.message.content,
+                model=response.model,
+                usage={
+                    "prompt_tokens": response.usage.prompt_tokens,
+                    "completion_tokens": response.usage.completion_tokens,
+                    "total_tokens": response.usage.total_tokens
+                },
+                finish_reason=choice.finish_reason,
+                metadata={"provider": "openai"}
+            )
+
+        except Exception as e:
+            # Provide more detailed error information
+            error_msg = str(e)
+            if "Expecting value" in error_msg:
+                logger.error(f"OpenAI API JSON parsing error: {error_msg}. This usually indicates the API returned malformed JSON.")
+            elif "timeout" in error_msg.lower():
+                logger.error(f"OpenAI API timeout error: {error_msg}")
+            elif "rate limit" in error_msg.lower():
+                logger.error(f"OpenAI API rate limit error: {error_msg}")
+            else:
+                logger.error(f"OpenAI API error: {error_msg}")
+            raise
+
+    async def text_completion(self, prompt: str, **kwargs) -> AIResponse:
+        """Generate text completion using OpenAI chat format"""
+        messages = [AIMessage(role=MessageRole.USER, content=prompt)]
+        return await self.chat_completion(messages, **kwargs)
+
+    async def stream_chat_completion(self, messages: List[AIMessage], **kwargs) -> AsyncGenerator[str, None]:
+        """Stream chat completion using OpenAI"""
+        if not self.client:
+            raise RuntimeError("OpenAI client not available")
+
+        config = self._merge_config(**kwargs)
+
+        # Convert messages to OpenAI format with multimodal support
+        openai_messages = [
+            self._convert_message_to_openai(msg)
+            for msg in messages
+        ]
+
+        try:
+            stream = await self.client.chat.completions.create(
+                model=config.get("model", self.model),
+                messages=openai_messages,
+                # max_tokens=config.get("max_tokens", 2000),
+                temperature=config.get("temperature", 0.7),
+                top_p=config.get("top_p", 1.0),
+                stream=True
+            )
+
+            async for chunk in stream:
+                if chunk.choices and chunk.choices[0].delta.content:
+                    yield chunk.choices[0].delta.content
+
+        except Exception as e:
+            logger.error(f"OpenAI streaming error: {e}")
+            raise
+
+    async def stream_text_completion(self, prompt: str, **kwargs) -> AsyncGenerator[str, None]:
+        """Stream text completion using OpenAI chat format"""
+        messages = [AIMessage(role=MessageRole.USER, content=prompt)]
+        async for chunk in self.stream_chat_completion(messages, **kwargs):
+            yield chunk
+
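+# Example usage (a minimal sketch; the api_key and model values are
+# placeholders -- any OpenAI-compatible values from your config will work):
+#
+#     provider = OpenAIProvider({"api_key": "sk-...", "model": "gpt-4o-mini"})
+#     response = await provider.chat_completion(
+#         [AIMessage(role=MessageRole.USER, content="Summarize LandPPT in one line.")]
+#     )
+#     print(response.content)
+#
+#     # Streaming variant:
+#     async for chunk in provider.stream_chat_completion(messages):
+#         print(chunk, end="")
+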
+class AnthropicProvider(AIProvider):
+    """Anthropic Claude API provider"""
+
+    def __init__(self, config: Dict[str, Any]):
+        super().__init__(config)
+        try:
+            import anthropic
+            self.client = anthropic.AsyncAnthropic(
+                api_key=config.get("api_key")
+            )
+        except ImportError:
+            logger.warning("Anthropic library not installed. Install with: pip install anthropic")
+            self.client = None
+
+    def _convert_message_to_anthropic(self, message: AIMessage) -> Dict[str, Any]:
+        """Convert AIMessage to Anthropic format, supporting multimodal content"""
+        anthropic_message = {"role": message.role.value}
+
+        if isinstance(message.content, str):
+            # Simple text message
+            anthropic_message["content"] = message.content
+        elif isinstance(message.content, list):
+            # Multimodal message
+            content_parts = []
+            for part in message.content:
+                if isinstance(part, TextContent):
+                    content_parts.append({
+                        "type": "text",
+                        "text": part.text
+                    })
+                elif isinstance(part, ImageContent):
+                    # Anthropic expects base64 data without the data URL prefix
+                    image_url = part.image_url.get("url", "")
+                    if image_url.startswith("data:image/"):
+                        # Extract base64 data and media type
+                        header, base64_data = image_url.split(",", 1)
+                        media_type = header.split(":")[1].split(";")[0]
+                        content_parts.append({
+                            "type": "image",
+                            "source": {
+                                "type": "base64",
+                                "media_type": media_type,
+                                "data": base64_data
+                            }
+                        })
+                    else:
+                        # For URL-based images we would need to fetch and convert
+                        # to base64; for now, fall back to a text description
+                        content_parts.append({
+                            "type": "text",
+                            "text": f"[Image: {image_url}]"
+                        })
+            anthropic_message["content"] = content_parts
+        else:
+            # Fallback to string representation
+            anthropic_message["content"] = str(message.content)
+
+        return anthropic_message
+
+    async def chat_completion(self, messages: List[AIMessage], **kwargs) -> AIResponse:
+        """Generate chat completion using Anthropic Claude"""
+        if not self.client:
+            raise RuntimeError("Anthropic client not available")
+
+        config = self._merge_config(**kwargs)
+
+        # Convert messages to Anthropic format
+        system_message = None
+        claude_messages = []
+
+        for msg in messages:
+            if msg.role == MessageRole.SYSTEM:
+                # System messages should be simple text for Anthropic
+                system_message = msg.content if isinstance(msg.content, str) else str(msg.content)
+            else:
+                claude_messages.append(self._convert_message_to_anthropic(msg))
+
+        try:
+            # Unlike the other providers, Anthropic's Messages API requires
+            # max_tokens, so it cannot be omitted. The system parameter is only
+            # included when a system message was actually supplied.
+            request_kwargs = {
+                "model": config.get("model", self.model),
+                "max_tokens": config.get("max_tokens", 4096),
+                "temperature": config.get("temperature", 0.7),
+                "messages": claude_messages,
+            }
+            if system_message:
+                request_kwargs["system"] = system_message
+            response = await self.client.messages.create(**request_kwargs)
+
+            content = response.content[0].text if response.content else ""
+
+            return AIResponse(
+                content=content,
+                model=response.model,
+                usage={
+                    "prompt_tokens": response.usage.input_tokens,
+                    "completion_tokens": response.usage.output_tokens,
+                    "total_tokens": response.usage.input_tokens + response.usage.output_tokens
+                },
+                finish_reason=response.stop_reason,
+                metadata={"provider": "anthropic"}
+            )
+
+        except Exception as e:
+            logger.error(f"Anthropic API error: {e}")
+            raise
+
+    async def text_completion(self, prompt: str, **kwargs) -> AIResponse:
+        """Generate text completion using Anthropic chat format"""
+        messages = [AIMessage(role=MessageRole.USER, content=prompt)]
+        return await self.chat_completion(messages, **kwargs)
+
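+# Example of a multimodal message (a minimal sketch; the base64 payload is a
+# placeholder, and the keyword constructors assume the dataclass-style types
+# implied by the attribute access in the converters above). Both
+# OpenAIProvider and AnthropicProvider accept this shape and translate it to
+# their own wire format:
+#
+#     message = AIMessage(
+#         role=MessageRole.USER,
+#         content=[
+#             TextContent(text="Describe this slide background."),
+#             ImageContent(image_url={"url": "data:image/png;base64,iVBOR..."}),
+#         ],
+#     )
+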
+class GoogleProvider(AIProvider):
+    """Google Gemini API provider"""
+
+    def __init__(self, config: Dict[str, Any]):
+        super().__init__(config)
+        try:
+            import google.generativeai as genai
+
+            # Configure the API key
+            genai.configure(api_key=config.get("api_key"))
+
+            # Store base_url for potential future use or proxy configurations
+            self.base_url = config.get("base_url", "https://generativelanguage.googleapis.com")
+
+            self.client = genai
+            self.model_instance = genai.GenerativeModel(config.get("model", "gemini-1.5-flash"))
+        except ImportError:
+            logger.warning("Google Generative AI library not installed. Install with: pip install google-generativeai")
+            self.client = None
+            self.model_instance = None
+
+    def _convert_messages_to_gemini(self, messages: List[AIMessage]):
+        """Convert AIMessage list to Gemini format, supporting multimodal content"""
+        import base64
+
+        # Try to import genai types for proper image handling
+        try:
+            from google.genai import types
+            GENAI_TYPES_AVAILABLE = True
+        except ImportError:
+            try:
+                # Fallback to older API structure
+                from google.generativeai import types
+                GENAI_TYPES_AVAILABLE = True
+            except ImportError:
+                logger.warning("Google GenAI types not available for proper image processing")
+                GENAI_TYPES_AVAILABLE = False
+
+        # Check if we have any images
+        has_images = any(
+            isinstance(msg.content, list) and
+            any(isinstance(part, ImageContent) for part in msg.content)
+            for msg in messages
+        )
+
+        if not has_images:
+            # Text-only mode - return a single prompt string
+            parts = []
+            for msg in messages:
+                role_prefix = f"[{msg.role.value.upper()}]: "
+                if isinstance(msg.content, str):
+                    parts.append(role_prefix + msg.content)
+                elif isinstance(msg.content, list):
+                    message_parts = [role_prefix]
+                    for part in msg.content:
+                        if isinstance(part, TextContent):
+                            message_parts.append(part.text)
+                    parts.append(" ".join(message_parts))
+                else:
+                    parts.append(role_prefix + str(msg.content))
+            return "\n\n".join(parts)
+        else:
+            # Multimodal mode - return list of parts for Gemini
+            content_parts = []
+
+            for msg in messages:
+                role_prefix = f"[{msg.role.value.upper()}]: "
+
+                if isinstance(msg.content, str):
+                    content_parts.append(role_prefix + msg.content)
+                elif isinstance(msg.content, list):
+                    text_parts = [role_prefix]
+
+                    for part in msg.content:
+                        if isinstance(part, TextContent):
+                            text_parts.append(part.text)
+                        elif isinstance(part, ImageContent):
+                            # Flush accumulated text first (text_parts may be
+                            # empty after a preceding image, so guard against
+                            # indexing into an empty list)
+                            if text_parts:
+                                content_parts.append(" ".join(text_parts))
+                                text_parts = []
+
+                            # Process image for Gemini
+                            image_url = part.image_url.get("url", "")
+                            if image_url.startswith("data:image/") and GENAI_TYPES_AVAILABLE:
+                                try:
+                                    # Extract base64 data and mime type (e.g. 'image/jpeg')
+                                    header, base64_data = image_url.split(",", 1)
+                                    mime_type = header.split(":")[1].split(";")[0]
+                                    image_data = base64.b64decode(base64_data)
+
+                                    # Create a Gemini-compatible part from the image bytes
+                                    image_part = None
+                                    if hasattr(types, 'Part') and hasattr(types.Part, 'from_bytes'):
+                                        image_part = types.Part.from_bytes(
+                                            data=image_data,
+                                            mime_type=mime_type
+                                        )
+                                    elif hasattr(types, 'to_part'):
+                                        image_part = types.to_part({
+                                            'inline_data': {
+                                                'mime_type': mime_type,
+                                                'data': image_data
+                                            }
+                                        })
+                                    if image_part is None:
+                                        image_part = {
+                                            'inline_data': {
+                                                'mime_type': mime_type,
+                                                'data': image_data
+                                            }
+                                        }
+                                    content_parts.append(image_part)
+                                    logger.info(f"Successfully processed image for Gemini: {mime_type}, {len(image_data)} bytes")
+                                except Exception as e:
+                                    logger.error(f"Failed to process image for Gemini: {e}")
+                                    content_parts.append("Please use the uploaded image as a design reference. It contains important reference information; follow its style, colors, and layout when generating the template.")
+                            else:
+                                # Fallback when genai types are unavailable or the image is not base64
+                                if image_url.startswith("data:image/"):
+                                    content_parts.append("Please use the uploaded image as a design reference. It contains important reference information; follow its style, colors, and layout when generating the template.")
+                                else:
+                                    content_parts.append(f"Please use the image at {image_url} as a design reference")
+
+                    # Add any remaining text
+                    if text_parts:
+                        content_parts.append(" ".join(text_parts))
+                else:
+                    content_parts.append(role_prefix + str(msg.content))
+
+            return content_parts
+
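+    # For reference, a text-only conversation such as
+    #     [AIMessage(role=MessageRole.SYSTEM, content="Be brief."),
+    #      AIMessage(role=MessageRole.USER, content="Hi")]
+    # is flattened by _convert_messages_to_gemini into the single prompt string
+    #     "[SYSTEM]: Be brief.\n\n[USER]: Hi"
+    # (a minimal sketch of the format; actual role names come from MessageRole).
+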
+    async def chat_completion(self, messages: List[AIMessage], **kwargs) -> AIResponse:
+        """Generate chat completion using Google Gemini"""
+        if not self.client or not self.model_instance:
+            raise RuntimeError("Google Gemini client not available")
+
+        config = self._merge_config(**kwargs)
+
+        # Convert messages to Gemini format with multimodal support
+        prompt = self._convert_messages_to_gemini(messages)
+
+        try:
+            # Configure generation parameters. Ensure max_tokens is not too
+            # small; reserve at least 1000 tokens for generated content.
+            max_tokens = max(config.get("max_tokens", 16384), 1000)
+            generation_config = {
+                "temperature": config.get("temperature", 0.7),
+                "top_p": config.get("top_p", 1.0),
+                # "max_output_tokens": max_tokens,
+            }
+
+            # Configure safety settings - use a relaxed level to reduce false positives
+            safety_settings = [
+                {
+                    "category": "HARM_CATEGORY_HARASSMENT",
+                    "threshold": "BLOCK_ONLY_HIGH"
+                },
+                {
+                    "category": "HARM_CATEGORY_HATE_SPEECH",
+                    "threshold": "BLOCK_ONLY_HIGH"
+                },
+                {
+                    "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                    "threshold": "BLOCK_ONLY_HIGH"
+                },
+                {
+                    "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                    "threshold": "BLOCK_ONLY_HIGH"
+                }
+            ]
+
+            response = await self._generate_async(prompt, generation_config, safety_settings)
+            logger.debug(f"Google Gemini API response: {response}")
+
+            # Check response status and safety filtering
+            finish_reason = "stop"
+            content = ""
+
+            if response.candidates:
+                candidate = response.candidates[0]
+                finish_reason = candidate.finish_reason.name if hasattr(candidate.finish_reason, 'name') else str(candidate.finish_reason)
+
+                # Check whether the response was blocked by safety filters or failed for other reasons
+                if finish_reason == "SAFETY":
+                    logger.warning("Content was blocked by safety filters")
+                    content = "[Content blocked by safety filters]"
+                elif finish_reason == "RECITATION":
+                    logger.warning("Content was blocked due to recitation")
+                    content = "[Content blocked due to recitation]"
+                elif finish_reason == "MAX_TOKENS":
+                    logger.warning("Response was truncated due to max tokens limit")
+                    # Try to retrieve the partial content
+                    try:
+                        if hasattr(candidate, 'content') and candidate.content and hasattr(candidate.content, 'parts') and candidate.content.parts:
+                            content = candidate.content.parts[0].text if candidate.content.parts[0].text else "[Response truncated by token limit, no content]"
+                        else:
+                            content = "[Response truncated by token limit, no content]"
+                    except Exception as text_error:
+                        logger.warning(f"Failed to get truncated response text: {text_error}")
+                        content = "[Response truncated by token limit, content unavailable]"
+                elif finish_reason == "OTHER":
+                    logger.warning("Content was blocked for other reasons")
+                    content = "[Content blocked for other reasons]"
+                else:
+                    # Normal case: extract the text
+                    try:
+                        if hasattr(candidate, 'content') and candidate.content and hasattr(candidate.content, 'parts') and candidate.content.parts:
+                            content = candidate.content.parts[0].text if candidate.content.parts[0].text else ""
+                        else:
+                            # Fall back to response.text
+                            content = response.text if hasattr(response, 'text') and response.text else ""
+                    except Exception as text_error:
+                        logger.warning(f"Failed to get response text: {text_error}")
+                        content = "[Unable to retrieve response content]"
+            else:
+                logger.warning("No candidates in response")
+                content = "[No candidates in response]"
+
+            return AIResponse(
+                content=content,
+                model=self.model,
+                usage={
+                    "prompt_tokens": response.usage_metadata.prompt_token_count if hasattr(response, 'usage_metadata') else 0,
+                    "completion_tokens": response.usage_metadata.candidates_token_count if hasattr(response, 'usage_metadata') else 0,
+                    "total_tokens": response.usage_metadata.total_token_count if hasattr(response, 'usage_metadata') else 0
+                },
+                finish_reason=finish_reason,
+                metadata={"provider": "google"}
+            )
+
+        except Exception as e:
+            logger.error(f"Google Gemini API error: {e}")
+            raise
+
+    async def _generate_async(self, prompt, generation_config: Dict[str, Any], safety_settings=None):
+        """Async wrapper for Gemini generation - supports both text and multimodal content"""
+        loop = asyncio.get_running_loop()
+
+        def _generate_sync():
+            kwargs = {
+                "generation_config": generation_config
+            }
+            if safety_settings:
+                kwargs["safety_settings"] = safety_settings
+
+            return self.model_instance.generate_content(
+                prompt,  # Can be a string or a list of parts
+                **kwargs
+            )
+
+        # Run the blocking SDK call in a worker thread so the event loop is not blocked
+        return await loop.run_in_executor(None, _generate_sync)
+
+    async def text_completion(self, prompt: str, **kwargs) -> AIResponse:
+        """Generate text completion using Google Gemini"""
+        messages = [AIMessage(role=MessageRole.USER, content=prompt)]
+        return await self.chat_completion(messages, **kwargs)
+
+class OllamaProvider(AIProvider):
+    """Ollama local model provider"""
+
+    def __init__(self, config: Dict[str, Any]):
+        super().__init__(config)
+        try:
+            import ollama
+            self.client = ollama.AsyncClient(host=config.get("base_url", "http://localhost:11434"))
+        except ImportError:
+            logger.warning("Ollama library not installed. Install with: pip install ollama")
+            self.client = None
+
+    async def chat_completion(self, messages: List[AIMessage], **kwargs) -> AIResponse:
+        """Generate chat completion using Ollama"""
+        if not self.client:
+            raise RuntimeError("Ollama client not available")
+
+        config = self._merge_config(**kwargs)
+
+        # Convert messages to Ollama format with multimodal support
+        ollama_messages = []
+        for msg in messages:
+            if isinstance(msg.content, str):
+                # Simple text message
+                ollama_messages.append({"role": msg.role.value, "content": msg.content})
+            elif isinstance(msg.content, list):
+                # Multimodal message - convert to a text description for Ollama
+                content_parts = []
+                for part in msg.content:
+                    if isinstance(part, TextContent):
+                        content_parts.append(part.text)
+                    elif isinstance(part, ImageContent):
+                        # Ollama doesn't support images directly here; add a text description
+                        image_url = part.image_url.get("url", "")
+                        if image_url.startswith("data:image/"):
+                            content_parts.append("[Image provided - base64 data]")
+                        else:
+                            content_parts.append(f"[Image: {image_url}]")
+                ollama_messages.append({
+                    "role": msg.role.value,
+                    "content": " ".join(content_parts)
+                })
+            else:
+                # Fallback to string representation
+                ollama_messages.append({"role": msg.role.value, "content": str(msg.content)})
+
+        try:
+            response = await self.client.chat(
+                model=config.get("model", self.model),
+                messages=ollama_messages,
+                options={
+                    "temperature": config.get("temperature", 0.7),
+                    "top_p": config.get("top_p", 1.0),
+                    # "num_predict": config.get("max_tokens", 2000)
+                }
+            )
+
+            content = response.get("message", {}).get("content", "")
+
+            return AIResponse(
+                content=content,
+                model=config.get("model", self.model),
+                # Estimate usage from the converted messages; msg.content may be
+                # a list for multimodal input, so use the flattened text instead
+                usage=self._calculate_usage(
+                    " ".join(m["content"] for m in ollama_messages),
+                    content
+                ),
+                finish_reason="stop",
+                metadata={"provider": "ollama"}
+            )
+
+        except Exception as e:
+            logger.error(f"Ollama API error: {e}")
+            raise
+
+    async def text_completion(self, prompt: str, **kwargs) -> AIResponse:
+        """Generate text completion using Ollama"""
+        messages = [AIMessage(role=MessageRole.USER, content=prompt)]
+        return await self.chat_completion(messages, **kwargs)
+
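+# Example usage against a local Ollama daemon (a minimal sketch; the model
+# name is an assumption -- use whatever `ollama pull` has fetched locally):
+#
+#     provider = OllamaProvider({"base_url": "http://localhost:11434", "model": "llama3"})
+#     response = await provider.text_completion("Name three slide layouts.")
+#     print(response.content)
+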
+class AIProviderFactory:
+    """Factory for creating AI providers"""
+
+    _providers = {
+        "openai": OpenAIProvider,
+        "anthropic": AnthropicProvider,
+        "google": GoogleProvider,
+        "gemini": GoogleProvider,  # Alias for google
+        "ollama": OllamaProvider,
+        "302ai": OpenAIProvider  # 302.AI uses an OpenAI-compatible API
+    }
+
+    @classmethod
+    def create_provider(cls, provider_name: str, config: Optional[Dict[str, Any]] = None) -> AIProvider:
+        """Create an AI provider instance"""
+        if config is None:
+            config = ai_config.get_provider_config(provider_name)
+
+        # Built-in providers only
+        if provider_name not in cls._providers:
+            raise ValueError(f"Unknown provider: {provider_name}")
+
+        provider_class = cls._providers[provider_name]
+        return provider_class(config)
+
+    @classmethod
+    def get_available_providers(cls) -> List[str]:
+        """Get list of available providers"""
+        return list(cls._providers.keys())
+
+class AIProviderManager:
+    """Manager for AI provider instances with caching and reloading"""
+
+    def __init__(self):
+        self._provider_cache = {}
+        self._config_cache = {}
+
+    def get_provider(self, provider_name: Optional[str] = None) -> AIProvider:
+        """Get AI provider instance with caching"""
+        if provider_name is None:
+            provider_name = ai_config.default_ai_provider
+
+        # Get the current config for the provider
+        current_config = ai_config.get_provider_config(provider_name)
+
+        # Reuse the cached provider only if its config has not changed
+        cache_key = provider_name
+        if (cache_key in self._provider_cache and
+            cache_key in self._config_cache and
+            self._config_cache[cache_key] == current_config):
+            return self._provider_cache[cache_key]
+
+        # Create a new provider instance
+        provider = AIProviderFactory.create_provider(provider_name, current_config)
+
+        # Cache the provider and its config
+        self._provider_cache[cache_key] = provider
+        self._config_cache[cache_key] = current_config
+
+        return provider
+
+    def clear_cache(self):
+        """Clear provider cache to force reload"""
+        self._provider_cache.clear()
+        self._config_cache.clear()
+
+    def reload_provider(self, provider_name: str):
+        """Reload a specific provider"""
+        cache_key = provider_name
+        if cache_key in self._provider_cache:
+            del self._provider_cache[cache_key]
+        if cache_key in self._config_cache:
+            del self._config_cache[cache_key]
+
+# Global provider manager
+_provider_manager = AIProviderManager()
+
+def get_ai_provider(provider_name: Optional[str] = None) -> AIProvider:
+    """Get AI provider instance"""
+    return _provider_manager.get_provider(provider_name)
+
+
+def get_role_provider(role: str, provider_override: Optional[str] = None) -> Tuple[AIProvider, Dict[str, Optional[str]]]:
+    """Get provider and settings for a specific task role"""
+    settings = ai_config.get_model_config_for_role(role, provider_override=provider_override)
+    provider = get_ai_provider(settings["provider"])
+    return provider, settings
+
+def reload_ai_providers():
+    """Reload all AI providers (clear cache)"""
+    _provider_manager.clear_cache()
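+
+# Typical entry points for callers (a minimal sketch; the provider name and
+# role are assumptions -- valid values come from ai_config):
+#
+#     provider = get_ai_provider("openai")
+#     provider, settings = get_role_provider("outline")
+#     response = await provider.chat_completion(messages, model=settings.get("model"))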