File size: 3,348 Bytes
9705b6c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
const { Tool } = require('langchain/tools');
const { SearchClient, AzureKeyCredential } = require('@azure/search-documents');

/**
 * LangChain tool that runs a query against an Azure Cognitive Search index and returns the
 * matching documents as a JSON string.
 *
 * Each setting is read from the constructor `fields` object first and, when that value is
 * falsy, from the environment variable of the same name.
 */
class AzureCognitiveSearch extends Tool {
  /**
   * Resolve configuration and create the underlying `SearchClient`.
   *
   * @param {Object} [fields={}] - Optional configuration overrides, keyed by the same names as
   *   the environment variables: AZURE_COGNITIVE_SEARCH_SERVICE_ENDPOINT,
   *   AZURE_COGNITIVE_SEARCH_INDEX_NAME, AZURE_COGNITIVE_SEARCH_API_KEY,
   *   AZURE_COGNITIVE_SEARCH_API_VERSION, AZURE_COGNITIVE_SEARCH_SEARCH_OPTION_QUERY_TYPE,
   *   AZURE_COGNITIVE_SEARCH_SEARCH_OPTION_TOP and AZURE_COGNITIVE_SEARCH_SEARCH_OPTION_SELECT.
   * @throws {Error} If the endpoint, index name or API key is missing from both `fields` and
   *   the environment, or if the configured `top` is not a positive integer.
   */
  constructor(fields = {}) {
    super();
    this.serviceEndpoint =
      fields.AZURE_COGNITIVE_SEARCH_SERVICE_ENDPOINT || this.getServiceEndpoint();
    this.indexName = fields.AZURE_COGNITIVE_SEARCH_INDEX_NAME || this.getIndexName();
    this.apiKey = fields.AZURE_COGNITIVE_SEARCH_API_KEY || this.getApiKey();

    this.apiVersion = fields.AZURE_COGNITIVE_SEARCH_API_VERSION || this.getApiVersion();

    this.queryType = fields.AZURE_COGNITIVE_SEARCH_SEARCH_OPTION_QUERY_TYPE || this.getQueryType();
    // Overrides from `fields` are normalized exactly like environment values, so a string such
    // as '10' or 'id, title' becomes the number / trimmed array the search client expects.
    this.top =
      AzureCognitiveSearch.parseTop(fields.AZURE_COGNITIVE_SEARCH_SEARCH_OPTION_TOP) ??
      this.getTop();
    this.select =
      AzureCognitiveSearch.parseSelect(fields.AZURE_COGNITIVE_SEARCH_SEARCH_OPTION_SELECT) ??
      this.getSelect();

    this.client = new SearchClient(
      this.serviceEndpoint,
      this.indexName,
      new AzureKeyCredential(this.apiKey),
      {
        apiVersion: this.apiVersion,
      },
    );
  }

  /**
   * The name of the tool.
   * @type {string}
   */
  name = 'azure-cognitive-search';

  /**
   * A description for the agent to use
   * @type {string}
   */
  description =
    'Use the \'azure-cognitive-search\' tool to retrieve search results relevant to your input';

  /**
   * Read a required environment variable.
   *
   * @param {string} variableName - Name of the environment variable.
   * @returns {string} Its non-empty value.
   * @throws {Error} If the variable is unset or empty.
   */
  static requireEnv(variableName) {
    const value = process.env[variableName];
    if (!value) {
      throw new Error(`Missing ${variableName} environment variable.`);
    }
    return value;
  }

  /**
   * Normalize a `top` setting (maximum number of documents to return).
   *
   * @param {*} value - Raw value from `fields` or the environment (number or numeric string).
   * @returns {number|undefined} The parsed positive integer, or `undefined` when the value is
   *   falsy (treated as "not set", matching the original `||` fallback).
   * @throws {Error} If the value is set but is not a positive integer.
   */
  static parseTop(value) {
    if (!value) {
      return undefined;
    }
    const top = Number(value);
    if (!Number.isInteger(top) || top < 1) {
      throw new Error(
        `Invalid AZURE_COGNITIVE_SEARCH_SEARCH_OPTION_TOP value "${value}": expected a positive integer.`,
      );
    }
    return top;
  }

