|
from PIL import Image |
|
import base64 |
|
from io import BytesIO |
|
import os |
|
from openai import OpenAI |
|
import json |
|
|
|
class Captioner:
    """Caption images with the ``yi-vision`` model and cache the replies.

    Requests go through an OpenAI-compatible chat-completions client pointed at
    ``api_base``. Each response is stored in ``self.history`` (keyed by the
    image path passed to :meth:`caption`) and persisted to a JSONL file, so
    images that already have a cached caption are not sent to the API again.
    """

    # History file used when load_history() is called without an explicit path.
    _DEFAULT_HISTORY_FILE = "datas/caption_history.jsonl"

    # Instruction sent with every image; the model is asked to reply in JSON.
    # The "Chinese_name" line asks (in Chinese) for the Chinese name of the
    # main object in the image.
    _PROMPT = (
        "Analyze the image and output in JSON format, including the following fields:\n"
        '- "detailed_description": A detailed description of the image content.\n'
        '- "major_object": Determine the main object/scene in the image based on the description, output with a simple word\n'
        '- "Chinese_name": 判断图片中主要物体的中文名\n'
        '- "real_or_composite": Determine whether this image was taken with a camera or created/modifed by a computer, output with real or composite.'
    )

    def __init__(self, api_key_path=None, proxy=None, api_base="https://api.lingyiwanwu.com/v1"):
        """Create the API client and load any previously cached captions.

        Args:
            api_key_path: Optional path to a text file holding the API key.
                Used only when the ``YI_VL_KEY`` environment variable is unset
                or empty.
            proxy: Currently unused; kept for interface compatibility.
                NOTE(review): the ``__main__`` block configures proxies through
                the HTTP_PROXY/HTTPS_PROXY environment variables instead.
            api_base: Base URL of the OpenAI-compatible endpoint.
        """
        self.api_key = os.getenv('YI_VL_KEY')
        if not self.api_key and api_key_path is not None:
            # Fall back to the key file only when the env var gives nothing,
            # so existing env-based setups keep working unchanged.
            self.api_key = self.load_access_token(api_key_path)
        self.api_base = api_base

        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.api_base,
        )

        self.history = {}
        self.history_file = None
        self.load_history()

    def load_access_token(self, file_path):
        """Return the contents of *file_path* with surrounding whitespace stripped."""
        with open(file_path, 'r', encoding='utf-8') as file:
            return file.read().strip()

    def image2base64(self, image_path, max_height=480):
        """Encode an image as a JPEG ``data:`` URL, downscaling tall images.

        Args:
            image_path: Path of the image to load.
            max_height: Images taller than this are resized to exactly this
                height, keeping the aspect ratio.

        Returns:
            A string of the form ``"data:image/jpeg;base64,<payload>"``.
        """
        with Image.open(image_path) as img:
            # JPEG cannot store alpha or palette modes (RGBA, LA, P, ...);
            # Pillow raises OSError on save unless the image is converted.
            # Converting before resizing also avoids Pillow falling back to
            # nearest-neighbour resampling for palette images.
            if img.mode not in ("RGB", "L"):
                img = img.convert("RGB")

            if img.height > max_height:
                aspect_ratio = img.width / img.height
                # Clamp to 1px so extremely tall, thin images don't yield width 0.
                new_width = max(1, int(max_height * aspect_ratio))
                # Image.ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter.
                img = img.resize((new_width, max_height), Image.LANCZOS)

            buffered = BytesIO()
            img.save(buffered, format="JPEG")

        encoded = base64.b64encode(buffered.getvalue()).decode('utf-8')
        return "data:image/jpeg;base64," + encoded

    def load_history(self, jsonl_file_name=None):
        """Merge cached captions from a JSONL file into ``self.history``.

        The path (default ``datas/caption_history.jsonl``) is remembered in
        ``self.history_file`` for later saves, even if the file does not exist
        yet. Each non-blank line must be a JSON object with ``file_name`` and
        ``response`` keys; later lines override earlier ones for the same
        ``file_name``. Existing entries in ``self.history`` are kept.
        """
        if jsonl_file_name is None:
            jsonl_file_name = self._DEFAULT_HISTORY_FILE

        self.history_file = jsonl_file_name

        if not os.path.exists(jsonl_file_name):
            return

        with open(jsonl_file_name, 'r', encoding='utf-8') as f:
            for line in f:
                # Tolerate blank lines (e.g. a trailing newline or manual edits);
                # json.loads("") would raise JSONDecodeError.
                if not line.strip():
                    continue
                data = json.loads(line)
                self.history[data['file_name']] = data['response']

    def search_from_history(self, file_name):
        """Return the cached response for *file_name*, or ``None`` if absent."""
        return self.history.get(file_name, None)

    def save_history(self, jsonl_file_name=None):
        """Write all of ``self.history`` to a JSONL file, one entry per line.

        Defaults to ``self.history_file``; does nothing if neither is set.
        Missing parent directories are created. The data is written to a
        sibling ``.tmp`` file and then moved into place with ``os.replace``, so
        a crash mid-write cannot truncate the existing history.
        """
        if jsonl_file_name is None:
            jsonl_file_name = self.history_file

        if not jsonl_file_name:
            return

        directory = os.path.dirname(jsonl_file_name)
        if directory:
            os.makedirs(directory, exist_ok=True)

        tmp_path = os.fspath(jsonl_file_name) + ".tmp"
        with open(tmp_path, 'w', encoding='utf-8') as f:
            for file_name, response in self.history.items():
                json.dump({'file_name': file_name, 'response': response}, f, ensure_ascii=False)
                f.write('\n')
        os.replace(tmp_path, jsonl_file_name)

    def add_to_history(self, file_name, response):
        """Record *response* as the cached caption for *file_name* (in memory only)."""
        self.history[file_name] = response

    def caption(self, image_name):
        """Return the model's caption for *image_name*, using the cache when possible.

        A cached response is returned as-is if it is truthy; empty or ``None``
        cached responses are treated as misses and the image is queried again.
        After a successful API call, the response is cached and the whole
        history is saved to disk.
        """
        cached_response = self.search_from_history(image_name)
        if cached_response:
            return cached_response

        img_base64 = self.image2base64(image_name)

        completion = self.client.chat.completions.create(
            model="yi-vision",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {
                            "type": "text",
                            "text": self._PROMPT,
                        },
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": img_base64,
                            },
                        },
                    ],
                }
            ],
            stream=False,
        )

        response = completion.choices[0].message.content

        self.add_to_history(image_name, response)
        self.save_history()

        return response
|
|
|
if __name__ == "__main__":
    # Demo run: send traffic through a local proxy. The env vars are set
    # before Captioner() builds its client so they are in place at creation.
    local_proxy = 'http://localhost:8234'
    for proxy_var in ('HTTP_PROXY', 'HTTPS_PROXY'):
        os.environ[proxy_var] = local_proxy

    demo_captioner = Captioner()
    demo_image = "temp_images/3zjz9b3l.jpg"
    print(demo_captioner.caption(demo_image))