Spaces:
Runtime error
Runtime error
Upload 5 files
Browse files- ai_model_plugins.py +217 -0
- config_manager.py +20 -0
- gradio_app.py +158 -0
- insert_snippet.py +244 -0
- template_manager.py +22 -0
ai_model_plugins.py
ADDED
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ai_model_plugins.py
|
2 |
+
import os
|
3 |
+
import re
|
4 |
+
import requests
|
5 |
+
|
6 |
+
# --- Model Generation Functions ---
|
7 |
+
|
8 |
+
def generate_with_gemini(prompt, model_name="gemini-2.0-flash", api_key=None):
    """Generate text with a Google Gemini model.

    Falls back to the GEMINI_API_KEY environment variable when no key is
    given. Returns a (text, cost) tuple, where cost is a rough USD estimate
    derived from a whitespace token count.
    """
    import google.generativeai as genai

    key = api_key or os.environ.get("GEMINI_API_KEY")
    if not key:
        raise ValueError("Gemini API key is not set and none was provided.")
    genai.configure(api_key=key)

    raw = genai.GenerativeModel(model_name).generate_content(prompt).text

    # Prefer the contents of a fenced code block when the model wraps its answer.
    fence = re.search(r"```[^\n]*\n(.*?)\n```", raw, re.DOTALL | re.IGNORECASE)
    text = (fence.group(1) if fence else raw).strip()

    # Gemini output pricing in USD per 1M tokens; unknown models use the
    # flash rate as a default, matching the original if/elif fallback.
    rates = {
        "gemini-2.0-flash": 0.40,
        "gemini-1.5-flash": 0.30,
        "gemini-1.5-pro": 5.00,
    }
    rate = rates.get(model_name.lower(), 0.40)
    cost = (len(text.split()) / 1_000_000) * rate
    return text, cost
|
37 |
+
|
38 |
+
def generate_with_openai(prompt, model_name="gpt-3.5-turbo", api_key=None):
    """Generate text with an OpenAI chat model.

    Falls back to the OPENAI_API_KEY environment variable when no key is
    given. Returns a (text, cost) tuple; cost is a rough USD estimate based
    on a whitespace token count and a placeholder rate.
    """
    # NOTE(review): openai.ChatCompletion / module-level openai.api_key is the
    # pre-1.0 openai-python API; confirm the pinned package version supports it.
    import openai
    actual_api_key = api_key if api_key else os.environ.get("OPENAI_API_KEY")
    if not actual_api_key:
        raise ValueError("OpenAI API key is not set and none was provided.")
    openai.api_key = actual_api_key
    response = openai.ChatCompletion.create(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
    )
    text = response.choices[0].message.content
    tokens = len(text.split())  # rough whitespace-based token estimate
    cost_rate = 2.00  # example rate; adjust per official pricing
    cost = (tokens / 1_000_000) * cost_rate
    return text, cost
|
53 |
+
|
54 |
+
def generate_with_anthropic(prompt, model_name="claude-v1", api_key=None):
    """Generate text with an Anthropic model via the HTTP completion API.

    Falls back to the ANTHROPIC_API_KEY environment variable when no key is
    given, for consistency with the Gemini/OpenAI helpers. Returns a
    (text, cost) tuple; cost is a rough USD estimate.

    Raises:
        ValueError: if no API key is available or the API returns a non-200.
    """
    # NOTE(review): /v1/complete with "max_tokens_to_sample" is Anthropic's
    # legacy text-completion API; newer models use /v1/messages — confirm.
    url = "https://api.anthropic.com/v1/complete"
    # Fix: fall back to the environment like the other provider helpers do.
    actual_api_key = api_key if api_key else os.environ.get("ANTHROPIC_API_KEY")
    if not actual_api_key:
        raise ValueError("Anthropic API key is not set and none was provided.")
    headers = {
        "Content-Type": "application/json",
        "x-api-key": actual_api_key,
    }
    data = {
        "prompt": prompt,
        "model": model_name,
        "max_tokens_to_sample": 150,
        "temperature": 0.7,
        "stop_sequences": ["\n\n"],
    }
    response = requests.post(url, json=data, headers=headers)
    if response.status_code != 200:
        raise ValueError("Anthropic API error: " + response.text)
    json_data = response.json()
    text = json_data.get("completion", "")
    tokens = len(text.split())  # rough whitespace-based token estimate
    cost_rate = 0.50  # example rate; update per official Anthropic pricing
    cost = (tokens / 1_000_000) * cost_rate
    return text, cost
