Spaces:
Sleeping
Sleeping
updated Groq API key
Browse files- app.py +7 -46
- LLM_test.py → groq_helper.py +6 -19
- retrieval_helper.py +3 -66
app.py
CHANGED
@@ -1,63 +1,24 @@
|
|
1 |
import gradio as gr
|
2 |
import json
|
3 |
-
from
|
4 |
from retrieval_helper import fetch
|
5 |
from groq import Groq
|
6 |
|
7 |
-
client = Groq(
|
8 |
-
api_key="[REDACTED — leaked Groq API key; removing it in a later commit does NOT un-leak it: revoke/rotate this key in the Groq console immediately]",
|
9 |
-
)
|
10 |
-
|
11 |
-
# related_vectors = '''
|
12 |
-
# attribute: spend, operators_supported: [">", "<", ">=", "<=", "=", "!="], value_type: "Number"
|
13 |
-
# attribute: clicks, operators_supported: [">", "<", ">=", "<=", "=", "!="], value_type: "Integer"
|
14 |
-
# attribute: impressions, operators_supported: [">", "<", ">=", "<=", "=", "!="], value_type: "Integer"
|
15 |
-
# '''
|
16 |
-
|
17 |
-
|
18 |
-
# SYSTEM_PROMPT = f'''
|
19 |
-
# You are a system that converts natural language queries into a structured filter schema.
|
20 |
-
# The filter schema consists of a list of conditions, each represented as:
|
21 |
-
# {{
|
22 |
-
# "attribute": "<attribute_name>",
|
23 |
-
# "op": "<operator>",
|
24 |
-
# "value": "<value>"
|
25 |
-
# }}
|
26 |
-
# There can be any number of conditions. You have to list them all.
|
27 |
-
|
28 |
-
# Supported attributes and their operators are:
|
29 |
-
# {related_vectors}
|
30 |
-
|
31 |
-
# Example:
|
32 |
-
# Input: "Show campaigns where spend is greater than 11"
|
33 |
-
# Output: [{{"attribute": "spend", "op": ">", "value": 11}}]
|
34 |
-
|
35 |
-
# Input: "Find ads with clicks less than 100 and impressions greater than 500"
|
36 |
-
# Output: [
|
37 |
-
# {{"attribute": "clicks", "op": "<", "value": 100}},
|
38 |
-
# {{"attribute": "impressions", "op": ">", "value": 500}}
|
39 |
-
# ]
|
40 |
-
|
41 |
-
# STRICTLY PROVIDE IN THE ABOVE JSON FORMAT WITHOUT ANY METADATA
|
42 |
-
|
43 |
-
# '''
|
44 |
-
|
45 |
-
# Define the Gradio interface
|
46 |
def generate_chat_completion_interface(USER_INPUT):
|
47 |
|
48 |
top_documents = fetch(USER_INPUT)
|
49 |
related_vectors = "\n".join(top_documents)
|
50 |
|
51 |
-
result = generate_chat_completion(
|
52 |
|
53 |
return result
|
54 |
|
55 |
-
#
|
56 |
iface = gr.Interface(
|
57 |
-
fn=generate_chat_completion_interface,
|
58 |
-
inputs=gr.Textbox(label="Enter your query"),
|
59 |
-
outputs=gr.Textbox(label="Generated JSON"),
|
60 |
-
title="RAG based search",
|
61 |
description="Provide your natural language search query"
|
62 |
)
|
63 |
|
|
|
1 |
import gradio as gr
|
2 |
import json
|
3 |
+
from groq_helper import generate_chat_completion
|
4 |
from retrieval_helper import fetch
|
5 |
from groq import Groq
|
6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
def generate_chat_completion_interface(USER_INPUT):
|
8 |
|
9 |
top_documents = fetch(USER_INPUT)
|
10 |
related_vectors = "\n".join(top_documents)
|
11 |
|
12 |
+
result = generate_chat_completion(USER_INPUT, related_vectors)
|
13 |
|
14 |
return result
|
15 |
|
16 |
+
# Gradio app interface
|
17 |
iface = gr.Interface(
|
18 |
+
fn=generate_chat_completion_interface,
|
19 |
+
inputs=gr.Textbox(label="Enter your query"),
|
20 |
+
outputs=gr.Textbox(label="Generated JSON"),
|
21 |
+
title="RAG based search",
|
22 |
description="Provide your natural language search query"
|
23 |
)
|
24 |
|
LLM_test.py → groq_helper.py
RENAMED
@@ -1,25 +1,11 @@
|
|
1 |
-
import
|
2 |
from groq import Groq
|
3 |
-
from retrieval_helper import fetch
|
4 |
|
5 |
client = Groq(
|
6 |
-
api_key=
|
7 |
)
|
8 |
|
9 |
-
|
10 |
-
# attribute: spend, operators_supported: [">", "<", ">=", "<=", "=", "!="], value_type: "Number"
|
11 |
-
# attribute: clicks, operators_supported: [">", "<", ">=", "<=", "=", "!="], value_type: "Integer"
|
12 |
-
# attribute: impressions, operators_supported: [">", "<", ">=", "<=", "=", "!="], value_type: "Integer"
|
13 |
-
# '''
|
14 |
-
|
15 |
-
query = "Show campaigns where spend is greater than 11 and labels include holiday"
|
16 |
-
top_documents = fetch(query)
|
17 |
-
|
18 |
-
|
19 |
-
USER_INPUT = "Show campaigns where spend is greater than 11 and labels include holiday and with impressions less than 500"
|
20 |
-
|
21 |
-
def generate_chat_completion(client, USER_INPUT, related_vectors):
|
22 |
-
|
23 |
|
24 |
SYSTEM_PROMPT = f'''
|
25 |
You are a system that converts natural language queries into a structured filter schema.
|
@@ -63,5 +49,6 @@ def generate_chat_completion(client, USER_INPUT, related_vectors):
|
|
63 |
)
|
64 |
return chat_completion.choices[0].message.content
|
65 |
|
66 |
-
|
67 |
-
#
|
|
|
|
1 |
+
import os
|
2 |
from groq import Groq
|
|
|
3 |
|
4 |
client = Groq(
|
5 |
+
api_key=os.getenv('GROQ_API_KEY'),
|
6 |
)
|
7 |
|
8 |
+
def generate_chat_completion(USER_INPUT, related_vectors):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
SYSTEM_PROMPT = f'''
|
11 |
You are a system that converts natural language queries into a structured filter schema.
|
|
|
49 |
)
|
50 |
return chat_completion.choices[0].message.content
|
51 |
|
52 |
+
# Test input
|
53 |
+
# USER_INPUT = "Show campaigns where spend is greater than 11 and labels include holiday and with impressions less than 500"
|
54 |
+
# print(generate_chat_completion(SYSTEM_PROMPT, USER_INPUT))
|
retrieval_helper.py
CHANGED
@@ -1,79 +1,16 @@
|
|
1 |
-
# from langchain.vectorstores import FAISS
|
2 |
-
# from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
3 |
-
# from langchain.schema import Document
|
4 |
-
# import json
|
5 |
-
# from pathlib import Path
|
6 |
-
# from pprint import pprint
|
7 |
-
|
8 |
-
# with open('Data.json', 'r') as file:
|
9 |
-
# json_data = json.load(file)
|
10 |
-
|
11 |
-
# text_data = []
|
12 |
-
# attribute_data = [] # Store extra data for operators
|
13 |
-
|
14 |
-
# for message in json_data["messages"]:
|
15 |
-
# attribute = message["attribute"]
|
16 |
-
# operators = message["supported_operators"] # Keep as a list
|
17 |
-
# value_type = "Number" if message["valueType"] == "Numeric" else message["valueType"]
|
18 |
-
# sentence = f'''attribute: {attribute}, value_type: {value_type}'''
|
19 |
-
# text_data.append(sentence)
|
20 |
-
|
21 |
-
# # Store attribute-to-operator mapping
|
22 |
-
# attribute_data.append({"attribute": attribute, "operators": operators})
|
23 |
-
|
24 |
-
# # Create documents for FAISS
|
25 |
-
# data = [Document(page_content=text) for text in text_data]
|
26 |
-
|
27 |
-
# pprint(data)
|
28 |
-
|
29 |
-
|
30 |
-
# db = FAISS.from_documents(data,
|
31 |
-
# HuggingFaceEmbeddings(model_name='sentence-transformers/paraphrase-MiniLM-L6-v2'))
|
32 |
-
|
33 |
-
# # Connect query to FAISS index using a retriever
|
34 |
-
# retriever = db.as_retriever(
|
35 |
-
# search_type="similarity",
|
36 |
-
# search_kwargs={"k": 5}
|
37 |
-
# )
|
38 |
-
|
39 |
-
# # Modify fetch function to include operators
|
40 |
-
# def fetch(query):
|
41 |
-
# res = retriever.get_relevant_documents(query)
|
42 |
-
# docs = []
|
43 |
-
# for i in res:
|
44 |
-
# # Extract attribute from the document content
|
45 |
-
# attribute_line = i.page_content.split(",")[0] # "attribute: X"
|
46 |
-
# attribute = attribute_line.split(": ")[1] # Extract "X"
|
47 |
-
|
48 |
-
# # Find the matching operators from attribute_data
|
49 |
-
# operators = next((item["operators"] for item in attribute_data if item["attribute"] == attribute), [])
|
50 |
-
|
51 |
-
# # Format the operators as a list
|
52 |
-
# operators_list = f"operators: {operators}"
|
53 |
-
|
54 |
-
# # Append the content with operators
|
55 |
-
# docs.append(f"{i.page_content}, {operators_list}")
|
56 |
-
# return docs
|
57 |
-
|
58 |
-
# query = "Show campaigns where spend is greater than 11 and labels include holiday"
|
59 |
-
# top_documents = fetch(query)
|
60 |
-
# pprint(top_documents)
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
import json
|
65 |
-
|
66 |
from langchain.vectorstores import FAISS
|
67 |
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
68 |
|
69 |
-
# Load the FAISS vector store from the directory
|
70 |
db = FAISS.load_local(
|
71 |
"faiss_index",
|
72 |
HuggingFaceEmbeddings(model_name='sentence-transformers/paraphrase-MiniLM-L6-v2'),
|
73 |
allow_dangerous_deserialization=True
|
74 |
)
|
75 |
|
76 |
-
|
|
|
77 |
|
78 |
# Connect query to FAISS index using a retriever
|
79 |
retriever = db.as_retriever(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import json
|
|
|
2 |
from langchain.vectorstores import FAISS
|
3 |
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
|
4 |
|
5 |
+
# Load the FAISS vector store from the directory 'faiss_index'
|
6 |
db = FAISS.load_local(
|
7 |
"faiss_index",
|
8 |
HuggingFaceEmbeddings(model_name='sentence-transformers/paraphrase-MiniLM-L6-v2'),
|
9 |
allow_dangerous_deserialization=True
|
10 |
)
|
11 |
|
12 |
+
# Load attributes along with their supported operators
|
13 |
+
attribute_data = json.load(open("attribute_data.json"))
|
14 |
|
15 |
# Connect query to FAISS index using a retriever
|
16 |
retriever = db.as_retriever(
|