svjack commited on
Commit
7275cad
Β·
verified Β·
1 Parent(s): 5ad5a1c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +93 -0
app.py CHANGED
@@ -1,3 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import spaces
2
  import gradio as gr
3
  from transformers import LlavaForConditionalGeneration, TextIteratorStreamer, AutoProcessor
 
1
+ '''
2
+ from gradio_client import Client, file
3
+ from datasets import load_dataset
4
+ import os
5
+ from tqdm import tqdm
6
+ from PIL import Image
7
+
8
+ # Initialize Gradio client
9
+ client = Client("http://localhost:7861")
10
+
11
+ # Load the dataset
12
+ dataset = load_dataset("svjack/Dont_be_your_lover_Images")
13
+
14
+ # Create directories for saving images and results
15
+ os.makedirs("Dont_be_your_lover_Images_Captioned", exist_ok=True)
16
+
17
+ # Process each image in the dataset
18
+ for i, item in enumerate(tqdm(dataset["train"], desc="Processing images")):
19
+ try:
20
+ # Get the PIL Image object
21
+ pil_image = item["image"]
22
+
23
+ # Save the image locally with 000i.png format
24
+ img_filename = f"{i:04d}.png"
25
+ img_path = os.path.join("Dont_be_your_lover_Images_Captioned", img_filename)
26
+ pil_image.save(img_path)
27
+
28
+ # Process the image through the API
29
+ result = client.predict(
30
+ input_image=file(img_path),
31
+ prompt="Write a long detailed description for this image.",
32
+ temperature=0.6,
33
+ top_p=0.9,
34
+ max_new_tokens=512,
35
+ log_prompt=True,
36
+ api_name="/chat_joycaption"
37
+ )
38
+
39
+ # Save the result as a text file with the same name
40
+ result_filename = f"{i:04d}.txt"
41
+ result_path = os.path.join("Dont_be_your_lover_Images_Captioned", result_filename)
42
+
43
+ with open(result_path, "w", encoding="utf-8") as f:
44
+ f.write(str(result))
45
+
46
+ except Exception as e:
47
+ print(f"Error processing image {i}: {str(e)}")
48
+ continue
49
+
50
+ print("Processing complete!")
51
+
52
+ # Load the dataset
53
+ dataset = load_dataset("svjack/Origin_Images")
54
+
55
+ # Create directories for saving images and results
56
+ os.makedirs("Origin_Images_Captioned", exist_ok=True)
57
+
58
+ # Process each image in the dataset
59
+ for i, item in enumerate(tqdm(dataset["train"], desc="Processing images")):
60
+ try:
61
+ # Get the PIL Image object
62
+ pil_image = item["image"]
63
+
64
+ # Save the image locally with 000i.png format
65
+ img_filename = f"{i:04d}.png"
66
+ img_path = os.path.join("Origin_Images_Captioned", img_filename)
67
+ pil_image.save(img_path)
68
+
69
+ # Process the image through the API
70
+ result = client.predict(
71
+ input_image=file(img_path),
72
+ prompt="Write a long detailed description for this image.",
73
+ temperature=0.6,
74
+ top_p=0.9,
75
+ max_new_tokens=512,
76
+ log_prompt=True,
77
+ api_name="/chat_joycaption"
78
+ )
79
+
80
+ # Save the result as a text file with the same name
81
+ result_filename = f"{i:04d}.txt"
82
+ result_path = os.path.join("Origin_Images_Captioned", result_filename)
83
+
84
+ with open(result_path, "w", encoding="utf-8") as f:
85
+ f.write(str(result))
86
+
87
+ except Exception as e:
88
+ print(f"Error processing image {i}: {str(e)}")
89
+ continue
90
+
91
+ print("Processing complete!")
92
+ '''
93
+
94
  import spaces
95
  import gradio as gr
96
  from transformers import LlavaForConditionalGeneration, TextIteratorStreamer, AutoProcessor