OmkarG committed on
Commit
7b240d9
·
1 Parent(s): a51ac58

updated Groq API key

Browse files
Files changed (3) hide show
  1. app.py +7 -46
  2. LLM_test.py → groq_helper.py +6 -19
  3. retrieval_helper.py +3 -66
app.py CHANGED
@@ -1,63 +1,24 @@
1
  import gradio as gr
2
  import json
3
- from LLM_test import generate_chat_completion
4
  from retrieval_helper import fetch
5
  from groq import Groq
6
 
7
- client = Groq(
8
- api_key="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", # This is the default and can be omitted
9
- )
10
-
11
- # related_vectors = '''
12
- # attribute: spend, operators_supported: [">", "<", ">=", "<=", "=", "!="], value_type: "Number"
13
- # attribute: clicks, operators_supported: [">", "<", ">=", "<=", "=", "!="], value_type: "Integer"
14
- # attribute: impressions, operators_supported: [">", "<", ">=", "<=", "=", "!="], value_type: "Integer"
15
- # '''
16
-
17
-
18
- # SYSTEM_PROMPT = f'''
19
- # You are a system that converts natural language queries into a structured filter schema.
20
- # The filter schema consists of a list of conditions, each represented as:
21
- # {{
22
- # "attribute": "<attribute_name>",
23
- # "op": "<operator>",
24
- # "value": "<value>"
25
- # }}
26
- # There can be any number of conditions. You have to list them all.
27
-
28
- # Supported attributes and their operators are:
29
- # {related_vectors}
30
-
31
- # Example:
32
- # Input: "Show campaigns where spend is greater than 11"
33
- # Output: [{{"attribute": "spend", "op": ">", "value": 11}}]
34
-
35
- # Input: "Find ads with clicks less than 100 and impressions greater than 500"
36
- # Output: [
37
- # {{"attribute": "clicks", "op": "<", "value": 100}},
38
- # {{"attribute": "impressions", "op": ">", "value": 500}}
39
- # ]
40
-
41
- # STRICLY PROVIDE IN THE ABOVE JSON FORMAT WITHOUT ANY METADATA
42
-
43
- # '''
44
-
45
- # Define the Gradio interface
46
  def generate_chat_completion_interface(USER_INPUT):
47
 
48
  top_documents = fetch(USER_INPUT)
49
  related_vectors = "\n".join(top_documents)
50
 
51
- result = generate_chat_completion(client, USER_INPUT, related_vectors)
52
 
53
  return result
54
 
55
- # Set up the Gradio app interface
56
  iface = gr.Interface(
57
- fn=generate_chat_completion_interface, # Function to run on input
58
- inputs=gr.Textbox(label="Enter your query"), # Input field
59
- outputs=gr.Textbox(label="Generated JSON"), # Output field
60
- title="RAG based search", # Title of the app
61
  description="Provide your natural language searhc query"
62
  )
63
 
 
1
  import gradio as gr
2
  import json
3
+ from groq_helper import generate_chat_completion
4
  from retrieval_helper import fetch
5
  from groq import Groq
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  def generate_chat_completion_interface(USER_INPUT):
8
 
9
  top_documents = fetch(USER_INPUT)
10
  related_vectors = "\n".join(top_documents)
11
 
12
+ result = generate_chat_completion(USER_INPUT, related_vectors)
13
 
14
  return result
15
 
16
+ # Gradio app interface
17
  iface = gr.Interface(
18
+ fn=generate_chat_completion_interface,
19
+ inputs=gr.Textbox(label="Enter your query"),
20
+ outputs=gr.Textbox(label="Generated JSON"),
21
+ title="RAG based search",
22
  description="Provide your natural language searhc query"
23
  )
24
 
LLM_test.py → groq_helper.py RENAMED
@@ -1,25 +1,11 @@
1
- import groq, os
2
  from groq import Groq
3
- from retrieval_helper import fetch
4
 
5
  client = Groq(
6
- api_key="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", # This is the default and can be omitted
7
  )
8
 
9
- # related_vectors = '''
10
- # attribute: spend, operators_supported: [">", "<", ">=", "<=", "=", "!="], value_type: "Number"
11
- # attribute: clicks, operators_supported: [">", "<", ">=", "<=", "=", "!="], value_type: "Integer"
12
- # attribute: impressions, operators_supported: [">", "<", ">=", "<=", "=", "!="], value_type: "Integer"
13
- # '''
14
-
15
- query = "Show campaigns where spend is greater than 11 and labels include holiday"
16
- top_documents = fetch(query)
17
-
18
-
19
- USER_INPUT = "Show campaigns where spend is greater than 11 and labels include holiday and with impressions less than 500"
20
-
21
- def generate_chat_completion(client, USER_INPUT, related_vectors):
22
-
23
 
