Sanjeev23oct's picture
Upload folder using huggingface_hub
f1d5e1c verified
raw
history blame
15.3 kB
import base64
import os
import time
from pathlib import Path
from typing import Dict, Optional
import requests
import json
import gradio as gr
import uuid
from langchain_anthropic import ChatAnthropic
from langchain_mistralai import ChatMistralAI
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_ollama import ChatOllama
from langchain_openai import AzureChatOpenAI, ChatOpenAI
from .llm import DeepSeekR1ChatOpenAI, DeepSeekR1ChatOllama
# Human-readable provider labels used in user-facing messages (see
# MissingAPIKeyError). Providers missing from this table fall back to the
# upper-cased provider key, so every provider that can raise a missing-key
# error should have an entry here.
PROVIDER_DISPLAY_NAMES = {
    "openai": "OpenAI",
    "azure_openai": "Azure OpenAI",
    "anthropic": "Anthropic",
    "deepseek": "DeepSeek",
    "google": "Google",
    "alibaba": "Alibaba",
    "moonshot": "MoonShot",
    "unbound": "Unbound AI",
    # Added for parity with the providers handled by get_llm_model().
    "mistral": "Mistral",
    "siliconflow": "SiliconFlow",
}
def _resolve_base_url(kwargs, env_var=None, default=None):
    """Pick the endpoint URL: explicit ``base_url`` kwarg, else env var, else default."""
    explicit = kwargs.get("base_url")
    if explicit:
        return explicit
    if env_var is None:
        return default
    return os.getenv(env_var, default)


def get_llm_model(provider: str, **kwargs):
    """
    Build a LangChain chat model for the given provider.

    :param provider: Provider key, e.g. ``"openai"``, ``"anthropic"``,
        ``"ollama"``, ``"siliconflow"``.
    :param kwargs: Optional settings. Recognised keys:
        ``api_key`` (falls back to the ``<PROVIDER>_API_KEY`` environment
        variable), ``base_url`` (falls back to a provider-specific
        ``*_ENDPOINT`` environment variable and/or a built-in default),
        ``model_name``, ``temperature`` (default ``0.0``), ``api_version``
        (Azure only), ``num_ctx`` and ``num_predict`` (Ollama only).
    :return: A configured chat model instance for the provider.
    :raises MissingAPIKeyError: if the provider is not ``"ollama"`` and no API
        key was passed or found in the environment. This check runs before the
        provider name is validated.
    :raises ValueError: if the provider is not supported.
    """
    api_key = None
    if provider != "ollama":
        # Every hosted provider needs a key: UI/caller value first, then env.
        env_var = f"{provider.upper()}_API_KEY"
        api_key = kwargs.get("api_key", "") or os.getenv(env_var, "")
        if not api_key:
            raise MissingAPIKeyError(provider, env_var)
        kwargs["api_key"] = api_key

    temperature = kwargs.get("temperature", 0.0)

    if provider == "anthropic":
        return ChatAnthropic(
            model=kwargs.get("model_name", "claude-3-5-sonnet-20241022"),
            temperature=temperature,
            base_url=_resolve_base_url(kwargs, default="https://api.anthropic.com"),
            api_key=api_key,
        )

    if provider == "mistral":
        # api_key is already resolved above (kwarg or MISTRAL_API_KEY).
        return ChatMistralAI(
            model=kwargs.get("model_name", "mistral-large-latest"),
            temperature=temperature,
            base_url=_resolve_base_url(kwargs, "MISTRAL_ENDPOINT", "https://api.mistral.ai/v1"),
            api_key=api_key,
        )

    if provider == "openai":
        return ChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o"),
            temperature=temperature,
            base_url=_resolve_base_url(kwargs, "OPENAI_ENDPOINT", "https://api.openai.com/v1"),
            api_key=api_key,
        )

    if provider == "deepseek":
        base_url = _resolve_base_url(kwargs, "DEEPSEEK_ENDPOINT", "")
        # The reasoner model needs the dedicated wrapper class.
        if kwargs.get("model_name", "deepseek-chat") == "deepseek-reasoner":
            return DeepSeekR1ChatOpenAI(
                model=kwargs.get("model_name", "deepseek-reasoner"),
                temperature=temperature,
                base_url=base_url,
                api_key=api_key,
            )
        return ChatOpenAI(
            model=kwargs.get("model_name", "deepseek-chat"),
            temperature=temperature,
            base_url=base_url,
            api_key=api_key,
        )

    if provider == "google":
        return ChatGoogleGenerativeAI(
            model=kwargs.get("model_name", "gemini-2.0-flash-exp"),
            temperature=temperature,
            api_key=api_key,
        )

    if provider == "ollama":
        base_url = _resolve_base_url(kwargs, "OLLAMA_ENDPOINT", "http://localhost:11434")
        # Any deepseek-r1 variant is routed to the dedicated wrapper class.
        if "deepseek-r1" in kwargs.get("model_name", "qwen2.5:7b"):
            return DeepSeekR1ChatOllama(
                model=kwargs.get("model_name", "deepseek-r1:14b"),
                temperature=temperature,
                num_ctx=kwargs.get("num_ctx", 32000),
                base_url=base_url,
            )
        return ChatOllama(
            model=kwargs.get("model_name", "qwen2.5:7b"),
            temperature=temperature,
            num_ctx=kwargs.get("num_ctx", 32000),
            num_predict=kwargs.get("num_predict", 1024),
            base_url=base_url,
        )

    if provider == "azure_openai":
        api_version = kwargs.get("api_version", "") or os.getenv(
            "AZURE_OPENAI_API_VERSION", "2025-01-01-preview"
        )
        return AzureChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o"),
            temperature=temperature,
            api_version=api_version,
            azure_endpoint=_resolve_base_url(kwargs, "AZURE_OPENAI_ENDPOINT", ""),
            api_key=api_key,
        )

    if provider == "alibaba":
        return ChatOpenAI(
            model=kwargs.get("model_name", "qwen-plus"),
            temperature=temperature,
            base_url=_resolve_base_url(
                kwargs, "ALIBABA_ENDPOINT", "https://dashscope.aliyuncs.com/compatible-mode/v1"
            ),
            api_key=api_key,
        )

    if provider == "moonshot":
        # Use the key resolved above so a caller/UI-supplied key is honoured
        # (previously only MOONSHOT_API_KEY from the environment was used),
        # and accept an explicit base_url like the other providers.
        return ChatOpenAI(
            model=kwargs.get("model_name", "moonshot-v1-32k-vision-preview"),
            temperature=temperature,
            base_url=_resolve_base_url(kwargs, "MOONSHOT_ENDPOINT"),
            api_key=api_key,
        )

    if provider == "unbound":
        return ChatOpenAI(
            model=kwargs.get("model_name", "gpt-4o-mini"),
            temperature=temperature,
            base_url=_resolve_base_url(kwargs, "UNBOUND_ENDPOINT", "https://api.getunbound.ai"),
            api_key=api_key,
        )

    if provider == "siliconflow":
        # Prefer the conventional upper-case variable; keep the legacy
        # mixed-case name as a fallback for existing deployments.
        base_url = (
            kwargs.get("base_url")
            or os.getenv("SILICONFLOW_ENDPOINT")
            or os.getenv("SiliconFLOW_ENDPOINT", "")
        )
        return ChatOpenAI(
            api_key=api_key,
            base_url=base_url,
            model=kwargs.get("model_name", "Qwen/QwQ-32B"),
            temperature=temperature,
        )

    raise ValueError(f"Unsupported provider: {provider}")
