anmolsahai commited on
Commit
2fe7634
·
1 Parent(s): ff82222
Files changed (1) hide show
  1. langchain_pipeline.py +15 -18
langchain_pipeline.py CHANGED
@@ -6,23 +6,20 @@ from langchain_openai import OpenAIEmbeddings
6
  from langchain_anthropic import ChatAnthropic
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
  import base64
9
- import vertexai
10
- from vertexai.generative_models import GenerativeModel, Part, FinishReason
11
- import vertexai.preview.generative_models as generative_models
12
 
13
  def generate(document_parts, prompt_text):
14
- vertexai.init(project="akroda", location="us-central1")
15
- model = GenerativeModel("gemini-1.5-pro-001")
16
- responses = model.generate_content(
17
- document_parts,
18
- generation_config=generation_config,
19
- safety_settings=safety_settings,
20
- stream=True,
21
- )
22
-
23
  response_text = ""
24
- for response in responses:
25
- response_text += response.text
 
26
  return response_text
27
 
28
  document1 = Part.from_data(
@@ -102,10 +99,10 @@ generation_config = {
102
  }
103
 
104
  safety_settings = {
105
- generative_models.HarmCategory.HARM_CATEGORY_HATE_SPEECH: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
106
- generative_models.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
107
- generative_models.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
108
- generative_models.HarmCategory.HARM_CATEGORY_HARASSMENT: generative_models.HarmBlockThreshold.BLOCK_MEDIUM_AND_ABOVE,
109
  }
110
 
111
  def pipeline(file, model_name, balance_type, apsn_transactions, max_fees_per_day, min_overdrawn_fee, min_transaction_overdraft):
 
6
  from langchain_anthropic import ChatAnthropic
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
  import base64
9
+ from google.cloud import aiplatform
 
 
10
 
11
  def generate(document_parts, prompt_text):
12
+ aiplatform.init(project="akroda", location="us-central1")
13
+ client = aiplatform.PredictionServiceClient()
14
+ model = client.model_path(project="akroda", location="us-central1", model="gemini-1.5-pro-001")
15
+ instances = [{"content": part} for part in document_parts]
16
+
17
+ response = client.predict(name=model, instances=instances, parameters={"temperature": 1.0})
18
+
 
 
19
  response_text = ""
20
+ for prediction in response.predictions:
21
+ response_text += prediction.get("content", "")
22
+
23
  return response_text
24
 
25
  document1 = Part.from_data(
 
99
  }
100
 
101
  safety_settings = {
102
+ "HARM_CATEGORY_HATE_SPEECH": "BLOCK_MEDIUM_AND_ABOVE",
103
+ "HARM_CATEGORY_DANGEROUS_CONTENT": "BLOCK_MEDIUM_AND_ABOVE",
104
+ "HARM_CATEGORY_SEXUALLY_EXPLICIT": "BLOCK_MEDIUM_AND_ABOVE",
105
+ "HARM_CATEGORY_HARASSMENT": "BLOCK_MEDIUM_AND_ABOVE",
106
  }
107
 
108
  def pipeline(file, model_name, balance_type, apsn_transactions, max_fees_per_day, min_overdrawn_fee, min_transaction_overdraft):