24
  SYSTEM_PROMPT = f'''
25
  You are a system that converts natural language queries into a structured filter schema.
@@ -63,5 +49,6 @@ def generate_chat_completion(client, USER_INPUT, related_vectors):
63
  )
64
  return chat_completion.choices[0].message.content
65
 
66
-
67
- # print(generate_chat_completion(client, SYSTEM_PROMPT, USER_INPUT))
 
 
1
+ import os
2
  from groq import Groq
 
3
 
4
  client = Groq(
5
+ api_key=os.getenv('GROQ_API_KEY'),
6
  )
7
 
8
+ def generate_chat_completion(USER_INPUT, related_vectors):
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  SYSTEM_PROMPT = f'''
11
  You are a system that converts natural language queries into a structured filter schema.
 
49
  )
50
  return chat_completion.choices[0].message.content
51
 
52
+ #Test input
53
+ # USER_INPUT = "Show campaigns where spend is greater than 11 and labels include holiday and with impressions less than 500"
54
+ # print(generate_chat_completion(SYSTEM_PROMPT, USER_INPUT))
retrieval_helper.py CHANGED
@@ -1,79 +1,16 @@
1
- # from langchain.vectorstores import FAISS
2
- # from langchain.embeddings.huggingface import HuggingFaceEmbeddings
3
- # from langchain.schema import Document
4
- # import json
5
- # from pathlib import Path
6
- # from pprint import pprint
7
-
8
- # with open('Data.json', 'r') as file:
9
- # json_data = json.load(file)
10
-
11
- # text_data = []
12
- # attribute_data = [] # Store extra data for operators
13
-
14
- # for message in json_data["messages"]:
15
- # attribute = message["attribute"]
16
- # operators = message["supported_operators"] # Keep as a list
17
- # value_type = "Number" if message["valueType"] == "Numeric" else message["valueType"]
18
- # sentence = f'''attribute: {attribute}, value_type: {value_type}'''
19
- # text_data.append(sentence)
20
-
21
- # # Store attribute-to-operator mapping
22
- # attribute_data.append({"attribute": attribute, "operators": operators})
23
-
24
- # # Create documents for FAISS
25
- # data = [Document(page_content=text) for text in text_data]
26
-
27
- # pprint(data)
28
-
29
-
30
- # db = FAISS.from_documents(data,
31
- # HuggingFaceEmbeddings(model_name='sentence-transformers/paraphrase-MiniLM-L6-v2'))
32
-
33
- # # Connect query to FAISS index using a retriever
34
- # retriever = db.as_retriever(
35
- # search_type="similarity",
36
- # search_kwargs={"k": 5}
37
- # )
38
-
39
- # # Modify fetch function to include operators
40
- # def fetch(query):
41
- # res = retriever.get_relevant_documents(query)
42
- # docs = []
43
- # for i in res:
44
- # # Extract attribute from the document content
45
- # attribute_line = i.page_content.split(",")[0] # "attribute: X"
46
- # attribute = attribute_line.split(": ")[1] # Extract "X"
47
-
48
- # # Find the matching operators from attribute_data
49
- # operators = next((item["operators"] for item in attribute_data if item["attribute"] == attribute), [])
50
-
51
- # # Format the operators as a list
52
- # operators_list = f"operators: {operators}"
53
-
54
- # # Append the content with operators
55
- # docs.append(f"{i.page_content}, {operators_list}")
56
- # return docs
57
-
58
- # query = "Show campaigns where spend is greater than 11 and labels include holiday"
59
- # top_documents = fetch(query)
60
- # pprint(top_documents)
61
-
62
-
63
-
64
  import json
65
-
66
  from langchain.vectorstores import FAISS
67
  from langchain.embeddings.huggingface import HuggingFaceEmbeddings
68
 
69
- # Load the FAISS vector store from the directory
70
  db = FAISS.load_local(
71
  "faiss_index",
72
  HuggingFaceEmbeddings(model_name='sentence-transformers/paraphrase-MiniLM-L6-v2'),
73
  allow_dangerous_deserialization=True
74
  )
75
 
76
- attribute_data = json.load(open("attribute_data.json"))
 
77
 
78
  # Connect query to FAISS index using a retriever
79
  retriever = db.as_retriever(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import json
 
2
  from langchain.vectorstores import FAISS
3
  from langchain.embeddings.huggingface import HuggingFaceEmbeddings
4
 
5
+ # Load the FAISS vector store from the directory 'faiss_index'
6
  db = FAISS.load_local(
7
  "faiss_index",
8
  HuggingFaceEmbeddings(model_name='sentence-transformers/paraphrase-MiniLM-L6-v2'),
9
  allow_dangerous_deserialization=True
10
  )
11
 
12
+ #Load attrributes along with their supported operators
13
+ attribute_data = json.load(open("attribute_data.json"))
14
 
15
  # Connect query to FAISS index using a retriever
16
  retriever = db.as_retriever(