|
79 |
+
|
80 |
+
def generate_with_ollama(prompt, model_name="ollama-model-1", api_key=None):
    """Generate text with a local Ollama server.

    The api_key parameter exists only for interface parity with the other
    provider helpers and is not used. Returns a (text, cost) tuple.
    """
    payload = {
        "model": model_name,
        "prompt": prompt,
        "max_tokens": 150,
    }
    # Ollama typically runs locally; adjust URL if needed.
    resp = requests.post("http://localhost:11434/api/v1/predict", json=payload)
    if resp.status_code != 200:
        raise ValueError("Ollama API error: " + resp.text)
    text = resp.json().get("prediction", "")
    cost_rate = 0.25  # Replace with actual pricing if available
    return text, (len(text.split()) / 1_000_000) * cost_rate
|
97 |
+
|
98 |
+
def generate_with_lmstudios(prompt, model_name="lmstudios-model-alpha", api_key=None):
    """Generate text via the LMStudios HTTP API; returns (text, cost)."""
    if not api_key:
        raise ValueError("LMStudios API key is not set and none was provided.")
    payload = {
        "prompt": prompt,
        "model": model_name,
        "max_tokens": 150,
        "temperature": 0.7,
    }
    auth_headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    resp = requests.post("https://api.lmstudios.com/v1/generate", json=payload, headers=auth_headers)
    if resp.status_code != 200:
        raise ValueError("LMStudios API error: " + resp.text)
    text = resp.json().get("text", "")
    cost_rate = 0.35  # Replace with official pricing
    return text, (len(text.split()) / 1_000_000) * cost_rate
|
118 |
+
|
119 |
+
def generate_with_openrouter(prompt, model_name="openrouter-model-x", api_key=None):
    """Generate text via the OpenRouter chat-completions API; returns (text, cost)."""
    if not api_key:
        raise ValueError("OpenRouter API key is not set and none was provided.")
    payload = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 150,
        "temperature": 0.7,
    }
    auth_headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
    resp = requests.post("https://api.openrouter.ai/v2/chat/completions", json=payload, headers=auth_headers)
    if resp.status_code != 200:
        raise ValueError("OpenRouter API error: " + resp.text)
    # Chat-completions shape: first choice's message content is the answer.
    text = resp.json()["choices"][0]["message"]["content"]
    cost_rate = 0.45  # Update per official pricing
    return text, (len(text.split()) / 1_000_000) * cost_rate
|
139 |
+
|
140 |
+
def get_model_generator(provider, model_name):
    """Return the generation function for *provider* (case-insensitive).

    Args:
        provider: Provider name, e.g. "OpenAI", "Anthropic", "Ollama".
        model_name: Accepted for interface compatibility; not needed to
            select the generator.

    Returns:
        The matching generate_with_* callable; unknown providers fall back
        to Gemini, exactly as the original if/elif chain did.
    """
    # Dispatch table replaces the if/elif ladder — easier to extend.
    generators = {
        "openai": generate_with_openai,
        "anthropic": generate_with_anthropic,
        "ollama": generate_with_ollama,
        "lmstudios": generate_with_lmstudios,
        "openrouter": generate_with_openrouter,
    }
    return generators.get(provider.lower(), generate_with_gemini)
