# app.py — Hugging Face Space by shukdevdatta123
# Last commit: 3344d73 ("Update app.py"), 6.72 kB
import base64
import glob
import json
import os
import tempfile
from io import BytesIO
from typing import Optional

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import chain
from langchain_openai import ChatOpenAI
from PIL import Image as PILImage
# Streamlit title
st.title("Vehicle Information Extraction from Images")
# Prompt user for OpenAI API key
openai_api_key = st.text_input("Enter your OpenAI API Key:", type="password")
if openai_api_key:
    # ChatOpenAI is constructed without an explicit api_key, so it falls back to
    # the OPENAI_API_KEY environment variable.
    # NOTE(review): os.environ is shared by every session served by this process;
    # passing api_key= to ChatOpenAI directly would keep keys per-session.
    os.environ["OPENAI_API_KEY"] = openai_api_key
elif not os.environ.get("OPENAI_API_KEY"):
    # No key typed and none configured for the deployment: every model call would
    # fail with an authentication error, so stop this run until a key is entered.
    st.info("Please enter your OpenAI API key to continue.")
    st.stop()
# Vehicle class (same as in the original code)
class Vehicle(BaseModel):
    """Schema of the vehicle details the model is asked to return.

    In this file it is used to generate the JSON format instructions
    (``instructions`` below) that are sent to the model; the parsed result is
    consumed downstream as a plain dict (``st.json`` / ``pd.DataFrame``).
    """

    Type: str = Field(..., examples=["Car", "Truck", "Motorcycle", "Bus", "Van"], description="The type of the vehicle.")
    License: str = Field(..., description="The license plate number of the vehicle.")
    Make: str = Field(..., examples=["Toyota", "Honda", "Ford", "Suzuki"], description="The Make of the vehicle.")
    Model: str = Field(..., examples=["Corolla", "Civic", "F-150"], description="The Model of the vehicle.")
    # `examples=` (not `example=`) so Color's schema uses the same key as the other fields.
    Color: str = Field(..., examples=["Red", "Blue", "Black", "White"], description="Return the color of the vehicle.")
    # The remaining fields default to None, so their annotation says so explicitly.
    Year: Optional[str] = Field(None, description="The year of the vehicle.")
    Condition: Optional[str] = Field(None, description="The condition of the vehicle.")
    Logo: Optional[str] = Field(None, description="The visible logo of the vehicle, if applicable.")
    Damage: Optional[str] = Field(None, description="Any visible damage or wear and tear on the vehicle.")
    Region: Optional[str] = Field(None, description="Region or country based on the license plate or clues from the image.")
    PlateType: Optional[str] = Field(None, description="Type of license plate, e.g., government, personal.")


# Parser for vehicle details; its format instructions are embedded in the prompt.
parser = JsonOutputParser(pydantic_object=Vehicle)
instructions = parser.get_format_instructions()
# Image encoding function (for base64 encoding)
def image_encoding(inputs):
    """Read an image file from disk and return its bytes as base64 text.

    Args:
        inputs: mapping with an ``"image_path"`` entry naming the file to read.

    Returns:
        ``{"image": <base64 str>}`` — the shape the ``prompt`` step consumes.
    """
    source_path = inputs["image_path"]
    with open(source_path, "rb") as handle:
        raw_bytes = handle.read()
    return {"image": base64.b64encode(raw_bytes).decode("utf-8")}
# Image display in grid (for multiple images)
def display_image_grid(image_paths, rows=2, cols=3, figsize=(10, 7)):
    """Render up to ``rows * cols`` images as a matplotlib grid in the Streamlit page.

    Args:
        image_paths: iterable of image file paths; entries beyond ``rows * cols``
            are ignored.
        rows: number of grid rows.
        cols: number of grid columns.
        figsize: matplotlib figure size ``(width, height)`` in inches.

    Nothing is rendered when ``image_paths`` is empty.
    """
    image_paths = list(image_paths)[: rows * cols]
    if not image_paths:
        # Avoid pushing an empty figure to the page (e.g. when a folder glob finds nothing).
        return
    fig = plt.figure(figsize=figsize)
    try:
        for idx, path in enumerate(image_paths, start=1):
            ax = fig.add_subplot(rows, cols, idx)
            ax.imshow(mpimg.imread(path))
            ax.axis("off")
            ax.set_title(os.path.basename(path))
        fig.tight_layout()
        st.pyplot(fig)
    finally:
        # pyplot keeps a reference to every figure created via plt.figure; close it
        # so repeated Streamlit reruns don't accumulate figures in memory.
        plt.close(fig)
# System instructions for the vision model. Missing fields are requested as JSON
# `null` (not Python `None`) so the reply stays parseable by the JSON parser.
_SYSTEM_PROMPT = (
    "You are an AI assistant tasked with extracting detailed information from a vehicle image. "
    "Please extract the following details:\n"
    "- Vehicle type (e.g., Car, Truck, Bus)\n"
    "- License plate number and type (if identifiable, such as personal, commercial, government)\n"
    "- Vehicle make, model, and year (e.g., 2020 Toyota Corolla)\n"
    "- Vehicle color and condition (e.g., Red, well-maintained, damaged)\n"
    "- Any visible brand logos or distinguishing marks (e.g., Tesla logo)\n"
    "- Details of any visible damage (e.g., scratches, dents)\n"
    "- Vehicle’s region or country (based on the license plate or other clues)\n"
    "If some details are unclear or not visible, return `null` for those fields. "
    "Do not guess or provide inaccurate information."
)