# Selectable model identifiers per provider, keyed by the same provider
# strings that get_llm_model() accepts. The first entry of each list is the
# one pre-selected by update_model_dropdown().
model_names = {
    "anthropic": [
        "claude-3-5-sonnet-20241022",
        "claude-3-5-sonnet-20240620",
        "claude-3-opus-20240229",
    ],
    "openai": [
        "gpt-4o",
        "gpt-4",
        "gpt-3.5-turbo",
        "o3-mini",
    ],
    "deepseek": [
        "deepseek-chat",
        "deepseek-reasoner",
    ],
    "google": [
        "gemini-2.0-flash",
        "gemini-2.0-flash-thinking-exp",
        "gemini-1.5-flash-latest",
        "gemini-1.5-flash-8b-latest",
        "gemini-2.0-flash-thinking-exp-01-21",
        "gemini-2.0-pro-exp-02-05",
    ],
    "ollama": [
        "qwen2.5:7b",
        "qwen2.5:14b",
        "qwen2.5:32b",
        "qwen2.5-coder:14b",
        "qwen2.5-coder:32b",
        "llama2:7b",
        "deepseek-r1:14b",
        "deepseek-r1:32b",
    ],
    "azure_openai": [
        "gpt-4o",
        "gpt-4",
        "gpt-3.5-turbo",
    ],
    "mistral": [
        "pixtral-large-latest",
        "mistral-large-latest",
        "mistral-small-latest",
        "ministral-8b-latest",
    ],
    "alibaba": [
        "qwen-plus",
        "qwen-max",
        "qwen-turbo",
        "qwen-long",
    ],
    "moonshot": [
        "moonshot-v1-32k-vision-preview",
        "moonshot-v1-8k-vision-preview",
    ],
    "unbound": [
        "gemini-2.0-flash",
        "gpt-4o-mini",
        "gpt-4o",
        "gpt-4.5-preview",
    ],
    "siliconflow": [
        "deepseek-ai/DeepSeek-R1",
        "deepseek-ai/DeepSeek-V3",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-32B",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-14B",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B",
        "deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        "deepseek-ai/DeepSeek-V2.5",
        "deepseek-ai/deepseek-vl2",
        "Qwen/Qwen2.5-72B-Instruct-128K",
        "Qwen/Qwen2.5-72B-Instruct",
        "Qwen/Qwen2.5-32B-Instruct",
        "Qwen/Qwen2.5-14B-Instruct",
        "Qwen/Qwen2.5-7B-Instruct",
        "Qwen/Qwen2.5-Coder-32B-Instruct",
        "Qwen/Qwen2.5-Coder-7B-Instruct",
        "Qwen/Qwen2-7B-Instruct",
        "Qwen/Qwen2-1.5B-Instruct",
        "Qwen/QwQ-32B-Preview",
        "Qwen/Qwen2-VL-72B-Instruct",
        "Qwen/Qwen2.5-VL-32B-Instruct",
        "Qwen/Qwen2.5-VL-72B-Instruct",
        "TeleAI/TeleChat2",
        "THUDM/glm-4-9b-chat",
        "Vendor-A/Qwen/Qwen2.5-72B-Instruct",
        "internlm/internlm2_5-7b-chat",
        "internlm/internlm2_5-20b-chat",
        "Pro/Qwen/Qwen2.5-7B-Instruct",
        "Pro/Qwen/Qwen2-7B-Instruct",
        "Pro/Qwen/Qwen2-1.5B-Instruct",
        "Pro/THUDM/chatglm3-6b",
        "Pro/THUDM/glm-4-9b-chat",
    ],
}
# Callback to update the model name dropdown based on the selected provider
def update_model_dropdown(llm_provider, api_key=None, base_url=None):
    """
    Build a model-name dropdown populated with the predefined models of a provider.

    :param llm_provider: Provider key looked up in ``model_names``.
    :param api_key: Accepted for callback-signature compatibility; not used.
    :param base_url: Accepted for callback-signature compatibility; not used.
    :return: A ``gr.Dropdown`` listing the provider's models with the first one
        selected, or an empty dropdown that allows custom values when the
        provider has no predefined models.
    """
    import gradio as gr

    # The previous implementation resolved api_key/base_url from the
    # environment here but never used them; that dead code (which also crashed
    # on a None provider) has been removed.
    choices = model_names.get(llm_provider)
    if choices:
        return gr.Dropdown(choices=choices, value=choices[0], interactive=True)
    # Unknown provider (or one with no predefined models): let the user type.
    return gr.Dropdown(choices=[], value="", interactive=True, allow_custom_value=True)