|
154 |
+
|
155 |
+
# --- Model Listing Functions ---
|
156 |
+
|
157 |
+
def list_models_openai(api_key=None):
    """Return the ids of all models visible to the given OpenAI API key.

    Falls back to the OPENAI_API_KEY environment variable when no key is
    given; raises ValueError when neither is available.
    """
    # NOTE(review): openai.Model.list() is the pre-1.0 openai-python API;
    # confirm the pinned package version still provides it.
    import openai
    actual_api_key = api_key if api_key else os.environ.get("OPENAI_API_KEY")
    if not actual_api_key:
        raise ValueError("OpenAI API key is not set and none was provided.")
    openai.api_key = actual_api_key
    response = openai.Model.list()
    return [model.id for model in response.data]
|
165 |
+
|
166 |
+
def list_models_gemini(api_key=None):
    """Return the static list of supported Gemini model names.

    No API call is made; api_key is accepted only for interface parity.
    """
    supported = [
        "gemini-2.0-flash",
        "gemini-1.5-flash",
        "gemini-1.5-pro",
    ]
    return supported
|
168 |
+
|
169 |
+
def list_models_anthropic(api_key=None):
    """Return the static list of supported Anthropic model names.

    No API call is made; api_key is accepted only for interface parity.
    """
    supported = [
        "claude-3-opus",
        "claude-3.5-sonnet",
        "claude-3-haiku",
    ]
    return supported
|
171 |
+
|
172 |
+
def list_models_ollama(api_key=None):
    """Return the model names reported by a local Ollama server.

    api_key is accepted only for interface parity and is not used.
    """
    resp = requests.get("http://localhost:11434/api/models")
    if resp.status_code != 200:
        raise ValueError("Ollama API error: " + resp.text)
    return resp.json().get("models", [])
|
179 |
+
|
180 |
+
def list_models_lmstudios(api_key=None):
    """Return the model ids available from the LMStudios API.

    Raises ValueError when no API key is supplied or the API errors.
    """
    if not api_key:
        raise ValueError("LMStudios API key is not set and none was provided.")
    resp = requests.get(
        "https://api.lmstudios.com/v1/models",
        headers={"Authorization": f"Bearer {api_key}"},
    )
    if resp.status_code != 200:
        raise ValueError("LMStudios API error: " + resp.text)
    return [entry["id"] for entry in resp.json().get("models", [])]
|
190 |
+
|
191 |
+
def list_models_openrouter(api_key=None):
    """Return the model ids available from the OpenRouter API.

    Raises ValueError when no API key is supplied or the API errors.
    """
    if not api_key:
        raise ValueError("OpenRouter API key is not set and none was provided.")
    resp = requests.get(
        "https://api.openrouter.ai/v2/models",
        headers={"Authorization": f"Bearer {api_key}"},
    )
    if resp.status_code != 200:
        raise ValueError("OpenRouter API error: " + resp.text)
    return [entry["id"] for entry in resp.json().get("models", [])]
|
201 |
+
|
202 |
+
def list_available_models(provider, api_key=None):
    """List model identifiers for *provider* (case-insensitive).

    Args:
        provider: Provider name, e.g. "OpenAI", "Gemini", "Anthropic".
        api_key: Optional key forwarded to the provider-specific lister.

    Returns:
        A list of model identifiers; an empty list for unknown providers,
        exactly as the original if/elif chain did.
    """
    # Dispatch table replaces the if/elif ladder — easier to extend.
    listers = {
        "openai": list_models_openai,
        "gemini": list_models_gemini,
        "anthropic": list_models_anthropic,
        "ollama": list_models_ollama,
        "lmstudios": list_models_lmstudios,
        "openrouter": list_models_openrouter,
    }
    lister = listers.get(provider.lower())
    return lister(api_key) if lister else []
