anmolsahai commited on
Commit
3b863e1
·
1 Parent(s): 2fe7634
Files changed (1) hide show
  1. langchain_pipeline.py +51 -8
langchain_pipeline.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  from pdfminer import high_level
3
  from langchain_astradb import AstraDBVectorStore
4
  from langchain_core.prompts import PromptTemplate
@@ -6,18 +7,60 @@ from langchain_openai import OpenAIEmbeddings
6
  from langchain_anthropic import ChatAnthropic
7
  from langchain_google_genai import ChatGoogleGenerativeAI
8
  import base64
9
- from google.cloud import aiplatform
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
 
11
  def generate(document_parts, prompt_text):
12
- aiplatform.init(project="akroda", location="us-central1")
13
- client = aiplatform.PredictionServiceClient()
14
- model = client.model_path(project="akroda", location="us-central1", model="gemini-1.5-pro-001")
 
 
 
15
  instances = [{"content": part} for part in document_parts]
16
-
17
- response = client.predict(name=model, instances=instances, parameters={"temperature": 1.0})
18
-
 
 
 
 
 
19
  response_text = ""
20
- for prediction in response.predictions:
21
  response_text += prediction.get("content", "")
22
 
23
  return response_text
 
1
  import os
2
+ import requests
3
  from pdfminer import high_level
4
  from langchain_astradb import AstraDBVectorStore
5
  from langchain_core.prompts import PromptTemplate
 
7
  from langchain_anthropic import ChatAnthropic
8
  from langchain_google_genai import ChatGoogleGenerativeAI
9
  import base64
10
+ import json
11
+
12
+ # Load Google Cloud credentials from the file
13
+ credentials_path = "/mnt/data/Composure IAM Admin.json"
14
+ with open(credentials_path) as f:
15
+ credentials = json.load(f)
16
+
17
+ def get_access_token():
18
+ # Request an access token using the service account credentials
19
+ auth_url = 'https://oauth2.googleapis.com/token'
20
+ auth_data = {
21
+ 'grant_type': 'urn:ietf:params:oauth:grant-type:jwt-bearer',
22
+ 'assertion': create_jwt_assertion(credentials)
23
+ }
24
+ response = requests.post(auth_url, data=auth_data)
25
+ response.raise_for_status()
26
+ return response.json()['access_token']
27
+
28
+ def create_jwt_assertion(credentials):
29
+ # Create a JWT assertion to use for requesting an access token
30
+ header = {
31
+ 'alg': 'RS256',
32
+ 'typ': 'JWT'
33
+ }
34
+ now = int(time.time())
35
+ claim_set = {
36
+ 'iss': credentials['client_email'],
37
+ 'scope': 'https://www.googleapis.com/auth/cloud-platform',
38
+ 'aud': 'https://oauth2.googleapis.com/token',
39
+ 'exp': now + 3600,
40
+ 'iat': now
41
+ }
42
+ jwt_message = json.dumps(header) + '.' + json.dumps(claim_set)
43
+ signed_jwt = jwt.encode(jwt_message, credentials['private_key'], algorithm='RS256')
44
+ return signed_jwt
45
 
46
  def generate(document_parts, prompt_text):
47
+ access_token = get_access_token()
48
+ model_endpoint = f'https://us-central1-aiplatform.googleapis.com/v1/projects/{credentials["project_id"]}/locations/us-central1/endpoints/gemini-1.5-pro-001:predict'
49
+ headers = {
50
+ 'Authorization': f'Bearer {access_token}',
51
+ 'Content-Type': 'application/json'
52
+ }
53
  instances = [{"content": part} for part in document_parts]
54
+ data = {
55
+ 'instances': instances,
56
+ 'parameters': {"temperature": 1.0}
57
+ }
58
+ response = requests.post(model_endpoint, headers=headers, json=data)
59
+ response.raise_for_status()
60
+ predictions = response.json()['predictions']
61
+
62
  response_text = ""
63
+ for prediction in predictions:
64
  response_text += prediction.get("content", "")
65
 
66
  return response_text