class MissingAPIKeyError(Exception):
    """Raised when no API key could be found for an LLM provider."""

    def __init__(self, provider: str, env_var: str):
        """
        :param provider: Provider key; shown via ``PROVIDER_DISPLAY_NAMES``
            when listed there, otherwise upper-cased.
        :param env_var: Name of the environment variable the user can set.
        """
        label = PROVIDER_DISPLAY_NAMES.get(provider, provider.upper())
        message = (
            f"💥 {label} API key not found! 🔑 Please set the "
            f"`{env_var}` environment variable or provide it in the UI."
        )
        super().__init__(message)
def encode_image(img_path):
    """
    Read a file and return its contents as a base64 string.

    :param img_path: Path of the file to read; a falsy value (``None``, ``""``)
        short-circuits.
    :return: Base64 text of the file's bytes, or ``None`` when no path is given.
    """
    if img_path:
        with open(img_path, "rb") as image_file:
            raw_bytes = image_file.read()
        return base64.b64encode(raw_bytes).decode("utf-8")
    return None
def get_latest_files(directory: str, file_types: Optional[list] = None) -> Dict[str, Optional[str]]:
    """
    Find the most recently modified file of each requested type under a directory.

    :param directory: Root directory, searched recursively. It is created if it
        does not exist yet (and nothing is returned for it in that case).
    :param file_types: File-name suffixes to look for; defaults to
        ``['.webm', '.zip']`` (recordings and traces).
    :return: Mapping of each suffix to the path of its newest file, or ``None``
        when there is no match or the newest match was modified within the last
        second (treated as still being written).
    """
    # Avoid a shared mutable default argument.
    if file_types is None:
        file_types = ['.webm', '.zip']

    latest_files: Dict[str, Optional[str]] = {ext: None for ext in file_types}

    if not os.path.exists(directory):
        os.makedirs(directory, exist_ok=True)
        return latest_files

    for file_type in file_types:
        try:
            newest_path = None
            newest_mtime = None
            for candidate in Path(directory).rglob(f"*{file_type}"):
                try:
                    # Stat once per file and reuse the value for the age check.
                    mtime = candidate.stat().st_mtime
                except OSError:
                    # File vanished or is unreadable mid-scan; skip just it.
                    continue
                if newest_mtime is None or mtime > newest_mtime:
                    newest_path, newest_mtime = candidate, mtime
            # Only return files that are complete (not being written)
            if newest_path is not None and time.time() - newest_mtime > 1.0:
                latest_files[file_type] = str(newest_path)
        except Exception as e:
            # Best-effort lookup: report and leave this type as None.
            print(f"Error getting latest {file_type} file: {e}")

    return latest_files