|
config_manager.py
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
def load_config(config_path="config.json"):
    """Load configuration settings from a JSON file.

    Args:
        config_path: Path to the JSON configuration file.

    Returns:
        A dict of config settings; an empty dict if the file does not exist
        or fails to parse.
    """
    # Bug fix: the original implicitly returned None when the file was
    # missing, contradicting its own docstring — callers now always get a dict.
    if os.path.exists(config_path):
        try:
            with open(config_path, "r", encoding="utf-8") as f:
                return json.load(f)
        except json.JSONDecodeError as e:
            print(f"Error parsing configuration file {config_path}: {e}")
            return {}
    return {}
|
17 |
+
|
18 |
+
if __name__ == "__main__":
    # Manual smoke test: load the default config.json and show what was found.
    loaded = load_config()
    print(loaded)
|
gradio_app.py
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import json
|
3 |
+
from insert_snippet import insert_snippet
|
4 |
+
from ai_model_plugins import list_available_models
|
5 |
+
|
6 |
+
def generate_snippet_gui(
    description, language, output_dir, output_file_name, clipboard,
    provider, model, api_key, template, params, existing_file, marker, format_code
):
    """Gradio callback: generate a snippet and return a (snippet, cost) pair.

    On any failure an error message is returned in the snippet slot with an
    empty cost, so the UI never raises.
    """
    try:
        try:
            template_params = json.loads(params) if params else {}
        except Exception as e:
            return f"Error parsing template parameters: {e}", ""

        def _blank_to_none(value):
            # Whitespace-only optional fields mean "not provided".
            return value if value.strip() != "" else None

        result = insert_snippet(
            description,
            language,
            output_dir=output_dir,
            output_file_name=output_file_name,
            use_clipboard=clipboard,
            model_name=model,
            template=template,
            template_params=template_params,
            existing_file=_blank_to_none(existing_file),
            marker=_blank_to_none(marker),
            format_code=format_code,
            provider=_blank_to_none(provider),
            api_key=_blank_to_none(api_key),
            return_snippet=True,
        )
        # insert_snippet returns (snippet, cost) when return_snippet is set.
        return result
    except Exception as e:
        return f"Error generating snippet: {e}", ""
|
36 |
+
|
37 |
+
def download_snippet_file_gui(
    description, language, output_dir, output_file_name, clipboard,
    provider, model, api_key, template, params, existing_file, marker, format_code
):
    """Gradio callback: generate a snippet, save it, and return the file path.

    On any failure an error message string is returned instead of a path,
    so the UI never raises.
    """
    try:
        try:
            template_params = json.loads(params) if params else {}
        except Exception as e:
            return f"Error parsing template parameters: {e}"

        def _blank_to_none(value):
            # Whitespace-only optional fields mean "not provided".
            return value if value.strip() != "" else None

        result = insert_snippet(
            description,
            language,
            output_dir=output_dir,
            output_file_name=output_file_name,
            use_clipboard=clipboard,
            model_name=model,
            template=template,
            template_params=template_params,
            existing_file=_blank_to_none(existing_file),
            marker=_blank_to_none(marker),
            format_code=format_code,
            provider=_blank_to_none(provider),
            api_key=_blank_to_none(api_key),
            return_file_path=True,
        )
        # insert_snippet returns (snippet, file_path) when return_file_path is set.
        return result[1]
    except Exception as e:
        return f"Error saving file: {e}"
|
66 |
+
|
67 |
+
def list_models_gui(provider, api_key):
    """Gradio callback: return available model names, one per line.

    Any failure (bad key, network error) is converted to an error string.
    """
    try:
        key = api_key if api_key.strip() != "" else None
        return "\n".join(list_available_models(provider, key))
    except Exception as e:
        return f"Error listing models: {e}"
|
73 |
+
|
74 |
+
def clear_fields():
    """Reset every GUI input (plus the snippet preview) to its default value.

    The tuple order must match the `outputs` list wired to the Clear button.
    """
    defaults = (
        "",                  # description
        "python",            # language
        "output",            # output_dir
        "snippet",           # output_file_name
        False,               # clipboard
        "Gemini",            # provider
        "gemini-2.0-flash",  # model
        "",                  # api_key
        "",                  # template
        "{}",                # params
        "",                  # existing_file
        "",                  # marker
        False,               # format_code
        "",                  # snippet preview
    )
    return defaults