# Create the prompt for the AI model
@chain
def prompt(inputs):
    """Build the chat messages for one vehicle image.

    Args:
        inputs: dict with an ``"image"`` key holding base64-encoded image bytes
            (as returned by ``image_encoding``). The data URL below labels it
            ``image/jpeg``; the upload widgets in this file only accept JPEG.

    Returns:
        ``[SystemMessage, HumanMessage]``: the extraction instructions, then the
        user turn containing the task text, the JSON format instructions, and the image.
    """
    messages = [
        SystemMessage(content=_SYSTEM_PROMPT),
        HumanMessage(
            content=[
                {"type": "text", "text": "Analyze the vehicle in the image and extract as many details as possible, including type, license plate, make, model, year, condition, damage, etc."},
                # JSON format instructions generated from the Vehicle schema.
                {"type": "text", "text": instructions},
                {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{inputs['image']}", "detail": "low"}},
            ]
        ),
    ]
    return messages
# Invoke the model for extracting vehicle details
@chain
def MLLM_response(inputs):
    """Send the prepared chat messages to GPT-4o and return the reply text.

    A new ``ChatOpenAI`` client is created on each call (deterministic
    temperature 0.0, at most 1024 output tokens).
    """
    llm = ChatOpenAI(model="gpt-4o-2024-08-06", temperature=0.0, max_tokens=1024)
    reply = llm.invoke(inputs)
    return reply.content


# End-to-end chain: {"image_path"} -> base64 -> chat messages -> reply text -> parsed JSON.
pipeline = image_encoding | prompt | MLLM_response | parser
# Streamlit Interface for uploading images and showing results
st.header("Upload Vehicle Images for Information Extraction")
# Option to select either single or batch image upload
upload_option = st.radio("Select Upload Type", ["Single Image Upload", "Batch Images Upload"])
# Single Image Upload
if upload_option == "Single Image Upload":
    st.subheader("Upload a Single Vehicle Image")
    uploaded_image = st.file_uploader("Choose a JPEG image", type="jpeg")
    if uploaded_image is not None:
        # Display the uploaded image
        image = PILImage.open(uploaded_image)
        st.image(image, caption="Uploaded Image", use_column_width=True)
        # The pipeline reads images from disk, so persist the upload first. A private
        # temporary directory (instead of a fixed /tmp path) keeps concurrent sessions
        # from overwriting each other's file and is deleted when the block exits.
        with tempfile.TemporaryDirectory() as tmp_dir:
            image_path = os.path.join(tmp_dir, "uploaded_image.jpeg")
            with open(image_path, "wb") as f:
                f.write(uploaded_image.getbuffer())
            # Process the image through the pipeline; report failures (bad key,
            # unparseable reply, ...) on the page instead of crashing the app.
            try:
                output = pipeline.invoke({"image_path": image_path})
            except Exception as exc:  # top-level UI boundary
                st.error(f"Could not extract vehicle information: {exc}")
            else:
                # Show the results in a user-friendly format
                st.subheader("Extracted Vehicle Information")
                st.json(output)
        # Optionally, display more vehicle images from the folder (skipped if none exist)
        img_dir = "/content/images"
        image_paths = glob.glob(os.path.join(img_dir, "*.jpeg"))
        if image_paths:
            display_image_grid(image_paths)
# Batch Images Upload
elif upload_option == "Batch Images Upload":
    st.sidebar.header("Batch Image Upload")
    batch_images = st.sidebar.file_uploader("Upload Images", type="jpeg", accept_multiple_files=True)
    if batch_images:
        with tempfile.TemporaryDirectory() as tmp_dir:
            image_paths = []
            for idx, file in enumerate(batch_images):
                # One sub-directory per upload: files sharing a name can't overwrite
                # each other, and the grid title remains the original file name.
                file_dir = os.path.join(tmp_dir, str(idx))
                os.makedirs(file_dir)
                # file.name comes from the client; keep only its final path component.
                safe_name = os.path.basename(file.name) or "image.jpeg"
                path = os.path.join(file_dir, safe_name)
                with open(path, "wb") as f:
                    f.write(file.getbuffer())
                image_paths.append(path)
            batch_input = [{"image_path": path} for path in image_paths]
            # Process the batch and display the results in a DataFrame
            try:
                batch_output = pipeline.batch(batch_input)
            except Exception as exc:  # top-level UI boundary
                st.error(f"Could not extract vehicle information: {exc}")
            else:
                df = pd.DataFrame(batch_output)
                st.dataframe(df)
            # Show images in a grid (files must still exist, so stay inside the with-block)
            display_image_grid(image_paths)