Upload 2 files
Browse files- GroqConfig.ini +0 -0
- groq_api_vlm.py +134 -0
GroqConfig.ini
ADDED
The diff for this file is too large to render.
See raw diff
|
|
groq_api_vlm.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from colorama import init, Fore, Style
|
7 |
+
from configparser import ConfigParser
|
8 |
+
from groq import Groq
|
9 |
+
|
10 |
+
from ..utils.api_utils import make_api_request, load_prompt_options, get_prompt_content
|
11 |
+
from ..utils.image_utils import encode_image, tensor_to_pil
|
12 |
+
|
13 |
+
init() # Initialize colorama
|
14 |
+
|
15 |
+
class GroqAPIVLM:
|
16 |
+
DEFAULT_PROMPT = "Use [system_message] and [user_input]"
|
17 |
+
|
18 |
+
VLM_MODELS = [
|
19 |
+
"llava-v1.5-7b-4096-preview",
|
20 |
+
"llama-3.2-11b-vision-preview",
|
21 |
+
"llama-3.1-70b-versatile",
|
22 |
+
"gemma2-9b-it"
|
23 |
+
]
|
24 |
+
|
25 |
+
def __init__(self):
|
26 |
+
current_directory = os.path.dirname(os.path.realpath(__file__))
|
27 |
+
groq_directory = os.path.join(current_directory, 'groq')
|
28 |
+
config_path = os.path.join(groq_directory, 'GroqConfig.ini')
|
29 |
+
self.config = ConfigParser()
|
30 |
+
self.config.read(config_path)
|
31 |
+
self.api_key = self.config.get('API', 'key')
|
32 |
+
self.client = Groq(api_key=self.api_key)
|
33 |
+
|
34 |
+
# Load prompt options
|
35 |
+
prompt_files = [
|
36 |
+
os.path.join(groq_directory, 'DefaultPrompts_VLM.json'),
|
37 |
+
os.path.join(groq_directory, 'UserPrompts_VLM.json')
|
38 |
+
]
|
39 |
+
self.prompt_options = load_prompt_options(prompt_files)
|
40 |
+
|
41 |
+
@classmethod
|
42 |
+
def INPUT_TYPES(cls):
|
43 |
+
try:
|
44 |
+
current_directory = os.path.dirname(os.path.realpath(__file__))
|
45 |
+
groq_directory = os.path.join(current_directory, 'groq')
|
46 |
+
prompt_files = [
|
47 |
+
os.path.join(groq_directory, 'DefaultPrompts_VLM.json'),
|
48 |
+
os.path.join(groq_directory, 'UserPrompts_VLM.json')
|
49 |
+
]
|
50 |
+
prompt_options = load_prompt_options(prompt_files)
|
51 |
+
except Exception as e:
|
52 |
+
print(Fore.RED + f"Failed to load prompt options: {e}" + Style.RESET_ALL)
|
53 |
+
prompt_options = {}
|
54 |
+
|
55 |
+
return {
|
56 |
+
"required": {
|
57 |
+
"model": (cls.VLM_MODELS, {"tooltip": "Select the Vision-Language Model (VLM) to use."}),
|
58 |
+
"preset": ([cls.DEFAULT_PROMPT] + list(prompt_options.keys()), {"tooltip": "Select a preset prompt or use a custom prompt for the model."}),
|
59 |
+
"system_message": ("STRING", {"multiline": True, "default": "", "tooltip": "Optional system message to guide model behavior."}),
|
60 |
+
"user_input": ("STRING", {"multiline": True, "default": "", "tooltip": "User input or prompt for the model to generate a response."}),
|
61 |
+
"image": ("IMAGE", {"label": "Image (required for VLM models)", "tooltip": "Upload an image for processing by the VLM model."}),
|
62 |
+
"temperature": ("FLOAT", {"default": 0.85, "min": 0.1, "max": 2.0, "step": 0.05, "tooltip": "Controls randomness in responses.\n\nA higher temperature makes the model take more risks, leading to more creative or varied answers.\n\nA lower temperature (closer to 0.1) makes the model more focused and predictable."}),
|
63 |
+
"max_tokens": ("INT", {"default": 1024, "min": 1, "max": 131072, "step": 1, "tooltip": "Maximum number of tokens to generate in the output."}),
|
64 |
+
"top_p": ("FLOAT", {"default": 1.0, "min": 0.1, "max": 1.0, "step": 0.01, "tooltip": "Limits the pool of words the model can choose from based on their combined probability.\n\nSet it closer to 1 to allow more variety in output. Lowering this (e.g., 0.9) will restrict the output to the most likely words, making responses more focused."}),
|
65 |
+
"seed": ("INT", {"default": 42, "min": 0, "max": 4294967295, "tooltip": "Seed for random number generation, ensuring reproducibility."}),
|
66 |
+
"max_retries": ("INT", {"default": 2, "min": 1, "max": 10, "step": 1, "tooltip": "Maximum number of retries in case of failures."}),
|
67 |
+
"stop": ("STRING", {"default": "", "tooltip": "Stop generation when the specified sequence is encountered."}),
|
68 |
+
"json_mode": ("BOOLEAN", {"default": False, "tooltip": "Enable JSON mode for structured output.\n\nIMPORTANT: Requires you to use the word 'JSON' in the prompt."}),
|
69 |
+
}
|
70 |
+
}
|
71 |
+
|
72 |
+
OUTPUT_NODE = True
|
73 |
+
RETURN_TYPES = ("STRING", "BOOLEAN", "STRING")
|
74 |
+
RETURN_NAMES = ("api_response", "success", "status_code")
|
75 |
+
OUTPUT_TOOLTIPS = ("The API response. This is the description of your input image generated by the model", "Whether the request was successful", "The status code of the request")
|
76 |
+
FUNCTION = "process_completion_request"
|
77 |
+
CATEGORY = "⚡ MNeMiC Nodes"
|
78 |
+
DESCRIPTION = "Uses Groq API for image processing."
|
79 |
+
|
80 |
+
def process_completion_request(self, model, image, temperature, max_tokens, top_p, seed, max_retries, stop, json_mode, preset="", system_message="", user_input=""):
|
81 |
+
# Set the seed for reproducibility
|
82 |
+
torch.manual_seed(seed)
|
83 |
+
np.random.seed(seed)
|
84 |
+
random.seed(seed)
|
85 |
+
|
86 |
+
if preset == self.DEFAULT_PROMPT:
|
87 |
+
system_message = system_message
|
88 |
+
else:
|
89 |
+
system_message = get_prompt_content(self.prompt_options, preset)
|
90 |
+
|
91 |
+
url = 'https://api.groq.com/openai/v1/chat/completions'
|
92 |
+
headers = {'Authorization': f'Bearer {self.api_key}'}
|
93 |
+
|
94 |
+
if image is not None and isinstance(image, torch.Tensor):
|
95 |
+
# Process the image
|
96 |
+
image_pil = tensor_to_pil(image)
|
97 |
+
base64_image = encode_image(image_pil)
|
98 |
+
if base64_image:
|
99 |
+
combined_message = f"{system_message}\n{user_input}"
|
100 |
+
# Send one single message containing both text and image
|
101 |
+
image_content = {
|
102 |
+
"role": "user",
|
103 |
+
"content": [
|
104 |
+
{"type": "text", "text": combined_message},
|
105 |
+
{
|
106 |
+
"type": "image_url",
|
107 |
+
"image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}
|
108 |
+
}
|
109 |
+
]
|
110 |
+
}
|
111 |
+
messages = [image_content]
|
112 |
+
else:
|
113 |
+
print(Fore.RED + "Failed to encode image." + Style.RESET_ALL)
|
114 |
+
messages = []
|
115 |
+
else:
|
116 |
+
print(Fore.RED + "Image is required for VLM models." + Style.RESET_ALL)
|
117 |
+
return "Image is required for VLM models.", False, "400 Bad Request"
|
118 |
+
|
119 |
+
data = {
|
120 |
+
'model': model,
|
121 |
+
'messages': messages,
|
122 |
+
'temperature': temperature,
|
123 |
+
'max_tokens': max_tokens,
|
124 |
+
'top_p': top_p,
|
125 |
+
'seed': seed
|
126 |
+
}
|
127 |
+
|
128 |
+
if stop: # Only add stop if it's not empty
|
129 |
+
data['stop'] = stop
|
130 |
+
|
131 |
+
#print(f"Sending request to {url} with data: {json.dumps(data, indent=4)} and headers: {headers}")
|
132 |
+
|
133 |
+
assistant_message, success, status_code = make_api_request(data, headers, url, max_retries)
|
134 |
+
return assistant_message, success, status_code
|