|
79 |
+
|
80 |
+
# --- Gradio UI definition and event wiring ---
with gr.Blocks() as demo:
    gr.Markdown("# HMVL Snippet Generator")

    # What to generate and in which language.
    with gr.Row():
        description = gr.Textbox(label="Description", lines=3, placeholder="Enter code snippet description here...")
        language = gr.Dropdown(
            label="Language",
            choices=["python", "javascript", "typescript", "markdown", "html", "css",
                     "cpp", "java", "go", "rust", "sql", "bash", "shell"],
            value="python"
        )

    # Where to save the generated snippet.
    with gr.Row():
        output_dir = gr.Textbox(label="Output Directory", value="output")
        output_file_name = gr.Textbox(label="Output File Name", value="snippet")

    with gr.Row():
        clipboard = gr.Checkbox(label="Copy to Clipboard", value=False)

    # LLM provider selection and credentials.
    with gr.Row():
        provider = gr.Dropdown(
            label="LLM Provider",
            choices=["Gemini", "OpenAI", "Anthropic", "Ollama", "LMStudios", "OpenRouter"],
            value="Gemini"
        )
        model = gr.Textbox(label="Model", value="gemini-2.0-flash")
        api_key = gr.Textbox(label="API Key", placeholder="Enter API key if required", type="password")

    # Optional template-based generation (bypasses the LLM).
    with gr.Row():
        template = gr.Textbox(label="Template (optional)", placeholder="e.g., functional_component.jsx")
        params = gr.Textbox(label="Template Parameters (JSON)", placeholder='e.g., {"component_name": "MyComponent", "title": "Hello"}', lines=2)

    # Optional insertion into an existing file at a marker.
    with gr.Row():
        existing_file = gr.Textbox(label="Existing File (optional)", placeholder="Path to existing file for snippet insertion")
        marker = gr.Textbox(label="Marker (optional)", placeholder="Marker string in file (if applicable)")
        format_code = gr.Checkbox(label="Format Python Code", value=False)

    with gr.Row():
        generate_button = gr.Button("Generate Snippet")
        download_button = gr.Button("Download Snippet File")
        list_models_button = gr.Button("List Available Models")
        clear_button = gr.Button("Clear Fields")

    # Output widgets.
    snippet_output = gr.Code(label="Generated Snippet", language="python")
    cost_output = gr.Textbox(label="Estimated Cost (USD)", interactive=False)
    file_output = gr.File(label="Download File")
    models_list_output = gr.Textbox(label="Available Models", interactive=False)

    # NOTE(review): `status` is never written to by any callback — presumably
    # reserved for future use; confirm or remove.
    with gr.Row():
        status = gr.Textbox(label="Status", interactive=False)

    # Generate: fills the snippet preview and the cost estimate.
    generate_button.click(
        fn=generate_snippet_gui,
        inputs=[description, language, output_dir, output_file_name, clipboard,
                provider, model, api_key, template, params, existing_file, marker, format_code],
        outputs=[snippet_output, cost_output]
    )

    # Download: regenerates and returns the saved file's path.
    download_button.click(
        fn=download_snippet_file_gui,
        inputs=[description, language, output_dir, output_file_name, clipboard,
                provider, model, api_key, template, params, existing_file, marker, format_code],
        outputs=file_output
    )

    list_models_button.click(
        fn=list_models_gui,
        inputs=[provider, api_key],
        outputs=models_list_output
    )

    # Clear: tuple from clear_fields() maps positionally onto these outputs.
    clear_button.click(
        fn=clear_fields,
        inputs=None,
        outputs=[description, language, output_dir, output_file_name, clipboard,
                 provider, model, api_key, template, params, existing_file, marker, format_code, snippet_output]
    )

