import streamlit as st import pickle import pandas as pd import torch from PIL import Image import numpy as np from main import predict_caption, CLIPModel, get_text_embeddings import openai import base64 from io import BytesIO from reportlab.lib.pagesizes import letter from reportlab.pdfgen import canvas # Set up OpenAI API openai.api_key = "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" # Custom CSS for the page st.markdown( """ """, unsafe_allow_html=True, ) device = torch.device("cpu") testing_df = pd.read_csv("testing_df.csv") model = CLIPModel().to(device) model.load_state_dict(torch.load("weights.pt", map_location=torch.device('cpu'))) text_embeddings = torch.load('saved_text_embeddings.pt', map_location=device) def show_predicted_caption(image): matches = predict_caption( image, model, text_embeddings, testing_df["caption"] )[0] return matches def generate_radiology_report(prompt): response = openai.Completion.create( engine="text-davinci-003", prompt=prompt, max_tokens=800, n=1, stop=None, temperature=0.9, ) return response.choices[0].text.strip() def chatbot_response(prompt): response = openai.Completion.create( engine="text-davinci-003", prompt=prompt, max_tokens=500, n=1, stop=None, temperature=0.8, ) return response.choices[0].text.strip() def create_pdf(caption, buffer): c = canvas.Canvas(buffer, pagesize=letter) c.drawString(50, 750, caption) c.save() buffer.seek(0) return buffer st.title("RadiXGPT: An Evolution of machine doctors towrads Radiology") st.write("Upload Scan to get Radiological Report:") uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"]) if uploaded_file is not None: image = Image.open(uploaded_file) st.image(image, caption="Uploaded Image", use_column_width=True) st.write("") if st.button("Generate Report"): with st.spinner("Generating Report...hold on"): image_np = np.array(image) caption = show_predicted_caption(image_np) st.success(f"Caption: {caption}") prompt = f"Write Complete Radiology Report for this: {caption}" report = generate_radiology_report(prompt) st.markdown("