ucsahin committed on
Commit
2584386
1 Parent(s): 8bab591

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -0
app.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
3
+ from threading import Thread
4
+ import re
5
+ import time
6
+ from PIL import Image
7
+ import torch
8
+ import spaces
9
+
10
+ processor = AutoProcessor.from_pretrained("ucsahin/TraVisionLM-base", trust_remote_code=True)
11
+ model = AutoModelForCausalLM.from_pretrained("ucsahin/TraVisionLM-base", trust_remote_code=True)
12
+ model_od = AutoModelForCausalLM.from_pretrained("ucsahin/TraVisionLM-Object-Detection-v2", trust_remote_code=True)
13
+
14
+ model.to("cuda:0")
15
+ model_od.to("cuda:0")
16
+
17
+ @spaces.GPU
18
+ def bot_streaming(message, history, max_tokens, temperature, top_p, top_k, repetition_penalty):
19
+ # print(message)
20
+ if message.files:
21
+ image = message.files[-1].path
22
+ else:
23
+ # if there's no image uploaded for this turn, look for images in the past turns
24
+ # kept inside tuples, take the last one
25
+ for hist in history:
26
+ if type(hist[0])==tuple:
27
+ image = hist[0][-1].path
28
+
29
+ if image is None:
30
+ gr.Error("Lütfen önce bir resim yükleyin.")
31
+
32
+ prompt = f"{message.text}"
33
+ image = Image.open(image).convert("RGB")
34
+ inputs = processor(text=prompt, images=image, return_tensors="pt").to("cuda:0")
35
+
36
+ streamer = TextIteratorStreamer(processor, **{"skip_special_tokens": True})
37
+ generation_kwargs = dict(
38
+ inputs, streamer=streamer, max_new_tokens=max_tokens,
39
+ do_sample=True, temperature=temperature, top_p=top_p,
40
+ top_k=top_k, repetition_penalty=repetition_penalty
41
+ )
42
+ generated_text = ""
43
+
44
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
45
+ thread.start()
46
+
47
+ text_prompt = f"{message.text}\n"
48
+
49
+ buffer = ""
50
+ for new_text in streamer:
51
+ buffer += new_text
52
+ generated_text_without_prompt = buffer[len(text_prompt):]
53
+
54
+ time.sleep(0.04)
55
+ yield generated_text_without_prompt
56
+
57
+
58
+ gr.set_static_paths(paths=["static/images/"])
59
+ logo_path = "static/images/logo-color-v2.png"
60
+
61
+ PLACEHOLDER = f"""
62
+ <div style="display: flex; flex-direction: column; align-items: center; text-align: center; margin: 30px">
63
+ <img src="/file={logo_path}" style="width: 60%; height: auto;">
64
+ <h3>Resim yükleyin ve bir soru sorun</h3>
65
+ </div>
66
+ """
67
+
68
+ # with gr.Blocks() as demo:
69
+ # with gr.Tab("Open-ended Questions (Soru-cevap)"):
70
+ with gr.Accordion("Generation parameters", open=False) as parameter_accordion:
71
+ max_tokens_item = gr.Slider(64, 1024, value=512, step=64, label="Max tokens")
72
+ temperature_item = gr.Slider(0.1, 2, value=0.6, step=0.1, label="Temperature")
73
+ top_p_item = gr.Slider(0, 1.0, value=0.9, step=0.05, label="Top_p")
74
+ top_k_item = gr.Slider(0, 100, value=50, label="Top_k")
75
+ repeat_penalty_item = gr.Slider(0, 2, value=1.2, label="Repeat penalty")
76
+
77
+ demo = gr.ChatInterface(
78
+ title="TraVisionLM - Turkish Visual Language Model",
79
+ description="",
80
+ fn=bot_streaming,
81
+ chatbot=gr.Chatbot(placeholder=PLACEHOLDER, scale=1),
82
+ # examples=[{"text": "", "files":[""]},{"text": "", "files":[""]}],
83
+ additional_inputs=[max_tokens_item, temperature_item, top_p_item, top_k_item, repeat_penalty_item],
84
+ additional_inputs_accordion=parameter_accordion,
85
+ stop_btn="Stop Generation",
86
+ multimodal=True
87
+ )
88
+
89
+ # with gr.Tab("Object Detection (Obje Tespiti)"):
90
+ # gr.Image("tiger.jpg")
91
+ # gr.Button("New Tiger")
92
+
93
+ # demo = gr.ChatInterface(fn=bot_streaming, title="TraVisionLM - Turkish Visual Language Model",
94
+ # # examples=[{"text": "", "files":[""]},{"text": "", "files":[""]}],
95
+ # description="",
96
+ # additional_inputs=[
97
+ # gr.Slider(64, 1024, value=512, step=64, label="Max tokens"),
98
+ # gr.Slider(0.1, 2, value=0.6, step=0.1, label="Temperature"),
99
+ # gr.Slider(0, 1.0, value=0.9, step=0.05, label="Top_p"),
100
+ # gr.Slider(0, 100, value=50, label="Top_k"),
101
+ # gr.Slider(0, 2, value=1.2, label="Repeat penalty"),
102
+ # ],
103
+ # additional_inputs_accordion_name="Text generation parameters",
104
+ # # additional_inputs_accordion=
105
+ # stop_btn="Stop Generation", multimodal=True)
106
+ demo.launch(allowed_paths="")