demo.launch()
|
insert_snippet.py
ADDED
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# insert_snippet.py
|
2 |
+
import os
|
3 |
+
import argparse
|
4 |
+
import re
|
5 |
+
import json
|
6 |
+
import requests
|
7 |
+
import google.generativeai as genai
|
8 |
+
import pyperclip # For clipboard integration
|
9 |
+
|
10 |
+
from config_manager import load_config
|
11 |
+
from template_manager import load_template, apply_template
|
12 |
+
from ai_model_plugins import get_model_generator
|
13 |
+
|
14 |
+
def insert_snippet(description, language, output_dir="output", output_file_name="snippet",
                   use_clipboard=False, model_name=None, template=None, template_params=None,
                   provider=None, api_key=None,
                   existing_file=None, marker=None, format_code=False,
                   return_snippet=False, return_file_path=False):
    """
    Generates a code snippet using an AI model or a template, and either:
      - Saves it to a new file (or inserts into an existing file if provided),
      - Copies it to the clipboard,
      - And/or returns the snippet string (and cost, if requested).

    Parameters:
      * description: Natural-language description of the desired snippet.
      * language: Target language; also determines the output file extension.
      * output_dir / output_file_name: Where to write a new snippet file.
      * use_clipboard: Also copy the snippet to the system clipboard.
      * model_name: AI model identifier (defaults to gemini-2.0-flash).
      * template / template_params: Generate from a local template instead of an AI call.
      * provider: The LLM provider to use (e.g., "Gemini", "OpenAI", "Anthropic", etc.).
      * api_key: The API key for the selected provider.
      * existing_file: Path to an existing file where the snippet will be inserted.
      * marker: A marker string in the existing file where the snippet should be inserted (if not found, appends).
      * format_code: If True and language is Python, formats the snippet using Black.
      * return_snippet: Return (snippet, cost) on success.
      * return_file_path: Return (snippet, final_path) on success; takes precedence.

    NOTE(review): on error, the except blocks return a bare error string, while
    the success path returns a 2-tuple — GUI callers that unpack two values will
    break on the error path; confirm and align the shapes.
    NOTE(review): when neither return flag is set, the function implicitly
    returns None on success.
    """
    # Determine file extension based on language (case-insensitive).
    file_extension = ".txt"
    lang = language.lower()
    if lang == "python":
        file_extension = ".py"
    elif lang == "javascript":
        file_extension = ".js"
    elif lang == "typescript":
        file_extension = ".ts"
    elif lang == "markdown":
        file_extension = ".md"
    elif lang == "html":
        file_extension = ".html"
    elif lang == "css":
        file_extension = ".css"
    elif lang == "cpp":
        file_extension = ".cpp"
    elif lang == "java":
        file_extension = ".java"
    elif lang == "go":
        file_extension = ".go"
    elif lang == "rust":
        file_extension = ".rs"
    elif lang == "sql":
        file_extension = ".sql"
    elif lang in ["bash", "shell"]:
        file_extension = ".sh"
    else:
        raise ValueError("Invalid language. Supported languages: python, javascript, typescript, markdown, html, css, cpp, java, go, rust, sql, bash, shell")

    # Destination path used only when no existing_file is given.
    output_file_path = os.path.join(output_dir, output_file_name + file_extension)

    try:
        # Generate snippet via template or AI model.
        if template:
            template_content = load_template(template)
            template_params = template_params or {}
            snippet = apply_template(template_content, template_params)
            cost = 0.0  # no API call, so no cost
        else:
            if lang == "markdown":
                # Markdown gets a trivial heading — no AI call needed.
                snippet = f"# {description}"
                cost = 0.0
            else:
                # Build a language-specific prompt.
                if lang == "python":
                    prompt = f"Generate Python code snippet for: {description}"
                elif lang == "javascript":
                    prompt = f"Generate Javascript code snippet for: {description}"
                elif lang == "typescript":
                    prompt = f"Generate Typescript code snippet for: {description}"
                else:
                    prompt = description

                # Get the appropriate generator function using provider and model_name.
                generator = get_model_generator(provider if provider else "gemini",
                                                model_name if model_name else "gemini-2.0-flash")
                # Expecting the generator to return a tuple: (snippet, cost).
                result = generator(prompt, model_name=model_name if model_name else "gemini-2.0-flash", api_key=api_key)
                snippet = result[0]
                cost = result[1]

                # Extract code from fenced code blocks if present.
                code_fence_regex = rf"```{language}\n(.*?)\n```"
                match = re.search(code_fence_regex, snippet, re.DOTALL | re.IGNORECASE)
                if match:
                    snippet = match.group(1).strip()
                else:
                    snippet = snippet.strip()
                    print("Warning: Could not extract code snippet from AI API response. Using raw response.")

        # Normalize trailing whitespace to exactly one newline.
        snippet = snippet.strip() + "\n"

        # Format the code if requested and if it's Python.
        if format_code and lang == "python":
            try:
                import black
                snippet = black.format_str(snippet, mode=black.FileMode())
            except Exception as e:
                # Best-effort: formatting failure does not abort generation.
                print("Warning: Failed to format code with Black:", e)

        # Create output or insert into existing file.
        if existing_file:
            if not os.path.exists(existing_file):
                raise FileNotFoundError(f"Existing file '{existing_file}' not found.")
            with open(existing_file, "r", encoding="utf-8") as f:
                file_contents = f.read()
            if marker:
                idx = file_contents.find(marker)
                if idx == -1:
                    # Marker missing: degrade gracefully to appending.
                    print("Warning: Marker not found in file. Appending snippet at the end.")
                    new_contents = file_contents + "\n" + snippet
                else:
                    # Insert immediately after the end of the marker's line.
                    marker_end = file_contents.find("\n", idx)
                    if marker_end == -1:
                        marker_end = len(file_contents)
                    new_contents = file_contents[:marker_end] + "\n" + snippet + file_contents[marker_end:]
            else:
                new_contents = file_contents + "\n" + snippet
            with open(existing_file, "w", encoding="utf-8") as f:
                f.write(new_contents)
            final_path = existing_file
        else:
            os.makedirs(output_dir, exist_ok=True)
            with open(output_file_path, "w", encoding="utf-8") as f:
                f.write(snippet)
            final_path = output_file_path

        # Clipboard integration.
        if use_clipboard:
            pyperclip.copy(snippet)
            print("Code snippet generated and copied to clipboard!")

        # Preview the snippet in console.
        print("\n--- Snippet Preview ---")
        print(snippet)
        print("--- End Preview ---\n")

        # Return values as requested.
        if return_file_path:
            return snippet, final_path
        elif return_snippet:
            return snippet, cost

    except requests.exceptions.RequestException as e:
        error_msg = ("Error: Could not connect to AI API. Please check your network connection and ensure the API is accessible.\n"
                     f"Detailed error: {e}")
        print(error_msg)
        if return_file_path or return_snippet:
            return error_msg
    except KeyError as e:
        error_msg = ("Error: Invalid response format from AI API. Missing key in response JSON.\n"
                     f"KeyError details: {e}")
        print(error_msg)
        if return_file_path or return_snippet:
            return error_msg
    except IndexError as e:
        error_msg = ("Error: Invalid response format from AI API. Index out of bounds in response array.\n"
                     f"IndexError details: {e}")
        print(error_msg)
        if return_file_path or return_snippet:
            return error_msg
    except ValueError as ve:
        error_msg = f"ValueError: {ve}"
        print(error_msg)
        if return_file_path or return_snippet:
            return error_msg
    except Exception as e:
        error_msg = "An unexpected error occurred:\n" + str(e)
        print(error_msg)
        if return_file_path or return_snippet:
            return error_msg