  /**
   * Normalize a `select` setting (document fields to return).
   *
   * @param {string|string[]|*} value - Comma-separated string or array of field names.
   * @returns {string[]|undefined} Trimmed, non-empty field names, or `undefined` when the value
   *   is falsy or contains no usable names.
   */
  static parseSelect(value) {
    if (!value) {
      return undefined;
    }
    const rawNames = Array.isArray(value) ? value : String(value).split(',');
    // Drop surrounding whitespace and empty entries ('a, b' / 'a,,b' / trailing commas) so the
    // service is never asked for a field named ' b' or ''.
    const names = rawNames.map((name) => String(name).trim()).filter(Boolean);
    return names.length > 0 ? names : undefined;
  }

  /**
   * @returns {string} The AZURE_COGNITIVE_SEARCH_SERVICE_ENDPOINT environment variable.
   * @throws {Error} If it is unset or empty.
   */
  getServiceEndpoint() {
    return AzureCognitiveSearch.requireEnv('AZURE_COGNITIVE_SEARCH_SERVICE_ENDPOINT');
  }

  /**
   * @returns {string} The AZURE_COGNITIVE_SEARCH_INDEX_NAME environment variable.
   * @throws {Error} If it is unset or empty.
   */
  getIndexName() {
    return AzureCognitiveSearch.requireEnv('AZURE_COGNITIVE_SEARCH_INDEX_NAME');
  }

  /**
   * @returns {string} The AZURE_COGNITIVE_SEARCH_API_KEY environment variable.
   * @throws {Error} If it is unset or empty.
   */
  getApiKey() {
    return AzureCognitiveSearch.requireEnv('AZURE_COGNITIVE_SEARCH_API_KEY');
  }

  /**
   * @returns {string} AZURE_COGNITIVE_SEARCH_API_VERSION, or '2020-06-30' when unset/empty.
   */
  getApiVersion() {
    return process.env.AZURE_COGNITIVE_SEARCH_API_VERSION || '2020-06-30';
  }

  /**
   * @returns {string} AZURE_COGNITIVE_SEARCH_SEARCH_OPTION_QUERY_TYPE, or 'simple' when
   *   unset/empty.
   */
  getQueryType() {
    return process.env.AZURE_COGNITIVE_SEARCH_SEARCH_OPTION_QUERY_TYPE || 'simple';
  }

  /**
   * @returns {number} AZURE_COGNITIVE_SEARCH_SEARCH_OPTION_TOP as a positive integer, or 5 when
   *   unset/empty.
   * @throws {Error} If the variable is set but is not a positive integer.
   */
  getTop() {
    return AzureCognitiveSearch.parseTop(process.env.AZURE_COGNITIVE_SEARCH_SEARCH_OPTION_TOP) ?? 5;
  }

  /**
   * @returns {string[]|null} Field names from AZURE_COGNITIVE_SEARCH_SEARCH_OPTION_SELECT
   *   (comma-separated, trimmed, empties dropped), or null when none are configured.
   */
  getSelect() {
    return (
      AzureCognitiveSearch.parseSelect(process.env.AZURE_COGNITIVE_SEARCH_SEARCH_OPTION_SELECT) ??
      null
    );
  }

  /**
   * Run a search and return the matching documents.
   *
   * @param {string} query - Search text passed to `SearchClient.search`.
   * @returns {Promise<string>} JSON array of the result documents, or a fixed error message if
   *   the request fails (so the calling agent receives text instead of an exception).
   */
  async _call(query) {
    const searchOptions = {
      queryType: this.queryType,
      top: this.top,
    };
    if (this.select) {
      searchOptions.select = this.select;
    }
    try {
      const searchResults = await this.client.search(query, searchOptions);
      const resultDocuments = [];
      // `results` is an async iterable that may page through the service lazily.
      for await (const result of searchResults.results) {
        resultDocuments.push(result.document);
      }
      return JSON.stringify(resultDocuments);
    } catch (error) {
      // Log the error object itself (not its string form) to keep the stack and service details.
      console.error('Azure Cognitive Search request failed:', error);
      return 'There was an error with Azure Cognitive Search.';
    }
  }
}

module.exports = AzureCognitiveSearch;