File size: 2,044 Bytes
0d08077
7d06c4c
0d08077
 
df766f8
0d08077
df766f8
 
 
 
0d08077
 
bc65b96
 
df766f8
 
 
 
 
 
 
0d08077
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7d06c4c
0d08077
df766f8
 
 
 
 
 
 
 
 
 
 
 
0d08077
2e7d5a4
df766f8
 
0d08077
 
 
 
2e7d5a4
0d08077
bc65b96
 
 
df766f8
 
 
 
 
 
 
bc65b96
 
 
 
0d08077
 
2e7d5a4
 
 
 
 
 
8ca734f
2e7d5a4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import gradio as gr
import torch
from PIL import Image

from model import GitBaseCocoModel, BlipBaseModel

# Registry of the selectable captioning models: maps the display name offered
# in the UI's "Model" dropdown to the class implementing it. Each class is
# instantiated with a device string ("cuda" or "cpu") by generate_captions.
MODELS = {
	"Git-Base-COCO": GitBaseCocoModel,
	"Blip Base": BlipBaseModel,
}

def generate_captions(
	image,
	num_captions,
	max_length,
	temperature,
	top_k,
	top_p,
	repetition_penalty,
	diversity_penalty,
	model_name,
):
	"""
	Caption an image with the selected model and return the captions as text.

	-----
	Parameters:
	image: PIL.Image
		The image to generate captions for.
	num_captions: int
		The number of captions to generate.
	max_length: int
		The maximum length of the caption.
	temperature, top_k, top_p, repetition_penalty, diversity_penalty:
		Generation settings handed unchanged to the model's ``generate``
		method.
	model_name: str
		Key into ``MODELS`` selecting which model class to instantiate.

	-----
	Returns:
	str
		The generated captions joined into one string, one caption per line.
	"""
	# Run on the GPU whenever one is visible to torch, otherwise on the CPU.
	if torch.cuda.is_available():
		target_device = "cuda"
	else:
		target_device = "cpu"

	# A fresh model instance is constructed on every call.
	captioner = MODELS[model_name](target_device)

	# NOTE: generate() is called positionally with max_length *before*
	# num_captions, which differs from this function's own parameter order.
	generation_args = (
		max_length,
		num_captions,
		temperature,
		top_k,
		top_p,
		repetition_penalty,
		diversity_penalty,
	)
	generated = captioner.generate(image, *generation_args)

	# The output textbox shows a single string, so put one caption per line.
	return "\n".join(generated)

title = "Git-Base-COCO Image Captioning"
description = "A model for generating captions for images."

# Gradio UI definition. The widgets in `inputs` are passed to
# generate_captions positionally, so their order must match that function's
# parameter order exactly.
interface = gr.Interface(
	fn=generate_captions,
	inputs=[
		gr.inputs.Image(type="pil", label="Image"),
		gr.inputs.Slider(minimum=1, maximum=10, step=1, default=1, label="Number of Captions to Generate"),
		gr.inputs.Slider(minimum=20, maximum=100, step=5, default=50, label="Maximum Caption Length"),
		gr.inputs.Slider(minimum=0.1, maximum=10.0, step=0.1, default=1.0, label="Temperature"),
		gr.inputs.Slider(minimum=1, maximum=100, step=1, default=50, label="Top K"),
		# Top P is a cumulative probability mass for nucleus sampling, so only
		# values in (0, 1] are meaningful; negative or >1 values are invalid.
		gr.inputs.Slider(minimum=0.1, maximum=1.0, step=0.1, default=1.0, label="Top P"),
		gr.inputs.Slider(minimum=1.0, maximum=10.0, step=0.1, default=1.0, label="Repetition Penalty"),
		gr.inputs.Slider(minimum=0.0, maximum=10.0, step=0.1, default=0.0, label="Diversity Penalty"),
		# `gr.inputs` (lowercase) is the namespace used by every other widget
		# here. The choices are materialised as a list rather than a dict view,
		# and the first model is preselected so model_name is never empty.
		gr.inputs.Dropdown(list(MODELS), default=next(iter(MODELS)), label="Model"),
	],
	outputs=[
		gr.outputs.Textbox(label="Caption"),
	],
	title=title,
	description=description,
)


if __name__ == "__main__":
	# Start the web UI only when run as a script, with request queueing and
	# debug mode both switched on.
	launch_options = {"enable_queue": True, "debug": True}
	interface.launch(**launch_options)