|
185 |
+
|
186 |
+
if __name__ == "__main__":
    # CLI entry point: parse arguments, merge config-file defaults, then delegate
    # to insert_snippet().
    parser = argparse.ArgumentParser(
        description="Generate and insert code snippets using an AI model or template."
    )
    parser.add_argument("description", nargs="?", help="Description for code snippet (not required if using a template).")
    parser.add_argument("-l", "--language", required=True,
                        help="Programming/markup language (python, javascript, typescript, markdown, html, css, cpp, java, go, rust, sql, bash, shell).")
    parser.add_argument("-o", "--output_dir", help="Output directory for the snippet (default: output).", default="output")
    parser.add_argument("-f", "--output_file_name", help="Base name for the output file (default: snippet).", default="snippet")
    parser.add_argument("-c", "--clipboard", action="store_true", help="Copy snippet to clipboard instead of saving.")
    parser.add_argument("-m", "--model", help="Specify the AI model to use (e.g., gemini-2.0-flash, gemini-1.5-pro, openai-...).")
    parser.add_argument("-t", "--template", help="Name of the template file (in the templates folder) to use.")
    parser.add_argument("-p", "--params", help="JSON string of parameters for template substitution.", default="{}")
    parser.add_argument("-C", "--config", help="Path to JSON configuration file.", default=None)
    # New arguments for feature expansion:
    parser.add_argument("--existing_file", help="Path to an existing file to insert the snippet into.", default=None)
    parser.add_argument("--marker", help="Marker string in the existing file where the snippet should be inserted.", default=None)
    parser.add_argument("--format", action="store_true", help="Format the generated Python code using Black.")
    # New parameters for provider integration:
    parser.add_argument("--provider", help="LLM provider to use (e.g., Gemini, OpenAI, Anthropic, Ollama, LMStudios, OpenRouter).", default="Gemini")
    parser.add_argument("--api_key", help="API key for the selected provider.", default="")

    args = parser.parse_args()

    # Config-file values only override arguments the user left at their
    # defaults — explicit CLI flags win over the config file.
    if args.config:
        config = load_config(args.config)
        defaults = {"output_dir": "output", "output_file_name": "snippet", "clipboard": False, "model": None}
        for key, default_val in defaults.items():
            if getattr(args, key) == default_val and key in config:
                setattr(args, key, config[key])

    # Bad JSON in --params degrades to an empty parameter dict with a warning.
    try:
        template_params = json.loads(args.params)
    except json.JSONDecodeError as e:
        print(f"Error parsing template parameters JSON: {e}")
        template_params = {}

    # A template makes the positional description optional.
    if args.template and not args.description:
        description = ""
    elif args.description:
        description = args.description
    else:
        parser.error("You must provide a description or a template.")

    insert_snippet(
        description,
        args.language,
        output_dir=args.output_dir,
        output_file_name=args.output_file_name,
        use_clipboard=args.clipboard,
        model_name=args.model,
        template=args.template,
        template_params=template_params,
        provider=args.provider,
        api_key=args.api_key,
        existing_file=args.existing_file,
        marker=args.marker,
        format_code=args.format
    )
|
template_manager.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
|
3 |
+
def load_template(template_name, templates_dir="templates"):
    """Read and return the text of *template_name* from *templates_dir*.

    Raises:
        FileNotFoundError: when the template file does not exist.
    """
    template_path = os.path.join(templates_dir, template_name)
    if not os.path.exists(template_path):
        raise FileNotFoundError(f"Template file '{template_path}' not found.")
    with open(template_path, "r", encoding="utf-8") as handle:
        content = handle.read()
    return content
|
12 |
+
|
13 |
+
def apply_template(template_content, params):
    """Substitute *params* into *template_content* via str.format.

    Raises:
        ValueError: when a placeholder is missing from params, or any other
            formatting error occurs.
    """
    try:
        rendered = template_content.format(**params)
    except KeyError as e:
        raise ValueError(f"Missing template parameter: {e}")
    except Exception as e:
        raise ValueError(f"Error applying template: {e}")
    return rendered
|