async def capture_screenshot(browser_context):
    """
    Take a JPEG screenshot of an open page and return it base64-encoded.

    :param browser_context: Object exposing ``.browser.playwright_browser``
        (presumably the project's browser-context wrapper -- confirm against
        the caller).
    :return: Base64 text of the screenshot, or ``None`` when there is no
        browser, no context, no page, or the screenshot call fails.
    """
    browser = browser_context.browser.playwright_browser  # Ensure this is correct.

    # Only the first existing context is considered; nothing is created here.
    if not (browser and browser.contexts):
        return None
    first_context = browser.contexts[0]

    open_pages = first_context.pages if first_context else None
    if not open_pages:
        return None

    # Prefer the last page that is not blank; otherwise use the first page.
    non_blank = [candidate for candidate in open_pages if candidate.url != "about:blank"]
    target_page = non_blank[-1] if non_blank else open_pages[0]

    try:
        image_bytes = await target_page.screenshot(
            type='jpeg',
            quality=75,
            scale="css"
        )
        return base64.b64encode(image_bytes).decode('utf-8')
    except Exception:
        # Best-effort: any failure while capturing yields no screenshot.
        return None
class ConfigManager:
    """Registry of named UI components whose values can be saved to and loaded from a file."""

    def __init__(self):
        # name -> component object
        self.components = {}
        # Names in first-registration order; defines the order of outputs.
        self.component_order = []

    def register_component(self, name: str, component):
        """Store ``component`` under ``name`` and return it unchanged.

        Re-registering an existing name replaces the component but keeps its
        original position in the ordering.
        """
        first_time = name not in self.component_order
        self.components[name] = component
        if first_time:
            self.component_order.append(name)
        return component

    def save_current_config(self):
        """Write every registered component's ``value`` attribute to a new config file.

        NOTE(review): this reads the ``value`` attribute of the component
        object itself (``None`` when absent); whether that reflects live UI
        state should be confirmed.

        :return: The status message produced by ``save_config_to_file``.
        """
        snapshot = {
            name: getattr(self.components[name], "value", None)
            for name in self.component_order
        }
        return save_config_to_file(snapshot)

    def update_ui_from_config(self, config_file):
        """Produce one UI update per registered component from a config file.

        :param config_file: Uploaded-file object exposing ``.name`` (its path),
            or ``None`` when nothing was selected.
        :return: A list with one ``gr.update`` per component, in registration
            order, followed by a status message string.
        """
        def _unchanged(status):
            # No-op update for every component plus the status text.
            return [gr.update() for _ in self.component_order] + [status]

        if config_file is None:
            return _unchanged("No file selected.")

        loaded = load_config_from_file(config_file.name)
        # The loader returns an error string (not a dict) on failure.
        if not isinstance(loaded, dict):
            return _unchanged("Error: Invalid configuration file.")

        # Components missing from the file are left untouched.
        updates = [
            gr.update(value=loaded[name]) if name in loaded else gr.update()
            for name in self.component_order
        ]
        updates.append("Configuration loaded successfully.")
        return updates

    def get_all_components(self):
        """Return the registered components in registration order."""
        ordered = []
        for name in self.component_order:
            ordered.append(self.components[name])
        return ordered
def load_config_from_file(config_file):
    """
    Load settings from a JSON config file.

    :param config_file: Path of the JSON file to read.
    :return: The parsed JSON value on success; on any failure, a string of the
        form ``"Error loading configuration: <details>"`` (callers distinguish
        the two by type).
    """
    try:
        # JSON is UTF-8 by specification; do not depend on the locale default.
        with open(config_file, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        # Deliberately broad: the error is reported as a value, never raised.
        return f"Error loading configuration: {str(e)}"
def save_config_to_file(settings, save_dir="./tmp/webui_settings"):
    """
    Write settings as indented JSON to a new, uniquely named file.

    :param settings: JSON-serialisable settings to store.
    :param save_dir: Directory for the file; created if missing.
    :return: Status message naming the written ``<uuid4>.json`` file.
    """
    os.makedirs(save_dir, exist_ok=True)
    file_name = f"{uuid.uuid4()}.json"
    config_file = os.path.join(save_dir, file_name)
    with open(config_file, 'w') as out:
        json.dump(settings, out, indent=2)
    return f"Configuration saved to {config_file}"