hk-bt-rnd commited on
Commit
d47b6b4
·
1 Parent(s): 1a6c69c

Init spaces

Browse files
Files changed (3) hide show
  1. app.py +98 -0
  2. model/best.pt +3 -0
  3. requirements.txt +76 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ from matplotlib import cm
5
+ import torch
6
+ from transformers import AutoTokenizer, AutoModel
7
+ from model import ImageModel, TextModel
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms.v2 as transforms
10
+
11
+ # Load model directly
12
+ MODEL_NAME = "distilbert/distilroberta-base"
13
+ class_names = ['Action', 'Adventure', 'Comedy', 'Drama', 'Fantasy', 'Romance', 'Sci-Fi']
14
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
15
+ cp = torch.load(r"model\best.pt", map_location="cpu")
16
+ model_img = ImageModel(len(class_names))
17
+ model_img.load_state_dict(cp['w_i'])
18
+ model_text = TextModel(MODEL_NAME, len(class_names))
19
+ model_text.load_state_dict(cp['w_t'])
20
+
21
+ image_transforms = transforms.Compose([
22
+ transforms.Resize((224, 224)),
23
+ transforms.ToTensor(),
24
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
25
+ ])
26
+ def text_predictor(title, synopsis):
27
+ encoded_synopsis = tokenizer(f"{title} </s> {synopsis}", \
28
+ add_special_tokens = True, \
29
+ max_length = 128, \
30
+ padding = "max_length", \
31
+ truncation = True,
32
+ return_tensors='pt')
33
+
34
+ with torch.no_grad():
35
+ score, isAward, genres = model_text((encoded_synopsis['input_ids'], encoded_synopsis['attention_mask']))
36
+ score, isAward, genres = score.squeeze(0), F.sigmoid(isAward.squeeze(0)) >= 0.5 , F.sigmoid(genres.squeeze(0))
37
+
38
+ preds_name = []
39
+ for prob, cls in zip(genres, class_names):
40
+ if prob >= 0.5:
41
+ preds_name.append(cls)
42
+ # print(preds_name)
43
+ return round(score.item(), 2), isAward.item(), {"genres":preds_name}
44
+
45
+ def img_predictor(img):
46
+ # Preprocess the image
47
+ img = Image.fromarray(img.astype('uint8'), 'RGB') # Convert NumPy array to PIL Image
48
+ img = image_transforms(img).unsqueeze(0) # Apply transforms and add batch dimension
49
+
50
+ # Make predictions
51
+ with torch.no_grad():
52
+ output = model_img(img)
53
+ score, isAward, genres = output[0].squeeze(0), F.sigmoid(output[1].squeeze(0)) >= 0.5, F.sigmoid(output[2].squeeze(0))
54
+
55
+ preds_name = []
56
+ for prob, cls in zip(genres, class_names):
57
+ if prob >= 0.5:
58
+ preds_name.append(cls)
59
+
60
+ return round(score.item(), 2), isAward.item(), {"genres": preds_name}
61
+
62
+
63
+ def combine_predictor(title, synopsis, img):
64
+ encoded_synopsis = tokenizer(f"{title} </s> {synopsis}", \
65
+ add_special_tokens = True, \
66
+ max_length = 128, \
67
+ padding = "max_length", \
68
+ truncation = True,
69
+ return_tensors='pt')
70
+
71
+ img = Image.fromarray(img.astype('uint8'), 'RGB') # Convert NumPy array to PIL Image
72
+ img = image_transforms(img).unsqueeze(0) # Apply transforms and add batch dimension
73
+
74
+ # Make predictions
75
+ with torch.no_grad():
76
+ output_text = model_text((encoded_synopsis['input_ids'], encoded_synopsis['attention_mask']))
77
+ output_img = model_img(img)
78
+
79
+ score = (output_img[0].squeeze(0) + output_text[0].squeeze(0))/2
80
+ isAward = F.sigmoid((output_img[1].squeeze(0) + output_text[1].squeeze(0))/2) >= 0.5
81
+ genres = F.sigmoid((output_img[2].squeeze(0) + output_text[2].squeeze(0))/2)
82
+ print(score, isAward, genres)
83
+ preds_name = []
84
+ for prob, cls in zip(genres, class_names):
85
+ if prob >= 0.5:
86
+ preds_name.append(cls)
87
+
88
+ return round(score.item(), 2), isAward.item(), {"genres": preds_name}
89
+
90
+ # iface_1 = gr.Interface(age_predictor_image, gr.Image(height=256, width=256), "json", examples=[["young.webp"], ["old.jpg"]])
91
+ iface_1 = gr.Interface(text_predictor, [gr.Text(placeholder="Input title here"), gr.Text(placeholder="Input synopsis here")], ["label", "label", "json"])
92
+
93
+ iface_2 = gr.Interface(img_predictor, gr.Image(height=224, width=224), ["label", "label", "json"])
94
+
95
+ iface_3 = gr.Interface(combine_predictor, [gr.Text(placeholder="Input title here"), gr.Text(placeholder="Input synopsis here"), gr.Image(height=224, width=224)], ["label", "label", "json"])
96
+ demo = gr.TabbedInterface([iface_1, iface_2, iface_3], ["From Text", "From Image", "From Text and Image"])
97
+ demo.launch() # Launches the mini app!
98
+
model/best.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4b30a7ffad7969310ec38bcd7f9ef63ce4247e86ab91838a0c61adf0bbba268
3
+ size 696898582
requirements.txt ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ aiofiles==23.2.1
2
+ altair==5.2.0
3
+ annotated-types==0.6.0
4
+ anyio==4.3.0
5
+ attrs==23.2.0
6
+ certifi==2024.2.2
7
+ charset-normalizer==3.3.2
8
+ click==8.1.7
9
+ colorama==0.4.6
10
+ contourpy==1.2.0
11
+ cycler==0.12.1
12
+ exceptiongroup==1.2.0
13
+ fastapi==0.110.0
14
+ ffmpy==0.3.2
15
+ filelock==3.13.1
16
+ fonttools==4.50.0
17
+ fsspec==2024.3.1
18
+ gradio==4.22.0
19
+ gradio_client==0.13.0
20
+ h11==0.14.0
21
+ httpcore==1.0.4
22
+ httpx==0.27.0
23
+ huggingface-hub==0.21.4
24
+ idna==3.6
25
+ importlib_resources==6.3.2
26
+ Jinja2==3.1.3
27
+ jsonschema==4.21.1
28
+ jsonschema-specifications==2023.12.1
29
+ kiwisolver==1.4.5
30
+ markdown-it-py==3.0.0
31
+ MarkupSafe==2.1.5
32
+ matplotlib==3.8.3
33
+ mdurl==0.1.2
34
+ mpmath==1.3.0
35
+ networkx==3.2.1
36
+ numpy==1.26.4
37
+ orjson==3.9.15
38
+ packaging==24.0
39
+ pandas==2.2.1
40
+ pillow==10.2.0
41
+ pydantic==2.6.4
42
+ pydantic_core==2.16.3
43
+ pydub==0.25.1
44
+ Pygments==2.17.2
45
+ pyparsing==3.1.2
46
+ python-dateutil==2.9.0.post0
47
+ python-multipart==0.0.9
48
+ pytz==2024.1
49
+ PyYAML==6.0.1
50
+ referencing==0.34.0
51
+ regex==2023.12.25
52
+ requests==2.31.0
53
+ rich==13.7.1
54
+ rpds-py==0.18.0
55
+ ruff==0.3.3
56
+ safetensors==0.4.2
57
+ semantic-version==2.10.0
58
+ shellingham==1.5.4
59
+ six==1.16.0
60
+ sniffio==1.3.1
61
+ starlette==0.36.3
62
+ sympy==1.12
63
+ tokenizers==0.15.2
64
+ tomlkit==0.12.0
65
+ toolz==0.12.1
66
+ torch==2.2.1
67
+ torchaudio==2.2.1
68
+ torchvision==0.17.1
69
+ tqdm==4.66.2
70
+ transformers==4.38.2
71
+ typer==0.9.0
72
+ typing_extensions==4.10.0
73
+ tzdata==2024.1
74
+ urllib3==2.2.1
75
+ uvicorn==0.29.0
76
+ websockets==11.0.3