arad1367 commited on
Commit
9e1902c
β€’
1 Parent(s): e7a21bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -12
app.py CHANGED
@@ -1,25 +1,18 @@
1
- # 1. Imports and API setup
2
- import spaces
3
  import gradio as gr
4
  from groq import Groq
5
  import base64
6
  import os
7
- import subprocess
8
 
9
  # Define models used in the process
10
  llava_model = 'llava-v1.5-7b-4096-preview'
11
  llama31_model = 'llama-3.1-70b-versatile'
12
 
13
- # Install flash-attn if not already installed
14
- subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
15
-
16
  # Image encoding function
17
  def encode_image(image_path):
18
  with open(image_path, "rb") as image_file:
19
  return base64.b64encode(image_file.read()).decode('utf-8')
20
 
21
  # Image to text function
22
- @spaces.GPU()
23
  def image_to_text(client, model, base64_image, prompt):
24
  try:
25
  chat_completion = client.chat.completions.create(
@@ -46,7 +39,6 @@ def image_to_text(client, model, base64_image, prompt):
46
  return f"Error generating text from image: {str(e)}"
47
 
48
  # Technical review generation function
49
- @spaces.GPU()
50
  def technical_review_generation(client, image_description):
51
  keywords = ["econometrics", "finance", "marketing", "stock", "prediction", "chart", "graph", "time series"]
52
  if not any(keyword in image_description.lower() for keyword in keywords):
@@ -72,6 +64,7 @@ def technical_review_generation(client, image_description):
72
 
73
  # Main function for Gradio interface
74
  def process_image(api_key, image, prompt="Describe this image in detail."):
 
75
  try:
76
  os.environ["GROQ_API_KEY"] = api_key
77
  client = Groq() # Initialize the Groq client with the provided key
@@ -94,6 +87,7 @@ def process_image(api_key, image, prompt="Describe this image in detail."):
94
  # Return both image description and the econometrics report
95
  return f"--- Image Description ---\n{image_description}", f"--- GroqLLaVA EconoMind Report ---\n{report}"
96
 
 
97
  css = """
98
  #title, #description {
99
  text-align: center;
@@ -121,7 +115,9 @@ css = """
121
  """
122
 
123
  # Gradio Interface
 
124
  def gradio_interface():
 
125
  footer = """
126
  <div id="footer">
127
  <a href="https://www.linkedin.com/in/pejman-ebrahimi-4a60151a7/" target="_blank">LinkedIn</a> |
@@ -140,7 +136,7 @@ def gradio_interface():
140
  with gr.Row():
141
  api_key_input = gr.Textbox(label="GROQ API Key", placeholder="Enter your GROQ API Key", type="password")
142
  with gr.Row():
143
- image_input = gr.Image(type="filepath", label="Upload an Image") # Changed type to 'filepath'
144
  with gr.Row():
145
  report_button = gr.Button("Generate Report")
146
  with gr.Row():
@@ -169,8 +165,7 @@ def gradio_interface():
169
  outputs=[api_key_input, image_input, output_description, output_report]
170
  )
171
 
172
- # Launch the interface
173
- demo.launch()
174
 
175
  # Start the Gradio interface
176
- gradio_interface()
 
 
 
1
  import gradio as gr
2
  from groq import Groq
3
  import base64
4
  import os
 
5
 
6
  # Define models used in the process
7
  llava_model = 'llava-v1.5-7b-4096-preview'
8
  llama31_model = 'llama-3.1-70b-versatile'
9
 
 
 
 
10
  # Image encoding function
11
  def encode_image(image_path):
12
  with open(image_path, "rb") as image_file:
13
  return base64.b64encode(image_file.read()).decode('utf-8')
14
 
15
  # Image to text function
 
16
  def image_to_text(client, model, base64_image, prompt):
17
  try:
18
  chat_completion = client.chat.completions.create(
 
39
  return f"Error generating text from image: {str(e)}"
40
 
41
  # Technical review generation function
 
42
  def technical_review_generation(client, image_description):
43
  keywords = ["econometrics", "finance", "marketing", "stock", "prediction", "chart", "graph", "time series"]
44
  if not any(keyword in image_description.lower() for keyword in keywords):
 
64
 
65
  # Main function for Gradio interface
66
  def process_image(api_key, image, prompt="Describe this image in detail."):
67
+ # Set the API key
68
  try:
69
  os.environ["GROQ_API_KEY"] = api_key
70
  client = Groq() # Initialize the Groq client with the provided key
 
87
  # Return both image description and the econometrics report
88
  return f"--- Image Description ---\n{image_description}", f"--- GroqLLaVA EconoMind Report ---\n{report}"
89
 
90
+ # Define CSS for centering elements and footer styling
91
  css = """
92
  #title, #description {
93
  text-align: center;
 
115
  """
116
 
117
  # Gradio Interface
118
+ @gr.load(src="zerogpu")
119
  def gradio_interface():
120
+ # Define the footer HTML
121
  footer = """
122
  <div id="footer">
123
  <a href="https://www.linkedin.com/in/pejman-ebrahimi-4a60151a7/" target="_blank">LinkedIn</a> |
 
136
  with gr.Row():
137
  api_key_input = gr.Textbox(label="GROQ API Key", placeholder="Enter your GROQ API Key", type="password")
138
  with gr.Row():
139
+ image_input = gr.Image(type="filepath", label="Upload an Image")
140
  with gr.Row():
141
  report_button = gr.Button("Generate Report")
142
  with gr.Row():
 
165
  outputs=[api_key_input, image_input, output_description, output_report]
166
  )
167
 
168
+ return demo
 
169
 
170
  # Start the Gradio interface
171
+ app = gradio_interface()