chillguyyyyyyyyyyer commited on
Commit
0b75645
Β·
verified Β·
1 Parent(s): 11b2c71

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import openai
3
+ import torch
4
+ from diffusers import StableVideoDiffusionPipeline
5
+ import tempfile
6
+ import cv2
7
+ import os
8
+ from langchain.vectorstores import FAISS
9
+ from langchain.embeddings import OpenAIEmbeddings
10
+ from langchain.chains import RetrievalQA
11
+ from langchain.document_loaders import TextLoader
12
+ from langchain.text_splitter import CharacterTextSplitter
13
+
14
+ # Configure OpenAI API (Replace 'your-api-key' with an actual key)
15
+ openai.api_key = "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"
16
+
17
+ # Load Stable Video Diffusion Model (Optimized for Performance)
18
+ @st.cache_resource
19
+ def load_video_model():
20
+ model_id = "stabilityai/stable-video-diffusion-img2vid"
21
+ pipe = StableVideoDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
22
+ pipe.to("cuda" if torch.cuda.is_available() else "cpu")
23
+ return pipe
24
+
25
+ pipe = load_video_model()
26
+
27
+ # Load RAG Components
28
+ @st.cache_resource
29
+ def load_rag():
30
+ loader = TextLoader("knowledge_base.txt") # Ensure you have a knowledge base file
31
+ documents = loader.load()
32
+ text_splitter = CharacterTextSplitter(chunk_size=500, chunk_overlap=50)
33
+ texts = text_splitter.split_documents(documents)
34
+ embeddings = OpenAIEmbeddings()
35
+ vectorstore = FAISS.from_documents(texts, embeddings)
36
+ retriever = vectorstore.as_retriever()
37
+ return RetrievalQA.from_chain_type(llm=openai.ChatCompletion, retriever=retriever)
38
+
39
+ rag_chain = load_rag()
40
+
41
+ # Streamlit UI Configuration
42
+ st.set_page_config(page_title="LiteGPT - Chat & Video AI", layout="wide")
43
+ st.title("πŸ’¬ LiteGPT - Chat & Video AI")
44
+
45
+ # Chatbot Function with RAG
46
+ @st.cache_resource
47
+ def chat_with_gpt(prompt):
48
+ response = rag_chain.run(prompt)
49
+ return response
50
+
51
+ # Sidebar - Video Generation
52
+ st.sidebar.header("πŸŽ₯ AI Video Generator")
53
+ video_prompt = st.sidebar.text_area("Enter a prompt for video generation")
54
+ if st.sidebar.button("Generate Video"):
55
+ if video_prompt:
56
+ with st.spinner("Generating video..."):
57
+ video_frames = pipe(video_prompt, num_inference_steps=50).frames
58
+ video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
59
+ height, width, _ = video_frames[0].shape
60
+ fourcc = cv2.VideoWriter_fourcc(*'mp4v')
61
+ out = cv2.VideoWriter(video_path, fourcc, 8, (width, height))
62
+ for frame in video_frames:
63
+ out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR))
64
+ out.release()
65
+ st.sidebar.video(video_path)
66
+ else:
67
+ st.sidebar.warning("Please enter a video prompt!")
68
+
69
+ # Chat Interface
70
+ st.subheader("πŸ’‘ Chat with LiteGPT")
71
+ user_input = st.text_input("Type your message:")
72
+ if st.button("Send"):
73
+ if user_input:
74
+ response = chat_with_gpt(user_input)
75
+ st.write("πŸ€– LiteGPT:", response)
76
+ else:
77
+ st.warning("Please enter a message!")